feat: Basic function call
RezaRahemtola committed Aug 16, 2024
1 parent 068199c commit 86d2042
Showing 5 changed files with 72 additions and 7 deletions.
40 changes: 36 additions & 4 deletions libertai_agents/agents.py
@@ -1,8 +1,9 @@
 import aiohttp
 from aiohttp import ClientSession
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams
+from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction
 from libertai_agents.models import Model
+from libertai_agents.utils import find
 
 
 class ChatAgent:
@@ -13,20 +14,30 @@ class ChatAgent:
     def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
         if tools is None:
             tools = []
+
+        if len(set(map(lambda x: x.__name__, tools))) != len(tools):
+            raise ValueError("Tool functions must have different names")
         self.model = model
         self.system_prompt = system_prompt
         self.tools = tools
 
     async def generate_answer(self, messages: list[Message]) -> str:
         if len(messages) == 0:
             raise ValueError("No previous message to respond to")
-        if messages[-1].role != MessageRoleEnum.user:
-            raise ValueError("Last message is not from the user")
+        if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
+            raise ValueError("Last message is not from the user or tool")
 
         prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
         print(prompt)
         async with aiohttp.ClientSession() as session:
-            return await self.__call_model(session, prompt)
+            response = await self.__call_model(session, prompt)
+
+        tool_calls = self.model.extract_tool_calls_from_response(response)
+        if len(tool_calls) == 0:
+            return response
+        messages.append(self.__create_tool_calls_message(tool_calls))
+        tool_messages = self.execute_tool_calls(tool_calls)
+        return await self.generate_answer(messages + tool_messages)
 
     async def __call_model(self, session: ClientSession, prompt: str):
         params = LlamaCppParams(prompt=prompt)
@@ -35,3 +46,24 @@ async def __call_model(self, session: ClientSession, prompt: str):
         if response.status == 200:
             response_data = await response.json()
             return response_data["content"]
+
+    def execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+        # TODO: support async function calls
+        messages = []
+        for call in tool_calls:
+            function_to_call = find(lambda x: x.__name__ == call.name, self.tools)
+            if function_to_call is None:
+                # TODO: handle error
+                continue
+            function_response = function_to_call(*call.arguments.values())
+            messages.append(Message(role=MessageRoleEnum.tool, name=call.name, content=str(function_response)))
+        return messages
+
+    @staticmethod
+    def __create_tool_calls_message(tool_calls: list[ToolCallFunction]) -> Message:
+        return Message(
+            role=MessageRoleEnum.assistant,
+            tool_calls=[MessageToolCall(type="function",
+                                        function=ToolCallFunction(name=call.name,
+                                                                  arguments=call.arguments))
+                        for call in tool_calls])
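
A note on the new execute_tool_calls: arguments are unpacked positionally with *call.arguments.values(), which relies on the JSON object's key order matching the function's parameter order; **call.arguments would bind by name instead. A minimal sketch of the lookup-and-call step, using a hypothetical tool body:

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.utils import find


def get_current_temperature(location: str) -> float:
    # Hypothetical tool body, for illustration only
    return 21.5


call = ToolCallFunction(name="get_current_temperature", arguments={"location": "Paris"})
# The same lookup the agent performs: match the call name against each tool's __name__
function_to_call = find(lambda t: t.__name__ == call.name, [get_current_temperature])
result = function_to_call(*call.arguments.values())  # get_current_temperature("Paris") -> 21.5
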
16 changes: 15 additions & 1 deletion libertai_agents/interfaces.py
@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import Optional
 
 from pydantic import BaseModel
 
@@ -7,11 +8,24 @@ class MessageRoleEnum(str, Enum):
     system = 'system'
     user = 'user'
     assistant = 'assistant'
+    tool = 'tool'
+
+
+class ToolCallFunction(BaseModel):
+    name: str
+    arguments: dict
+
+
+class MessageToolCall(BaseModel):
+    type: str
+    function: ToolCallFunction
 
 
 class Message(BaseModel):
     role: MessageRoleEnum
-    content: str
+    name: Optional[str] = None
+    content: Optional[str] = None
+    tool_calls: Optional[list[MessageToolCall]] = None
 
 
 class LlamaCppParams(BaseModel):
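
For reference, a sketch of the two message shapes these models describe in a tool-call round trip (field values are illustrative):

from libertai_agents.interfaces import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction

# Assistant message recording the call the model decided to make
assistant_message = Message(
    role=MessageRoleEnum.assistant,
    tool_calls=[MessageToolCall(type="function",
                                function=ToolCallFunction(name="get_current_temperature",
                                                          arguments={"location": "Paris"}))])

# Tool message carrying the function's result back to the model
tool_message = Message(role=MessageRoleEnum.tool, name="get_current_temperature", content="21.5")
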
10 changes: 9 additions & 1 deletion libertai_agents/models.py
@@ -1,6 +1,9 @@
+import json
+import re
+
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
-from libertai_agents.interfaces import Message, MessageRoleEnum
+from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
 
 
 class Model:
@@ -18,6 +21,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
         return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                   add_generation_prompt=True)
 
+    @staticmethod
+    def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
+        tool_calls = re.findall(r"^<tool_call>\s*(.*)\s*</tool_call>$", response)
+        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
 
 
 Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
                    vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
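
A quick sketch of extract_tool_calls_from_response; the wrapper format below is an assumption chosen to satisfy the regex (the response must start with <tool_call> and end with </tool_call>, with the JSON payload on a single line). Note that importing libertai_agents.models instantiates the Hermes tokenizer via transformers.

from libertai_agents.models import Hermes2Pro

response = '<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>'
calls = Hermes2Pro.extract_tool_calls_from_response(response)
print(calls[0].name, calls[0].arguments)  # get_current_temperature {'location': 'Paris'}
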
11 changes: 11 additions & 0 deletions libertai_agents/utils.py
@@ -0,0 +1,11 @@
+from typing import TypeVar, Callable
+
+T = TypeVar("T")
+
+
+def find(f: Callable[[T], bool], seq: list[T]) -> T | None:
+    """Return first item in sequence where f(item) == True."""
+    for item in seq:
+        if f(item):
+            return item
+    return None
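
Usage of the new find helper (values are illustrative):

from libertai_agents.utils import find

numbers = [1, 4, 9, 16]
find(lambda n: n % 2 == 0, numbers)  # 4, the first even item
find(lambda n: n > 100, numbers)     # None, since nothing matches
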
2 changes: 1 addition & 1 deletion main.py
@@ -1,7 +1,7 @@
 import asyncio
 
 from libertai_agents.agents import ChatAgent
-from libertai_agents.interfaces import MessageRoleEnum, Message
+from libertai_agents.interfaces import Message, MessageRoleEnum
 from libertai_agents.models import Hermes2Pro
 from libertai_agents.tools import get_current_temperature
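
Putting the pieces together, a hedged end-to-end sketch in the spirit of main.py (the system prompt and user question are illustrative; get_current_temperature lives in libertai_agents.tools, which this commit does not touch):

import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import Message, MessageRoleEnum
from libertai_agents.models import Hermes2Pro
from libertai_agents.tools import get_current_temperature


async def main():
    agent = ChatAgent(model=Hermes2Pro,
                      system_prompt="You are a helpful assistant.",
                      tools=[get_current_temperature])
    # If the model emits a <tool_call>, generate_answer executes it and recurses
    answer = await agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris?")])
    print(answer)


asyncio.run(main())
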
