From 9160e3f19579dc35811f72a0a875a3f91e8a8c18 Mon Sep 17 00:00:00 2001
From: Reza Rahemtola
Date: Wed, 19 Feb 2025 19:17:41 +0900
Subject: [PATCH] feat(agent): Keep same aiohttp session across requests

---
 libertai_agents/libertai_agents/agents.py | 52 ++++++++++++-------
 .../libertai_agents/models/base.py        |  1 -
 .../libertai_agents/models/models.py      |  5 --
 3 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/libertai_agents/libertai_agents/agents.py b/libertai_agents/libertai_agents/agents.py
index 0893bb2..fa5e3a0 100644
--- a/libertai_agents/libertai_agents/agents.py
+++ b/libertai_agents/libertai_agents/agents.py
@@ -2,10 +2,10 @@
 import inspect
 import json
 import time
+import weakref
 from http import HTTPStatus
 from typing import Any, AsyncIterable, Awaitable
 
-import aiohttp
 from aiohttp import ClientSession
 from fastapi import APIRouter, FastAPI
 from starlette.responses import StreamingResponse
@@ -36,6 +36,8 @@ class ChatAgent:
     llamacpp_params: CustomizableLlamaCppParams
     app: FastAPI | None
 
+    __session: ClientSession | None
+
     def __init__(
         self,
         model: Model,
@@ -62,7 +64,11 @@ def __init__(
         self.system_prompt = system_prompt
         self.tools = tools
         self.llamacpp_params = llamacpp_params
-        self.call_session: ClientSession | None = None
+        self.__session = None
+
+        weakref.finalize(
+            self, self.__sync_cleanup
+        )  # Ensures cleanup when object is deleted
 
         if expose_api:
             # Define API routes
@@ -78,8 +84,27 @@ def __init__(
             self.app = FastAPI(title="LibertAI ChatAgent")
             self.app.include_router(router)
 
+    @property
+    def session(self) -> ClientSession:
+        if self.__session is None:
+            self.__session = ClientSession()
+        return self.__session
+
     def __repr__(self):
-        return f"ChatAgent(model={self.model.model_id})"
+        return f"{self.__class__.__name__}(model={self.model.model_id})"
+
+    async def __cleanup(self):
+        if self.__session is not None and not self.__session.closed:
+            await self.__session.close()
+
+    def __sync_cleanup(self):
+        """Schedules the async cleanup coroutine properly."""
+        try:
+            loop = asyncio.get_running_loop()
+            loop.create_task(self.__cleanup())
+        except RuntimeError:
+            # No running loop, run cleanup synchronously
+            asyncio.run(self.__cleanup())
 
     def get_model_information(self) -> ModelInformation:
         """
@@ -94,7 +119,6 @@ async def generate_answer(
         messages: list[Message],
         only_final_answer: bool = True,
         system_prompt: str | None = None,
-        session: ClientSession | None = None,
     ) -> AsyncIterable[Message]:
         """
         Generate an answer based on a conversation
@@ -111,11 +135,8 @@ async def generate_answer(
         prompt = self.model.generate_prompt(
             messages, self.tools, system_prompt=system_prompt or self.system_prompt
         )
-        if session is None:
-            async with aiohttp.ClientSession() as local_session:
-                response = await self.__call_model(local_session, prompt)
-        else:
-            response = await self.__call_model(session, prompt)
+
+        response = await self.__call_model(prompt)
 
         if response is None:
             # TODO: handle error correctly
@@ -161,8 +182,6 @@ async def __api_generate_answer(
         Generate an answer based on an existing conversation.
         The response messages can be streamed or sent in a single block.
         """
-        if self.call_session is None:
-            self.call_session = ClientSession()
 
         if stream:
             return StreamingResponse(
@@ -174,7 +193,7 @@ async def __api_generate_answer(
 
         response_messages: list[Message] = []
         async for message in self.generate_answer(
-            messages, only_final_answer=only_final_answer, session=self.call_session
+            messages, only_final_answer=only_final_answer
         ):
             response_messages.append(message)
         return response_messages
@@ -189,19 +208,16 @@ async def __dump_api_generate_streamed_answer(
         :param only_final_answer: Param to pass to generate_answer
         :return: Iterable of each messages from generate_answer dumped to JSON
         """
-        if self.call_session is None:
-            self.call_session = ClientSession()
 
         async for message in self.generate_answer(
-            messages, only_final_answer=only_final_answer, session=self.call_session
+            messages, only_final_answer=only_final_answer
        ):
             yield json.dumps(message.model_dump(), indent=4)
 
-    async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
+    async def __call_model(self, prompt: str) -> str | None:
         """
         Call the model with a given prompt
 
-        :param session: aiohttp session to use to make the call
         :param prompt: Prompt to give to the model
         :return: String response (if no error)
         """
@@ -211,7 +227,7 @@ async def __call_model(self, prompt: str) -> str | None:
         max_retries = 150
         # Looping until we get a satisfying response
         for _ in range(max_retries):
-            async with session.post(
+            async with self.session.post(
                 self.model.vm_url, json=params.model_dump()
             ) as response:
                 if response.status == HTTPStatus.OK:
diff --git a/libertai_agents/libertai_agents/models/base.py b/libertai_agents/libertai_agents/models/base.py
index 5362d71..ea69cfa 100644
--- a/libertai_agents/libertai_agents/models/base.py
+++ b/libertai_agents/libertai_agents/models/base.py
@@ -12,7 +12,6 @@
 ModelId = Literal[
     "NousResearch/Hermes-3-Llama-3.1-8B",
     "mistralai/Mistral-Nemo-Instruct-2407",
-    "deepseek-ai/DeepSeek-V3",
 ]
 
 
diff --git a/libertai_agents/libertai_agents/models/models.py b/libertai_agents/libertai_agents/models/models.py
index 1c0bf48..8c89978 100644
--- a/libertai_agents/libertai_agents/models/models.py
+++ b/libertai_agents/libertai_agents/models/models.py
@@ -25,11 +25,6 @@ class FullModelConfiguration(ModelConfiguration):
         context_length=16384,
         constructor=HermesModel,
     ),
-    "deepseek-ai/DeepSeek-V3": FullModelConfiguration(
-        vm_url="https://curated.aleph.cloud/vm/9aa80dc7f00c515a5f56b70e65fdab4c367e35f341c3b4220419adb6ca86a33f/completion",
-        context_length=16384,
-        constructor=HermesModel,
-    ),
     "mistralai/Mistral-Nemo-Instruct-2407": FullModelConfiguration(
         vm_url="https://curated.aleph.cloud/vm/2c4ad0bf343fb12924936cbc801732d95ce90f84cd895aa8bee82c0a062815c2/completion",
         context_length=8192,