Skip to content

Commit

Permalink
wip: Basic model call
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Aug 15, 2024
1 parent 7f04619 commit 068199c
Show file tree
Hide file tree
Showing 7 changed files with 540 additions and 3 deletions.
37 changes: 37 additions & 0 deletions libertai_agents/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import aiohttp
from aiohttp import ClientSession

from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams
from libertai_agents.models import Model


class ChatAgent:
    """Conversational agent that builds a chat prompt and queries a hosted LLM endpoint."""

    model: Model          # wraps the tokenizer and the inference VM URL
    system_prompt: str    # prepended as the system message of every conversation
    tools: list           # callables exposed to the model through the chat template

    def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
        """Create an agent.

        Args:
            model: Model used for prompt templating and inference.
            system_prompt: System instruction inserted before the conversation.
            tools: Optional list of tool callables; defaults to no tools.
        """
        self.model = model
        self.system_prompt = system_prompt
        # Avoid a shared mutable default argument: each agent owns its list.
        self.tools = [] if tools is None else tools

    async def generate_answer(self, messages: list[Message]) -> str:
        """Generate the assistant's reply to the given conversation.

        Args:
            messages: Conversation history; must be non-empty and end with a user message.

        Returns:
            The model's generated answer text.

        Raises:
            ValueError: If there are no messages or the last one isn't from the user.
        """
        if not messages:
            raise ValueError("No previous message to respond to")
        if messages[-1].role != MessageRoleEnum.user:
            raise ValueError("Last message is not from the user")

        prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
        async with aiohttp.ClientSession() as session:
            return await self.__call_model(session, prompt)

    async def __call_model(self, session: ClientSession, prompt: str) -> str:
        """POST the prompt to the model VM and return the generated content.

        Raises:
            aiohttp.ClientResponseError: If the endpoint returns an HTTP error status.
        """
        params = LlamaCppParams(prompt=prompt)

        async with session.post(self.model.vm_url, json=params.model_dump()) as response:
            # Fail loudly on HTTP errors instead of silently returning None,
            # which would violate generate_answer's `-> str` contract.
            response.raise_for_status()
            response_data = await response.json()
            return response_data["content"]
18 changes: 18 additions & 0 deletions libertai_agents/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from enum import Enum

from pydantic import BaseModel


class MessageRoleEnum(str, Enum):
    """Author role of a chat message, as understood by chat templates.

    Subclasses ``str`` so members compare equal to their plain string values.
    """

    system = "system"
    user = "user"
    assistant = "assistant"


class Message(BaseModel):
    """A single chat message: who said it and what was said."""

    # Author of the message (system / user / assistant).
    role: MessageRoleEnum
    # Plain-text content of the message.
    content: str


class LlamaCppParams(BaseModel):
    """Request body sent to the llama.cpp completion endpoint."""

    # Fully rendered prompt string (chat template already applied).
    prompt: str
16 changes: 14 additions & 2 deletions libertai_agents/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from libertai_agents.interfaces import Message, MessageRoleEnum


class Model:
    """Pairs a Hugging Face tokenizer (for chat templating) with an inference VM URL."""

    tokenizer: PreTrainedTokenizerFast  # provides apply_chat_template for this model family
    vm_url: str                         # completion endpoint of the hosted model

    def __init__(self, model_id: str, vm_url: str):
        """Load the tokenizer for ``model_id`` and remember the inference endpoint.

        Args:
            model_id: Hugging Face model identifier used to fetch the tokenizer.
            vm_url: URL of the completion endpoint serving this model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.vm_url = vm_url

    def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
        """Render a full prompt string from the conversation, system prompt and tools.

        Args:
            messages: Conversation history (not modified by this call).
            system_prompt: System instruction prepended to the conversation.
            tools: Tool callables passed through to the chat template.

        Returns:
            The templated prompt, ending with the generation cue for the assistant.
        """
        # Build a new list instead of insert(0, ...) so the caller's
        # `messages` list is not mutated as a side effect.
        conversation = [Message(role=MessageRoleEnum.system, content=system_prompt), *messages]
        raw_messages = [message.model_dump() for message in conversation]

        return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
                                                  add_generation_prompt=True)


# Default Hermes 2 Pro (Llama 3 8B) model served from an Aleph Cloud VM.
# NOTE: a stale duplicate assignment without `vm_url` (which would raise
# TypeError at import) was removed; only the valid definition remains.
Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
                   vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
11 changes: 11 additions & 0 deletions libertai_agents/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def get_current_temperature(location: str, unit: str) -> float:
    """
    Get the current temperature at a location.
    Args:
        location: The location to get the temperature for, in the format "City, Country"
        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
    Returns:
        The current temperature at the specified location in the specified units, as a float.
    """
    # Placeholder implementation — a real tool would query a weather service
    # with `location` and convert to the requested `unit`.
    placeholder_temperature = 22.0
    return placeholder_temperature
16 changes: 16 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import MessageRoleEnum, Message
from libertai_agents.models import Hermes2Pro
from libertai_agents.tools import get_current_temperature


async def start():
    """One-shot demo: ask the Hermes 2 Pro agent about the weather and print its reply."""
    agent = ChatAgent(model=Hermes2Pro, system_prompt="You are a helpful assistant", tools=[get_current_temperature])
    response = await agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")])
    print(response)


if __name__ == "__main__":
    # Guard the entry point so importing this module doesn't trigger network I/O.
    asyncio.run(start())
Loading

0 comments on commit 068199c

Please sign in to comment.