Skip to content

Commit

Permalink
feat(pj/ai): added text/output methods to response and chunk wrappers (
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Mar 2, 2025
1 parent 81cbfcd commit a57cf85
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 22 deletions.
56 changes: 47 additions & 9 deletions py/packages/genkit/src/genkit/ai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def my_model(request: GenerateRequest) -> GenerateResponse:
"""

from collections.abc import Callable
from functools import cached_property
from typing import Any

from genkit.core.extract import extract_json
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand All @@ -30,7 +33,15 @@ def my_model(request: GenerateRequest) -> GenerateResponse:


class GenerateResponseWrapper(GenerateResponse):
"""A helper wrapper class for GenerateResponse that offer a few utility methods"""

def __init__(self, response: GenerateResponse, request: GenerateRequest):
"""Initializes a GenerateResponseWrapper instance.
Args:
response: The original GenerateResponse object.
request: The GenerateRequest object associated with the response.
"""
super().__init__(
message=response.message,
finish_reason=response.finish_reason,
Expand All @@ -52,33 +63,60 @@ def assert_valid_schema(self):
# TODO: implement
pass

def text(self):
@cached_property
def text(self) -> str:
"""Returns all text parts of the response joined into a single string"""
return ''.join([
p.root.text if p.root.text is not None else ''
for p in self.message.content
])

def output(self):
# TODO: implement
pass
@cached_property
def output(self) -> Any:
"""Parses out JSON data from the text parts of the response."""
return extract_json(self.text)


class GenerateResponseChunkWrapper(GenerateResponseChunk):
"""A helper wrapper class for GenerateResponseChunk that offer a few utility methods"""

previous_chunks: list[GenerateResponseChunk]

def __init__(
self,
chunk: GenerateResponseChunk,
index: int,
previous_chunks: list[GenerateResponseChunk],
index: str,
):
super().__init__(
role=chunk.role,
index=chunk.index,
index=index,
content=chunk.content,
custom=chunk.custom,
aggregated=chunk.aggregated,
previous_chunks=previous_chunks,
)

def text(self):
return ''.join([
@cached_property
def text(self) -> str:
"""Returns all text parts of the current chunk joined into a single string."""
return ''.join(
p.root.text if p.root.text is not None else '' for p in self.content
])
)

@cached_property
def accumulated_text(self) -> str:
"""Returns all text parts from previous chunks plus the latest chunk."""
if not self.previous_chunks:
return ''
atext = ''
for chunk in self.previous_chunks:
for p in chunk.content:
if p.root.text:
atext += p.root.text
return atext + self.text

@cached_property
def output(self) -> Any:
"""Parses out JSON data from the accumulated text parts of the response."""
return extract_json(self.accumulated_text)
24 changes: 12 additions & 12 deletions py/packages/genkit/src/genkit/veneer/veneer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_generate_uses_default_model(setup_test) -> None:

response = await ai.generate(prompt='hi', config={'temperature': 11})

assert response.text() == '[ECHO] user: "hi" {"temperature": 11}'
assert response.text == '[ECHO] user: "hi" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -52,7 +52,7 @@ async def test_generate_with_explicit_model(setup_test) -> None:
model='echoModel', prompt='hi', config={'temperature': 11}
)

assert response.text() == '[ECHO] user: "hi" {"temperature": 11}'
assert response.text == '[ECHO] user: "hi" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -61,7 +61,7 @@ async def test_generate_with_str_prompt(setup_test) -> None:

response = await ai.generate(prompt='hi', config={'temperature': 11})

assert response.text() == '[ECHO] user: "hi" {"temperature": 11}'
assert response.text == '[ECHO] user: "hi" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -72,7 +72,7 @@ async def test_generate_with_part_prompt(setup_test) -> None:
prompt=TextPart(text='hi'), config={'temperature': 11}
)

assert response.text() == '[ECHO] user: "hi" {"temperature": 11}'
assert response.text == '[ECHO] user: "hi" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -84,7 +84,7 @@ async def test_generate_with_part_list_prompt(setup_test) -> None:
config={'temperature': 11},
)

assert response.text() == '[ECHO] user: "hello","world" {"temperature": 11}'
assert response.text == '[ECHO] user: "hello","world" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -96,7 +96,7 @@ async def test_generate_with_str_system(setup_test) -> None:
)

assert (
response.text()
response.text
== '[ECHO] system: "talk like pirate" user: "hi" {"temperature": 11}'
)

Expand All @@ -112,7 +112,7 @@ async def test_generate_with_part_system(setup_test) -> None:
)

assert (
response.text()
response.text
== '[ECHO] system: "talk like pirate" user: "hi" {"temperature": 11}'
)

Expand All @@ -128,7 +128,7 @@ async def test_generate_with_part_list_system(setup_test) -> None:
)

assert (
response.text()
response.text
== '[ECHO] system: "talk","like pirate" user: "hi" {"temperature": 11}'
)

Expand All @@ -147,7 +147,7 @@ async def test_generate_with_messages(setup_test) -> None:
config={'temperature': 11},
)

assert response.text() == '[ECHO] user: "hi" {"temperature": 11}'
assert response.text == '[ECHO] user: "hi" {"temperature": 11}'


@pytest.mark.asyncio
Expand All @@ -170,7 +170,7 @@ async def test_generate_with_system_prompt_messages(setup_test) -> None:
)

assert (
response.text()
response.text
== '[ECHO] system: "talk like pirate" user: "hi" model: "bye" user: "hi again"'
)

Expand All @@ -186,7 +186,7 @@ async def test_generate_with_tools(setup_test) -> None:
tools=['testTool'],
)

assert response.text() == '[ECHO] user: "hi" tool_choice=required'
assert response.text == '[ECHO] user: "hi" tool_choice=required'
assert echo.last_request.tools == [
ToolDefinition(
name='testTool',
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_tool(input: ToolInput):
tools=['testTool'],
)

assert response.text() == 'tool called'
assert response.text == 'tool called'
assert response.request.messages[0] == Message(
role=Role.USER, content=[TextPart(text='hi')]
)
Expand Down
2 changes: 1 addition & 1 deletion py/packages/genkit/tests/genkit/ai/generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def test_simple_text_generate_request(setup_test) -> None:
),
)

assert response.text() == 'bye'
assert response.text == 'bye'


##########################################################################
Expand Down
92 changes: 92 additions & 0 deletions py/packages/genkit/tests/genkit/ai/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
#
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Tests for the action module."""

from genkit.ai.model import (
GenerateResponseChunkWrapper,
GenerateResponseWrapper,
)
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
GenerateResponseChunk,
Message,
TextPart,
)


