diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 04c86b3..07c3cf7 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -7,6 +7,8 @@
 from typing import AsyncIterable, Iterable, Literal
 
 import boto3
+import asyncio
+import functools
 import numpy as np
 import requests
 import tiktoken
@@ -55,6 +57,11 @@
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
 
+async def run_in_executor(func, **kwargs):
+    """Run a blocking function in the default executor so it does not block the event loop."""
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, functools.partial(func, **kwargs))
+
+
 class BedrockModel(BaseChatModel):
     # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@@ -189,13 +196,13 @@ def validate(self, chat_request: ChatRequest):
                 status_code=400,
                 detail=error,
             )
-
-    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
-        """Common logic for invoke bedrock models"""
+
+    def _invoke_bedrock_common(self, chat_request: ChatRequest, stream: bool = False):
+        """Common logic for invoking bedrock models"""
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())
 
-        # convert OpenAI chat request to Bedrock SDK request
+        # Convert OpenAI chat request to Bedrock SDK request
         args = self._parse_request(chat_request)
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(args))
@@ -204,20 +211,28 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         try:
             if stream:
                 response = bedrock_runtime.converse_stream(**args)
             else:
                 response = bedrock_runtime.converse(**args)
+            return response
         except bedrock_runtime.exceptions.ValidationException as e:
             logger.error("Validation Error: " + str(e))
             raise HTTPException(status_code=400, detail=str(e))
         except Exception as e:
             logger.error(e)
             raise HTTPException(status_code=500, detail=str(e))
-        return response
+
+    async def _invoke_bedrock_async(self, chat_request: ChatRequest):
+        """Invoke bedrock models asynchronously by running the common logic in an executor."""
+        return await run_in_executor(self._invoke_bedrock_common, chat_request=chat_request)
+
+    def _invoke_bedrock_stream(self, chat_request: ChatRequest):
+        """Invoke bedrock models in streaming mode using the common logic."""
+        return self._invoke_bedrock_common(chat_request, stream=True)
 
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+    async def chat(self, chat_request: ChatRequest) -> ChatResponse:
         """Default implementation for Chat API."""
         message_id = self.generate_message_id()
-        response = self._invoke_bedrock(chat_request)
+        response = await self._invoke_bedrock_async(chat_request)
 
         output_message = response["output"]["message"]
         input_tokens = response["usage"]["inputTokens"]
@@ -238,7 +253,7 @@
 
     def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        response = self._invoke_bedrock(chat_request, stream=True)
+        response = self._invoke_bedrock_stream(chat_request)
 
         message_id = self.generate_message_id()
         stream = response.get("stream")
@@ -677,13 +692,13 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
     accept = "application/json"
     content_type = "application/json"
 
-    def _invoke_model(self, args: dict, model_id: str):
+    async def _invoke_model(self, args: dict, model_id: str):
         body = json.dumps(args)
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
         try:
-            return bedrock_runtime.invoke_model(
+            return await run_in_executor(bedrock_runtime.invoke_model,
                 body=body,
                 modelId=model_id,
                 accept=self.accept,
@@ -757,7 +772,7 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         }
         return args
 
-    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
+    async def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = await self._invoke_model(
             args=self._parse_args(embeddings_request), model_id=embeddings_request.model
         )
@@ -798,7 +813,7 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         )
         return args
 
-    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
+    async def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = await self._invoke_model(
             args=self._parse_args(embeddings_request), model_id=embeddings_request.model
         )
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 1e48a48..542ee36 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -42,4 +42,4 @@ async def chat_completions(
         return StreamingResponse(
             content=model.chat_stream(chat_request), media_type="text/event-stream"
         )
-    return model.chat(chat_request)
+    return await model.chat(chat_request)
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index e5cde31..3fe95d2 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -33,4 +33,4 @@ async def embeddings(
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL
     # Exception will be raised if model not supported.
     model = get_embeddings_model(embeddings_request.model)
-    return model.embed(embeddings_request)
+    return await model.embed(embeddings_request)
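
Reviewer note: the sketch below is a minimal, self-contained illustration of the executor pattern this diff applies; it is not part of the change. `blocking_call` is a hypothetical stand-in for a blocking boto3 call such as `bedrock_runtime.converse(**args)`.

```python
import asyncio
import functools
import time


async def run_in_executor(func, **kwargs):
    """Run a blocking callable in the default thread-pool executor."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, **kwargs))


def blocking_call(delay: float = 1.0) -> str:
    # Hypothetical stand-in for a blocking SDK call such as
    # bedrock_runtime.converse(**args); not part of the codebase.
    time.sleep(delay)
    return "ok"


async def main() -> None:
    started = time.perf_counter()
    # Both calls run in worker threads, so they overlap instead of
    # serializing on the event loop: ~1s total rather than ~2s.
    results = await asyncio.gather(
        run_in_executor(blocking_call, delay=1.0),
        run_in_executor(blocking_call, delay=1.0),
    )
    print(results, f"{time.perf_counter() - started:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
```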
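Two caveats worth flagging in review: `run_in_executor` forwards keyword arguments only (they are bound with `functools.partial`), which suffices here because `converse`, `invoke_model`, and `_invoke_bedrock_common` are all invoked with keyword arguments; and the streaming path deliberately stays synchronous, since the event stream returned by `converse_stream` is consumed incrementally by `chat_stream`.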