From 54c35b2404ea1fcbebd69aadadf60ca414469209 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Sun, 2 Mar 2025 20:16:54 -0500 Subject: [PATCH] feat(py/ai): allow setting config schema and info in the model metadata (#2220) --- .../genkit/src/genkit/ai/testing_utils.py | 2 +- .../genkit/src/genkit/veneer/registry.py | 24 +++++- .../genkit/src/genkit/veneer/veneer_test.py | 86 +++++++++++++++++++ 3 files changed, 110 insertions(+), 2 deletions(-) diff --git a/py/packages/genkit/src/genkit/ai/testing_utils.py b/py/packages/genkit/src/genkit/ai/testing_utils.py index f0b1af2b4..957a2b91b 100644 --- a/py/packages/genkit/src/genkit/ai/testing_utils.py +++ b/py/packages/genkit/src/genkit/ai/testing_utils.py @@ -5,7 +5,7 @@ """Testing utils/helpers for genkit.ai""" -from genkit.core.action import Action, ActionRunContext +from genkit.core.action import ActionRunContext from genkit.core.codec import dump_json from genkit.core.typing import ( GenerateRequest, diff --git a/py/packages/genkit/src/genkit/veneer/registry.py b/py/packages/genkit/src/genkit/veneer/registry.py index 503fb1a25..cc5ef80de 100644 --- a/py/packages/genkit/src/genkit/veneer/registry.py +++ b/py/packages/genkit/src/genkit/veneer/registry.py @@ -11,7 +11,11 @@ from genkit.ai.formats.types import FormatDef from genkit.ai.model import ModelFn from genkit.core.action import Action, ActionKind +from genkit.core.codec import dump_dict from genkit.core.registry import Registry +from genkit.core.schema import to_json_schema +from genkit.core.typing import ModelInfo +from pydantic import BaseModel class GenkitRegistry: @@ -138,19 +142,37 @@ def define_model( self, name: str, fn: ModelFn, + config_schema: BaseModel | dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, + info: ModelInfo | None = None, ) -> Action: """Define a custom model action. Args: name: Name of the model. fn: Function implementing the model behavior. + config_schema: Optional schema for model configuration. metadata: Optional metadata for the model. + info: Optional ModelInfo for the model. """ + model_meta = metadata if metadata else {} + if info: + model_meta['model'] = dump_dict(info) + if 'model' not in model_meta: + model_meta['model'] = {} + if ( + 'label' not in model_meta['model'] + or not model_meta['model']['label'] + ): + model_meta['model']['label'] = name + + if config_schema: + model_meta['model']['customOptions'] = to_json_schema(config_schema) + return self.registry.register_action( name=name, kind=ActionKind.MODEL, fn=fn, - metadata=metadata, + metadata=model_meta, ) def define_embedder( diff --git a/py/packages/genkit/src/genkit/veneer/veneer_test.py b/py/packages/genkit/src/genkit/veneer/veneer_test.py index f96a3763a..5c98c5af3 100644 --- a/py/packages/genkit/src/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/src/genkit/veneer/veneer_test.py @@ -23,8 +23,10 @@ GenerateResponseChunk, Message, Metadata, + ModelInfo, OutputConfig, Role, + Supports, TextPart, ToolDefinition, ToolRequest1, @@ -628,3 +630,87 @@ def collect_chunks(chunk): content_type='application/banana', ), ) + + +def test_define_model_default_metadata(setup_test: SetupFixture): + ai, _, pm, *_ = setup_test + + def foo_model_fn(): + return GenerateResponse( + message=Message(role=Role.MODEL, content=[TextPart(text='banana!')]) + ) + + action = ai.define_model( + name='foo', + fn=foo_model_fn, + ) + + assert action.metadata['model'] == { + 'label': 'foo', + } + + +def test_define_model_with_schema(setup_test: SetupFixture): + ai, _, pm, *_ = setup_test + + class Config(BaseModel): + field_a: str = Field(description='a field') + field_b: str = Field(description='b field') + + def foo_model_fn(): + return GenerateResponse( + message=Message(role=Role.MODEL, content=[TextPart(text='banana!')]) + ) + + action = ai.define_model( + name='foo', + fn=foo_model_fn, + config_schema=Config, + ) + assert action.metadata['model'] == { + 'customOptions': { + 'properties': { + 'field_a': { + 'description': 'a field', + 'title': 'Field A', + 'type': 'string', + }, + 'field_b': { + 'description': 'b field', + 'title': 'Field B', + 'type': 'string', + }, + }, + 'required': [ + 'field_a', + 'field_b', + ], + 'title': 'Config', + 'type': 'object', + }, + 'label': 'foo', + } + + +def test_define_model_with_info(setup_test: SetupFixture): + ai, _, pm, *_ = setup_test + + def foo_model_fn(): + return GenerateResponse( + message=Message(role=Role.MODEL, content=[TextPart(text='banana!')]) + ) + + action = ai.define_model( + name='foo', + fn=foo_model_fn, + info=ModelInfo( + label='Foo Bar', supports=Supports(multiturn=True, tools=True) + ), + ) + assert action.metadata['model'] == { + 'label': 'Foo Bar', + 'supports': { + 'multiturn': True, + 'tools': True, + }, + }