Skip to content

Commit

Permalink
feat(py/ai): allow setting config schema and info in the model metada…
Browse files Browse the repository at this point in the history
…ta (#2220)
  • Loading branch information
pavelgj authored Mar 3, 2025
1 parent 458db7c commit 54c35b2
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 2 deletions.
2 changes: 1 addition & 1 deletion py/packages/genkit/src/genkit/ai/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

"""Testing utils/helpers for genkit.ai"""

from genkit.core.action import Action, ActionRunContext
from genkit.core.action import ActionRunContext
from genkit.core.codec import dump_json
from genkit.core.typing import (
GenerateRequest,
Expand Down
24 changes: 23 additions & 1 deletion py/packages/genkit/src/genkit/veneer/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from genkit.ai.formats.types import FormatDef
from genkit.ai.model import ModelFn
from genkit.core.action import Action, ActionKind
from genkit.core.codec import dump_dict
from genkit.core.registry import Registry
from genkit.core.schema import to_json_schema
from genkit.core.typing import ModelInfo
from pydantic import BaseModel


class GenkitRegistry:
Expand Down Expand Up @@ -138,19 +142,37 @@ def define_model(
self,
name: str,
fn: ModelFn,
config_schema: BaseModel | dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
info: ModelInfo | None = None,
) -> Action:
"""Define a custom model action.
Args:
name: Name of the model.
fn: Function implementing the model behavior.
config_schema: Optional schema for model configuration.
metadata: Optional metadata for the model.
info: Optional ModelInfo for the model.
"""
model_meta = metadata if metadata else {}
if info:
model_meta['model'] = dump_dict(info)
if 'model' not in model_meta:
model_meta['model'] = {}
if (
'label' not in model_meta['model']
or not model_meta['model']['label']
):
model_meta['model']['label'] = name

if config_schema:
model_meta['model']['customOptions'] = to_json_schema(config_schema)

return self.registry.register_action(
name=name,
kind=ActionKind.MODEL,
fn=fn,
metadata=metadata,
metadata=model_meta,
)

def define_embedder(
Expand Down
86 changes: 86 additions & 0 deletions py/packages/genkit/src/genkit/veneer/veneer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
GenerateResponseChunk,
Message,
Metadata,
ModelInfo,
OutputConfig,
Role,
Supports,
TextPart,
ToolDefinition,
ToolRequest1,
Expand Down Expand Up @@ -628,3 +630,87 @@ def collect_chunks(chunk):
content_type='application/banana',
),
)


def test_define_model_default_metadata(setup_test: SetupFixture):
ai, _, pm, *_ = setup_test

def foo_model_fn():
return GenerateResponse(
message=Message(role=Role.MODEL, content=[TextPart(text='banana!')])
)

action = ai.define_model(
name='foo',
fn=foo_model_fn,
)

assert action.metadata['model'] == {
'label': 'foo',
}


def test_define_model_with_schema(setup_test: SetupFixture):
ai, _, pm, *_ = setup_test

class Config(BaseModel):
field_a: str = Field(description='a field')
field_b: str = Field(description='b field')

def foo_model_fn():
return GenerateResponse(
message=Message(role=Role.MODEL, content=[TextPart(text='banana!')])
)

action = ai.define_model(
name='foo',
fn=foo_model_fn,
config_schema=Config,
)
assert action.metadata['model'] == {
'customOptions': {
'properties': {
'field_a': {
'description': 'a field',
'title': 'Field A',
'type': 'string',
},
'field_b': {
'description': 'b field',
'title': 'Field B',
'type': 'string',
},
},
'required': [
'field_a',
'field_b',
],
'title': 'Config',
'type': 'object',
},
'label': 'foo',
}


def test_define_model_with_info(setup_test: SetupFixture):
ai, _, pm, *_ = setup_test

def foo_model_fn():
return GenerateResponse(
message=Message(role=Role.MODEL, content=[TextPart(text='banana!')])
)

action = ai.define_model(
name='foo',
fn=foo_model_fn,
info=ModelInfo(
label='Foo Bar', supports=Supports(multiturn=True, tools=True)
),
)
assert action.metadata['model'] == {
'label': 'Foo Bar',
'supports': {
'multiturn': True,
'tools': True,
},
}

0 comments on commit 54c35b2

Please sign in to comment.