diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 6445d0f..bd383c9 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -3,7 +3,8 @@
 import aiohttp
 from aiohttp import ClientSession
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction
+from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction, \
+    ToolCallMessage, ToolResponseMessage
 from libertai_agents.models import Model
 from libertai_agents.utils import find
 
@@ -30,15 +31,15 @@ async def generate_answer(self, messages: list[Message]) -> str:
             raise ValueError("Last message is not from the user or tool")
 
         prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
-        print(prompt)
         async with aiohttp.ClientSession() as session:
             response = await self.__call_model(session, prompt)
             tool_calls = self.model.extract_tool_calls_from_response(response)
             if len(tool_calls) == 0:
                 return response
-            messages.append(self.__create_tool_calls_message(tool_calls))
-            tool_messages = self.__execute_tool_calls(tool_calls)
+            tool_calls_message = self.__create_tool_calls_message(tool_calls)
+            messages.append(tool_calls_message)
+            tool_messages = self.__execute_tool_calls(tool_calls_message.tool_calls)
             return await self.generate_answer(messages + tool_messages)
 
     async def __call_model(self, session: ClientSession, prompt: str):
@@ -50,24 +51,26 @@ async def __call_model(self, session: ClientSession, prompt: str):
             response_data = await response.json()
             return response_data["content"]
 
-    def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+    def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Message]:
         # TODO: support async function calls
-        messages = []
+        messages: list[Message] = []
         for call in tool_calls:
-            function_name = call.name
+            function_name = call.function.name
             function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
             if function_to_call is None:
                 # TODO: handle error
                 continue
-            function_response = function_to_call(*call.arguments.values())
-            messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
+            function_response = function_to_call(*call.function.arguments.values())
+            messages.append(
+                ToolResponseMessage(role=MessageRoleEnum.tool, name=function_name, tool_call_id=call.id,
+                                    content=str(function_response)))
         return messages
 
-    def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
-        return Message(role=MessageRoleEnum.assistant,
-                       tool_calls=[MessageToolCall(type="function",
-                                                   id=self.model.generate_tool_call_id(),
-                                                   function=ToolCallFunction(name=call.name,
-                                                                             arguments=call.arguments)) for
-                                   call in
-                                   tool_calls])
+    def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> ToolCallMessage:
+        return ToolCallMessage(role=MessageRoleEnum.assistant,
+                               tool_calls=[MessageToolCall(type="function",
+                                                           id=self.model.generate_tool_call_id(),
+                                                           function=ToolCallFunction(name=call.name,
+                                                                                     arguments=call.arguments)) for
+                                           call in
+                                           tool_calls])
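For context, here is roughly how the reworked loop threads a tool call through the conversation history. This is a hand-written sketch based on the classes in this patch, not code from the repository; `get_current_temperature` stands in for the tool registered in main.py, and the id value is made up.

```python
from libertai_agents.interfaces import (
    Message,
    MessageRoleEnum,
    MessageToolCall,
    ToolCallFunction,
    ToolCallMessage,
    ToolResponseMessage,
)


def get_current_temperature(location: str) -> float:
    """Stand-in for the tool registered in main.py."""
    return 22.5


history: list[Message] = [
    Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")
]

# What __create_tool_calls_message now builds from the parsed model output
# (the id is a made-up example; Hermes models leave it as None):
call = MessageToolCall(
    type="function",
    id="kx92mRq4T",
    function=ToolCallFunction(name="get_current_temperature",
                              arguments={"location": "Paris"}),
)
history.append(ToolCallMessage(role=MessageRoleEnum.assistant, tool_calls=[call]))

# What __execute_tool_calls appends for that call before generate_answer recurses:
result = get_current_temperature(*call.function.arguments.values())
history.append(ToolResponseMessage(role=MessageRoleEnum.tool,
                                   name=call.function.name,
                                   tool_call_id=call.id,
                                   content=str(result)))
```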
diff --git a/libertai_agents/interfaces.py b/libertai_agents/interfaces.py
index 0f6e641..1b81f2a 100644
--- a/libertai_agents/interfaces.py
+++ b/libertai_agents/interfaces.py
@@ -24,10 +24,16 @@ class MessageToolCall(BaseModel):
 
 class Message(BaseModel):
     role: MessageRoleEnum
-    name: Optional[str] = None
     content: Optional[str] = None
+
+
+class ToolCallMessage(Message):
+    tool_calls: list[MessageToolCall]
+
+
+class ToolResponseMessage(Message):
+    name: Optional[str] = None
     tool_call_id: Optional[str] = None
-    tool_calls: Optional[list[MessageToolCall]] = None
 
 
 class LlamaCppParams(BaseModel):
diff --git a/libertai_agents/models/__init__.py b/libertai_agents/models/__init__.py
index 7d5547c..72de728 100644
--- a/libertai_agents/models/__init__.py
+++ b/libertai_agents/models/__init__.py
@@ -1,2 +1,2 @@
-from .base import Model
-from .models import get_model
+from .base import Model as Model
+from .models import get_model as get_model
diff --git a/libertai_agents/models/base.py b/libertai_agents/models/base.py
index 7527ea0..d5f9f5d 100644
--- a/libertai_agents/models/base.py
+++ b/libertai_agents/models/base.py
@@ -2,27 +2,29 @@
 
 from transformers import PreTrainedTokenizerFast, AutoTokenizer
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
+from libertai_agents.interfaces import Message, ToolCallFunction, MessageRoleEnum
 
 
 class Model(ABC):
     tokenizer: PreTrainedTokenizerFast
     vm_url: str
+    system_message: bool
 
-    def __init__(self, model_id: str, vm_url: str):
+    def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.vm_url = vm_url
+        self.system_message = system_message
 
     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
-        messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+        if self.system_message:
+            messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
         raw_messages = list(map(lambda x: x.model_dump(), messages))
 
         return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                   add_generation_prompt=True)
 
-    @abstractmethod
     def generate_tool_call_id(self) -> str | None:
-        pass
+        return None
 
     @staticmethod
     @abstractmethod
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
index f6aa1c0..1500321 100644
--- a/libertai_agents/models/hermes.py
+++ b/libertai_agents/models/hermes.py
@@ -1,16 +1,11 @@
 import json
 import re
 
-from transformers import PreTrainedTokenizerFast
-
 from libertai_agents.interfaces import ToolCallFunction
 from libertai_agents.models.base import Model
 
 
 class HermesModel(Model):
-    tokenizer: PreTrainedTokenizerFast
-    vm_url: str
-
     def __init__(self, model_id: str, vm_url: str):
         super().__init__(model_id, vm_url)
 
@@ -18,6 +13,3 @@ def __init__(self, model_id: str, vm_url: str):
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
         tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
         return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
-
-    def generate_tool_call_id(self) -> str | None:
-        return None
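HermesModel's extractor assumes each call is wrapped in `<tool_call>` tags (the Hermes function-calling convention) with a JSON body between them. A standalone check of that parsing, using a made-up response string:

```python
import json
import re

# Made-up model output in the Hermes function-calling format:
response = ('Let me look that up.\n'
            '<tool_call>{"name": "get_current_temperature", '
            '"arguments": {"location": "Paris"}}</tool_call>')

# Same pattern HermesModel uses: grab the JSON payload between the tags.
calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
print([json.loads(call) for call in calls])
# -> [{'name': 'get_current_temperature', 'arguments': {'location': 'Paris'}}]
```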
re.findall("\s*(.*)\s*", response) - return [ToolCallFunction(**json.loads(call)) for call in tool_calls] + try: + tool_calls = json.loads(response) + return [ToolCallFunction(**call) for call in tool_calls] + except Exception: + return [] - def generate_tool_call_id(self) -> str | None: - return None + def generate_tool_call_id(self) -> str: + return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(9)) diff --git a/main.py b/main.py index 8946fe1..262df25 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ async def start(): system_prompt="You are a helpful assistant", tools=[get_current_temperature]) response = await agent.generate_answer( - [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon ?")]) + [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon in Celsius ?")]) print(response)