-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(py): Implementation of embedders api in Ollama Plugin
- Loading branch information
Showing
3 changed files
with
131 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from genkit.ai.embedding import EmbedRequest, EmbedResponse | ||
from genkit.plugins.ollama.models import EmbeddingModelDefinition | ||
|
||
import ollama as ollama_api | ||
|
||
|
||
class OllamaEmbedder: | ||
def __init__( | ||
self, | ||
client: ollama_api.Client, | ||
embedding_definition: EmbeddingModelDefinition, | ||
): | ||
self.client = client | ||
self.embedding_definition = embedding_definition | ||
|
||
def embed(self, request: EmbedRequest) -> EmbedResponse: | ||
return self.client.embed( | ||
model=self.embedding_definition.name, | ||
input=request.documents, | ||
) | ||
|
||
|
||
class AsyncOllamaEmbedder: | ||
def __init__( | ||
self, | ||
client: ollama_api.AsyncClient, | ||
embedding_definition: EmbeddingModelDefinition, | ||
): | ||
self.client = client | ||
self.embedding_definition = embedding_definition | ||
|
||
async def embed(self, request: EmbedRequest) -> EmbedResponse: | ||
return await self.client.embed( | ||
model=self.embedding_definition.name, | ||
input=request.documents, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import asyncio | ||
|
||
from genkit.plugins.ollama import Ollama, ollama_name | ||
from genkit.plugins.ollama.models import ( | ||
EmbeddingModelDefinition, | ||
OllamaPluginParams, | ||
) | ||
from genkit.veneer import Genkit | ||
|
||
# model can be pulled with `ollama pull *LLM_VERSION*` | ||
EMBEDDER_VERSION = 'mxbai-embed-large' | ||
|
||
plugin_params = OllamaPluginParams( | ||
embedders=[ | ||
EmbeddingModelDefinition( | ||
name=EMBEDDER_VERSION, | ||
dimensions=512, | ||
) | ||
], | ||
use_async_api=True, | ||
) | ||
|
||
ai = Genkit( | ||
plugins=[ | ||
Ollama( | ||
plugin_params=plugin_params, | ||
) | ||
], | ||
) | ||
|
||
|
||
async def sample_embed(documents: list[str]): | ||
"""Generate a request to greet a user. | ||
Args: | ||
hi_input: Input data containing user information. | ||
Returns: | ||
A GenerateRequest object with the greeting message. | ||
""" | ||
return await ai.embed( | ||
model=ollama_name(EMBEDDER_VERSION), | ||
documents=documents, | ||
) | ||
|
||
|
||
async def main() -> None: | ||
response = await sample_embed( | ||
documents=[ | ||
'test document 1', | ||
'test document 2', | ||
] | ||
) | ||
print(response) | ||
|
||
|
||
if __name__ == '__main__': | ||
asyncio.run(main()) |