feat: Mistral model support
RezaRahemtola committed Aug 17, 2024
1 parent cecbf7e commit 21e78e8
Showing 7 changed files with 48 additions and 46 deletions.
37 changes: 20 additions & 17 deletions libertai_agents/agents.py
@@ -3,7 +3,8 @@
 import aiohttp
 from aiohttp import ClientSession
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction
+from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction, \
+    ToolCallMessage, ToolResponseMessage
 from libertai_agents.models import Model
 from libertai_agents.utils import find
@@ -30,15 +31,15 @@ async def generate_answer(self, messages: list[Message]) -> str:
             raise ValueError("Last message is not from the user or tool")
 
         prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
-        print(prompt)
         async with aiohttp.ClientSession() as session:
             response = await self.__call_model(session, prompt)
 
         tool_calls = self.model.extract_tool_calls_from_response(response)
         if len(tool_calls) == 0:
             return response
-        messages.append(self.__create_tool_calls_message(tool_calls))
-        tool_messages = self.__execute_tool_calls(tool_calls)
+        tool_calls_message = self.__create_tool_calls_message(tool_calls)
+        messages.append(tool_calls_message)
+        tool_messages = self.__execute_tool_calls(tool_calls_message.tool_calls)
         return await self.generate_answer(messages + tool_messages)
 
     async def __call_model(self, session: ClientSession, prompt: str):
@@ -50,24 +51,26 @@ async def __call_model(self, session: ClientSession, prompt: str):
             response_data = await response.json()
             return response_data["content"]
 
-    def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
+    def __execute_tool_calls(self, tool_calls: list[MessageToolCall]) -> list[Message]:
         # TODO: support async function calls
-        messages = []
+        messages: list[Message] = []
         for call in tool_calls:
-            function_name = call.name
+            function_name = call.function.name
            function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
            if function_to_call is None:
                # TODO: handle error
                continue
-            function_response = function_to_call(*call.arguments.values())
-            messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
+            function_response = function_to_call(*call.function.arguments.values())
+            messages.append(
+                ToolResponseMessage(role=MessageRoleEnum.tool, name=function_name, tool_call_id=call.id,
+                                    content=str(function_response)))
         return messages
 
-    def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
-        return Message(role=MessageRoleEnum.assistant,
-                       tool_calls=[MessageToolCall(type="function",
-                                                   id=self.model.generate_tool_call_id(),
-                                                   function=ToolCallFunction(name=call.name,
-                                                                             arguments=call.arguments)) for
-                                   call in
-                                   tool_calls])
+    def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> ToolCallMessage:
+        return ToolCallMessage(role=MessageRoleEnum.assistant,
+                               tool_calls=[MessageToolCall(type="function",
+                                                           id=self.model.generate_tool_call_id(),
+                                                           function=ToolCallFunction(name=call.name,
+                                                                                     arguments=call.arguments)) for
+                                           call in
+                                           tool_calls])
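
Note on the argument handling above: function_to_call(*call.function.arguments.values()) expands the parsed arguments positionally, so it relies on the JSON object's key order matching the tool's parameter order. A keyword-based call avoids that assumption; a minimal standalone sketch (tool name and stub invented for illustration):

    def call_tool(tool, arguments: dict):
        # Keyword expansion does not depend on the argument dict's ordering
        return tool(**arguments)

    def get_current_temperature(location: str, unit: str = "celsius") -> float:
        return 21.0  # hypothetical stub standing in for a real tool

    print(call_tool(get_current_temperature, {"unit": "fahrenheit", "location": "Paris"}))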
10 changes: 8 additions & 2 deletions libertai_agents/interfaces.py
@@ -24,10 +24,16 @@ class MessageToolCall(BaseModel):
 
 class Message(BaseModel):
     role: MessageRoleEnum
-    name: Optional[str] = None
     content: Optional[str] = None
+
+
+class ToolCallMessage(Message):
+    tool_calls: list[MessageToolCall]
+
+
+class ToolResponseMessage(Message):
+    name: Optional[str] = None
+    tool_call_id: Optional[str] = None
-    tool_calls: Optional[list[MessageToolCall]] = None
 
 
 class LlamaCppParams(BaseModel):
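
Taken together, the new hierarchy keeps plain chat messages minimal and moves tool metadata into dedicated subclasses. A minimal construction sketch, assuming only the fields shown in this diff (the id string is invented):

    from libertai_agents.interfaces import (Message, MessageRoleEnum, MessageToolCall,
                                            ToolCallFunction, ToolCallMessage,
                                            ToolResponseMessage)

    user = Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")
    call = MessageToolCall(type="function", id="abc123XYZ",
                           function=ToolCallFunction(name="get_current_temperature",
                                                     arguments={"location": "Paris"}))
    assistant = ToolCallMessage(role=MessageRoleEnum.assistant, tool_calls=[call])
    result = ToolResponseMessage(role=MessageRoleEnum.tool, name="get_current_temperature",
                                 tool_call_id="abc123XYZ", content="21")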
4 changes: 2 additions & 2 deletions libertai_agents/models/__init__.py
@@ -1,2 +1,2 @@
-from .base import Model
-from .models import get_model
+from .base import Model as Model
+from .models import get_model as get_model
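
The redundant X as X aliases are the PEP 484 convention for marking explicit re-exports, so strict type checkers (e.g. mypy with --no-implicit-reexport, or pyright in strict mode) keep treating these names as part of the package's public API:

    # Downstream imports keep working under strict re-export checking
    from libertai_agents.models import Model, get_model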
12 changes: 7 additions & 5 deletions libertai_agents/models/base.py
@@ -2,27 +2,29 @@
 
 from transformers import PreTrainedTokenizerFast, AutoTokenizer
 
-from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction
+from libertai_agents.interfaces import Message, ToolCallFunction, MessageRoleEnum
 
 
 class Model(ABC):
     tokenizer: PreTrainedTokenizerFast
     vm_url: str
+    system_message: bool
 
-    def __init__(self, model_id: str, vm_url: str):
+    def __init__(self, model_id: str, vm_url: str, system_message: bool = True):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.vm_url = vm_url
+        self.system_message = system_message
 
     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
-        messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+        if self.system_message:
+            messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
         raw_messages = list(map(lambda x: x.model_dump(), messages))
 
         return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                   add_generation_prompt=True)
 
-    @abstractmethod
     def generate_tool_call_id(self) -> str | None:
-        pass
+        return None
 
     @staticmethod
     @abstractmethod
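
The new system_message flag lets subclasses opt out of system-prompt injection, since some chat templates (Mistral's among them) do not accept a separate system role. A minimal sketch of such a subclass, assuming extract_tool_calls_from_response is the only remaining abstract method:

    from libertai_agents.interfaces import ToolCallFunction
    from libertai_agents.models.base import Model

    class NoSystemModel(Model):
        def __init__(self, model_id: str, vm_url: str):
            # Skip inserting the system prompt for templates without a system role
            super().__init__(model_id, vm_url, system_message=False)

        @staticmethod
        def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
            return []  # stub; a real subclass parses the model's tool-call format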
8 changes: 0 additions & 8 deletions libertai_agents/models/hermes.py
@@ -1,23 +1,15 @@
 import json
 import re
 
 from transformers import PreTrainedTokenizerFast
 
 from libertai_agents.interfaces import ToolCallFunction
 from libertai_agents.models.base import Model
 
 
 class HermesModel(Model):
-    tokenizer: PreTrainedTokenizerFast
-    vm_url: str
-
-    def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url)
 
     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
         tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
         return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
-
-    def generate_tool_call_id(self) -> str | None:
-        return None
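
Hermes-style models wrap each tool call in <tool_call> tags with JSON inside, which is what the regex above extracts. A quick standalone check (sample response invented; a raw string is used here to avoid escape-sequence warnings):

    import json
    import re

    response = '<tool_call>{"name": "get_current_temperature", "arguments": {"location": "Paris"}}</tool_call>'
    calls = re.findall(r"<tool_call>\s*(.*)\s*</tool_call>", response)
    print([json.loads(call) for call in calls])
    # [{'name': 'get_current_temperature', 'arguments': {'location': 'Paris'}}]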
21 changes: 10 additions & 11 deletions libertai_agents/models/mistral.py
@@ -1,23 +1,22 @@
 import json
-import re
 
-from transformers import PreTrainedTokenizerFast
+import random
+import string
 
 from libertai_agents.interfaces import ToolCallFunction
 from libertai_agents.models.base import Model
 
 
 class MistralModel(Model):
-    tokenizer: PreTrainedTokenizerFast
-    vm_url: str
-
     def __init__(self, model_id: str, vm_url: str):
-        super().__init__(model_id, vm_url)
+        super().__init__(model_id, vm_url, system_message=False)
 
     @staticmethod
     def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
-        tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
-        return [ToolCallFunction(**json.loads(call)) for call in tool_calls]
+        try:
+            tool_calls = json.loads(response)
+            return [ToolCallFunction(**call) for call in tool_calls]
+        except Exception:
+            return []
 
-    def generate_tool_call_id(self) -> str | None:
-        return None
+    def generate_tool_call_id(self) -> str:
+        return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(9))
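
By contrast, Mistral-style responses carry tool calls as a bare JSON array, and each call gets a fresh alphanumeric id (the 9-character length here comes from the diff and matches what Mistral's tool-calling templates validate). A quick standalone check with an invented sample:

    import json
    import random
    import string

    response = '[{"name": "get_current_temperature", "arguments": {"location": "Paris"}}]'
    print(json.loads(response))  # [{'name': 'get_current_temperature', 'arguments': {'location': 'Paris'}}]
    print(''.join(random.choice(string.ascii_letters + string.digits) for _ in range(9)))  # e.g. 'x5pDq82Wn'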
2 changes: 1 addition & 1 deletion main.py
@@ -11,7 +11,7 @@ async def start():
         system_prompt="You are a helpful assistant",
         tools=[get_current_temperature])
     response = await agent.generate_answer(
-        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon ?")])
+        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon in Celsius ?")])
     print(response)


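
The commit does not show get_current_temperature itself. Since generate_prompt hands tools to the tokenizer's apply_chat_template, which can derive a JSON schema from each function's signature, type hints, and docstring, a compatible tool could look like this (entirely hypothetical; only the name appears in the diff):

    def get_current_temperature(location: str, unit: str) -> float:
        """
        Get the current temperature at a location.

        Args:
            location: The location to get the temperature for
            unit: The unit to return the temperature in (e.g. "celsius")
        """
        return 21.5  # hypothetical stub; a real tool would call a weather API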
