diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 6445d0f..bd383c9 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -3,7 +3,8 @@
import aiohttp
from aiohttp import ClientSession
-from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction
+from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction, \
+ ToolCallMessage, ToolResponseMessage
from libertai_agents.models import Model
from libertai_agents.utils import find
@@ -30,15 +31,15 @@ async def generate_answer(self, messages: list[Message]) -> str:
raise ValueError("Last message is not from the user or tool")
prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
- print(prompt)
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)
tool_calls = self.model.extract_tool_calls_from_response(response)
if len(tool_calls) == 0:
return response
- messages.append(self.__create_tool_calls_message(tool_calls))
- tool_messages = self.__execute_tool_calls(tool_calls)
+ tool_calls_message = self.__create_tool_calls_message(tool_calls)
+ messages.append(tool_calls_message)
+ tool_messages = self.__execute_tool_calls(tool_calls_message.tool_calls)
return await self.generate_answer(messages + tool_messages)
async def __call_model(self, session: ClientSession, prompt: str):
@@ -50,24 +51,26 @@ async def __call_model(self, session: ClientSession, prompt: str):
response_data = await response.json()
return response_data["content"]
- def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+ def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Message]:
# TODO: support async function calls
- messages = []
+ messages: list[Message] = []
for call in tool_calls:
- function_name = call.name
+ function_name = call.function.name
function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
if function_to_call is None:
# TODO: handle error
continue
- function_response = function_to_call(*call.arguments.values())
- messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
+ function_response = function_to_call(*call.function.arguments.values())
+ messages.append(
+ ToolResponseMessage(role=MessageRoleEnum.tool, name=function_name, tool_call_id=call.id,
+ content=str(function_response)))
return messages
- def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
- return Message(role=MessageRoleEnum.assistant,
- tool_calls=[MessageToolCall(type="function",
- id=self.model.generate_tool_call_id(),
- function=ToolCallFunction(name=call.name,
- arguments=call.arguments)) for
- call in
- tool_calls])
+ def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> ToolCallMessage:
+ return ToolCallMessage(role=MessageRoleEnum.assistant,
+ tool_calls=[MessageToolCall(type="function",
+ id=self.model.generate_tool_call_id(),
+ function=ToolCallFunction(name=call.name,
+ arguments=call.arguments)) for
+ call in
+ tool_calls])
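
For context, a minimal sketch of the message round trip that generate_answer now builds, assuming the interfaces from this diff; the id value and the weather tool name are illustrative, not part of the patch.

from libertai_agents.interfaces import (MessageRoleEnum, MessageToolCall,
                                        ToolCallFunction, ToolCallMessage,
                                        ToolResponseMessage)

# 1. The model asked for a tool call; the agent records it in the history.
call = MessageToolCall(
    type="function",
    id="aB3xY9kQ2",  # illustrative; Hermes models leave this None
    function=ToolCallFunction(name="get_current_temperature",
                              arguments={"location": "Paris"}))
assistant_turn = ToolCallMessage(role=MessageRoleEnum.assistant, tool_calls=[call])

# 2. __execute_tool_calls runs the matching Python function and wraps the
#    result so the next model pass can read it.
tool_turn = ToolResponseMessage(role=MessageRoleEnum.tool,
                                name="get_current_temperature",
                                tool_call_id=call.id,
                                content="21.5")
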
diff --git a/libertai_agents/interfaces.py b/libertai_agents/interfaces.py
index 0f6e641..1b81f2a 100644
--- a/libertai_agents/interfaces.py
+++ b/libertai_agents/interfaces.py
@@ -24,10 +24,16 @@ class MessageToolCall(BaseModel):
class Message(BaseModel):
role: MessageRoleEnum
- name: Optional[str] = None
content: Optional[str] = None
+
+
+class ToolCallMessage(Message):
+ tool_calls: list[MessageToolCall]
+
+
+class ToolResponseMessage(Message):
+ name: Optional[str] = None
tool_call_id: Optional[str] = None
- tool_calls: Optional[list[MessageToolCall]] = None
class LlamaCppParams(BaseModel):
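
A quick sketch of what the subclass split buys: plain messages no longer serialize tool-only fields, which keeps the chat-template input clean.

from libertai_agents.interfaces import Message, MessageRoleEnum

msg = Message(role=MessageRoleEnum.user, content="Hi")
# Only ToolCallMessage carries tool_calls, and only ToolResponseMessage
# carries name/tool_call_id, so a plain message dumps to just these keys.
assert set(msg.model_dump().keys()) == {"role", "content"}
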
diff --git a/libertai_agents/models/__init__.py b/libertai_agents/models/__init__.py
index 7d5547c..72de728 100644
--- a/libertai_agents/models/__init__.py
+++ b/libertai_agents/models/__init__.py
@@ -1,2 +1,2 @@
-from .base import Model
-from .models import get_model
+from .base import Model as Model
+from .models import get_model as get_model
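
The explicit `X as X` form marks these names as intentional re-exports (the convention that mypy's implicit-reexport check and ruff's F401 rule recognize), so downstream imports keep type-checking cleanly:

from libertai_agents.models import Model, get_model
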
diff --git a/libertai_agents/models/base.py b/libertai_agents/models/base.py
index 7527ea0..d5f9f5d 100644
--- a/libertai_agents/models/base.py
+++ b/libertai_agents/models/base.py
@@ -2,27 +2,29 @@
from transformers import PreTrainedTokenizerFast, AutoTokenizer
-from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
+from libertai_agents.interfaces import Message, ToolCallFunction, MessageRoleEnum
class Model(ABC):
tokenizer: PreTrainedTokenizerFast
vm_url: str
+ system_message: bool
- def __init__(self, model_id: str, vm_url: str):
+ def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.vm_url = vm_url
+ self.system_message = system_message
def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
- messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+ if self.system_message:
+ messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
raw_messages = list(map(lambda x: x.model_dump(), messages))
return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
add_generation_prompt=True)
- @abstractmethod
def generate_tool_call_id(self) -> str | None:
- pass
+ return None
@staticmethod
@abstractmethod
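
To illustrate the new extension points, a minimal concrete subclass, assuming the base class above; EchoModel and its no-op parser are made up for the sketch.

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.models.base import Model

class EchoModel(Model):
    def __init__(self, model_id: str, vm_url: str):
        # Pass system_message=False to skip the system prompt, as MistralModel does.
        super().__init__(model_id, vm_url)

    @staticmethod
    def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
        # generate_tool_call_id is no longer abstract: the inherited default
        # returns None, so only response parsing must be implemented.
        return []
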
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
index f6aa1c0..1500321 100644
--- a/libertai_agents/models/hermes.py
+++ b/libertai_agents/models/hermes.py
@@ -1,16 +1,11 @@
import json
import re
-from transformers import PreTrainedTokenizerFast
-
from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.models.base import Model
class HermesModel(Model):
- tokenizer: PreTrainedTokenizerFast
- vm_url: str
-
def __init__(self, model_id: str, vm_url: str):
super().__init__(model_id, vm_url)
@@ -18,6 +13,3 @@ def __init__(self, model_id: str, vm_url: str):
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
        tool_calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
-
- def generate_tool_call_id(self) -> str | None:
- return None
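
For reference, what the Hermes parser extracts, assuming the model wraps each call in <tool_call> tags as the regex above expects; the response string is illustrative.

from libertai_agents.models.hermes import HermesModel

response = '<tool_call>{"name": "get_current_temperature", "arguments": {"location": "Paris"}}</tool_call>'
calls = HermesModel.extract_tool_calls_from_response(response)
assert calls[0].name == "get_current_temperature"
assert calls[0].arguments == {"location": "Paris"}
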
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
index bc73017..cddc814 100644
--- a/libertai_agents/models/mistral.py
+++ b/libertai_agents/models/mistral.py
@@ -1,23 +1,22 @@
import json
-import re
-
-from transformers import PreTrainedTokenizerFast
+import random
+import string
from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.models.base import Model
class MistralModel(Model):
- tokenizer: PreTrainedTokenizerFast
- vm_url: str
-
def __init__(self, model_id: str, vm_url: str):
- super().__init__(model_id, vm_url)
+ super().__init__(model_id, vm_url, system_message=False)
@staticmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
- return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+ try:
+ tool_calls = json.loads(response)
+ return [ToolCallFunction(**call) for call in tool_calls]
+ except Exception:
+ return []
- def generate_tool_call_id(self) -> str | None:
- return None
+ def generate_tool_call_id(self) -> str:
+ return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(9))
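
And the Mistral side, assuming the model returns a bare JSON array of calls (anything else falls through to the empty list); the 9-character alphanumeric id matches the tool-call id format Mistral's templates expect.

from libertai_agents.models.mistral import MistralModel

good = '[{"name": "get_current_temperature", "arguments": {"location": "Lyon"}}]'
assert len(MistralModel.extract_tool_calls_from_response(good)) == 1
# Plain-text answers simply yield no tool calls instead of raising.
assert MistralModel.extract_tool_calls_from_response("It is sunny.") == []
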
diff --git a/main.py b/main.py
index 8946fe1..262df25 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ async def start():
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])
response = await agent.generate_answer(
- [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon ?")])
+ [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon in Celsius ?")])
print(response)
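
The example tool itself is not part of this diff; a plausible shape, given that __execute_tool_calls unpacks the arguments positionally and the prompt now asks for Celsius (the signature is an assumption):

# Hypothetical tool matching the name registered in main.py; only the name
# get_current_temperature appears in the diff, the signature is assumed.
def get_current_temperature(location: str, unit: str = "celsius") -> float:
    """Return the current temperature for a location (placeholder logic)."""
    return 21.5  # a real tool would call a weather API here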