def test_response_wrapper_text() -> None:
wrapper = GenerateResponseWrapper(
response=GenerateResponse(
message=Message(
role='model',
content=[TextPart(text='hello'), TextPart(text=' world')],
)
),
request=GenerateRequest(
messages=[], # doesn't matter for now
),
)

assert wrapper.text == 'hello world'


def test_response_wrapper_output() -> None:
wrapper = GenerateResponseWrapper(
response=GenerateResponse(
message=Message(
role='model',
content=[TextPart(text='{"foo":'), TextPart(text='"bar')],
)
),
request=GenerateRequest(
messages=[], # doesn't matter for now
),
)

assert wrapper.output == {'foo': 'bar'}


def test_chunk_wrapper_text() -> None:
wrapper = GenerateResponseChunkWrapper(
chunk=GenerateResponseChunk(
content=[TextPart(text='hello'), TextPart(text=' world')]
),
index=0,
previous_chunks=[],
)

assert wrapper.text == 'hello world'


def test_chunk_wrapper_accumulated_text() -> None:
wrapper = GenerateResponseChunkWrapper(
GenerateResponseChunk(content=[TextPart(text=' PS: aliens')]),
index=0,
previous_chunks=[
GenerateResponseChunk(
content=[TextPart(text='hello'), TextPart(text=' ')]
),
GenerateResponseChunk(content=[TextPart(text='world!')]),
],
)

assert wrapper.accumulated_text == 'hello world! PS: aliens'


def test_chunk_wrapper_output() -> None:
wrapper = GenerateResponseChunkWrapper(
GenerateResponseChunk(content=[TextPart(text=', "baz":[1,2,')]),
index=0,
previous_chunks=[
GenerateResponseChunk(
content=[TextPart(text='{"foo":'), TextPart(text='"ba')]
),
GenerateResponseChunk(content=[TextPart(text='r"')]),
],
)

assert wrapper.output == {'foo': 'bar', 'baz': [1, 2]}

0 comments on commit a57cf85

Please sign in to comment.