diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py new file mode 100644 index 000000000..e0e9ebc7a --- /dev/null +++ b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +from genkit.ai.embedding import EmbedRequest, EmbedResponse +from genkit.plugins.ollama.models import EmbeddingModelDefinition + +import ollama as ollama_api + + +class OllamaEmbedder: + def __init__( + self, + client: ollama_api.Client, + embedding_definition: EmbeddingModelDefinition, + ): + self.client = client + self.embedding_definition = embedding_definition + + def embed(self, request: EmbedRequest) -> EmbedResponse: + return self.client.embed( + model=self.embedding_definition.name, + input=request.documents, + ) + + +class AsyncOllamaEmbedder: + def __init__( + self, + client: ollama_api.AsyncClient, + embedding_definition: EmbeddingModelDefinition, + ): + self.client = client + self.embedding_definition = embedding_definition + + async def embed(self, request: EmbedRequest) -> EmbedResponse: + return await self.client.embed( + model=self.embedding_definition.name, + input=request.documents, + ) diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index 11f68fe38..e541987a6 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -9,8 +9,7 @@ from functools import cached_property from typing import Type -from genkit.core.action import ActionKind -from genkit.core.registry import Registry +from genkit.plugins.ollama.embedders import AsyncOllamaEmbedder, OllamaEmbedder from genkit.plugins.ollama.models import ( AsyncOllamaModel, OllamaAPITypes, @@ -60,7 +59,21 @@ def ollama_model_class(self) -> Type[AsyncOllamaModel | OllamaModel]: else OllamaModel ) + 
@cached_property + def ollama_embedder_class( + self, + ) -> Type[AsyncOllamaEmbedder | OllamaEmbedder]: + return ( + AsyncOllamaEmbedder + if self.plugin_params.use_async_api + else OllamaEmbedder + ) + def initialize(self, ai: GenkitRegistry) -> None: + self._initialize_models(ai=ai) + self._initialize_embedders(ai=ai) + + def _initialize_models(self, ai: GenkitRegistry): for model_definition in self.plugin_params.models: model = self.ollama_model_class( client=self.client, @@ -75,5 +88,20 @@ def initialize(self, ai: GenkitRegistry) -> None: 'system_role': True, }, ) - # TODO: introduce embedders here - # for embedder in self.plugin_params.embedders: + + def _initialize_embedders(self, ai: GenkitRegistry): + for embedding_definition in self.plugin_params.embedders: + embedder = self.ollama_embedder_class( + client=self.client, embedding_definition=embedding_definition + ) + ai.define_embedder( + name=ollama_name(embedding_definition.name), + fn=embedder.embed, + metadata={ + 'label': f'Ollama Embedding - {embedding_definition.name}', + 'dimensions': embedding_definition.dimensions, + 'supports': { + 'input': ['text'], + }, + }, + ) diff --git a/py/samples/ollama/embed.py b/py/samples/ollama/embed.py new file mode 100644 index 000000000..22b3412a9 --- /dev/null +++ b/py/samples/ollama/embed.py @@ -0,0 +1,60 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 +import asyncio + +from genkit.plugins.ollama import Ollama, ollama_name +from genkit.plugins.ollama.models import ( + EmbeddingModelDefinition, + OllamaPluginParams, +) +from genkit.veneer import Genkit + +# embedding model can be pulled with `ollama pull *EMBEDDER_VERSION*` +EMBEDDER_VERSION = 'mxbai-embed-large' + +plugin_params = OllamaPluginParams( + embedders=[ + EmbeddingModelDefinition( + name=EMBEDDER_VERSION, + dimensions=512, + ) + ], + use_async_api=True, +) + +ai = Genkit( + plugins=[ + Ollama( + plugin_params=plugin_params, + ) + ], +) + + +async def sample_embed(documents: list[str]): + 
"""Generate a request to greet a user. + + Args: + hi_input: Input data containing user information. + + Returns: + A GenerateRequest object with the greeting message. + """ + return await ai.embed( + model=ollama_name(EMBEDDER_VERSION), + documents=documents, + ) + + +async def main() -> None: + response = await sample_embed( + documents=[ + 'test document 1', + 'test document 2', + ] + ) + print(response) + + +if __name__ == '__main__': + asyncio.run(main())