diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
new file mode 100644
index 000000000..f1fd26f59
--- /dev/null
+++ b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from genkit.ai.embedding import EmbedRequest, EmbedResponse
+from genkit.plugins.ollama.models import EmbeddingModelDefinition
+
+import ollama as ollama_api
+
+
+class OllamaEmbedder:
+    def __init__(
+        self,
+        client: ollama_api.AsyncClient,
+        embedding_definition: EmbeddingModelDefinition,
+    ):
+        self.client = client
+        self.embedding_definition = embedding_definition
+
+    async def embed(self, request: EmbedRequest) -> EmbedResponse:
+        return await self.client.embed(
+            model=self.embedding_definition.name,
+            input=request.documents,
+        )
diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/mixins.py b/py/plugins/ollama/src/genkit/plugins/ollama/mixins.py
deleted file mode 100644
index 200e0a48d..000000000
--- a/py/plugins/ollama/src/genkit/plugins/ollama/mixins.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2025 Google LLC
-# SPDX-License-Identifier: Apache-2.0
-
-import logging
-
-from genkit.core.action import ActionRunContext, noop_streaming_callback
-
-# Common helpers extracted into a base class or module
-from genkit.core.typing import GenerateRequest, GenerationCommonConfig, TextPart
-
-import ollama as ollama_api
-
-LOG = logging.getLogger(__name__)
-
-
-class BaseOllamaModelMixin:
-    @staticmethod
-    def build_request_options(
-        config: GenerationCommonConfig,
-    ) -> ollama_api.Options:
-        if config:
-            return ollama_api.Options(
-                top_k=config.top_k,
-                top_p=config.top_p,
-                stop=config.stop_sequences,
-                temperature=config.temperature,
-                num_predict=config.max_output_tokens,
-            )
-
-    @staticmethod
-    def build_prompt(request: GenerateRequest) -> str:
-        prompt = ''
-        for message in request.messages:
-            for text_part in message.content:
-                if isinstance(text_part.root, TextPart):
-                    prompt += text_part.root.text
-                else:
-                    LOG.error('Non-text messages are not supported')
-        return prompt
-
-    @staticmethod
-    def build_chat_messages(request: GenerateRequest) -> list[dict[str, str]]:
-        messages = []
-        for message in request.messages:
-            item = {
-                'role': message.role.value,
-                'content': '',
-            }
-            for text_part in message.content:
-                if isinstance(text_part.root, TextPart):
-                    item['content'] += text_part.root.text
-                else:
-                    LOG.error(f'Unsupported part of message: {text_part}')
-            messages.append(item)
-        return messages
-
-    @staticmethod
-    def is_streaming_request(ctx: ActionRunContext | None) -> bool:
-        return ctx and ctx.is_streaming
diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/models.py b/py/plugins/ollama/src/genkit/plugins/ollama/models.py
index 7f45f1de5..a87df78b0 100644
--- a/py/plugins/ollama/src/genkit/plugins/ollama/models.py
+++ b/py/plugins/ollama/src/genkit/plugins/ollama/models.py
@@ -6,6 +6,8 @@
 from genkit.core.typing import (
     GenerateRequest,
     GenerateResponse,
+    GenerateResponseChunk,
+    GenerationCommonConfig,
     Message,
     Role,
     TextPart,
@@ -14,7 +16,6 @@
     DEFAULT_OLLAMA_SERVER_URL,
     OllamaAPITypes,
 )
-from genkit.plugins.ollama.mixins import BaseOllamaModelMixin
 from pydantic import BaseModel, Field, HttpUrl
 
 import ollama as ollama_api
@@ -37,81 +38,9 @@ class OllamaPluginParams(BaseModel):
     embedders: list[EmbeddingModelDefinition] = Field(default_factory=list)
     server_address: HttpUrl = Field(
         default=HttpUrl(DEFAULT_OLLAMA_SERVER_URL)
     )
     request_headers: dict[str, str] | None = None
-    use_async_api: bool = Field(default=True)
 
 
-class OllamaModel(BaseOllamaModelMixin):
-    def __init__(
-        self, client: ollama_api.Client, model_definition: ModelDefinition
-    ):
-        self.client = client
-        self.model_definition = model_definition
-
-    def generate(
-        self, request: GenerateRequest, ctx: ActionRunContext | None
-    ) -> GenerateResponse:
-        txt_response = 'Failed to get response from Ollama API'
-
-        if self.model_definition.api_type == OllamaAPITypes.CHAT:
-            api_response = self._chat_with_ollama(request=request, ctx=ctx)
-            if api_response:
-                txt_response = api_response.message.content
-        else:
-            api_response = self._generate_ollama_response(
-                request=request, ctx=ctx
-            )
-            if api_response:
-                txt_response = api_response.response
-
-        return GenerateResponse(
-            message=Message(
-                role=Role.MODEL,
-                content=[TextPart(text=txt_response)],
-            )
-        )
-
-    def _chat_with_ollama(
-        self, request: GenerateRequest, ctx: ActionRunContext | None = None
-    ) -> ollama_api.ChatResponse | None:
-        messages = self.build_chat_messages(request)
-        streaming_request = self.is_streaming_request(ctx=ctx)
-
-        chat_response = self.client.chat(
-            model=self.model_definition.name,
-            messages=messages,
-            options=self.build_request_options(config=request.config),
-            stream=streaming_request,
-        )
-
-        if streaming_request:
-            for chunk in chat_response:
-                ctx.send_chunk(chunk=chunk)
-        else:
-            return chat_response
-
-    def _generate_ollama_response(
-        self, request: GenerateRequest, ctx: ActionRunContext | None = None
-    ) -> ollama_api.GenerateResponse | None:
-        prompt = self.build_prompt(request)
-        streaming_request = self.is_streaming_request(ctx=ctx)
-
-        request_kwargs = {
-            'model': self.model_definition.name,
-            'prompt': prompt,
-            'options': self.build_request_options(config=request.config),
-            'stream': streaming_request,
-        }
-
-        generate_response = self.client.generate(**request_kwargs)
-
-        if streaming_request:
-            for chunk in generate_response:
-                ctx.send_chunk(chunk=chunk)
-        else:
-            return generate_response
-
-
-class AsyncOllamaModel(BaseOllamaModelMixin):
+class OllamaModel:
     def __init__(
         self, client: ollama_api.AsyncClient, model_definition: ModelDefinition
     ):
@@ -138,6 +67,9 @@ async def generate(
         else:
             LOG.error(f'Unresolved API type: {self.model_definition.api_type}')
 
+        if self.is_streaming_request(ctx=ctx):
+            txt_response = 'Response sent to Streaming API'
+
         return GenerateResponse(
             message=Message(
                 role=Role.MODEL,
@@ -159,8 +91,21 @@ async def _chat_with_ollama(
         )
 
         if streaming_request:
+            idx = 0
             async for chunk in chat_response:
-                ctx.send_chunk(chunk=chunk)
+                idx += 1
+                role = (
+                    Role.MODEL
+                    if chunk.message.role == 'assistant'
+                    else Role.TOOL
+                )
+                ctx.send_chunk(
+                    chunk=GenerateResponseChunk(
+                        role=role,
+                        index=idx,
+                        content=[TextPart(text=chunk.message.content)],
+                    )
+                )
         else:
             return chat_response
 
@@ -180,7 +125,59 @@ async def _generate_ollama_response(
         generate_response = await self.client.generate(**request_kwargs)
 
         if streaming_request:
+            idx = 0
             async for chunk in generate_response:
-                ctx.send_chunk(chunk=chunk)
+                idx += 1
+                ctx.send_chunk(
+                    chunk=GenerateResponseChunk(
+                        role=Role.MODEL,
+                        index=idx,
+                        content=[TextPart(text=chunk.response)],
+                    )
+                )
         else:
             return generate_response
+
+    @staticmethod
+    def build_request_options(
+        config: GenerationCommonConfig,
+    ) -> ollama_api.Options | None:
+        if config:
+            return ollama_api.Options(
+                top_k=config.top_k,
+                top_p=config.top_p,
+                stop=config.stop_sequences,
+                temperature=config.temperature,
+                num_predict=config.max_output_tokens,
+            )
+
+    @staticmethod
+    def build_prompt(request: GenerateRequest) -> str:
+        prompt = ''
+        for message in request.messages:
+            for text_part in message.content:
+                if isinstance(text_part.root, TextPart):
+                    prompt += text_part.root.text
+                else:
+                    LOG.error('Non-text messages are not supported')
+        return prompt
+
+    @staticmethod
+    def build_chat_messages(request: GenerateRequest) -> list[dict[str, str]]:
+        messages = []
+        for message in request.messages:
+            item = {
+                'role': message.role,
+                'content': '',
+            }
+            for text_part in message.content:
+                if isinstance(text_part.root, TextPart):
+                    item['content'] += text_part.root.text
+                else:
+                    LOG.error(f'Unsupported part of message: {text_part}')
+            messages.append(item)
+        return messages
+
+    @staticmethod
+    def is_streaming_request(ctx: ActionRunContext | None) -> bool:
+        return bool(ctx and ctx.is_streaming)
diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
index 11f68fe38..ceb0a95fa 100644
--- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
+++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
@@ -6,13 +6,9 @@
 """
 
 import logging
-from functools import cached_property
-from typing import Type
 
-from genkit.core.action import ActionKind
-from genkit.core.registry import Registry
+from genkit.plugins.ollama.embedders import OllamaEmbedder
 from genkit.plugins.ollama.models import (
-    AsyncOllamaModel,
     OllamaAPITypes,
     OllamaModel,
     OllamaPluginParams,
@@ -34,35 +30,17 @@ class Ollama(Plugin):
 
     def __init__(self, plugin_params: OllamaPluginParams):
         self.plugin_params = plugin_params
-        self._sync_client = ollama_api.Client(
+        self.client = ollama_api.AsyncClient(
             host=self.plugin_params.server_address.unicode_string()
         )
-        self._async_client = ollama_api.AsyncClient(
-            host=self.plugin_params.server_address.unicode_string()
-        )
-
-    @cached_property
-    def client(self) -> ollama_api.AsyncClient | ollama_api.Client:
-        client_cls = (
-            ollama_api.AsyncClient
-            if self.plugin_params.use_async_api
-            else ollama_api.Client
-        )
-        return client_cls(
-            host=self.plugin_params.server_address.unicode_string(),
-        )
-
-    @cached_property
-    def ollama_model_class(self) -> Type[AsyncOllamaModel | OllamaModel]:
-        return (
-            AsyncOllamaModel
-            if self.plugin_params.use_async_api
-            else OllamaModel
-        )
 
     def initialize(self, ai: GenkitRegistry) -> None:
+        self._initialize_models(ai=ai)
+        self._initialize_embedders(ai=ai)
+
+    def _initialize_models(self, ai: GenkitRegistry):
         for model_definition in self.plugin_params.models:
-            model = self.ollama_model_class(
+            model = OllamaModel(
                 client=self.client,
                 model_definition=model_definition,
            )
@@ -75,5 +53,21 @@ def initialize(self, ai: GenkitRegistry) -> None:
                     'system_role': True,
                 },
             )
-        # TODO: introduce embedders here
-        # for embedder in self.plugin_params.embedders:
+
+    def _initialize_embedders(self, ai: GenkitRegistry):
+        for embedding_definition in self.plugin_params.embedders:
+            embedder = OllamaEmbedder(
+                client=self.client,
+                embedding_definition=embedding_definition,
+            )
+            ai.define_embedder(
+                name=ollama_name(embedding_definition.name),
+                fn=embedder.embed,
+                metadata={
+                    'label': f'Ollama Embedding - {embedding_definition.name}',
+                    'dimensions': embedding_definition.dimensions,
+                    'supports': {
+                        'input': ['text'],
+                    },
+                },
+            )
diff --git a/py/samples/ollama/LICENSE b/py/samples/ollama-hello/LICENSE
similarity index 100%
rename from py/samples/ollama/LICENSE
rename to py/samples/ollama-hello/LICENSE
diff --git a/py/samples/ollama/README.md b/py/samples/ollama-hello/README.md
similarity index 81%
rename from py/samples/ollama/README.md
rename to py/samples/ollama-hello/README.md
index e6ed1cfc8..deac7f68a 100644
--- a/py/samples/ollama/README.md
+++ b/py/samples/ollama-hello/README.md
@@ -7,5 +7,5 @@ In case of questions, please refer to `./py/plugins/ollama/README.md`
 ## Execute "Hello World" Sample
 
 ```bash
-genkit start -- uv run hello.py
+genkit start -- uv run ./ollama-hello/src/hello.py
 ```
diff --git a/py/samples/ollama/pyproject.toml b/py/samples/ollama-hello/pyproject.toml
similarity index 100%
rename from py/samples/ollama/pyproject.toml
rename to py/samples/ollama-hello/pyproject.toml
diff --git a/py/samples/ollama/hello.py b/py/samples/ollama-hello/src/hello.py
similarity index 89%
rename from py/samples/ollama/hello.py
rename to py/samples/ollama-hello/src/hello.py
index ec0da7d22..e7ec73714 100644
--- a/py/samples/ollama/hello.py
+++ b/py/samples/ollama-hello/src/hello.py
@@ -21,7 +21,6 @@
             api_type=OllamaAPITypes.CHAT,
         )
     ],
-    use_async_api=True,
 )
 
 ai = Genkit(
@@ -34,6 +33,10 @@
 )
 
 
+def on_chunk(chunk):
+    print('received chunk: ', chunk)
+
+
 @ai.flow()
 async def say_hi(hi_input: str):
     """Generate a request to greet a user.
@@ -52,7 +55,9 @@ async def say_hi(hi_input: str):
                     TextPart(text='hi ' + hi_input),
                 ],
             )
-        ]
+        ],
+        # uncomment me to handle streaming response
+        # on_chunk=on_chunk,
     )
diff --git a/py/samples/ollama-simple-embed/LICENSE b/py/samples/ollama-simple-embed/LICENSE
new file mode 100644
index 000000000..220539673
--- /dev/null
+++ b/py/samples/ollama-simple-embed/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2025 Google LLC
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/py/samples/ollama-simple-embed/README.md b/py/samples/ollama-simple-embed/README.md
new file mode 100644
index 000000000..f2c376649
--- /dev/null
+++ b/py/samples/ollama-simple-embed/README.md
@@ -0,0 +1,11 @@
+# Run the sample
+
+## NOTE
+Before running the sample, make sure to install the models and start the Ollama server.
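+For example, pull the models this sample uses (names taken from the sample source; substitute your own if you change them) and start serving:
+
+```bash
+ollama pull nomic-embed-text
+ollama pull phi3.5:latest
+ollama serve
+```
+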
+In case of questions, please refer to `./py/plugins/ollama/README.md`
+
+## Execute "Ollama Embed" Sample
+
+```bash
+genkit start -- uv run ./ollama-simple-embed/src/pokemon_glossary.py
+```
diff --git a/py/samples/ollama-simple-embed/pyproject.toml b/py/samples/ollama-simple-embed/pyproject.toml
new file mode 100644
index 000000000..ba5fe2833
--- /dev/null
+++ b/py/samples/ollama-simple-embed/pyproject.toml
@@ -0,0 +1,30 @@
+[project]
+authors = [{ name = "Google" }]
+classifiers = [
+  "Development Status :: 3 - Alpha",
+  "Environment :: Console",
+  "Environment :: Web Environment",
+  "Intended Audience :: Developers",
+  "Operating System :: OS Independent",
+  "License :: OSI Approved :: Apache Software License",
+  "Programming Language :: Python",
+  "Programming Language :: Python :: 3.12",
+  "Programming Language :: Python :: 3.13",
+  "Programming Language :: Python :: 3 :: Only",
+  "Topic :: Scientific/Engineering :: Artificial Intelligence",
+  "Topic :: Software Development :: Libraries",
+]
+dependencies = [
+  "genkit",
+  "genkit-firebase-plugin",
+  "genkit-google-ai-plugin",
+  "genkit-google-cloud-plugin",
+  "genkit-ollama-plugin",
+  "pydantic>=2.10.5",
+]
+description = "Ollama Simple Embed"
+license = { text = "Apache-2.0" }
+name = "ollama_simple_embed"
+readme = "README.md"
+requires-python = ">=3.12"
+version = "0.1.0"
diff --git a/py/samples/ollama-simple-embed/src/pokemon_glossary.py b/py/samples/ollama-simple-embed/src/pokemon_glossary.py
new file mode 100644
index 000000000..137621ae4
--- /dev/null
+++ b/py/samples/ollama-simple-embed/src/pokemon_glossary.py
@@ -0,0 +1,161 @@
+# Copyright 2025 Google LLC
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+from math import sqrt
+
+from genkit.core.typing import GenerateResponse
+from genkit.plugins.ollama import Ollama, ollama_name
+from genkit.plugins.ollama.constants import OllamaAPITypes
+from genkit.plugins.ollama.models import (
+    EmbeddingModelDefinition,
+    ModelDefinition,
+    OllamaPluginParams,
+)
+from genkit.veneer import Genkit
+from pydantic import BaseModel
+
+EMBEDDER_MODEL = 'nomic-embed-text'
+EMBEDDER_DIMENSIONS = 768
+GENERATE_MODEL = 'phi3.5:latest'
+
+plugin_params = OllamaPluginParams(
+    models=[
+        ModelDefinition(
+            name=GENERATE_MODEL,
+            api_type=OllamaAPITypes.GENERATE,
+        )
+    ],
+    embedders=[
+        EmbeddingModelDefinition(
+            name=EMBEDDER_MODEL,
+            dimensions=EMBEDDER_DIMENSIONS,
+        )
+    ],
+)
+
+ai = Genkit(
+    plugins=[
+        Ollama(
+            plugin_params=plugin_params,
+        )
+    ],
+)
+
+
+class PokemonInfo(BaseModel):
+    name: str
+    description: str
+    embedding: list[float] | None = None
+
+
+pokemon_list = [
+    PokemonInfo(
+        name='Pikachu',
+        description='An Electric-type Pokemon known for its strong electric attacks.',
+        embedding=None,
+    ),
+    PokemonInfo(
+        name='Charmander',
+        description='A Fire-type Pokemon that evolves into the powerful Charizard.',
+        embedding=None,
+    ),
+    PokemonInfo(
+        name='Bulbasaur',
+        description='A Grass/Poison-type Pokemon that grows into a powerful Venusaur.',
+        embedding=None,
+    ),
+    PokemonInfo(
+        name='Squirtle',
+        description='A Water-type Pokemon known for its water-based attacks and high defense.',
+        embedding=None,
+    ),
+    PokemonInfo(
+        name='Jigglypuff',
+        description='A Normal/Fairy-type Pokemon with a hypnotic singing ability.',
+        embedding=None,
+    ),
+]
+
+
+async def embed_pokemons():
+    for pokemon in pokemon_list:
+        embedding_response = await ai.embed(
+            model=ollama_name(EMBEDDER_MODEL),
+            documents=[pokemon.description],
+        )
+        pokemon.embedding = embedding_response.embeddings[0]
+
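+# Retrieval sketch: rank each Pokemon description by cosine distance to the
+# question embedding,
+#     cosine_distance(a, b) = 1 - (a . b) / (|a| * |b|)
+# and keep the top_n closest matches (smaller distance = more similar).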
+
+def find_nearest_pokemons(
+    input_embedding: list[float], top_n: int = 3
+) -> list[dict]:
+    if any(pokemon.embedding is None for pokemon in pokemon_list):
+        raise AttributeError('Some Pokemon are not yet embedded')
+    pokemon_distances = [
+        {
+            **pokemon.model_dump(),
+            'distance': cosine_distance(input_embedding, pokemon.embedding),
+        }
+        for pokemon in pokemon_list
+    ]
+    return sorted(
+        pokemon_distances,
+        key=lambda pokemon_distance: pokemon_distance['distance'],
+    )[:top_n]
+
+
+def cosine_distance(a: list[float], b: list[float]) -> float:
+    if len(a) != len(b):
+        raise ValueError('Input vectors must have the same length')
+
+    dot_product = sum(ai * bi for ai, bi in zip(a, b))
+    magnitude_a = sqrt(sum(ai * ai for ai in a))
+    magnitude_b = sqrt(sum(bi * bi for bi in b))
+
+    if magnitude_a == 0 or magnitude_b == 0:
+        raise ValueError('Invalid input: zero vector')
+
+    return 1 - (dot_product / (magnitude_a * magnitude_b))
+
+
+async def generate_response(question: str) -> GenerateResponse:
+    input_embedding = await ai.embed(
+        model=ollama_name(EMBEDDER_MODEL),
+        documents=[question],
+    )
+    nearest_pokemon = find_nearest_pokemons(input_embedding.embeddings[0])
+    pokemons_context = '\n'.join(
+        f'{pokemon["name"]}: {pokemon["description"]}'
+        for pokemon in nearest_pokemon
+    )
+
+    return await ai.generate(
+        model=ollama_name(GENERATE_MODEL),
+        prompt=f'Given the following context on Pokemon:\n{pokemons_context}\n\nQuestion: {question}\n\nAnswer:',
+    )
+
+
+@ai.flow(
+    name='Pokedex',
+)
+async def pokemon_flow(question: str):
+    """Answer a question about Pokemon using embedding-based retrieval.
+
+    Args:
+        question: The question to ask about Pokemon.
+
+    Returns:
+        The generated answer text.
+    """
+    await embed_pokemons()
+    response = await generate_response(question=question)
+    return response.message.content[0].root.text
+
+
+async def main() -> None:
+    response = await pokemon_flow('Who is the best water pokemon?')
+    print(response)
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/py/uv.lock b/py/uv.lock
index 2a7f1bd91..64f6ac211 100644
--- a/py/uv.lock
+++ b/py/uv.lock
@@ -24,6 +24,7 @@ members = [
     "imagen",
     "menu",
     "ollama-example",
+    "ollama-simple-embed",
     "prompt-file",
     "rag",
     "vertex-ai-model-garden",
@@ -1899,7 +1900,30 @@ wheels = [
 
 [[package]]
 name = "ollama-example"
 version = "0.1.0"
-source = { virtual = "samples/ollama" }
+source = { virtual = "samples/ollama-hello" }
+dependencies = [
+    { name = "genkit" },
+    { name = "genkit-firebase-plugin" },
+    { name = "genkit-google-ai-plugin" },
+    { name = "genkit-google-cloud-plugin" },
+    { name = "genkit-ollama-plugin" },
+    { name = "pydantic" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "genkit", editable = "packages/genkit" },
+    { name = "genkit-firebase-plugin", editable = "plugins/firebase" },
+    { name = "genkit-google-ai-plugin", editable = "plugins/google-ai" },
+    { name = "genkit-google-cloud-plugin", editable = "plugins/google-cloud" },
+    { name = "genkit-ollama-plugin", editable = "plugins/ollama" },
+    { name = "pydantic", specifier = ">=2.10.5" },
+]
+
+[[package]]
+name = "ollama-simple-embed"
+version = "0.1.0"
+source = { virtual = "samples/ollama-simple-embed" }
 dependencies = [
     { name = "genkit" },
     { name = "genkit-firebase-plugin" },