diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index 9153b8c..1343b44 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -1,8 +1,9 @@
 import aiohttp
 from aiohttp import ClientSession
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams
+from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction
 from libertai_agents.models import Model
+from libertai_agents.utils import find
 
 
 class ChatAgent:
@@ -13,6 +14,9 @@ class ChatAgent:
     def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
         if tools is None:
             tools = []
+
+        if len(set(map(lambda x: x.__name__, tools))) != len(tools):
+            raise ValueError("Tool functions must have different names")
         self.model = model
         self.system_prompt = system_prompt
         self.tools = tools
@@ -20,13 +24,20 @@ def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
     async def generate_answer(self, messages: list[Message]) -> str:
         if len(messages) == 0:
             raise ValueError("No previous message to respond to")
-        if messages[-1].role != MessageRoleEnum.user:
-            raise ValueError("Last message is not from the user")
+        if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
+            raise ValueError("Last message is not from the user or tool")
 
         prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
         print(prompt)
         async with aiohttp.ClientSession() as session:
-            return await self.__call_model(session, prompt)
+            response = await self.__call_model(session, prompt)
+
+            tool_calls = self.model.extract_tool_calls_from_response(response)
+            if len(tool_calls) == 0:
+                return response
+            messages.append(self.__create_tool_calls_message(tool_calls))
+            tool_messages = self.execute_tool_calls(tool_calls)
+            return await self.generate_answer(messages + tool_messages)
 
     async def __call_model(self, session: ClientSession, prompt: str):
         params = LlamaCppParams(prompt=prompt)
@@ -35,3 +46,24 @@
             if response.status == 200:
                 response_data = await response.json()
                 return response_data["content"]
+
+    def execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+        # TODO: support async function calls
+        messages = []
+        for call in tool_calls:
+            function_to_call = find(lambda x: x.__name__ == call.name, self.tools)
+            if function_to_call is None:
+                # TODO: handle error
+                continue
+            function_response = function_to_call(*call.arguments.values())
+            messages.append(Message(role=MessageRoleEnum.tool, name=call.name, content=str(function_response)))
+        return messages
+
+    @staticmethod
+    def __create_tool_calls_message(tool_calls: list[ToolCallFunction]) -> Message:
+        return Message(role=MessageRoleEnum.assistant,
+                       tool_calls=[MessageToolCall(type="function",
+                                                   function=ToolCallFunction(name=call.name,
+                                                                             arguments=call.arguments)) for
+                                   call in
+                                   tool_calls])
diff --git a/libertai_agents/interfaces.py b/libertai_agents/interfaces.py
index 67de31e..a39a0ae 100644
--- a/libertai_agents/interfaces.py
+++ b/libertai_agents/interfaces.py
@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import Optional
 
 from pydantic import BaseModel
 
@@ -7,11 +8,24 @@ class MessageRoleEnum(str, Enum):
     system = 'system'
     user = 'user'
     assistant = 'assistant'
+    tool = 'tool'
+
+
+class ToolCallFunction(BaseModel):
+    name: str
+    arguments: dict
+
+
+class MessageToolCall(BaseModel):
+    type: str
+    function: ToolCallFunction
 
 
 class Message(BaseModel):
     role: MessageRoleEnum
-    content: str
+    name: Optional[str] = None
+    content: Optional[str] = None
+    tool_calls: Optional[list[MessageToolCall]] = None
 
 
 class LlamaCppParams(BaseModel):
diff --git a/libertai_agents/models.py b/libertai_agents/models.py
index 9cdf322..703cde5 100644
--- a/libertai_agents/models.py
+++ b/libertai_agents/models.py
@@ -1,6 +1,9 @@
+import json
+import re
+
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
-from libertai_agents.interfaces import Message, MessageRoleEnum
+from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
 
 
 class Model:
@@ -18,6 +21,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
         return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                   add_generation_prompt=True)
 
+    @staticmethod
+    def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall("^\s*<tool_call>(.*)</tool_call>\s*$", response)
+        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+
 
 Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
                    vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
diff --git a/libertai_agents/utils.py b/libertai_agents/utils.py
new file mode 100644
index 0000000..4661912
--- /dev/null
+++ b/libertai_agents/utils.py
@@ -0,0 +1,11 @@
+from typing import TypeVar, Callable
+
+T = TypeVar("T")
+
+
+def find(f: Callable[[T], bool], seq: list[T]) -> T | None:
+    """Return first item in sequence where f(item) == True."""
+    for item in seq:
+        if f(item):
+            return item
+    return None
diff --git a/main.py b/main.py
index 9cb5168..fef78bb 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,7 @@
 import asyncio
 
 from libertai_agents.agents import ChatAgent
-from libertai_agents.interfaces import MessageRoleEnum, Message
+from libertai_agents.interfaces import Message, MessageRoleEnum
 from libertai_agents.models import Hermes2Pro
 from libertai_agents.tools import get_current_temperature
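
Below is a minimal usage sketch (not part of the diff) of the new tool-calling flow, based on the imports visible in the main.py hunk above and the ChatAgent signature introduced in agents.py. The asyncio entrypoint, system prompt, and user question are illustrative assumptions, not the repository's actual main.py body.

import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import Message, MessageRoleEnum
from libertai_agents.models import Hermes2Pro
from libertai_agents.tools import get_current_temperature


async def main():
    # Tools are plain callables; __init__ now rejects duplicate __name__ values.
    agent = ChatAgent(model=Hermes2Pro,
                      system_prompt="You are a helpful assistant.",  # illustrative prompt
                      tools=[get_current_temperature])
    # generate_answer resolves any <tool_call> blocks the model emits
    # (via execute_tool_calls) before recursing to produce the final text answer.
    answer = await agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris?")])  # illustrative question
    print(answer)


asyncio.run(main())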