Skip to content

Commit

Permalink
feat(py): Add Embedders and test on VertexAI (#2040)
Browse files Browse the repository at this point in the history
  • Loading branch information
Irillit authored Feb 24, 2025
1 parent 6a8e94d commit 443561e
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 22 deletions.
17 changes: 17 additions & 0 deletions py/packages/genkit/src/genkit/ai/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable

from pydantic import BaseModel


class EmbedRequest(BaseModel):
    """Request payload for an embedder action.

    Attributes:
        documents: The texts to compute embeddings for.
    """

    documents: list[str]


class EmbedResponse(BaseModel):
    """Response payload produced by an embedder action.

    Attributes:
        embeddings: One embedding vector per input document, in the same
            order as the request's ``documents``.
    """

    embeddings: list[list[float]]


EmbedderFn = Callable[[EmbedRequest], EmbedResponse]
19 changes: 19 additions & 0 deletions py/packages/genkit/src/genkit/veneer/veneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from http.server import HTTPServer
from typing import Any

from genkit.ai.embedding import EmbedRequest, EmbedResponse
from genkit.ai.model import ModelFn
from genkit.core.action import ActionKind
from genkit.core.environment import is_dev_environment
Expand Down Expand Up @@ -128,6 +129,24 @@ async def generate(
)
).response

async def embed(
    self, model: str | None = None, documents: list[str] | None = None
) -> EmbedResponse:
    """Calculates embeddings for the given texts.

    Args:
        model: Optional embedder model name to use.
        documents: Texts to embed; treated as an empty list when omitted.

    Returns:
        The generated response with embeddings.
    """
    embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, model)

    # `EmbedRequest.documents` is typed `list[str]`, so passing the
    # default `None` through would raise a pydantic ValidationError.
    # Normalize the optional parameter to an empty list instead.
    request = EmbedRequest(
        documents=documents if documents is not None else []
    )
    return (await embed_action.arun(request)).response

def flow(self, name: str | None = None) -> Callable[[Callable], Callable]:
"""Decorator to register a function as a flow.
Expand Down
10 changes: 9 additions & 1 deletion py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
enabling the use of Vertex AI models and services within the Genkit framework.
"""

from genkit.plugins.vertex_ai.embedding import EmbeddingModels
from genkit.plugins.vertex_ai.gemini import GeminiVersion
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name


Expand All @@ -18,4 +20,10 @@ def package_name() -> str:
return 'genkit.plugins.vertex_ai'


__all__ = ['package_name', 'VertexAI', 'vertexai_name']
__all__ = [
'package_name',
'VertexAI',
'vertexai_name',
'EmbeddingModels',
'GeminiVersion',
]
56 changes: 56 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum
from typing import Any

from genkit.ai.embedding import EmbedRequest, EmbedResponse
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


class EmbeddingModels(StrEnum):
    """Vertex AI text-embedding model identifiers supported by the plugin.

    The enum value is the model name as accepted by
    ``TextEmbeddingModel.from_pretrained``. Judging by the member names,
    ``_ENG``-suffixed entries are English models and the others are
    multilingual.
    """

    GECKO_003_ENG = 'textembedding-gecko@003'
    TEXT_EMBEDDING_004_ENG = 'text-embedding-004'
    TEXT_EMBEDDING_005_ENG = 'text-embedding-005'
    GECKO_MULTILINGUAL = 'textembedding-gecko-multilingual@001'
    TEXT_EMBEDDING_002_MULTILINGUAL = 'text-multilingual-embedding-002'


class TaskType(StrEnum):
    """Embedding task types passed to ``TextEmbeddingInput``.

    The task type hints to Vertex AI what the embedding will be used for,
    which tunes the resulting vectors for that downstream use case.
    """

    SEMANTIC_SIMILARITY = 'SEMANTIC_SIMILARITY'
    CLASSIFICATION = 'CLASSIFICATION'
    CLUSTERING = 'CLUSTERING'
    RETRIEVAL_DOCUMENT = 'RETRIEVAL_DOCUMENT'
    RETRIEVAL_QUERY = 'RETRIEVAL_QUERY'
    QUESTION_ANSWERING = 'QUESTION_ANSWERING'
    FACT_VERIFICATION = 'FACT_VERIFICATION'
    CODE_RETRIEVAL_QUERY = 'CODE_RETRIEVAL_QUERY'


