diff --git a/py/packages/genkit/src/genkit/ai/model.py b/py/packages/genkit/src/genkit/ai/model.py index 382d0f5a2..01ded4d6a 100644 --- a/py/packages/genkit/src/genkit/ai/model.py +++ b/py/packages/genkit/src/genkit/ai/model.py @@ -16,7 +16,10 @@ def my_model(request: GenerateRequest) -> GenerateResponse: """ from collections.abc import Callable +from functools import cached_property +from typing import Any +from genkit.core.extract import extract_json from genkit.core.typing import ( GenerateRequest, GenerateResponse, @@ -30,7 +33,15 @@ def my_model(request: GenerateRequest) -> GenerateResponse: class GenerateResponseWrapper(GenerateResponse): + """A helper wrapper class for GenerateResponse that offer a few utility methods""" + def __init__(self, response: GenerateResponse, request: GenerateRequest): + """Initializes a GenerateResponseWrapper instance. + + Args: + response: The original GenerateResponse object. + request: The GenerateRequest object associated with the response. + """ super().__init__( message=response.message, finish_reason=response.finish_reason, @@ -52,33 +63,60 @@ def assert_valid_schema(self): # TODO: implement pass - def text(self): + @cached_property + def text(self) -> str: + """Returns all text parts of the response joined into a single string""" return ''.join([ p.root.text if p.root.text is not None else '' for p in self.message.content ]) - def output(self): - # TODO: implement - pass + @cached_property + def output(self) -> Any: + """Parses out JSON data from the text parts of the response.""" + return extract_json(self.text) class GenerateResponseChunkWrapper(GenerateResponseChunk): + """A helper wrapper class for GenerateResponseChunk that offer a few utility methods""" + + previous_chunks: list[GenerateResponseChunk] + def __init__( self, chunk: GenerateResponseChunk, - index: int, previous_chunks: list[GenerateResponseChunk], + index: str, ): super().__init__( role=chunk.role, - index=chunk.index, + index=index, content=chunk.content, custom=chunk.custom, aggregated=chunk.aggregated, + previous_chunks=previous_chunks, ) - def text(self): - return ''.join([ + @cached_property + def text(self) -> str: + """Returns all text parts of the current chunk joined into a single string.""" + return ''.join( p.root.text if p.root.text is not None else '' for p in self.content - ]) + ) + + @cached_property + def accumulated_text(self) -> str: + """Returns all text parts from previous chunks plus the latest chunk.""" + if not self.previous_chunks: + return '' + atext = '' + for chunk in self.previous_chunks: + for p in chunk.content: + if p.root.text: + atext += p.root.text + return atext + self.text + + @cached_property + def output(self) -> Any: + """Parses out JSON data from the accumulated text parts of the response.""" + return extract_json(self.accumulated_text) diff --git a/py/packages/genkit/src/genkit/veneer/veneer_test.py b/py/packages/genkit/src/genkit/veneer/veneer_test.py index f7174cb78..a222f5bd6 100644 --- a/py/packages/genkit/src/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/src/genkit/veneer/veneer_test.py @@ -41,7 +41,7 @@ async def test_generate_uses_default_model(setup_test) -> None: response = await ai.generate(prompt='hi', config={'temperature': 11}) - assert response.text() == '[ECHO] user: "hi" {"temperature": 11}' + assert response.text == '[ECHO] user: "hi" {"temperature": 11}' @pytest.mark.asyncio @@ -52,7 +52,7 @@ async def test_generate_with_explicit_model(setup_test) -> None: model='echoModel', prompt='hi', config={'temperature': 11} ) - assert response.text() == '[ECHO] user: "hi" {"temperature": 11}' + assert response.text == '[ECHO] user: "hi" {"temperature": 11}' @pytest.mark.asyncio @@ -61,7 +61,7 @@ async def test_generate_with_str_prompt(setup_test) -> None: response = await ai.generate(prompt='hi', config={'temperature': 11}) - assert response.text() == '[ECHO] user: "hi" {"temperature": 11}' + assert response.text == '[ECHO] user: "hi" {"temperature": 11}' @pytest.mark.asyncio @@ -72,7 +72,7 @@ async def test_generate_with_part_prompt(setup_test) -> None: prompt=TextPart(text='hi'), config={'temperature': 11} ) - assert response.text() == '[ECHO] user: "hi" {"temperature": 11}' + assert response.text == '[ECHO] user: "hi" {"temperature": 11}' @pytest.mark.asyncio @@ -84,7 +84,7 @@ async def test_generate_with_part_list_prompt(setup_test) -> None: config={'temperature': 11}, ) - assert response.text() == '[ECHO] user: "hello","world" {"temperature": 11}' + assert response.text == '[ECHO] user: "hello","world" {"temperature": 11}' @pytest.mark.asyncio @@ -96,7 +96,7 @@ async def test_generate_with_str_system(setup_test) -> None: ) assert ( - response.text() + response.text == '[ECHO] system: "talk like pirate" user: "hi" {"temperature": 11}' ) @@ -112,7 +112,7 @@ async def test_generate_with_part_system(setup_test) -> None: ) assert ( - response.text() + response.text == '[ECHO] system: "talk like pirate" user: "hi" {"temperature": 11}' ) @@ -128,7 +128,7 @@ async def test_generate_with_part_list_system(setup_test) -> None: ) assert ( - response.text() + response.text == '[ECHO] system: "talk","like pirate" user: "hi" {"temperature": 11}' ) @@ -147,7 +147,7 @@ async def test_generate_with_messages(setup_test) -> None: config={'temperature': 11}, ) - assert response.text() == '[ECHO] user: "hi" {"temperature": 11}' + assert response.text == '[ECHO] user: "hi" {"temperature": 11}' @pytest.mark.asyncio @@ -170,7 +170,7 @@ async def test_generate_with_system_prompt_messages(setup_test) -> None: ) assert ( - response.text() + response.text == '[ECHO] system: "talk like pirate" user: "hi" model: "bye" user: "hi again"' ) @@ -186,7 +186,7 @@ async def test_generate_with_tools(setup_test) -> None: tools=['testTool'], ) - assert response.text() == '[ECHO] user: "hi" tool_choice=required' + assert response.text == '[ECHO] user: "hi" tool_choice=required' assert echo.last_request.tools == [ ToolDefinition( name='testTool', @@ -251,7 +251,7 @@ def test_tool(input: ToolInput): tools=['testTool'], ) - assert response.text() == 'tool called' + assert response.text == 'tool called' assert response.request.messages[0] == Message( role=Role.USER, content=[TextPart(text='hi')] ) diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index 2dc97f9b8..7c19b25be 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -63,7 +63,7 @@ async def test_simple_text_generate_request(setup_test) -> None: ), ) - assert response.text() == 'bye' + assert response.text == 'bye' ########################################################################## diff --git a/py/packages/genkit/tests/genkit/ai/model_test.py b/py/packages/genkit/tests/genkit/ai/model_test.py new file mode 100644 index 000000000..c68def38a --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/model_test.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the action module.""" + +from genkit.ai.model import ( + GenerateResponseChunkWrapper, + GenerateResponseWrapper, +) +from genkit.core.typing import ( + GenerateRequest, + GenerateResponse, + GenerateResponseChunk, + Message, + TextPart, +) + + +def test_response_wrapper_text() -> None: + wrapper = GenerateResponseWrapper( + response=GenerateResponse( + message=Message( + role='model', + content=[TextPart(text='hello'), TextPart(text=' world')], + ) + ), + request=GenerateRequest( + messages=[], # doesn't matter for now + ), + ) + + assert wrapper.text == 'hello world' + + +def test_response_wrapper_output() -> None: + wrapper = GenerateResponseWrapper( + response=GenerateResponse( + message=Message( + role='model', + content=[TextPart(text='{"foo":'), TextPart(text='"bar')], + ) + ), + request=GenerateRequest( + messages=[], # doesn't matter for now + ), + ) + + assert wrapper.output == {'foo': 'bar'} + + +def test_chunk_wrapper_text() -> None: + wrapper = GenerateResponseChunkWrapper( + chunk=GenerateResponseChunk( + content=[TextPart(text='hello'), TextPart(text=' world')] + ), + index=0, + previous_chunks=[], + ) + + assert wrapper.text == 'hello world' + + +def test_chunk_wrapper_accumulated_text() -> None: + wrapper = GenerateResponseChunkWrapper( + GenerateResponseChunk(content=[TextPart(text=' PS: aliens')]), + index=0, + previous_chunks=[ + GenerateResponseChunk( + content=[TextPart(text='hello'), TextPart(text=' ')] + ), + GenerateResponseChunk(content=[TextPart(text='world!')]), + ], + ) + + assert wrapper.accumulated_text == 'hello world! PS: aliens' + + +def test_chunk_wrapper_output() -> None: + wrapper = GenerateResponseChunkWrapper( + GenerateResponseChunk(content=[TextPart(text=', "baz":[1,2,')]), + index=0, + previous_chunks=[ + GenerateResponseChunk( + content=[TextPart(text='{"foo":'), TextPart(text='"ba')] + ), + GenerateResponseChunk(content=[TextPart(text='r"')]), + ], + ) + + assert wrapper.output == {'foo': 'bar', 'baz': [1, 2]}