Skip to content

Commit

Permalink
feat: Generic model creation and mistral support started
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Aug 16, 2024
1 parent 86d2042 commit cecbf7e
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 21 deletions.
20 changes: 12 additions & 8 deletions libertai_agents/agents.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

import aiohttp
from aiohttp import ClientSession

Expand All @@ -9,9 +11,9 @@
class ChatAgent:
model: Model
system_prompt: str
tools: list
tools: list[Callable]

def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
def __init__(self, model: Model, system_prompt: str, tools: list[Callable] | None = None):
if tools is None:
tools = []

Expand All @@ -36,33 +38,35 @@ async def generate_answer(self, messages: list[Message]) -> str:
if len(tool_calls) == 0:
return response
messages.append(self.__create_tool_calls_message(tool_calls))
tool_messages = self.execute_tool_calls(tool_calls)
tool_messages = self.__execute_tool_calls(tool_calls)
return await self.generate_answer(messages + tool_messages)

async def __call_model(self, session: ClientSession, prompt: str):
params = LlamaCppParams(prompt=prompt)

async with session.post(self.model.vm_url, json=params.model_dump()) as response:
# TODO: handle errors and retries
if response.status == 200:
response_data = await response.json()
return response_data["content"]

def execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
def __execute_tool_calls(self, tool_calls: list[ToolCallFunction]) -> list[Message]:
# TODO: support async function calls
messages = []
for call in tool_calls:
function_to_call = find(lambda x: x.__name__ == call.name, self.tools)
function_name = call.name
function_to_call = find(lambda x: x.__name__ == function_name, self.tools)
if function_to_call is None:
# TODO: handle error
continue
function_response = function_to_call(*call.arguments.values())
messages.append(Message(role=MessageRoleEnum.tool, name=call.name, content=str(function_response)))
messages.append(Message(role=MessageRoleEnum.tool, name=function_name, content=str(function_response)))
return messages

@staticmethod
def __create_tool_calls_message(tool_calls: list[ToolCallFunction]) -> Message:
def __create_tool_calls_message(self, tool_calls: list[ToolCallFunction]) -> Message:
return Message(role=MessageRoleEnum.assistant,
tool_calls=[MessageToolCall(type="function",
id=self.model.generate_tool_call_id(),
function=ToolCallFunction(name=call.name,
arguments=call.arguments)) for
call in
Expand Down
2 changes: 2 additions & 0 deletions libertai_agents/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class ToolCallFunction(BaseModel):

class MessageToolCall(BaseModel):
type: str
id: Optional[str] = None
function: ToolCallFunction


class Message(BaseModel):
role: MessageRoleEnum
name: Optional[str] = None
content: Optional[str] = None
tool_call_id: Optional[str] = None
tool_calls: Optional[list[MessageToolCall]] = None


Expand Down
2 changes: 2 additions & 0 deletions libertai_agents/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import Model

Check failure on line 1 in libertai_agents/models/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

libertai_agents/models/__init__.py:1:19: F401 `.base.Model` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
from .models import get_model

Check failure on line 2 in libertai_agents/models/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

libertai_agents/models/__init__.py:2:21: F401 `.models.get_model` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
19 changes: 9 additions & 10 deletions libertai_agents/models.py → libertai_agents/models/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
import re
from abc import ABC, abstractmethod

from transformers import AutoTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerFast, AutoTokenizer

from libertai_agents.interfaces import Message, MessageRoleEnum, ToolCallFunction


class Model:
class Model(ABC):
tokenizer: PreTrainedTokenizerFast
vm_url: str

Expand All @@ -21,11 +20,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
add_generation_prompt=True)

@abstractmethod
def generate_tool_call_id(self) -> str | None:
pass

@staticmethod
@abstractmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
tool_calls = re.findall("^<tool_call>\s*(.*)\s*</tool_call>$", response)
return [ToolCallFunction(**json.loads(call)) for call in tool_calls]


Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
pass
23 changes: 23 additions & 0 deletions libertai_agents/models/hermes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json
import re

from transformers import PreTrainedTokenizerFast

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.models.base import Model


class HermesModel(Model):
tokenizer: PreTrainedTokenizerFast
vm_url: str

def __init__(self, model_id: str, vm_url: str):
super().__init__(model_id, vm_url)

@staticmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
return [ToolCallFunction(**json.loads(call)) for call in tool_calls]

def generate_tool_call_id(self) -> str | None:
return None
23 changes: 23 additions & 0 deletions libertai_agents/models/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json
import re

from transformers import PreTrainedTokenizerFast

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.models.base import Model


class MistralModel(Model):
tokenizer: PreTrainedTokenizerFast
vm_url: str

def __init__(self, model_id: str, vm_url: str):
super().__init__(model_id, vm_url)

@staticmethod
def extract_tool_calls_from_response(response: str) -> list[ToolCallFunction]:
tool_calls = re.findall("<tool_call>\s*(.*)\s*</tool_call>", response)
return [ToolCallFunction(**json.loads(call)) for call in tool_calls]

def generate_tool_call_id(self) -> str | None:
return None
43 changes: 43 additions & 0 deletions libertai_agents/models/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing

from huggingface_hub import login

Check failure on line 3 in libertai_agents/models/models.py

View workflow job for this annotation

GitHub Actions / mypy

[mypy] reported by reviewdog 🐶 Skipping analyzing "huggingface_hub": module is installed, but missing library stubs or py.typed marker [import-untyped] Raw Output: /home/runner/work/libertai-agents/libertai-agents/libertai_agents/models/models.py:3:1: error: Skipping analyzing "huggingface_hub": module is installed, but missing library stubs or py.typed marker [import-untyped]
from pydantic import BaseModel

from libertai_agents.models.base import Model
from libertai_agents.models.hermes import HermesModel
from libertai_agents.models.mistral import MistralModel


class ModelConfiguration(BaseModel):
vm_url: str
constructor: typing.Type[Model]


ModelId = typing.Literal[
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
"mistralai/Mistral-Nemo-Instruct-2407"
]
MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))

MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
"NousResearch/Hermes-2-Pro-Llama-3-8B": ModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
constructor=HermesModel),
"NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="http://localhost:8080/completion",
constructor=HermesModel),
"mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="http://localhost:8080/completion",
constructor=MistralModel)
}


def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
model_configuration = MODELS_CONFIG.get(model_id)

if model_configuration is None:
raise ValueError(f'model_id must be one of {MODEL_IDS}')

if hf_token is not None:
login(hf_token)

return model_configuration.constructor(model_id=model_id, **model_configuration.model_dump(exclude={'constructor'}))
8 changes: 5 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import Message, MessageRoleEnum
from libertai_agents.models import Hermes2Pro
from libertai_agents.models import get_model
from libertai_agents.tools import get_current_temperature


async def start():
agent = ChatAgent(model=Hermes2Pro, system_prompt="You are a helpful assistant", tools=[get_current_temperature])
agent = ChatAgent(model=get_model("mistralai/Mistral-Nemo-Instruct-2407"),
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])
response = await agent.generate_answer(
[Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")])
[Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon ?")])
print(response)


Expand Down

0 comments on commit cecbf7e

Please sign in to comment.