class Embedder:
    """Serves embedding requests for one Vertex AI text-embedding model."""

    # Task type attached to every TextEmbeddingInput built from a request.
    TASK = TaskType.RETRIEVAL_QUERY

    # By default, the model generates embeddings with 768 dimensions.
    # Models such as `text-embedding-004`, `text-embedding-005`,
    # and `text-multilingual-embedding-002` allow the output dimensionality
    # to be adjusted between 1 and 768.
    # NOTE(review): currently unused — get_embeddings is called without an
    # output_dimensionality argument, so the model's default size applies.
    DIMENSIONALITY = 768

    def __init__(self, version):
        """Initialize the embedder.

        Args:
            version: Embedding model identifier, e.g. an
                ``EmbeddingModels`` value.
        """
        self._version = version
        # Lazily-created model handle; see embedding_model below.
        self._model = None

    @property
    def embedding_model(self) -> TextEmbeddingModel:
        """The Vertex AI model handle, resolved once and cached.

        ``from_pretrained`` is not free, so avoid re-resolving the model
        on every request.
        """
        if self._model is None:
            self._model = TextEmbeddingModel.from_pretrained(self._version)
        return self._model

    def handle_request(self, request: EmbedRequest) -> EmbedResponse:
        """Compute embeddings for every document in the request.

        Args:
            request: The documents to embed.

        Returns:
            An EmbedResponse with one vector per document, in order.
        """
        inputs = [
            TextEmbeddingInput(text, self.TASK) for text in request.documents
        ]
        vertexai_embeddings = self.embedding_model.get_embeddings(inputs)
        embeddings = [embedding.values for embedding in vertexai_embeddings]
        return EmbedResponse(embeddings=embeddings)

    @property
    def model_metadata(self) -> dict[str, dict[str, Any]]:
        """Metadata registered alongside the embedder action (none yet)."""
        return {}
14 changes: 12 additions & 2 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

from enum import StrEnum
from typing import Any

