diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 1343b44..6445d0f 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -1,3 +1,5 @@
+from typing import Callable
+
import aiohttp
from aiohttp import ClientSession
@@ -9,9 +11,9 @@
class ChatAgent:
model: Model
system_prompt: str
- tools: list
+ tools: list[Callable]
- def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
+ def __init__(self, model: Model, system_prompt: str, tools: list[Callable] | None = None):
if tools is None:
tools = []
@@ -36,33 +38,35 @@ async def generate_answer(self, messages: list[Message]) -> str:
if len(tool_calls) == 0:
return response
messages.append(self.__create_tool_calls_message(tool_calls))
- tool_messages = self.execute_tool_calls(tool_calls)
+ tool_messages = self.__execute_tool_calls(tool_calls)
return await self.generate_answer(messages + tool_messages)
async def __call_model(self, session: ClientSession, prompt: str):
params = LlamaCppParams(prompt=prompt)
async with session.post(self.model.vm_url, json=params.model_dump()) as response:
+ # TODO: handle errors and retries
if response.status == 200:
response_data = await response.json()
return response_data["content"]
- def execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+ def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
# TODO: support async function calls
messages = []
for call in tool_calls:
- function_to_call = find(lambda x: x.__name__ == call.name, self.tools)
+ function_name = call.name
+ function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
if function_to_call is None:
# TODO: handle error
continue
function_response = function_to_call(*call.arguments.values())
- messages.append(Message(role=MessageRoleEnum.tool, name=call.name, content=str(function_response)))
+ messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
return messages
- @staticmethod
- def __create_tool_calls_message(tool_calls: list[ToolCallFunction]) -> Message:
+ def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
return Message(role=MessageRoleEnum.assistant,
tool_calls=[MessageToolCall(type="function",
+ id=self.model.generate_tool_call_id(),
function=ToolCallFunction(name=call.name,
arguments=call.arguments)) for
call in
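
For context, a minimal sketch of how a tool plugs into this loop (assumptions: any plain synchronous function works, matched by __name__; its parameter order must match the JSON argument order, since __execute_tool_calls unpacks call.arguments.values() positionally):

    from libertai_agents.agents import ChatAgent
    from libertai_agents.models import get_model

    def get_current_temperature(location: str) -> float:
        """Hypothetical stub tool; the agent matches it by __name__."""
        return 21.5

    agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                      system_prompt="You are a helpful assistant",
                      tools=[get_current_temperature])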
diff --git a/libertai_agents/interfaces.py b/libertai_agents/interfaces.py
index a39a0ae..0f6e641 100644
--- a/libertai_agents/interfaces.py
+++ b/libertai_agents/interfaces.py
@@ -18,6 +18,7 @@ class ToolCallFunction(BaseModel):
class MessageToolCall(BaseModel):
type: str
+ id: Optional[str] = None
function: ToolCallFunction
@@ -25,6 +26,7 @@ class Message(BaseModel):
role: MessageRoleEnum
name: Optional[str] = None
content: Optional[str] = None
+ tool_call_id: Optional[str] = None
tool_calls: Optional[list[MessageToolCall]] = None
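
The two new optional fields lay the groundwork for models that use OpenAI-style call ids; Hermes returns None and both stay unset. A minimal sketch, assuming OpenAI-style semantics where the tool result echoes the id of the call it answers:

    call_id = "abc123"  # hypothetical id, as produced by Model.generate_tool_call_id()
    assistant_msg = Message(
        role=MessageRoleEnum.assistant,
        tool_calls=[MessageToolCall(
            type="function",
            id=call_id,
            function=ToolCallFunction(name="get_current_temperature",
                                      arguments={"location": "Paris"}))])
    tool_msg = Message(role=MessageRoleEnum.tool, name="get_current_temperature",
                       tool_call_id=call_id, content="21.5")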
diff --git a/libertai_agents/models/__init__.py b/libertai_agents/models/__init__.py
new file mode 100644
index 0000000..7d5547c
--- /dev/null
+++ b/libertai_agents/models/__init__.py
@@ -0,0 +1,2 @@
+from .base import Model
+from .models import get_model
diff --git a/libertai_agents/models.py b/libertai_agents/models/base.py
similarity index 64%
rename from libertai_agents/models.py
rename to libertai_agents/models/base.py
index 703cde5..7527ea0 100644
--- a/libertai_agents/models.py
+++ b/libertai_agents/models/base.py
@@ -1,12 +1,11 @@
-import json
-import re
+from abc import ABC, abstractmethod
-from transformers import AutoTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizerFast, AutoTokenizer
from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
-class Model:
+class Model(ABC):
tokenizer: PreTrainedTokenizerFast
vm_url: str
@@ -21,11 +20,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
add_generation_prompt=True)
+ @abstractmethod
+ def generate_tool_call_id(self) -> str | None:
+ pass
+
@staticmethod
+ @abstractmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("^<tool_call>\s*(.*)\s*</tool_call>$", response)
- return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
-
-
-Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
- vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
+ pass
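
With Model now abstract, adding a backend means implementing the two hooks above; a minimal sketch (ExampleModel is hypothetical, for illustration only):

    class ExampleModel(Model):
        @staticmethod
        def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
            # Parse this backend's tool-call syntax; return [] when there is none.
            return []

        def generate_tool_call_id(self) -> str | None:
            # Return an id only if the chat template expects one (see MessageToolCall.id).
            return None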
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
new file mode 100644
index 0000000..f6aa1c0
--- /dev/null
+++ b/libertai_agents/models/hermes.py
@@ -0,0 +1,23 @@
+import json
+import re
+
+from transformers import PreTrainedTokenizerFast
+
+from libertai_agents.interfaces import ToolCallFunction
+from libertai_agents.models.base import Model
+
+
+class HermesModel(Model):
+ tokenizer: PreTrainedTokenizerFast
+ vm_url: str
+
+ def __init__(self, model_id: str, vm_url: str):
+ super().__init__(model_id, vm_url)
+
+ @staticmethod
+ def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
+ return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+
+ def generate_tool_call_id(self) -> str | None:
+ return None
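
A quick usage sketch of the extraction above, assuming Hermes's format of one JSON object per <tool_call> block on a single line:

    response = '<tool_call> {"name": "get_current_temperature", "arguments": {"location": "Paris"}} </tool_call>'
    calls = HermesModel.extract_tool_calls_from_response(response)
    # calls == [ToolCallFunction(name="get_current_temperature", arguments={"location": "Paris"})]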
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
new file mode 100644
index 0000000..bc73017
--- /dev/null
+++ b/libertai_agents/models/mistral.py
@@ -0,0 +1,23 @@
+import json
+import re
+
+from transformers import PreTrainedTokenizerFast
+
+from libertai_agents.interfaces import ToolCallFunction
+from libertai_agents.models.base import Model
+
+
+class MistralModel(Model):
+ tokenizer: PreTrainedTokenizerFast
+ vm_url: str
+
+ def __init__(self, model_id: str, vm_url: str):
+ super().__init__(model_id, vm_url)
+
+ @staticmethod
+ def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
+ return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+
+ def generate_tool_call_id(self) -> str | None:
+ return None
diff --git a/libertai_agents/models/models.py b/libertai_agents/models/models.py
new file mode 100644
index 0000000..66763a8
--- /dev/null
+++ b/libertai_agents/models/models.py
@@ -0,0 +1,43 @@
+import typing
+
+from huggingface_hub import login
+from pydantic import BaseModel
+
+from libertai_agents.models.base import Model
+from libertai_agents.models.hermes import HermesModel
+from libertai_agents.models.mistral import MistralModel
+
+
+class ModelConfiguration(BaseModel):
+ vm_url: str
+ constructor: typing.Type[Model]
+
+
+ModelId = typing.Literal[
+ "NousResearch/Hermes-2-Pro-Llama-3-8B",
+ "NousResearch/Hermes-3-Llama-3.1-8B",
+ "mistralai/Mistral-Nemo-Instruct-2407"
+]
+MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))
+
+MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
+ "NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
+ vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
+ constructor=HermesModel),
+ "NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
+ constructor=HermesModel),
+ "mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
+ constructor=MistralModel)
+}
+
+
+def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
+ model_configuration = MODELS_CONFIG.get(model_id)
+
+ if model_configuration is None:
+ raise ValueError(f'model_id must be one of {MODEL_IDS}')
+
+ if hf_token is not None:
+ login(hf_token)
+
+ return model_configuration.constructor(model_id=model_id, **model_configuration.model_dump(exclude={'constructor'}))
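
Usage sketch for the factory above (the token value is a placeholder; hf_token should only be needed when the tokenizer repo is gated on the Hugging Face Hub):

    from libertai_agents.models import get_model

    model = get_model("NousResearch/Hermes-2-Pro-Llama-3-8B")
    mistral = get_model("mistralai/Mistral-Nemo-Instruct-2407", hf_token="hf_...")  # hypothetical token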
diff --git a/main.py b/main.py
index fef78bb..8946fe1 100644
--- a/main.py
+++ b/main.py
@@ -2,14 +2,16 @@
from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import Message, MessageRoleEnum
-from libertai_agents.models import Hermes2Pro
+from libertai_agents.models import get_model
from libertai_agents.tools import get_current_temperature
async def start():
- agent = ChatAgent(model=Hermes2Pro, system_prompt="You are a helpful assistant", tools=[get_current_temperature])
+ agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
+ system_prompt="You are a helpful assistant",
+ tools=[get_current_temperature])
response = await agent.generate_answer(
- [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")])
+        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon?")])
print(response)