Skip to content

Commit

Permalink
feat(py): Implementation of embedders api in Ollama Plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
kirgrim committed Mar 3, 2025
1 parent 54c35b2 commit 70d110b
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 4 deletions.
39 changes: 39 additions & 0 deletions py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from genkit.ai.embedding import EmbedRequest, EmbedResponse
from genkit.plugins.ollama.models import EmbeddingModelDefinition

import ollama as ollama_api


class OllamaEmbedder:
def __init__(
self,
client: ollama_api.Client,
embedding_definition: EmbeddingModelDefinition,
):
self.client = client
self.embedding_definition = embedding_definition

def embed(self, request: EmbedRequest) -> EmbedResponse:
return self.client.embed(
model=self.embedding_definition.name,
input=request.documents,
)


class AsyncOllamaEmbedder:
def __init__(
self,
client: ollama_api.AsyncClient,
embedding_definition: EmbeddingModelDefinition,
):
self.client = client
self.embedding_definition = embedding_definition

async def embed(self, request: EmbedRequest) -> EmbedResponse:
return await self.client.embed(
model=self.embedding_definition.name,
input=request.documents,
)
36 changes: 32 additions & 4 deletions py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from functools import cached_property
from typing import Type

from genkit.core.action import ActionKind
from genkit.core.registry import Registry
from genkit.plugins.ollama.embedders import AsyncOllamaEmbedder, OllamaEmbedder
from genkit.plugins.ollama.models import (
AsyncOllamaModel,
OllamaAPITypes,
Expand Down Expand Up @@ -60,7 +59,21 @@ def ollama_model_class(self) -> Type[AsyncOllamaModel | OllamaModel]:
else OllamaModel
)

@cached_property
def ollama_embedder_class(
self,
) -> Type[AsyncOllamaEmbedder | OllamaEmbedder]:
return (
AsyncOllamaEmbedder
if self.plugin_params.use_async_api
else OllamaEmbedder
)

def initialize(self, ai: GenkitRegistry) -> None:
self._initialize_models(ai=ai)
self._initialize_embedders(ai=ai)

def _initialize_models(self, ai: GenkitRegistry):
for model_definition in self.plugin_params.models:
model = self.ollama_model_class(
client=self.client,
Expand All @@ -75,5 +88,20 @@ def initialize(self, ai: GenkitRegistry) -> None:
'system_role': True,
},
)
# TODO: introduce embedders here
# for embedder in self.plugin_params.embedders:

def _initialize_embedders(self, ai: GenkitRegistry):
for embedding_definition in self.plugin_params.embedders:
embedder = self.ollama_embedder_class(
client=self.client, embedding_definition=embedding_definition
)
ai.define_embedder(
name=ollama_name(embedding_definition.name),
fn=embedder.embed,
metadata={
'label': f'Ollama Embedding - {embedding_definition.name}',
'dimensions': embedding_definition.dimensions,
'supports': {
'input': ['text'],
},
},
)
60 changes: 60 additions & 0 deletions py/samples/ollama/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
import asyncio

from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.models import (
EmbeddingModelDefinition,
OllamaPluginParams,
)
from genkit.veneer import Genkit

# model can be pulled with `ollama pull *LLM_VERSION*`
EMBEDDER_VERSION = 'mxbai-embed-large'

plugin_params = OllamaPluginParams(
embedders=[
EmbeddingModelDefinition(
name=EMBEDDER_VERSION,
dimensions=512,
)
],
use_async_api=True,
)

ai = Genkit(
plugins=[
Ollama(
plugin_params=plugin_params,
)
],
)


async def sample_embed(documents: list[str]):
"""Generate a request to greet a user.
Args:
hi_input: Input data containing user information.
Returns:
A GenerateRequest object with the greeting message.
"""
return await ai.embed(
model=ollama_name(EMBEDDER_VERSION),
documents=documents,
)


async def main() -> None:
response = await sample_embed(
documents=[
'test document 1',
'test document 2',
]
)
print(response)


if __name__ == '__main__':
asyncio.run(main())

0 comments on commit 70d110b

Please sign in to comment.