feat(agent): Keep same aiohttp session across requests
RezaRahemtola committed Feb 19, 2025
1 parent 88b0700 commit 9160e3f
Showing 3 changed files with 34 additions and 24 deletions.
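In short, the commit replaces per-request aiohttp sessions with a single lazily created ClientSession stored on the agent and reused for every model call. A minimal sketch of the pattern, using simplified stand-in names rather than the actual ChatAgent code:

    import asyncio
    import weakref

    from aiohttp import ClientSession


    class SessionOwner:
        """Owns one ClientSession, created lazily and reused across calls."""

        def __init__(self) -> None:
            self._session: ClientSession | None = None
            # Register a finalizer so the session gets closed when the owner
            # is torn down (mirrors the commit's weakref.finalize usage).
            weakref.finalize(self, self._sync_cleanup)

        @property
        def session(self) -> ClientSession:
            # Lazy initialization: the first access creates the session and
            # later accesses reuse it, so aiohttp can pool TCP connections.
            if self._session is None:
                self._session = ClientSession()
            return self._session

        async def _cleanup(self) -> None:
            if self._session is not None and not self._session.closed:
                await self._session.close()

        def _sync_cleanup(self) -> None:
            try:
                # A loop is running: schedule the async close on it.
                loop = asyncio.get_running_loop()
                loop.create_task(self._cleanup())
            except RuntimeError:
                # No running loop: drive the close to completion ourselves.
                asyncio.run(self._cleanup())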
52 changes: 34 additions & 18 deletions libertai_agents/libertai_agents/agents.py
@@ -2,10 +2,10 @@
import inspect
import json
import time
import weakref
from http import HTTPStatus
from typing import Any, AsyncIterable, Awaitable

import aiohttp
from aiohttp import ClientSession
from fastapi import APIRouter, FastAPI
from starlette.responses import StreamingResponse
@@ -36,6 +36,8 @@ class ChatAgent:
llamacpp_params: CustomizableLlamaCppParams
app: FastAPI | None

__session: ClientSession | None

def __init__(
self,
model: Model,
@@ -62,7 +64,11 @@ def __init__(
self.system_prompt = system_prompt
self.tools = tools
self.llamacpp_params = llamacpp_params
self.call_session: ClientSession | None = None
self.__session = None

weakref.finalize(
self, self.__sync_cleanup
) # Ensures cleanup when object is deleted

if expose_api:
# Define API routes
@@ -78,8 +84,27 @@ def __init__(
self.app = FastAPI(title="LibertAI ChatAgent")
self.app.include_router(router)

@property
def session(self) -> ClientSession:
if self.__session is None:
self.__session = ClientSession()
return self.__session

def __repr__(self):
return f"ChatAgent(model={self.model.model_id})"
return f"{self.__class__.__name__}(model={self.model.model_id})"

async def __cleanup(self):
if self.__session is not None and not self.__session.closed:
await self.__session.close()

def __sync_cleanup(self):
"""Schedules the async cleanup coroutine properly."""
try:
loop = asyncio.get_running_loop()
loop.create_task(self.__cleanup())
except RuntimeError:

# No running loop, run cleanup synchronously
asyncio.run(self.__cleanup())

def get_model_information(self) -> ModelInformation:
"""
@@ -94,7 +119,6 @@ async def generate_answer(
messages: list[Message],
only_final_answer: bool = True,
system_prompt: str | None = None,
session: ClientSession | None = None,
) -> AsyncIterable[Message]:
"""
Generate an answer based on a conversation
@@ -111,11 +135,8 @@
prompt = self.model.generate_prompt(
messages, self.tools, system_prompt=system_prompt or self.system_prompt
)
if session is None:
async with aiohttp.ClientSession() as local_session:
response = await self.__call_model(local_session, prompt)
else:
response = await self.__call_model(session, prompt)

response = await self.__call_model(prompt)

if response is None:
# TODO: handle error correctly
@@ -161,8 +182,6 @@ async def __api_generate_answer(
Generate an answer based on an existing conversation.
The response messages can be streamed or sent in a single block.
"""
if self.call_session is None:
self.call_session = ClientSession()

if stream:
return StreamingResponse(
@@ -174,7 +193,7 @@

response_messages: list[Message] = []
async for message in self.generate_answer(
messages, only_final_answer=only_final_answer, session=self.call_session
messages, only_final_answer=only_final_answer
):
response_messages.append(message)
return response_messages
@@ -189,19 +208,16 @@ async def __dump_api_generate_streamed_answer(
:param only_final_answer: Param to pass to generate_answer
:return: Iterable of each messages from generate_answer dumped to JSON
"""
if self.call_session is None:
self.call_session = ClientSession()

async for message in self.generate_answer(
messages, only_final_answer=only_final_answer, session=self.call_session
messages, only_final_answer=only_final_answer
):
yield json.dumps(message.model_dump(), indent=4)

async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
async def __call_model(self, prompt: str) -> str | None:
"""
Call the model with a given prompt
:param session: aiohttp session to use to make the call
:param prompt: Prompt to give to the model
:return: String response (if no error)
"""
@@ -211,7 +227,7 @@ async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
max_retries = 150
# Looping until we get a satisfying response
for _ in range(max_retries):
async with session.post(
async with self.session.post(
self.model.vm_url, json=params.model_dump()
) as response:
if response.status == HTTPStatus.OK:
1 change: 0 additions & 1 deletion libertai_agents/libertai_agents/models/base.py
@@ -12,7 +12,6 @@
ModelId = Literal[
"NousResearch/Hermes-3-Llama-3.1-8B",
"mistralai/Mistral-Nemo-Instruct-2407",
"deepseek-ai/DeepSeek-V3",
]


5 changes: 0 additions & 5 deletions libertai_agents/libertai_agents/models/models.py
@@ -25,11 +25,6 @@ class FullModelConfiguration(ModelConfiguration):
context_length=16384,
constructor=HermesModel,
),
"deepseek-ai/DeepSeek-V3": FullModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/9aa80dc7f00c515a5f56b70e65fdab4c367e35f341c3b4220419adb6ca86a33f/completion",
context_length=16384,
constructor=HermesModel,
),
"mistralai/Mistral-Nemo-Instruct-2407": FullModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/2c4ad0bf343fb12924936cbc801732d95ce90f84cd895aa8bee82c0a062815c2/completion",
context_length=8192,

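For context, a hedged usage sketch of the updated call path: callers no longer pass a session into generate_answer; the agent's internal session is reused across calls and across API requests. The get_model helper, the Message import path, and the message shape shown here are assumptions about the surrounding library, not confirmed by this diff:

    import asyncio

    from libertai_agents.agents import ChatAgent
    from libertai_agents.interfaces.messages import Message  # assumed import path
    from libertai_agents.models import get_model  # assumed helper


    async def main() -> None:
        agent = ChatAgent(model=get_model("NousResearch/Hermes-3-Llama-3.1-8B"))
        messages = [Message(role="user", content="Hello!")]  # assumed message shape
        # This call reuses agent.session under the hood; before this commit,
        # each call either opened a throwaway session or had to be handed one.
        async for message in agent.generate_answer(messages, only_final_answer=True):
            print(message)


    asyncio.run(main())

Reusing one session lets aiohttp keep pooled connections to the model VM, which matters here because __call_model may retry a request up to 150 times against the same vm_url.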