From 9790c367cf8a54cba84d04de074c4a215f5f8245 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Wed, 22 Jan 2025 16:10:13 -0800 Subject: [PATCH] Update for coerced existing types --- pyproject.toml | 2 +- trustcall/_base.py | 58 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8eacaa5..40c9ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "jsonpatch<2.0,>=1.33", ] name = "trustcall" -version = "0.0.27" +version = "0.0.28" description = "Tenacious & trustworthy tool calling built on LangGraph." readme = "README.md" diff --git a/trustcall/_base.py b/trustcall/_base.py index a0ae6dc..b4858dc 100644 --- a/trustcall/_base.py +++ b/trustcall/_base.py @@ -80,7 +80,7 @@ class SchemaInstance(NamedTuple): """ record_id: str - schema_name: str + schema_name: str | Literal["__any__"] record: Dict[str, Any] @@ -653,7 +653,7 @@ def _setup(self, state: ExtractionState): existing = state.existing if not existing: raise ValueError("No existing schemas provided.") - self._validate_existing(existing) + existing = self._validate_existing(existing) schema_strings = [] if isinstance(existing, dict): for k, v in existing.items(): @@ -720,7 +720,18 @@ def _teardown( (e for e in existing if e[0] == json_doc_id), ) if not tool_name: - raise ValueError("Could not find tool name") + raise ValueError( + f"Could not find tool name for json_doc_id {json_doc_id}" + ) + except StopIteration: + logger.error( + f"Could not find existing schema in dict for {json_doc_id}" + ) + if rt: + rt.error = ( + f"Could not find existing schema for {json_doc_id}" + ) + continue except (ValueError, IndexError, TypeError): logger.error( f"Could not find existing schema in list for {json_doc_id}" @@ -782,8 +793,10 @@ def _validate_existing(self, existing: ExistingType): " with keys matching one of the provided tool names:" f" {self._provided_tools}" ) - elif isinstance(existing, list): + return existing + if isinstance(existing, list): # For list types, validate each item's schema_name + coerced = [] for i, item in enumerate(existing): if isinstance(item, SchemaInstance): if item.schema_name not in self.tools: @@ -793,6 +806,7 @@ def _validate_existing(self, existing: ExistingType): f"name. Provided: {item}, Expected: SchemaInstance" f" with schema_name in {self._provided_tools}" ) + coerced.append(coerced) elif isinstance(item, tuple) and len(item) == 3: if item[1] not in self.tools: raise ValueError( @@ -801,12 +815,46 @@ def _validate_existing(self, existing: ExistingType): f" Expected: Tuple(str, str, dict) with second" f" element in {self._provided_tools}" ) + coerced.append(SchemaInstance(item[0], item[1], item[2])) + elif isinstance(item, tuple) and len(item) == 2: + # Assume record_ID, item + if hasattr(item[1], "__name__"): + schema_name = item[1].__name__ + else: + schema_name = item[1].__repr_name__() + + if schema_name not in self.tools: + raise ValueError( + f"Schema name '{schema_name}' at index {i} does" + f" not match any tool name. Provided: {item}," + f" Expected: Tuple(str, str, dict) with second" + f" element in {self._provided_tools}" + ) + val = ( + item[1].model_dump(mode="json") + if isinstance(item[1], BaseModel) + else item[1] + ) + coerced.append(SchemaInstance(item[0], schema_name, val)) + elif isinstance(item, BaseModel): + if hasattr(item, "__name__"): + schema_name = item.__name__ + else: + schema_name = item.__repr_name__() + coerced.append( + SchemaInstance( + str(uuid.uuid4()), + schema_name, + item.model_dump(mode="json"), + ) + ) else: raise ValueError( f"Invalid item at index {i} in existing list." f" Provided: {item}, Expected: SchemaInstance" - f" or Tuple(str, str, dict)" + f" or Tuple(str, str, dict) or BaseModel" ) + return coerced else: raise ValueError( f"Invalid type for existing. Provided: {type(existing)},"