diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 1343b44..6445d0f 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -1,3 +1,5 @@
+from typing import Callable
+
 import aiohttp
 from aiohttp import ClientSession
 
@@ -9,9 +11,9 @@ class ChatAgent:
     model: Model
     system_prompt: str
-    tools: list
+    tools: list[Callable]
 
-    def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
+    def __init__(self, model: Model, system_prompt: str, tools: list[Callable] | None = None):
         if tools is None:
             tools = []
@@ -36,33 +38,35 @@ async def generate_answer(self, messages: list[Message]) -> str:
             if len(tool_calls) == 0:
                 return response
             messages.append(self.__create_tool_calls_message(tool_calls))
-            tool_messages = self.execute_tool_calls(tool_calls)
+            tool_messages = self.__execute_tool_calls(tool_calls)
             return await self.generate_answer(messages + tool_messages)
 
     async def __call_model(self, session: ClientSession, prompt: str):
         params = LlamaCppParams(prompt=prompt)
 
         async with session.post(self.model.vm_url, json=params.model_dump()) as response:
+            # TODO: handle errors and retries
             if response.status == 200:
                 response_data = await response.json()
                 return response_data["content"]
 
-    def execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+    def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
         # TODO: support async function calls
         messages = []
 
         for call in tool_calls:
-            function_to_call = find(lambda x: x.__name__ == call.name, self.tools)
+            function_name = call.name
+            function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
             if function_to_call is None:
                 # TODO: handle error
                 continue
             function_response = function_to_call(*call.arguments.values())
-            messages.append(Message(role=MessageRoleEnum.tool, name=call.name, content=str(function_response)))
+            messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
         return messages
 
-    @staticmethod
-    def __create_tool_calls_message(tool_calls: list[ToolCallFunction]) -> Message:
+    def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
         return Message(role=MessageRoleEnum.assistant,
                        tool_calls=[MessageToolCall(type="function",
+                                                   id=self.model.generate_tool_call_id(),
                                                    function=ToolCallFunction(name=call.name,
                                                                              arguments=call.arguments)) for call in
diff --git a/libertai_agents/interfaces.py b/libertai_agents/interfaces.py
index a39a0ae..0f6e641 100644
--- a/libertai_agents/interfaces.py
+++ b/libertai_agents/interfaces.py
@@ -18,6 +18,7 @@ class ToolCallFunction(BaseModel):
 
 class MessageToolCall(BaseModel):
     type: str
+    id: Optional[str] = None
     function: ToolCallFunction
 
 
@@ -25,6 +26,7 @@ class Message(BaseModel):
     role: MessageRoleEnum
     name: Optional[str] = None
     content: Optional[str] = None
+    tool_call_id: Optional[str] = None
     tool_calls: Optional[list[MessageToolCall]] = None
diff --git a/libertai_agents/models/__init__.py b/libertai_agents/models/__init__.py
new file mode 100644
index 0000000..7d5547c
--- /dev/null
+++ b/libertai_agents/models/__init__.py
@@ -0,0 +1,2 @@
+from .base import Model
+from .models import get_model
diff --git a/libertai_agents/models.py b/libertai_agents/models/base.py
similarity index 64%
rename from libertai_agents/models.py
rename to libertai_agents/models/base.py
index 703cde5..7527ea0 100644
--- a/libertai_agents/models.py
+++ b/libertai_agents/models/base.py
@@ -1,12 +1,11 @@
-import json
-import re
+from abc import ABC, abstractmethod
 
-from transformers import AutoTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizerFast, AutoTokenizer
 
 from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
 
 
-class Model:
+class Model(ABC):
     tokenizer: PreTrainedTokenizerFast
     vm_url: str
 
@@ -21,11 +20,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
         return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                   add_generation_prompt=True)
 
+    @abstractmethod
+    def generate_tool_call_id(self) -> str | None:
+        pass
+
     @staticmethod
+    @abstractmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("^<tool_call>\s*(.*)\s*</tool_call>$", response)
-        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
-
-
-Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
-                   vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
+        pass
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
new file mode 100644
index 0000000..f6aa1c0
--- /dev/null
+++ b/libertai_agents/models/hermes.py
@@ -0,0 +1,23 @@
+import json
+import re
+
+from transformers import PreTrainedTokenizerFast
+
+from libertai_agents.interfaces import ToolCallFunction
+from libertai_agents.models.base import Model
+
+
+class HermesModel(Model):
+    tokenizer: PreTrainedTokenizerFast
+    vm_url: str
+
+    def __init__(self, model_id: str, vm_url: str):
+        super().__init__(model_id, vm_url)
+
+    @staticmethod
+    def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
+        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+
+    def generate_tool_call_id(self) -> str | None:
+        return None
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
new file mode 100644
index 0000000..bc73017
--- /dev/null
+++ b/libertai_agents/models/mistral.py
@@ -0,0 +1,23 @@
+import json
+import re
+
+from transformers import PreTrainedTokenizerFast
+
+from libertai_agents.interfaces import ToolCallFunction
+from libertai_agents.models.base import Model
+
+
+class MistralModel(Model):
+    tokenizer: PreTrainedTokenizerFast
+    vm_url: str
+
+    def __init__(self, model_id: str, vm_url: str):
+        super().__init__(model_id, vm_url)
+
+    @staticmethod
+    def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
+        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+
+    def generate_tool_call_id(self) -> str | None:
+        return None
diff --git a/libertai_agents/models/models.py b/libertai_agents/models/models.py
new file mode 100644
index 0000000..66763a8
--- /dev/null
+++ b/libertai_agents/models/models.py
@@ -0,0 +1,43 @@
+import typing
+
+from huggingface_hub import login
+from pydantic import BaseModel
+
+from libertai_agents.models.base import Model
+from libertai_agents.models.hermes import HermesModel
+from libertai_agents.models.mistral import MistralModel
+
+
+class ModelConfiguration(BaseModel):
+    vm_url: str
+    constructor: typing.Type[Model]
+
+
+ModelId = typing.Literal[
+    "NousResearch/Hermes-2-Pro-Llama-3-8B",
+    "NousResearch/Hermes-3-Llama-3.1-8B",
+    "mistralai/Mistral-Nemo-Instruct-2407"
+]
+MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))
+
+MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
+    "NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
+        vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
+        constructor=HermesModel),
+    "NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                             constructor=HermesModel),
+    "mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
+                                                               constructor=MistralModel)
+}
+
+
+def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
+    model_configuration = MODELS_CONFIG.get(model_id)
+
+    if model_configuration is None:
+        raise ValueError(f'model_id must be one of {MODEL_IDS}')
+
+    if hf_token is not None:
+        login(hf_token)
+
+    return model_configuration.constructor(model_id=model_id, **model_configuration.model_dump(exclude={'constructor'}))
diff --git a/main.py b/main.py
index fef78bb..8946fe1 100644
--- a/main.py
+++ b/main.py
@@ -2,14 +2,16 @@
 
 from libertai_agents.agents import ChatAgent
 from libertai_agents.interfaces import Message, MessageRoleEnum
-from libertai_agents.models import Hermes2Pro
+from libertai_agents.models import get_model
 from libertai_agents.tools import get_current_temperature
 
 
 async def start():
-    agent = ChatAgent(model=Hermes2Pro, system_prompt="You are a helpful assistant", tools=[get_current_temperature])
+    agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
+                      system_prompt="You are a helpful assistant",
+                      tools=[get_current_temperature])
     response = await agent.generate_answer(
-        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")])
+        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon ?")])
     print(response)