Commit 068199c (1 parent: 7f04619)
Showing 7 changed files with 540 additions and 3 deletions.
libertai_agents/agents.py (new file):

```python
import aiohttp
from aiohttp import ClientSession

from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams
from libertai_agents.models import Model


class ChatAgent:
    model: Model
    system_prompt: str
    tools: list

    def __init__(self, model: Model, system_prompt: str, tools: list | None = None):
        if tools is None:
            tools = []
        self.model = model
        self.system_prompt = system_prompt
        self.tools = tools

    async def generate_answer(self, messages: list[Message]) -> str:
        if len(messages) == 0:
            raise ValueError("No previous message to respond to")
        if messages[-1].role != MessageRoleEnum.user:
            raise ValueError("Last message is not from the user")

        prompt = self.model.generate_prompt(messages, self.system_prompt, self.tools)
        print(prompt)
        async with aiohttp.ClientSession() as session:
            return await self.__call_model(session, prompt)

    async def __call_model(self, session: ClientSession, prompt: str):
        params = LlamaCppParams(prompt=prompt)

        async with session.post(self.model.vm_url, json=params.model_dump()) as response:
            if response.status == 200:
                response_data = await response.json()
                return response_data["content"]
```
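The endpoint is treated as a llama.cpp-style completion API: ChatAgent posts a JSON body with a single prompt field and reads the generated text back from the content field of the response (on a non-200 status, __call_model falls through and implicitly returns None). A minimal sketch of that exchange without the agent wrapper, reusing the vm_url declared in models.py below:

```python
import asyncio

import aiohttp

# Completion endpoint taken from Hermes2Pro.vm_url in models.py below.
COMPLETION_URL = "https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion"


async def complete(prompt: str) -> str | None:
    # Same request/response shape ChatAgent.__call_model relies on:
    # {"prompt": "..."} in, {"content": "..."} out.
    async with aiohttp.ClientSession() as session:
        async with session.post(COMPLETION_URL, json={"prompt": prompt}) as response:
            if response.status == 200:
                return (await response.json())["content"]
            return None


print(asyncio.run(complete("Q: What is the capital of France?\nA:")))
```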
libertai_agents/interfaces.py (new file):

```python
from enum import Enum

from pydantic import BaseModel


class MessageRoleEnum(str, Enum):
    system = 'system'
    user = 'user'
    assistant = 'assistant'


class Message(BaseModel):
    role: MessageRoleEnum
    content: str


class LlamaCppParams(BaseModel):
    prompt: str
```
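These pydantic models are the agent's wire format: Message entries are dumped to plain dicts before templating, and LlamaCppParams is the POST body sent to the completion endpoint. A quick sketch (assuming pydantic v2, which the model_dump calls in this commit already imply):

```python
from libertai_agents.interfaces import LlamaCppParams, Message, MessageRoleEnum

message = Message(role=MessageRoleEnum.user, content="Hi!")
# The str-backed enum collapses to its plain value in JSON mode.
print(message.model_dump(mode="json"))  # {'role': 'user', 'content': 'Hi!'}

params = LlamaCppParams(prompt="Q: ping\nA:")
print(params.model_dump())  # {'prompt': 'Q: ping\nA:'}
```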
libertai_agents/models.py (modified):

```diff
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
+from libertai_agents.interfaces import Message, MessageRoleEnum
+
 
 class Model:
     tokenizer: PreTrainedTokenizerFast
+    vm_url: str
 
-    def __init__(self, model_id: str):
+    def __init__(self, model_id: str, vm_url: str):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.vm_url = vm_url
+
+    def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
+        messages.insert(0, Message(role=MessageRoleEnum.system, content=system_prompt))
+        raw_messages = list(map(lambda x: x.model_dump(), messages))
+
+        return self.tokenizer.apply_chat_template(conversation=raw_messages, tools=tools, tokenize=False,
+                                                  add_generation_prompt=True)
 
 
-Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B")
+Hermes2Pro = Model(model_id="NousResearch/Hermes-2-Pro-Llama-3-8B",
+                   vm_url='https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion')
```
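generate_prompt delegates formatting to the tokenizer's bundled chat template: apply_chat_template with tokenize=False and add_generation_prompt=True renders the system message, the conversation, and (in this commit) the forwarded tool signatures into the model's prompt format, then appends the assistant header so generation continues from there. One caveat: messages.insert(0, ...) mutates the caller's list, so reusing a history across calls would stack duplicate system messages. A rough sketch of the template call in isolation (output abbreviated; the exact tags are defined by the Hermes 2 Pro template):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B")
conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What's the temperature in Paris ?"},
]
# Returns the rendered prompt string instead of token ids (tokenize=False).
prompt = tokenizer.apply_chat_template(conversation=conversation, tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # <|im_start|>system ... <|im_start|>user ... <|im_start|>assistant
```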
libertai_agents/tools.py (new file):

```python
def get_current_temperature(location: str, unit: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
    Returns:
        The current temperature at the specified location in the specified units, as a float.
    """
    return 22.  # A real function should probably actually get the temperature!
```
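The Google-style docstring here is load-bearing: when a plain Python function is passed as a tool, transformers parses its type hints and docstring into a JSON schema that the chat template embeds in the prompt, so the argument descriptions above are exactly what the model sees. Recent transformers versions expose the same parser directly (availability of get_json_schema depends on the installed version):

```python
from transformers.utils import get_json_schema

from libertai_agents.tools import get_current_temperature

# Prints a {'type': 'function', 'function': {...}} schema whose name,
# description, and per-argument descriptions come from the docstring above.
print(get_json_schema(get_current_temperature))
```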
Example script (new file):

```python
import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import MessageRoleEnum, Message
from libertai_agents.models import Hermes2Pro
from libertai_agents.tools import get_current_temperature


async def start():
    agent = ChatAgent(model=Hermes2Pro, system_prompt="You are a helpful assistant", tools=[get_current_temperature])
    response = await agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="What's the temperature in Paris ?")])
    print(response)


asyncio.run(start())
```