run non-stream requests as async calls to avoid blocking requests #23
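For readers skimming the diff: boto3 is a synchronous SDK, so calling bedrock_runtime.converse(...) directly inside an async FastAPI handler stalls the event loop for the whole round trip. This PR's run_in_executor helper pushes that call onto the default thread pool instead. Below is a minimal, self-contained sketch of the same pattern; the slow_sdk_call stand-in is illustrative and not part of the diff:

```python
import asyncio
import functools
import time


def slow_sdk_call(prompt: str, delay: float = 1.0) -> str:
    """Stand-in for a blocking SDK call such as bedrock_runtime.converse."""
    time.sleep(delay)  # blocks its thread, like a synchronous HTTP round trip
    return f"echo: {prompt}"


async def run_in_executor(func, **kwargs):
    """Same shape as the helper in this PR: run a blocking call in the
    default thread pool so the event loop keeps serving other requests."""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, functools.partial(func, **kwargs))


async def main():
    start = time.perf_counter()
    # Both calls overlap in worker threads: roughly 1s total instead of 2s.
    results = await asyncio.gather(
        run_in_executor(slow_sdk_call, prompt="a"),
        run_in_executor(slow_sdk_call, prompt="b"),
    )
    print(results, f"{time.perf_counter() - start:.1f}s")


asyncio.run(main())
```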

Open · wants to merge 2 commits into base: main
41 changes: 28 additions & 13 deletions src/api/models/bedrock.py
@@ -7,6 +7,8 @@
 from typing import AsyncIterable, Iterable, Literal
 
 import boto3
+import asyncio
+import functools
 import numpy as np
 import requests
 import tiktoken
@@ -55,6 +57,11 @@
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+async def run_in_executor(func, **kwargs):
+    """Run a blocking function in the default executor without blocking the event loop."""
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, functools.partial(func, **kwargs))
+
 
 class BedrockModel(BaseChatModel):
     # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@@ -189,13 +196,13 @@ def validate(self, chat_request: ChatRequest):
                 status_code=400,
                 detail=error,
             )
 
-    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
-        """Common logic for invoke bedrock models"""
+    def _invoke_bedrock_common(self, chat_request: ChatRequest, stream: bool = False):
+        """Common logic for invoking bedrock models"""
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())
 
-        # convert OpenAI chat request to Bedrock SDK request
+        # Convert OpenAI chat request to Bedrock SDK request
         args = self._parse_request(chat_request)
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(args))
@@ -204,20 +211,28 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             if stream:
                 response = bedrock_runtime.converse_stream(**args)
             else:
-                response = bedrock_runtime.converse(**args)
+                response = run_in_executor(bedrock_runtime.converse, **args)
+            return response
         except bedrock_runtime.exceptions.ValidationException as e:
             logger.error("Validation Error: " + str(e))
             raise HTTPException(status_code=400, detail=str(e))
         except Exception as e:
             logger.error(e)
             raise HTTPException(status_code=500, detail=str(e))
         return response
 
+    async def _invoke_bedrock_async(self, chat_request: ChatRequest):
+        """Invoke bedrock models in async mode using common logic"""
+        return await self._invoke_bedrock_common(chat_request)
+
+    def _invoke_bedrock_stream(self, chat_request: ChatRequest):
+        """Invoke bedrock models in streaming mode using common logic"""
+        return self._invoke_bedrock_common(chat_request, stream=True)
+
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+    async def chat(self, chat_request: ChatRequest) -> ChatResponse:
         """Default implementation for Chat API."""
 
         message_id = self.generate_message_id()
-        response = self._invoke_bedrock(chat_request)
+        response = await self._invoke_bedrock_async(chat_request)
 
         output_message = response["output"]["message"]
         input_tokens = response["usage"]["inputTokens"]
@@ -238,7 +253,7 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse:
 
     def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        response = self._invoke_bedrock(chat_request, stream=True)
+        response = self._invoke_bedrock_stream(chat_request)
         message_id = self.generate_message_id()
 
         stream = response.get("stream")
@@ -677,13 +692,13 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
     accept = "application/json"
     content_type = "application/json"
 
-    def _invoke_model(self, args: dict, model_id: str):
+    async def _invoke_model(self, args: dict, model_id: str):
         body = json.dumps(args)
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
         try:
-            return bedrock_runtime.invoke_model(
+            return await run_in_executor(bedrock_runtime.invoke_model,
                 body=body,
                 modelId=model_id,
                 accept=self.accept,
@@ -757,7 +772,7 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         }
         return args
 
-    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
+    async def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = await self._invoke_model(
             args=self._parse_args(embeddings_request), model_id=embeddings_request.model
         )
@@ -798,7 +813,7 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         )
         return args
 
-    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
+    async def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = await self._invoke_model(
             args=self._parse_args(embeddings_request), model_id=embeddings_request.model
         )
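One subtlety worth noting: _invoke_bedrock_common itself stays synchronous, so in the non-stream branch it returns the coroutine from run_in_executor without awaiting it. The try/except in that method therefore cannot catch errors raised by converse; a ValidationException only surfaces at the await inside _invoke_bedrock_async and falls through to generic 500 handling instead of being mapped to a 400. A hedged sketch of one way to keep the mapping, using asyncio.to_thread (Python 3.9+); this is a possible variant relying on names defined in bedrock.py, not what the PR implements:

```python
import asyncio

from fastapi import HTTPException


class BedrockModel(BaseChatModel):  # assumes the class and module-level names from bedrock.py
    async def _invoke_bedrock_async(self, chat_request: ChatRequest):
        """Variant sketch: await inside the try block so Bedrock errors are
        still translated into the intended HTTP status codes."""
        args = self._parse_request(chat_request)
        try:
            # asyncio.to_thread(func, **kwargs) is roughly equivalent to the
            # PR's run_in_executor helper, without the functools.partial step.
            return await asyncio.to_thread(bedrock_runtime.converse, **args)
        except bedrock_runtime.exceptions.ValidationException as e:
            logger.error("Validation Error: " + str(e))
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(e)
            raise HTTPException(status_code=500, detail=str(e))
```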
2 changes: 1 addition & 1 deletion src/api/routers/chat.py
@@ -42,4 +42,4 @@ async def chat_completions(
         return StreamingResponse(
             content=model.chat_stream(chat_request), media_type="text/event-stream"
         )
-    return model.chat(chat_request)
+    return await model.chat(chat_request)
2 changes: 1 addition & 1 deletion src/api/routers/embeddings.py
@@ -33,4 +33,4 @@ async def embeddings(
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL
     # Exception will be raised if model not supported.
     model = get_embeddings_model(embeddings_request.model)
-    return model.embed(embeddings_request)
+    return await model.embed(embeddings_request)