feat(py): Implementation of embedders api in Ollama Plugin
kirgrim committed Mar 6, 2025
1 parent 54c35b2 commit 13fa811
Showing 13 changed files with 557 additions and 170 deletions.
23 changes: 23 additions & 0 deletions py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from genkit.ai.embedding import EmbedRequest, EmbedResponse
from genkit.plugins.ollama.models import EmbeddingModelDefinition

import ollama as ollama_api


class OllamaEmbedder:
def __init__(
self,
client: ollama_api.AsyncClient,
embedding_definition: EmbeddingModelDefinition,
):
self.client = client
self.embedding_definition = embedding_definition

async def embed(self, request: EmbedRequest) -> EmbedResponse:
return await self.client.embed(
model=self.embedding_definition.name,
input=request.documents,
)
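As a quick sanity check, the new embedder can be exercised directly against a local Ollama server. This is only a sketch: the host URL, the `nomic-embed-text` model name, and the dimension count are illustrative assumptions, not part of this commit.

```python
import asyncio

import ollama as ollama_api

from genkit.ai.embedding import EmbedRequest
from genkit.plugins.ollama.embedders import OllamaEmbedder
from genkit.plugins.ollama.models import EmbeddingModelDefinition


async def main() -> None:
    # Assumes a local Ollama server with the embedding model already pulled.
    client = ollama_api.AsyncClient(host='http://127.0.0.1:11434')
    embedder = OllamaEmbedder(
        client=client,
        embedding_definition=EmbeddingModelDefinition(
            name='nomic-embed-text',  # illustrative model name
            dimensions=768,  # illustrative dimension count
        ),
    )
    # OllamaEmbedder.embed forwards the request documents to AsyncClient.embed.
    response = await embedder.embed(EmbedRequest(documents=['hello world']))
    print(response)


asyncio.run(main())
```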
59 changes: 0 additions & 59 deletions py/plugins/ollama/src/genkit/plugins/ollama/mixins.py

This file was deleted.

149 changes: 73 additions & 76 deletions py/plugins/ollama/src/genkit/plugins/ollama/models.py
@@ -6,6 +6,8 @@
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
GenerateResponseChunk,
GenerationCommonConfig,
Message,
Role,
TextPart,
@@ -14,7 +16,6 @@
DEFAULT_OLLAMA_SERVER_URL,
OllamaAPITypes,
)
from genkit.plugins.ollama.mixins import BaseOllamaModelMixin
from pydantic import BaseModel, Field, HttpUrl

import ollama as ollama_api
@@ -37,81 +38,9 @@ class OllamaPluginParams(BaseModel):
embedders: list[EmbeddingModelDefinition] = Field(default_factory=list)
server_address: HttpUrl = Field(default=HttpUrl(DEFAULT_OLLAMA_SERVER_URL))
request_headers: dict[str, str] | None = None
use_async_api: bool = Field(default=True)


class OllamaModel(BaseOllamaModelMixin):
def __init__(
self, client: ollama_api.Client, model_definition: ModelDefinition
):
self.client = client
self.model_definition = model_definition

def generate(
self, request: GenerateRequest, ctx: ActionRunContext | None
) -> GenerateResponse:
txt_response = 'Failed to get response from Ollama API'

if self.model_definition.api_type == OllamaAPITypes.CHAT:
api_response = self._chat_with_ollama(request=request, ctx=ctx)
if api_response:
txt_response = api_response.message.content
else:
api_response = self._generate_ollama_response(
request=request, ctx=ctx
)
if api_response:
txt_response = api_response.response

return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[TextPart(text=txt_response)],
)
)

def _chat_with_ollama(
self, request: GenerateRequest, ctx: ActionRunContext | None = None
) -> ollama_api.ChatResponse | None:
messages = self.build_chat_messages(request)
streaming_request = self.is_streaming_request(ctx=ctx)

chat_response = self.client.chat(
model=self.model_definition.name,
messages=messages,
options=self.build_request_options(config=request.config),
stream=streaming_request,
)

if streaming_request:
for chunk in chat_response:
ctx.send_chunk(chunk=chunk)
else:
return chat_response

def _generate_ollama_response(
self, request: GenerateRequest, ctx: ActionRunContext | None = None
) -> ollama_api.GenerateResponse | None:
prompt = self.build_prompt(request)
streaming_request = self.is_streaming_request(ctx=ctx)

request_kwargs = {
'model': self.model_definition.name,
'prompt': prompt,
'options': self.build_request_options(config=request.config),
'stream': streaming_request,
}

generate_response = self.client.generate(**request_kwargs)

if streaming_request:
for chunk in generate_response:
ctx.send_chunk(chunk=chunk)
else:
return generate_response


class AsyncOllamaModel(BaseOllamaModelMixin):
class OllamaModel:
def __init__(
self, client: ollama_api.AsyncClient, model_definition: ModelDefinition
):
@@ -138,6 +67,9 @@ async def generate(
else:
LOG.error(f'Unresolved API type: {self.model_definition.api_type}')

if self.is_streaming_request(ctx=ctx):
txt_response = 'Response sent to Streaming API'

return GenerateResponse(
message=Message(
role=Role.MODEL,
@@ -159,8 +91,21 @@ async def _chat_with_ollama(
)

if streaming_request:
idx = 0
async for chunk in chat_response:
ctx.send_chunk(chunk=chunk)
idx += 1
role = (
Role.MODEL
if chunk.message.role == 'assistant'
else Role.TOOL
)
ctx.send_chunk(
chunk=GenerateResponseChunk(
role=role,
index=idx,
content=[TextPart(text=chunk.message.content)],
)
)
else:
return chat_response

@@ -180,7 +125,59 @@ async def _generate_ollama_response(
generate_response = await self.client.generate(**request_kwargs)

if streaming_request:
idx = 0
async for chunk in generate_response:
ctx.send_chunk(chunk=chunk)
idx += 1
ctx.send_chunk(
chunk=GenerateResponseChunk(
role=Role.MODEL,
index=idx,
content=[TextPart(text=chunk.response)],
)
)
else:
return generate_response

@staticmethod
def build_request_options(
config: GenerationCommonConfig,
) -> ollama_api.Options:
if config:
return ollama_api.Options(
top_k=config.top_k,
top_p=config.top_p,
stop=config.stop_sequences,
temperature=config.temperature,
num_predict=config.max_output_tokens,
)

@staticmethod
def build_prompt(request: GenerateRequest) -> str:
prompt = ''
for message in request.messages:
for text_part in message.content:
if isinstance(text_part.root, TextPart):
prompt += text_part.root.text
else:
LOG.error('Non-text messages are not supported')
return prompt

@staticmethod
def build_chat_messages(request: GenerateRequest) -> list[dict[str, str]]:
messages = []
for message in request.messages:
item = {
'role': message.role,
'content': '',
}
for text_part in message.content:
if isinstance(text_part.root, TextPart):
item['content'] += text_part.root.text
else:
LOG.error(f'Unsupported part of message: {text_part}')
messages.append(item)
return messages

@staticmethod
def is_streaming_request(ctx: ActionRunContext | None) -> bool:
return ctx and ctx.is_streaming
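The static helpers above translate Genkit's request types into the `ollama` client's format. A small sketch of `build_request_options`, assuming `GenerationCommonConfig` accepts these snake_case keyword arguments (the values are arbitrary):

```python
from genkit.core.typing import GenerationCommonConfig
from genkit.plugins.ollama.models import OllamaModel

config = GenerationCommonConfig(
    temperature=0.2,
    top_k=40,
    top_p=0.9,
    stop_sequences=['\n\n'],
    max_output_tokens=256,
)

# Maps the generic config onto ollama_api.Options (num_predict <- max_output_tokens,
# stop <- stop_sequences); returns None when no config is supplied.
print(OllamaModel.build_request_options(config))
```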
56 changes: 25 additions & 31 deletions py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
@@ -6,13 +6,9 @@
"""

import logging
from functools import cached_property
from typing import Type

from genkit.core.action import ActionKind
from genkit.core.registry import Registry
from genkit.plugins.ollama.embedders import OllamaEmbedder
from genkit.plugins.ollama.models import (
AsyncOllamaModel,
OllamaAPITypes,
OllamaModel,
OllamaPluginParams,
@@ -34,35 +30,17 @@ class Ollama(Plugin):

def __init__(self, plugin_params: OllamaPluginParams):
self.plugin_params = plugin_params
self._sync_client = ollama_api.Client(
self.client = ollama_api.AsyncClient(
host=self.plugin_params.server_address.unicode_string()
)
self._async_client = ollama_api.AsyncClient(
host=self.plugin_params.server_address.unicode_string()
)

@cached_property
def client(self) -> ollama_api.AsyncClient | ollama_api.Client:
client_cls = (
ollama_api.AsyncClient
if self.plugin_params.use_async_api
else ollama_api.Client
)
return client_cls(
host=self.plugin_params.server_address.unicode_string(),
)

@cached_property
def ollama_model_class(self) -> Type[AsyncOllamaModel | OllamaModel]:
return (
AsyncOllamaModel
if self.plugin_params.use_async_api
else OllamaModel
)

def initialize(self, ai: GenkitRegistry) -> None:
self._initialize_models(ai=ai)
self._initialize_embedders(ai=ai)

def _initialize_models(self, ai: GenkitRegistry):
for model_definition in self.plugin_params.models:
model = self.ollama_model_class(
model = OllamaModel(
client=self.client,
model_definition=model_definition,
)
@@ -75,5 +53,21 @@ def initialize(self, ai: GenkitRegistry) -> None:
'system_role': True,
},
)
# TODO: introduce embedders here
# for embedder in self.plugin_params.embedders:

def _initialize_embedders(self, ai: GenkitRegistry):
for embedding_definition in self.plugin_params.embedders:
embedder = OllamaEmbedder(
client=self.client,
embedding_definition=embedding_definition,
)
ai.define_embedder(
name=ollama_name(embedding_definition.name),
fn=embedder.embed,
metadata={
'label': f'Ollama Embedding - {embedding_definition.name}',
'dimensions': embedding_definition.dimensions,
'supports': {
'input': ['text'],
},
},
)
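With `_initialize_embedders` in place, embedders are declared next to models in the plugin parameters. A hedged configuration sketch — the model and embedder names and the dimension value are illustrative, and `ModelDefinition`/`EmbeddingModelDefinition` are assumed to be importable from `genkit.plugins.ollama.models` as in the files above:

```python
from genkit.plugins.ollama.models import (
    EmbeddingModelDefinition,
    ModelDefinition,
    OllamaAPITypes,
    OllamaPluginParams,
)
from genkit.plugins.ollama.plugin_api import Ollama

plugin_params = OllamaPluginParams(
    models=[
        ModelDefinition(name='llama3.2', api_type=OllamaAPITypes.CHAT),
    ],
    embedders=[
        EmbeddingModelDefinition(name='nomic-embed-text', dimensions=768),
    ],
)

# Initializing the plugin registers the chat model via _initialize_models and
# the embedder via _initialize_embedders against the shared AsyncClient.
ollama_plugin = Ollama(plugin_params=plugin_params)
```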
File renamed without changes.
@@ -7,5 +7,5 @@ In case of questions, please refer to `./py/plugins/ollama/README.md`
## Execute "Hello World" Sample

```bash
genkit start -- uv run hello.py
genkit start -- uv run ./ollama-hello/src/hello.py
```
File renamed without changes.
@@ -21,7 +21,6 @@
api_type=OllamaAPITypes.CHAT,
)
],
use_async_api=True,
)

ai = Genkit(
@@ -34,6 +33,10 @@
)


def on_chunk(chunk):
print('received chunk: ', chunk)


@ai.flow()
async def say_hi(hi_input: str):
"""Generate a request to greet a user.
@@ -52,7 +55,9 @@ async def say_hi(hi_input: str):
TextPart(text='hi ' + hi_input),
],
)
]
],
# uncomment me to handle streaming response
# on_chunk=on_chunk,
)
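If the commented-out `on_chunk` argument is enabled, each streamed chunk arrives as the `GenerateResponseChunk` built in `models.py` above, so a handler can read the index and text parts directly. A minimal sketch, assuming only the fields set in this commit:

```python
def on_chunk(chunk):
    # chunk is a GenerateResponseChunk with a role, a 1-based index, and a
    # list of text parts (see the streaming loops in models.py).
    print(f'received chunk #{chunk.index}:', chunk.content)
```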

