Skip to content

Commit

Permalink
Update to support (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Jan 21, 2025
1 parent eaaaad1 commit db8edb6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [
"jsonpatch<2.0,>=1.33",
]
name = "trustcall"
version = "0.0.26"
version = "0.0.27"
description = "Tenacious & trustworthy tool calling built on LangGraph."
readme = "README.md"

Expand All @@ -31,6 +31,7 @@ dev = [
"anyio>=4.7.0",
"pytest-asyncio-cooperative>=0.37.0",
]
standard = ["langchain>=0.3"]

[tool.setuptools]
packages = ["trustcall"]
Expand Down
11 changes: 5 additions & 6 deletions tests/evals/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import langsmith as ls
import pytest
from dydantic import create_model_from_schema
from langchain.chat_models import init_chat_model
from langsmith import aevaluate, expect, traceable
from langsmith.evaluation import EvaluationResults
from langsmith.schemas import Example, Run
Expand Down Expand Up @@ -61,8 +60,7 @@ async def predict_with_model(
),
("user", inputs["input_str"]),
]
llm = init_chat_model(model_name, temperature=0.8)
extractor = create_extractor(llm, tools=[tool_def], tool_choice=tool_def["name"])
extractor = create_extractor(model_name, tools=[tool_def], tool_choice=tool_def["name"])
existing = inputs.get("current_value", {})
extractor_inputs: dict = {"messages": messages}
if existing:
Expand Down Expand Up @@ -226,7 +224,7 @@ def query_docs(query: str) -> str:
return "I am a document."

extractor = create_extractor(
init_chat_model("gpt-4o"), tools=[query_docs], tool_choice="query_docs"
"gpt-4o", tools=[query_docs], tool_choice="query_docs"
)
extractor.invoke({"messages": [("user", "What are the docs about?")]})

Expand All @@ -246,8 +244,9 @@ def validate_query_length(cls, v: str) -> str:
)
return v

llm = init_chat_model("gpt-4o-mini")
extractor = create_extractor(llm, tools=[query_docs], tool_choice="any")
extractor = create_extractor(
"gpt-4o", tools=[query_docs], tool_choice="any"
)
extractor.invoke(
{
"messages": [
Expand Down
13 changes: 12 additions & 1 deletion trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class ExtractionOutputs(TypedDict):


def create_extractor(
llm: BaseChatModel,
llm: str | BaseChatModel,
*,
tools: Sequence[TOOL_T],
tool_choice: Optional[str] = None,
Expand Down Expand Up @@ -258,6 +258,17 @@ def create_extractor(
... }
... )
""" # noqa
if isinstance(llm, str):
try:
from langchain.chat_models import init_chat_model
except ImportError:
raise ImportError(
"Creating extractors from a string requires langchain>=0.3.0,"
" as well as the provider-specific package"
" (like langchain-openai, langchain-anthropic, etc.)"
" Please install langchain to continue."
)
llm = init_chat_model(llm)
builder = StateGraph(ExtractionState)

def format_exception(error: BaseException, call: ToolCall, schema: Type[BaseModel]):
Expand Down
6 changes: 5 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit db8edb6

Please sign in to comment.