diff --git a/src/api/app.py b/src/api/app.py
index 7fa2f01..56f815d 100644
--- a/src/api/app.py
+++ b/src/api/app.py
@@ -42,6 +42,7 @@ async def health():
     return {"status": "OK"}
 
 
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):
     return PlainTextResponse(str(exc), status_code=400)
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 16bfd33..e9fdf1b 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -5,6 +5,7 @@ import time
 from abc import ABC
 from typing import AsyncIterable, Iterable, Literal
 
+from api.models.model_manager import ModelManager
 import boto3
 import numpy as np
@@ -75,83 +76,88 @@ def get_inference_region_prefix():
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+# Initialize the model list.
+#bedrock_model_list = list_bedrock_models()
 
-def list_bedrock_models() -> dict:
-    """Automatically getting a list of supported models.
-
-    Returns a model list combines:
-    - ON_DEMAND models.
-    - Cross-Region Inference Profiles (if enabled via Env)
-    """
-    model_list = {}
-    try:
-        profile_list = []
-        if ENABLE_CROSS_REGION_INFERENCE:
-            # List system defined inference profile IDs
-            response = bedrock_client.list_inference_profiles(
-                maxResults=1000,
-                typeEquals='SYSTEM_DEFINED'
-            )
-            profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
-
-        # List foundation models, only cares about text outputs here.
-        response = bedrock_client.list_foundation_models(
-            byOutputModality='TEXT'
-        )
-
-        for model in response['modelSummaries']:
-            model_id = model.get('modelId', 'N/A')
-            stream_supported = model.get('responseStreamingSupported', True)
-            status = model['modelLifecycle'].get('status', 'ACTIVE')
-
-            # currently, use this to filter out rerank models and legacy models
-            if not stream_supported or status != "ACTIVE":
-                continue
-
-            inference_types = model.get('inferenceTypesSupported', [])
-            input_modalities = model['inputModalities']
-            # Add on-demand model list
-            if 'ON_DEMAND' in inference_types:
-                model_list[model_id] = {
-                    'modalities': input_modalities
-                }
-
-            # Add cross-region inference model list.
-            profile_id = cr_inference_prefix + '.' + model_id
-            if profile_id in profile_list:
-                model_list[profile_id] = {
-                    'modalities': input_modalities
-                }
-    except Exception as e:
-        logger.error(f"Unable to list models: {str(e)}")
+class BedrockModel(BaseChatModel):
 
-    if not model_list:
-        # In case stack not updated.
-        model_list[DEFAULT_MODEL] = {
-            'modalities': ["TEXT", "IMAGE"]
-        }
+    #bedrock_model_list = None
+    model_manager = None
 
+    def __init__(self):
+        super().__init__()
+        self.model_manager = ModelManager()
 
-    return model_list
+    def list_bedrock_models(self) -> dict:
+        """Automatically get the list of supported models.
 
+        Returns a model list combining:
+        - ON_DEMAND models.
+        - Cross-Region Inference Profiles (if enabled via env)
+        """
+        #model_list = {}
+        try:
+            profile_list = []
+            if ENABLE_CROSS_REGION_INFERENCE:
+                # List system defined inference profile IDs
+                response = bedrock_client.list_inference_profiles(
+                    maxResults=1000,
+                    typeEquals='SYSTEM_DEFINED'
+                )
+                profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
 
-# Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+            # List foundation models; only text-output models are considered here.
+            response = bedrock_client.list_foundation_models(
+                byOutputModality='TEXT'
+            )
 
+            for model in response['modelSummaries']:
+                model_id = model.get('modelId', 'N/A')
+                stream_supported = model.get('responseStreamingSupported', True)
+                status = model['modelLifecycle'].get('status', 'ACTIVE')
+
+                # currently, use this to filter out rerank models and legacy models
+                if not stream_supported or status != "ACTIVE":
+                    continue
+
+                inference_types = model.get('inferenceTypesSupported', [])
+                input_modalities = model['inputModalities']
+                # Add on-demand models; add_model() expects a {model_id: config} dict,
+                # so pass a fresh dict instead of mutating the model summary.
+                if 'ON_DEMAND' in inference_types:
+                    self.model_manager.add_model({model_id: {
+                        'modalities': input_modalities
+                    }})
+                    # model_list[model_id] = {
+                    #     'modalities': input_modalities
+                    # }
+
+                # Add cross-region inference profiles.
+                profile_id = cr_inference_prefix + '.' + model_id
+                if profile_id in profile_list:
+                    self.model_manager.add_model({profile_id: {
+                        'modalities': input_modalities
+                    }})
 
-class BedrockModel(BaseChatModel):
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
-        global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
-        return list(bedrock_model_list.keys())
+        #global bedrock_model_list
+        self.list_bedrock_models()
+        return list(self.model_manager.get_all_models().keys())
 
     def validate(self, chat_request: ChatRequest):
         """Perform basic validation on requests"""
+        error = ""
+
+        ###### TODO - failing here as kb and agents are not in the bedrock_model_list
         # check if model is supported
-        if chat_request.model not in bedrock_model_list.keys():
+        if chat_request.model not in self.model_manager.get_all_models().keys():
             error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
 
         if error:
@@ -659,7 +665,7 @@ def _parse_content_parts(
 
     @staticmethod
     def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
-        model = bedrock_model_list.get(model_id)
+        model = ModelManager().get_all_models().get(model_id)
         modalities = model.get('modalities', [])
         if modality in modalities:
             return True
@@ -851,4 +857,4 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
     raise HTTPException(
         status_code=400,
         detail="Unsupported embedding model id " + model_id,
-    )
+    )
\ No newline at end of file
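Note: with this change, each Bedrock model is kept in the shared ModelManager registry as a {model_id: config} entry rather than in the old module-level bedrock_model_list. A minimal sketch of the registry shape this appears to produce (the model IDs below are illustrative, not output of the code):

    from api.models.model_manager import ModelManager

    manager = ModelManager()
    # list_bedrock_models() registers entries of this shape:
    manager.add_model({"anthropic.claude-3-haiku-20240307-v1:0": {"modalities": ["TEXT", "IMAGE"]}})
    # cross-region inference profile, prefixed with the region prefix:
    manager.add_model({"us.anthropic.claude-3-haiku-20240307-v1:0": {"modalities": ["TEXT", "IMAGE"]}})

    print(list(manager.get_all_models().keys()))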
diff --git a/src/api/models/bedrock_agents.py b/src/api/models/bedrock_agents.py
new file mode 100644
index 0000000..6f4206f
--- /dev/null
+++ b/src/api/models/bedrock_agents.py
@@ -0,0 +1,391 @@
+import base64
+import json
+import logging
+import re
+import time
+import uuid
+from abc import ABC
+from typing import AsyncIterable
+
+import boto3
+from botocore.config import Config
+import numpy as np
+import requests
+import tiktoken
+from fastapi import HTTPException
+from api.models.model_manager import ModelManager
+
+from api.models.bedrock import (
+    BedrockModel,
+    bedrock_client,
+    bedrock_runtime)
+
+from api.schema import (
+    ChatResponse,
+    ChatRequest,
+    ChatResponseMessage,
+    ChatStreamResponse,
+    ChoiceDelta
+)
+
+from api.setting import (DEBUG, AWS_REGION, DEFAULT_KB_MODEL, KB_PREFIX, AGENT_PREFIX)
+
+logger = logging.getLogger(__name__)
+config = Config(connect_timeout=1, read_timeout=120, retries={"max_attempts": 1})
+
+bedrock_agent = boto3.client(
+    service_name="bedrock-agent",
+    region_name=AWS_REGION,
+    config=config,
+)
+
+bedrock_agent_runtime = boto3.client(
+    service_name="bedrock-agent-runtime",
+    region_name=AWS_REGION,
+    config=config,
+)
+
+
+class BedrockAgents(BedrockModel):
+
+    #bedrock_model_list = None
+    def __init__(self):
+        super().__init__()
+        self.model_manager = ModelManager()
+
+    def list_models(self) -> list[str]:
+        """Always refresh the latest model list"""
+        super().list_models()
+        self.get_kbs()
+        self.get_agents()
+        return list(self.model_manager.get_all_models().keys())
+
+    # get list of active knowledge bases
+    def get_kbs(self):
+
+        # List knowledge bases
+        response = bedrock_agent.list_knowledge_bases(maxResults=100)
+
+        # Register each knowledge base as a selectable pseudo-model
+        for kb in response['knowledgeBaseSummaries']:
+            name = f"{KB_PREFIX}{kb['name']}"
+            val = {
+                "system": True,  # Supports system prompts for context setting
+                "multimodal": True,  # Capable of processing both text and images
+                "tool_call": True,
+                "stream_tool_call": True,
+                "kb_id": kb['knowledgeBaseId'],
+                "model_id": DEFAULT_KB_MODEL
+            }
+            #self.model_manager.get_all_models()[name] = val
+            model = {}
+            model[name] = val
+            self.model_manager.add_model(model)
+
+    def get_latest_agent_alias(self, client, agent_id):
+
+        # List all aliases for the agent
+        response = client.list_agent_aliases(
+            agentId=agent_id,
+            maxResults=100  # Adjust based on your needs
+        )
+
+        if not response.get('agentAliasSummaries'):
+            return None
+
+        # Sort aliases by creation time to get the latest one
+        aliases = response['agentAliasSummaries']
+        latest_alias = None
+        latest_creation_time = None
+
+        for alias in aliases:
+            # Only consider aliases that are in PREPARED state
+            if alias['agentAliasStatus'] == 'PREPARED':
+                creation_time = alias.get('creationDateTime')
+                if latest_creation_time is None or creation_time > latest_creation_time:
+                    latest_creation_time = creation_time
+                    latest_alias = alias
+
+        if latest_alias:
+            return latest_alias['agentAliasId']
+
+        return None
+
+    def get_agents(self):
+        bedrock_ag = boto3.client(
+            service_name="bedrock-agent",
+            region_name=AWS_REGION,
+            config=config,
+        )
+        # List Agents
+        response = bedrock_agent.list_agents(maxResults=100)
+
+        # Register each prepared agent as a selectable pseudo-model
+        for agent in response['agentSummaries']:
+
+            if (agent['agentStatus'] != 'PREPARED'):
+                continue
+
+            name = f"{AGENT_PREFIX}{agent['agentName']}"
+            agentId = agent['agentId']
+
+            aliasId = self.get_latest_agent_alias(bedrock_ag, agentId)
+            if (aliasId is None):
+                continue
+
+            val = {
+                "system": False,  # System prompts are already set in the Bedrock Agent configuration
+                "multimodal": True,  # Capable of processing both text and images
+                "tool_call": False,  # Tool Use not required for Agents
+                "stream_tool_call": False,
+                "agent_id": agentId,
+                "alias_id": aliasId
+            }
+            #self.model_manager.get_all_models()[name] = val
+            model = {}
+            model[name] = val
+            self.model_manager.add_model(model)
+
+    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+        """Common logic for invoking Bedrock models"""
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(str(args)))
+
+        try:
+            if stream:
+                response = bedrock_runtime.converse_stream(**args)
+            else:
+                response = bedrock_runtime.converse(**args)
+        except bedrock_client.exceptions.ValidationException as e:
+            logger.error("Validation Error: " + str(e))
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+        return response
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        """Default implementation for Chat API."""
+
+        message_id = self.generate_message_id()
+        response = self._invoke_bedrock(chat_request)
+
+        output_message = response["output"]["message"]
+        input_tokens = response["usage"]["inputTokens"]
+        output_tokens = response["usage"]["outputTokens"]
+        finish_reason = response["stopReason"]
+
+        chat_response = self._create_response(
+            model=chat_request.model,
+            message_id=message_id,
+            content=output_message["content"],
+            finish_reason=finish_reason,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+        )
+        if DEBUG:
+            logger.info("Proxy response :" + chat_response.model_dump_json())
+        return chat_response
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        """Default implementation for Chat Stream API"""
+
+        response = ''
+        message_id = self.generate_message_id()
+
+        if (chat_request.model.startswith(KB_PREFIX)):
+            response = self._invoke_kb(chat_request, stream=True)
+        elif (chat_request.model.startswith(AGENT_PREFIX)):
+            response = self._invoke_agent(chat_request, stream=True)
+
+            _event_stream = response["completion"]
+
+            chunk_count = 1
+            message = ChatResponseMessage(
+                role="assistant",
+                content="",
+            )
+            stream_response = ChatStreamResponse(
+                id=message_id,
+                model=chat_request.model,
+                choices=[
+                    ChoiceDelta(
+                        index=0,
+                        delta=message,
+                        logprobs=None,
+                        finish_reason=None,
+                    )
+                ],
+                usage=None,
+            )
+            yield self.stream_response_to_bytes(stream_response)
+
+            for event in _event_stream:
+                chunk_count += 1
+                if "chunk" in event:
+                    _data = event["chunk"]["bytes"].decode("utf8")
+                    message = ChatResponseMessage(content=_data)
+
+                    stream_response = ChatStreamResponse(
+                        id=message_id,
+                        model=chat_request.model,
+                        choices=[
+                            ChoiceDelta(
+                                index=0,
+                                delta=message,
+                                logprobs=None,
+                                finish_reason=None,
+                            )
+                        ],
+                        usage=None,
+                    )
+                    yield self.stream_response_to_bytes(stream_response)
+
+                    #message = self._make_fully_cited_answer(_data, event, False, 0)
+
+            # return an [DONE] message at the end.
+            yield self.stream_response_to_bytes()
+            return None
+        else:
+            response = self._invoke_bedrock(chat_request, stream=True)
+
+        stream = response.get("stream")
+        for chunk in stream:
+            stream_response = self._create_response_stream(
+                model_id=chat_request.model, message_id=message_id, chunk=chunk
+            )
+            if not stream_response:
+                continue
+            if DEBUG:
+                logger.info("Proxy response :" + stream_response.model_dump_json())
+            if stream_response.choices:
+                yield self.stream_response_to_bytes(stream_response)
+            elif (
+                chat_request.stream_options
+                and chat_request.stream_options.include_usage
+            ):
+                # An empty choices for Usage as per OpenAI doc below:
+                # if you set stream_options: {"include_usage": true},
+                # an additional chunk will be streamed before the data: [DONE] message.
+                # The usage field on this chunk shows the token usage statistics for the entire request,
+                # and the choices field will always be an empty array.
+                # All other chunks will also include a usage field, but with a null value.
+                yield self.stream_response_to_bytes(stream_response)
+
+        # return an [DONE] message at the end.
+        yield self.stream_response_to_bytes()
+
+    # This function invokes the knowledge base
+    def _invoke_kb(self, chat_request: ChatRequest, stream=False):
+        """Common logic for invoking a knowledge base with the default KB model"""
+        if DEBUG:
+            logger.info("BedrockAgents._invoke_kb: Raw request: " + chat_request.model_dump_json())
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(str(args)))
+
+        model = self.model_manager.get_all_models()[chat_request.model]
+        args['modelId'] = model['model_id']
+
+        try:
+            messages = args['messages']
+            query = messages[-1]['content'][0]['text']
+
+            # Step 1 - Retrieve context from the knowledge base
+            retrieval_request_body = {
+                "retrievalQuery": {
+                    "text": query
+                },
+                "retrievalConfiguration": {
+                    "vectorSearchConfiguration": {
+                        "numberOfResults": 2
+                    }
+                }
+            }
+
+            # Make the retrieve request
+            response = bedrock_agent_runtime.retrieve(knowledgeBaseId=model['kb_id'], **retrieval_request_body)
+
+            # Extract the results
+            context = ''
+            if "retrievalResults" in response:
+                for result in response["retrievalResults"]:
+                    result = result["content"]["text"]
+                    context = f"{context}\n{result}"
+
+            # Step 2 - Prepend the retrieved context to the prompt
+            args['messages'][0]['content'][0]['text'] = f"Context: {context} \n\n {query}"
+
+            # Step 3 - Make the converse request
+            if stream:
+                response = bedrock_runtime.converse_stream(**args)
+            else:
+                response = bedrock_runtime.converse(**args)
+
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
+        return response
+
+    # This function invokes the agent
+    def _invoke_agent(self, chat_request: ChatRequest, stream=False):
+        """Common logic for invoking an agent"""
+        if DEBUG:
+            logger.info("BedrockAgents._invoke_agent: Raw request: " + chat_request.model_dump_json())
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(str(args)))
+
+        model = self.model_manager.get_all_models()[chat_request.model]
+
+        try:
+            messages = args['messages']
+            query = messages[-1]['content'][0]['text']
+
+            # Build the agent invocation request
+            request_params = {
+                'agentId': model['agent_id'],
+                'agentAliasId': model['alias_id'],
+                'sessionId': str(uuid.uuid4()),  # Generate a unique session ID per request
+                'inputText': query
+            }
+
+            # Invoke the agent
+            response = bedrock_agent_runtime.invoke_agent(**request_params)
+            return response
+
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
\ No newline at end of file
diff --git a/src/api/models/model_manager.py b/src/api/models/model_manager.py
new file mode 100644
index 0000000..2974634
--- /dev/null
+++ b/src/api/models/model_manager.py
@@ -0,0 +1,35 @@
+# This is a singleton class to maintain the list of models
+class ModelManager:
+    _instance = None
+    _models = None
+
+    def __new__(cls, *args, **kwargs):
+        # Ensure that only one instance of ModelManager is created
+        if cls._instance is None:
+            cls._instance = super(ModelManager, cls).__new__(cls)
+            cls._instance._models = {}  # Initialize the shared model registry
+
+        return cls._instance
+
+    def get_all_models(self):
+        """Return the full {model_name: config} registry."""
+        return self._models
+
+    def add_model(self, model):
+        """Add a model to the list; `model` is a {model_name: config} dict."""
+        if (self._models is None):
+            self._models = {}
+        self._models.update(model)
+
+    def get_model_by_name(self, model_name: str):
+        """Get a single model's config by name."""
+        return self._models.get(model_name)
+
+    def clear_models(self):
+        """Clear the list of models."""
+        self._models = {}
+
+    def __repr__(self):
+        return f"ModelManager(models={self._models})"
+
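The ModelManager above is a process-wide singleton: every instantiation returns the same object, so BedrockModel, BedrockAgents and the routers all share one registry. A minimal usage sketch based on the class as written (the kb name, kb id and model id below are illustrative placeholders, not part of the diff):

    from api.models.model_manager import ModelManager

    a = ModelManager()
    b = ModelManager()
    assert a is b  # same instance everywhere in the process

    # add_model() merges a {name: config} dict into the shared registry
    a.add_model({"kb-product-docs": {
        "kb_id": "ABCD1234",
        "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
    }})

    print(list(b.get_all_models().keys()))        # ['kb-product-docs']
    print(b.get_model_by_name("kb-product-docs"))  # the config dict above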
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 1e48a48..c12181b 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -2,12 +2,15 @@
 from fastapi import APIRouter, Depends, Body
 from fastapi.responses import StreamingResponse
-
+import logging
 from api.auth import api_key_auth
-from api.models.bedrock import BedrockModel
+#from api.models.bedrock import BedrockModel
+from api.models.bedrock_agents import BedrockAgents
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL
 
+logger = logging.getLogger(__name__)
+
 router = APIRouter(
     prefix="/chat",
     dependencies=[Depends(api_key_auth)],
@@ -32,14 +35,16 @@ async def chat_completions(
         ),
     ]
 ):
+    # this method gets called by front-end
+
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
 
     # Exception will be raised if model not supported.
-    model = BedrockModel()
+    model = BedrockAgents()
     model.validate(chat_request)
 
     if chat_request.stream:
-        return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
-        )
+        response = StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
+        return response
+
     return model.chat(chat_request)
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index e5cde31..01f83d0 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -3,7 +3,7 @@
 from fastapi import APIRouter, Depends, Body
 
 from api.auth import api_key_auth
-from api.models.bedrock import get_embeddings_model
+#from api.models.bedrock import get_embeddings_model
 from api.schema import EmbeddingsRequest, EmbeddingsResponse
 from api.setting import DEFAULT_EMBEDDING_MODEL
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index ce5e8a1..36c4d02 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -1,10 +1,13 @@
 from typing import Annotated
-
+import logging
 from fastapi import APIRouter, Depends, HTTPException, Path
 
 from api.auth import api_key_auth
-from api.models.bedrock import BedrockModel
+#from api.models.bedrock import BedrockModel
+from api.models.bedrock_agents import BedrockAgents
 from api.schema import Models, Model
 
+logger = logging.getLogger(__name__)
+
 router = APIRouter(
     prefix="/models",
@@ -12,16 +15,18 @@
     # responses={404: {"description": "Not found"}},
 )
 
-chat_model = BedrockModel()
-
+#chat_model = BedrockModel()
+chat_model = BedrockAgents()
 
 async def validate_model_id(model_id: str):
+    logger.info(f"validate_model_id: {model_id}")
     if model_id not in chat_model.list_models():
         raise HTTPException(status_code=500, detail="Unsupported Model Id")
 
 
 @router.get("", response_model=Models)
 async def list_models():
+
     model_list = [
         Model(id=model_id) for model_id in chat_model.list_models()
     ]
@@ -38,5 +43,6 @@ async def get_model(
         Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
     ]
 ):
+    logger.info(f"get_model: {model_id}")
     await validate_model_id(model_id)
     return Model(id=model_id)
diff --git a/src/api/setting.py b/src/api/setting.py
index 9026202..940cd0b 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -20,3 +20,13 @@
     "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
 )
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+
+KB_PREFIX = 'kb-'
+AGENT_PREFIX = 'ag-'
+
+DEFAULT_KB_MODEL = os.environ.get(
+    "DEFAULT_KB_MODEL", "anthropic.claude-3-haiku-20240307-v1:0"
+)
+
+
+DEFAULT_KB_MODEL_ARN = f'arn:aws:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}'
\ No newline at end of file
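With these changes, knowledge bases surface as models named kb-<name> and agents as ag-<name> alongside the regular Bedrock model IDs, so they can be selected through the existing OpenAI-compatible endpoints. A rough usage sketch, assuming the gateway is already deployed; the base URL, API key and model name below are placeholders. Note that only chat_stream branches on the kb-/ag- prefixes in this diff, so stream=True appears to be the supported path for those models.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="placeholder-key")

    # Knowledge bases and agents show up in the model list with kb-/ag- prefixes.
    print([m.id for m in client.models.list().data])

    # Chat against a knowledge base; the gateway retrieves context and calls the default KB model.
    stream = client.chat.completions.create(
        model="kb-my-knowledge-base",  # placeholder name
        messages=[{"role": "user", "content": "What does the onboarding doc say about VPN access?"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")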