feat: pass headers through as metadata #106

Open · wants to merge 1 commit into base: main
21 changes: 14 additions & 7 deletions src/api/models/bedrock.py
@@ -160,13 +160,13 @@ def validate(self, chat_request: ChatRequest):
                 detail=error,
             )

-    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+    def _invoke_bedrock(self, chat_request: ChatRequest, headers: dict, stream=False):
         """Common logic for invoke bedrock models"""
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())

         # convert OpenAI chat request to Bedrock SDK request
-        args = self._parse_request(chat_request)
+        args = self._parse_request(chat_request, headers)
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(str(args)))

@@ -183,11 +183,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             raise HTTPException(status_code=500, detail=str(e))
         return response

-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+    def chat(self, chat_request: ChatRequest, headers: dict) -> ChatResponse:
         """Default implementation for Chat API."""

         message_id = self.generate_message_id()
-        response = self._invoke_bedrock(chat_request)
+        response = self._invoke_bedrock(chat_request, headers)

         output_message = response["output"]["message"]
         input_tokens = response["usage"]["inputTokens"]
@@ -206,9 +206,9 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse:
         logger.info("Proxy response :" + chat_response.model_dump_json())
         return chat_response

-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+    def chat_stream(self, chat_request: ChatRequest, headers: dict) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        response = self._invoke_bedrock(chat_request, stream=True)
+        response = self._invoke_bedrock(chat_request, headers=headers, stream=True)
         message_id = self.generate_message_id()

         stream = response.get("stream")
@@ -390,7 +390,7 @@ def _reframe_multi_payloard(self, messages: list) -> list:

         return reformatted_messages

-    def _parse_request(self, chat_request: ChatRequest) -> dict:
+    def _parse_request(self, chat_request: ChatRequest, headers: dict) -> dict:
         """Create default converse request body.

         Also perform validations to tool call etc.
@@ -420,6 +420,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             "system": system_prompts,
             "inferenceConfig": inference_config,
         }
+
+        # Pass headers through as metadata
+        args["requestMetadata"] = {
+            "X-Header-Value-0": headers.get("X-Header-Value-0", "unknown"),
+            "X-Header-Value-1": headers.get("X-Header-Value-1", "unknown"),
+        }
+
         # add tool config
         if chat_request.tools:
             args["toolConfig"] = {
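For context, requestMetadata is a documented field of the Bedrock Converse API: its key-value pairs are not sent to the model but are recorded with the invocation (for example in model invocation logs, where they can be used for filtering). Below is a minimal sketch, assuming a boto3 bedrock-runtime client, of the kind of converse call the gateway ends up making once _parse_request has injected the header values; the model ID and metadata values are illustrative, not taken from this PR.

import boto3

# Plain boto3 sketch of the request shape produced by _invoke_bedrock /
# _parse_request after this change; all values below are made up.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

args = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative
    "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
    "inferenceConfig": {"maxTokens": 256},
    # _parse_request fills these from the incoming HTTP headers,
    # defaulting to "unknown" when a header is absent
    "requestMetadata": {
        "X-Header-Value-0": "tenant-a",
        "X-Header-Value-1": "session-123",
    },
}

response = client.converse(**args)
print(response["output"]["message"])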
12 changes: 10 additions & 2 deletions src/api/routers/chat.py
@@ -1,4 +1,5 @@
 from typing import Annotated
+import logging

-from fastapi import APIRouter, Depends, Body
+from fastapi import APIRouter, Depends, Body, Request
 from fastapi.responses import StreamingResponse
@@ -8,6 +9,8 @@
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL

+logger = logging.getLogger(__name__)
+
 router = APIRouter(
     prefix="/chat",
     dependencies=[Depends(api_key_auth)],
@@ -17,6 +20,7 @@

 @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
 async def chat_completions(
+    request: Request,
     chat_request: Annotated[
         ChatRequest,
         Body(
@@ -32,6 +36,10 @@ async def chat_completions(
         ),
     ]
 ):
+    # Log headers for security and analytics
+    headers = request.headers
+    logger.info(f"Headers: {headers}")
+
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL

@@ -40,6 +48,6 @@ async def chat_completions(
     model.validate(chat_request)
     if chat_request.stream:
         return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
+            content=model.chat_stream(chat_request, headers), media_type="text/event-stream"
         )
-    return model.chat(chat_request)
+    return model.chat(chat_request, headers)
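
For a quick end-to-end check, here is a hypothetical client call against the gateway. The base URL, API key, and model ID are placeholders, and only the two header names hard-coded in _parse_request are forwarded; any header that is missing falls back to "unknown" in the metadata.

import requests

# Hypothetical smoke test for the header pass-through; the endpoint and
# credentials are placeholders, not part of the PR.
resp = requests.post(
    "http://localhost:8000/api/v1/chat/completions",
    headers={
        "Authorization": "Bearer placeholder-api-key",
        "X-Header-Value-0": "tenant-a",     # ends up in requestMetadata
        "X-Header-Value-1": "session-123",  # ends up in requestMetadata
    },
    json={
        "model": "anthropic.claude-3-sonnet-20240229-v1:0",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code, resp.json())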