Skip to content

Commit

Permalink
v0.0.58
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jul 21, 2024
1 parent 3fd6f0f commit d98f456
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 3 deletions.
100 changes: 98 additions & 2 deletions agixtsdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,20 @@
from datetime import datetime
from pydub import AudioSegment
from pydantic import BaseModel
from typing import Dict, List, Any, Optional, Callable
from typing import (
Dict,
List,
Any,
Optional,
Callable,
Type,
get_args,
get_origin,
Union,
)
from enum import Enum
from pydantic import BaseModel
import json


class ChatCompletions(BaseModel):
Expand Down Expand Up @@ -62,9 +75,9 @@ def __init__(
"Authorization": f"{api_key}",
"Content-Type": "application/json",
}

if self.base_uri[-1] == "/":
self.base_uri = self.base_uri[:-1]
self.failures = 0

def handle_error(self, error) -> str:
print(f"Error: {error}")
Expand Down Expand Up @@ -1660,3 +1673,86 @@ def plan_task(
return response.json()["response"]
except Exception as e:
return self.handle_error(e)

def convert_to_model(
self,
input_string: str,
model: Type[BaseModel],
agent_name: str = "gpt4free",
max_failures: int = 3,
response_type: str = None,
):
input_string = str(input_string)
fields = model.__annotations__
field_descriptions = []
for field, field_type in fields.items():
description = f"{field}: {field_type}"
if get_origin(field_type) == Union:
field_type = get_args(field_type)[0]
if isinstance(field_type, type) and issubclass(field_type, Enum):
enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type])
description += f" (Enum values: {enum_values})"
field_descriptions.append(description)
schema = "\n".join(field_descriptions)
response = self.prompt_agent(
agent_name=agent_name,
prompt_name="Convert to Pydantic Model",
prompt_args={
"schema": schema,
"user_input": input_string,
},
)
if "```json" in response:
response = response.split("```json")[1].split("```")[0].strip()
elif "```" in response:
response = response.split("```")[1].strip()
try:
response = json.loads(response)
if response_type == "json":
return response
else:
return model(**response)
except Exception as e:
self.failures += 1
if self.failures > max_failures:
print(
f"Error: {e} . Failed to convert the response to the model after 3 attempts. Response: {response}"
)
return (
response
if response
else "Failed to convert the response to the model."
)
else:
self.failures = 1
print(
f"Error: {e} . Failed to convert the response to the model, trying again. {self.failures}/3 failures. Response: {response}"
)
return self.convert_to_model(
input_string=input_string,
model=model,
agent_name=agent_name,
max_failures=max_failures,
failures=self.failures,
)

def convert_list_of_dicts(
self,
data: List[dict],
model: Type[BaseModel],
agent_name: str = "gpt4free",
):
converted_data = self.convert_to_model(
input_string=json.dumps(data[0], indent=4),
model=model,
agent_name=agent_name,
)
mapped_list = []
for info in data:
new_data = {}
for key, value in converted_data.items():
item = [k for k, v in data[0].items() if v == value]
if item:
new_data[key] = info[item[0]]
mapped_list.append(new_data)
return mapped_list
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="agixtsdk",
version="0.0.57",
version="0.0.58",
description="The AGiXT SDK for Python.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit d98f456

Please sign in to comment.