diff --git a/py/packages/genkit/pyproject.toml b/py/packages/genkit/pyproject.toml index a37e37ada..2968c0e79 100644 --- a/py/packages/genkit/pyproject.toml +++ b/py/packages/genkit/pyproject.toml @@ -20,6 +20,8 @@ dependencies = [ "pydantic>=2.10.5", "requests>=2.32.3", "dotprompt", + "partial-json-parser>=0.2.1.1.post5", + "json5>=0.10.0", ] description = "Genkit AI Framework" license = { text = "Apache-2.0" } diff --git a/py/packages/genkit/src/genkit/core/extract.py b/py/packages/genkit/src/genkit/core/extract.py new file mode 100644 index 000000000..86451f384 --- /dev/null +++ b/py/packages/genkit/src/genkit/core/extract.py @@ -0,0 +1,238 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for extracting JSON data from text and markdown.""" + +from typing import Any + +import json5 +from partial_json_parser import loads + +CHAR_NON_BREAKING_SPACE = '\u00a0' + + +def parse_partial_json(json_string: str) -> Any: + """Parses a partially complete JSON string and returns the parsed object. + + This function attempts to parse the given JSON string, even if it is not + a complete or valid JSON document. + + Args: + json_string: The string to parse as JSON. + + Returns: + The parsed JSON object. + + Raises: + AssertionError: If the string cannot be parsed as JSON. + """ + # TODO: add handling for malformed JSON cases. + return loads(json_string) + + +def extract_json(text: str, throw_on_bad_json: bool = True) -> Any: + """ + Extracts JSON from a string with lenient parsing. + + This function attempts to extract a valid JSON object or array from a + string, even if the string contains extraneous characters or minor + formatting issues. It uses a combination of basic parsing and + `json5` and `partial-json` libraries to maximize the chance of + successful extraction. + + Args: + text: The string to extract JSON from. + throw_on_bad_json: If True, raises a ValueError if no valid JSON + can be extracted. If False, returns None in such cases. + + Returns: + The extracted JSON object (dict or list), or None if no valid + JSON is found and `throw_on_bad_json` is False. + + Raises: + ValueError: If `throw_on_bad_json` is True and no valid JSON + can be extracted. + + Examples: + >>> extract_json(' { "key" : "value" } ') + {'key': 'value'} + + >>> extract_json('{"key": "value",}') # Trailing comma + {'key': 'value'} + + >>> extract_json('some text {"key": "value"} more text') + {'key': 'value'} + + >>> extract_json('invalid json', throw_on_bad_json=False) + None + """ + opening_char = None + closing_char = None + start_pos = None + nesting_count = 0 + in_string = False + escape_next = False + + for i in range(len(text)): + char = text[i].replace(CHAR_NON_BREAKING_SPACE, ' ') + + if escape_next: + escape_next = False + continue + + if char == '\\': + escape_next = True + continue + + if char == '"': + in_string = not in_string + continue + + if in_string: + continue + + if not opening_char and (char == '{' or char == '['): + # Look for opening character + opening_char = char + closing_char = '}' if char == '{' else ']' + start_pos = i + nesting_count += 1 + elif char == opening_char: + # Increment nesting for matching opening character + nesting_count += 1 + elif char == closing_char: + # Decrement nesting for matching closing character + nesting_count -= 1 + if not nesting_count: + # Reached end of target element + return json5.loads(text[start_pos or 0 : i + 1]) + if start_pos is not None and nesting_count > 0: + # If an incomplete JSON structure is detected + try: + # Parse the incomplete JSON structure using partial-json for lenient parsing + return parse_partial_json(text[start_pos:]) + except: + # If parsing fails, throw an error + if throw_on_bad_json: + raise ValueError( + f'Invalid JSON extracted from model output: {text}' + ) + return None + + if throw_on_bad_json: + raise ValueError(f'Invalid JSON extracted from model output: {text}') + return None + + +class ExtractItemsResult: + """Result of array item extraction.""" + + def __init__(self, items: list, cursor: int): + self.items = items + self.cursor = cursor + + +def extract_items(text: str, cursor: int = 0) -> ExtractItemsResult: + """ + Extracts complete JSON objects from the first array found in the text. + + This function searches for the first JSON array within the input string, + starting from an optional cursor position. It extracts complete JSON + objects from this array and returns them along with an updated cursor + position, indicating how much of the string has been processed. + + Args: + text: The string to extract items from. + cursor: The starting position for searching the array (default: 0). + Useful for processing large strings in chunks. + + Returns: + An `ExtractItemsResult` object containing: + - `items`: A list of extracted JSON objects (dictionaries). + - `cursor`: The updated cursor position, which is the index + immediately after the last processed character. If no array is + found, the cursor will be the length of the text. + + Examples: + >>> text = '[{"a": 1}, {"b": 2}, {"c": 3}]' + >>> result = extract_items(text) + >>> result.items + [{'a': 1}, {'b': 2}, {'c': 3}] + >>> result.cursor + 29 + + >>> text = ' [ {"x": 10}, {"y": 20} ] ' + >>> result = extract_items(text) + >>> result.items + [{'x': 10}, {'y': 20}] + >>> result.cursor + 25 + + >>> text = 'some text [ {"p": 100} , {"q": 200} ] more text' + >>> result = extract_items(text, cursor=10) + >>> result.items + [{'p': 100}, {'q': 200}] + >>> result.cursor + 35 + + >>> text = 'no array here' + >>> result = extract_items(text) + >>> result.items + [] + >>> result.cursor + 13 + """ + items = [] + current_cursor = cursor + + # Find the first array start if we haven't already processed any text + if cursor == 0: + array_start = text.find('[') + if array_start == -1: + return ExtractItemsResult(items=[], cursor=len(text)) + current_cursor = array_start + 1 + + object_start = -1 + brace_count = 0 + in_string = False + escape_next = False + + # Process the text from the cursor position + for i in range(current_cursor, len(text)): + char = text[i] + + if escape_next: + escape_next = False + continue + + if char == '\\': + escape_next = True + continue + + if char == '"': + in_string = not in_string + continue + + if in_string: + continue + + if char == '{': + if brace_count == 0: + object_start = i + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0 and object_start != -1: + try: + obj = json5.loads(text[object_start : i + 1]) + items.append(obj) + current_cursor = i + 1 + object_start = -1 + except: + # If parsing fails, continue + pass + elif char == ']' and brace_count == 0: + # End of array + break + + return ExtractItemsResult(items=items, cursor=current_cursor) diff --git a/py/packages/genkit/tests/genkit/core/extract_test.py b/py/packages/genkit/tests/genkit/core/extract_test.py new file mode 100644 index 000000000..fa34440d2 --- /dev/null +++ b/py/packages/genkit/tests/genkit/core/extract_test.py @@ -0,0 +1,167 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from genkit.core.extract import extract_items, extract_json, parse_partial_json + +# TODO: consider extracting these tests into shared yaml spec. They are already +# duplicated in js/ai/tests/extract_test.ts + +test_cases_extract_items = [ + ( + 'handles simple array in chunks', + [ + {'chunk': '[', 'want': []}, + {'chunk': '{"a": 1},', 'want': [{'a': 1}]}, + {'chunk': '{"b": 2}', 'want': [{'b': 2}]}, + {'chunk': ']', 'want': []}, + ], + ), + ( + 'handles nested objects', + [ + {'chunk': '[{"outer": {', 'want': []}, + { + 'chunk': '"inner": "value"}},', + 'want': [{'outer': {'inner': 'value'}}], + }, + {'chunk': '{"next": true}]', 'want': [{'next': True}]}, + ], + ), + ( + 'handles escaped characters', + [ + {'chunk': '[{"text": "line1\\n', 'want': []}, + { + 'chunk': 'line2"},', + 'want': [{'text': 'line1\nline2'}], + }, + { + 'chunk': '{"text": "tab\\there"}]', + 'want': [{'text': 'tab\there'}], + }, + ], + ), + ( + 'ignores content before first array', + [ + {'chunk': 'Here is an array:\n```json\n\n[', 'want': []}, + {'chunk': '{"a": 1},', 'want': [{'a': 1}]}, + { + 'chunk': '{"b": 2}]\n```\nDid you like my array?', + 'want': [{'b': 2}], + }, + ], + ), + ( + 'handles whitespace', + [ + {'chunk': '[\n ', 'want': []}, + {'chunk': '{"a": 1},\n ', 'want': [{'a': 1}]}, + {'chunk': '{"b": 2}\n]', 'want': [{'b': 2}]}, + ], + ), +] + + +@pytest.mark.parametrize( + 'name, steps', + test_cases_extract_items, + ids=[tc[0] for tc in test_cases_extract_items], +) +def test_extract_items(name, steps): + text = '' + cursor = 0 + for step in steps: + text += step['chunk'] + result = extract_items(text, cursor) + assert result.items == step['want'] + cursor = result.cursor + + +test_cases_extract_json = [ + ( + 'extracts simple object', + {'text': 'prefix{"a":1}suffix'}, + {'expected': {'a': 1}}, + ), + ( + 'extracts simple array', + {'text': 'prefix[1,2,3]suffix'}, + {'expected': [1, 2, 3]}, + ), + ( + 'handles nested structures', + {'text': 'text{"a":{"b":[1,2]}}more'}, + {'expected': {'a': {'b': [1, 2]}}}, + ), + ( + 'handles strings with braces', + {'text': '{"text": "not {a} json"}'}, + {'expected': {'text': 'not {a} json'}}, + ), + ( + 'returns null for invalid JSON without throw', + {'text': 'not json at all'}, + {'expected': None}, + ), + ( + 'throws for invalid JSON with throw flag', + {'text': 'not json at all', 'throwOnBadJson': True}, + {'throws': True}, + ), +] + + +@pytest.mark.parametrize( + 'name, input_data, expected_data', + test_cases_extract_json, + ids=[tc[0] for tc in test_cases_extract_json], +) +def test_extract_json(name, input_data, expected_data): + if expected_data.get('throws'): + with pytest.raises(Exception): + extract_json(input_data['text'], throw_on_bad_json=True) + else: + result = extract_json( + input_data['text'], + throw_on_bad_json=input_data.get('throwOnBadJson', False), + ) + assert result == expected_data['expected'] + + +test_cases_parse_partial_json = [ + ( + 'parses complete object', + '{"a":1,"b":2}', + {'expected': {'a': 1, 'b': 2}}, + ), + ( + 'parses partial object', + '{"a":1,"b":', + {'expected': {'a': 1}}, + ), + ( + 'parses partial array', + '[1,2,3,', + {'expected': [1, 2, 3]}, + ), + # NOTE: this testcase diverges from the one in js/ai/tests/extract_test.ts + # Specifically, python partial json parser lib doesn't like malformed json. + # JS one handles input: '{"a":{"b":1,"c":]}}', + ( + 'parses nested partial structures', + '{"a":{"b":1,"c":[', + {'expected': {'a': {'b': 1, 'c': []}}}, + ), +] + + +@pytest.mark.parametrize( + 'name, input_str, expected_data', + test_cases_parse_partial_json, + ids=[tc[0] for tc in test_cases_parse_partial_json], +) +def test_parse_partial_json(name, input_str, expected_data): + result = parse_partial_json(input_str) + assert result == expected_data['expected'] diff --git a/py/uv.lock b/py/uv.lock index 63ac3a7a4..2a7f1bd91 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -675,8 +675,10 @@ version = "0.1.0" source = { editable = "packages/genkit" } dependencies = [ { name = "dotprompt" }, + { name = "json5" }, { name = "opentelemetry-api" }, { name = "opentelemetry-sdk" }, + { name = "partial-json-parser" }, { name = "pydantic" }, { name = "requests" }, ] @@ -684,8 +686,10 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "dotprompt" }, + { name = "json5", specifier = ">=0.10.0" }, { name = "opentelemetry-api", specifier = ">=1.29.0" }, { name = "opentelemetry-sdk", specifier = ">=1.29.0" }, + { name = "partial-json-parser", specifier = ">=0.2.1.1.post5" }, { name = "pydantic", specifier = ">=2.10.5" }, { name = "requests", specifier = ">=2.32.3" }, ] @@ -1991,6 +1995,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, ] +[[package]] +name = "partial-json-parser" +version = "0.2.1.1.post5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/9c/9c366aed65acb40a97842ce1375a87b27ea37d735fc9717f7729bae3cc00/partial_json_parser-0.2.1.1.post5.tar.gz", hash = "sha256:992710ac67e90b367921d52727698928040f7713ba7ecb33b96371ea7aec82ca", size = 10313 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/ee/a9476f01f27c74420601be208c6c2c0dd3486681d515e9d765931b89851c/partial_json_parser-0.2.1.1.post5-py3-none-any.whl", hash = "sha256:627715aaa3cb3fb60a65b0d62223243acaa6c70846520a90326fef3a2f0b61ca", size = 10885 }, +] + [[package]] name = "pathspec" version = "0.12.1"