from genkit.core.typing import (
GenerateRequest,
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(self, version: str):
version: The version of the Gemini model to use, should be
one of the values from GeminiVersion.
"""
self.version = version
self._version = version

@property
def gemini_model(self) -> GenerativeModel:
Expand All @@ -98,7 +99,7 @@ def gemini_model(self) -> GenerativeModel:
Returns:
A configured GenerativeModel instance for the specified version.
"""
return GenerativeModel(self.version)
return GenerativeModel(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
"""Handle a generation request using the Gemini model.
Expand All @@ -125,3 +126,12 @@ def handle_request(self, request: GenerateRequest) -> GenerateResponse:
content=[TextPart(text=response.text)],
)
)

@property
def model_metadata(self) -> dict[str, dict[str, Any]]:
    """Capability metadata for this model version, for action registration.

    Returns:
        A dict of the form ``{'model': {'supports': {...}}}`` where the
        inner mapping is the dumped ``supports`` entry from
        ``SUPPORTED_MODELS`` for this version.
    """
    capabilities = SUPPORTED_MODELS[self._version].supports
    return {'model': {'supports': capabilities.model_dump()}}
34 changes: 19 additions & 15 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from genkit.core.plugin_abc import Plugin
from genkit.core.registry import Registry
from genkit.plugins.vertex_ai import constants as const
from genkit.plugins.vertex_ai.embedding import Embedder, EmbeddingModels
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion

LOG = logging.getLogger(__name__)
Expand All @@ -37,9 +38,6 @@ class VertexAI(Plugin):
registration of model actions.
"""

# This is 'gemini-1.5-pro' - the latest stable model
VERTEX_AI_GENERATIVE_MODEL_NAME: str = GeminiVersion.GEMINI_1_5_FLASH.value

def __init__(
self, project_id: str | None = None, location: str | None = None
):
Expand All @@ -56,8 +54,6 @@ def __init__(
project_id if project_id else os.getenv(const.GCLOUD_PROJECT)
)
location = location if location else const.DEFAULT_REGION

self._gemini = Gemini(self.VERTEX_AI_GENERATIVE_MODEL_NAME)
vertexai.init(project=project_id, location=location)

def initialize(self, registry: Registry) -> None:
Expand All @@ -69,13 +65,21 @@ def initialize(self, registry: Registry) -> None:
Args:
registry: The registry to register actions with.
"""
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(self.VERTEX_AI_GENERATIVE_MODEL_NAME),
fn=self._gemini.handle_request,
metadata={
'model': {
'supports': {'multiturn': True},
}
},
)

for model_version in GeminiVersion:
gemini = Gemini(model_version)
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(model_version),
fn=gemini.handle_request,
metadata=gemini.model_metadata,
)

for embed_model in EmbeddingModels:
embedder = Embedder(embed_model)
registry.register_action(
kind=ActionKind.EMBEDDER,
name=vertexai_name(embed_model),
fn=embedder.handle_request,
metadata=embedder.model_metadata,
)
34 changes: 30 additions & 4 deletions py/samples/hello/src/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@

from genkit.core.action import ActionRunContext
from genkit.core.typing import GenerateRequest, Message, Role, TextPart
from genkit.plugins.vertex_ai import VertexAI, vertexai_name
from genkit.plugins.vertex_ai import (
EmbeddingModels,
GeminiVersion,
VertexAI,
vertexai_name,
)
from genkit.veneer.veneer import Genkit
from pydantic import BaseModel, Field

ai = Genkit(
plugins=[VertexAI()],
model=vertexai_name(VertexAI.VERTEX_AI_GENERATIVE_MODEL_NAME),
model=vertexai_name(GeminiVersion.GEMINI_1_5_FLASH),
)


Expand Down Expand Up @@ -71,6 +76,22 @@ async def say_hi(name: str):
)


@ai.flow()
async def embed_docs(docs: list[str]):
    """Generate an embedding for the words in a list.

    Args:
        docs: list of texts (string)

    Returns:
        The generated embedding.
    """
    return await ai.embed(
        model=vertexai_name(EmbeddingModels.TEXT_EMBEDDING_004_ENG),
        documents=docs,
    )


@ai.flow()
def sum_two_numbers2(my_input: MyInput) -> Any:
"""Add two numbers together.
Expand All @@ -85,15 +106,17 @@ def sum_two_numbers2(my_input: MyInput) -> Any:


@ai.flow()
def streamingSyncFlow(input: str, ctx: ActionRunContext):
def streaming_sync_flow(inp: str, ctx: ActionRunContext):
    """Example of a synchronous streaming flow.

    Emits three chunks through the action context, then returns a
    final string value.
    """
    for chunk in (1, {'chunk': 'blah'}, 3):
        ctx.send_chunk(chunk)
    return 'streamingSyncFlow 4'


@ai.flow()
async def streamingAsyncFlow(input: str, ctx: ActionRunContext):
async def streaming_async_flow(inp: str, ctx: ActionRunContext):
"""Example of async flow."""
ctx.send_chunk(1)
ctx.send_chunk({'chunk': 'blah'})
ctx.send_chunk(3)
Expand All @@ -108,6 +131,9 @@ async def main() -> None:
"""
print(await say_hi('John Doe'))
print(sum_two_numbers2(MyInput(a=1, b=3)))
print(
await embed_docs(['banana muffins? ', 'banana bread? banana muffins?'])
)


if __name__ == '__main__':
Expand Down

0 comments on commit 443561e

Please sign in to comment.