From 9aa52a8199b29f25be0c73e37222df4a155cedb3 Mon Sep 17 00:00:00 2001 From: Deepesh Dhapola Date: Fri, 20 Dec 2024 12:35:39 +0530 Subject: [PATCH 1/3] added support for agents and kb --- src/api/models/bedrock.py | 25 +- src/api/models/bedrock_agents.py | 1064 ++++++++++++++++++++++++++++++ src/api/routers/chat.py | 19 +- src/api/routers/embeddings.py | 2 +- src/api/routers/model.py | 14 +- 5 files changed, 1111 insertions(+), 13 deletions(-) create mode 100644 src/api/models/bedrock_agents.py diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index f59856a..2bb4b74 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -335,13 +335,27 @@ class BedrockModel(BaseChatModel): "tool_call": True, "stream_tool_call": True, }, + + "knowledgebase": { + "system": True, # Supports system prompts for context setting + "multimodal": True, # Capable of processing both text and images + "tool_call": True, + "stream_tool_call": True, + }, } + def supported_models(self): + logger.info("BedrockModel.supported_models") + return list(self._supported_models.keys()) + def list_models(self) -> list[str]: + logger.info("BedrockModel.list_models") return list(self._supported_models.keys()) def validate(self, chat_request: ChatRequest): """Perform basic validation on requests""" + logger.info(f"BedrockModel.validate: {chat_request}") + error = "" # check if model is supported if chat_request.model not in self._supported_models.keys(): @@ -361,7 +375,7 @@ def validate(self, chat_request: ChatRequest): def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): """Common logic for invoke bedrock models""" if DEBUG: - logger.info("Raw request: " + chat_request.model_dump_json()) + logger.info("BedrockModel._invoke_bedrock: Raw request: " + chat_request.model_dump_json()) # convert OpenAI chat request to Bedrock SDK request args = self._parse_request(chat_request) @@ -383,7 +397,7 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): def chat(self, chat_request: ChatRequest) -> ChatResponse: """Default implementation for Chat API.""" - + logger.info(f"BedrockModel.chat: {chat_request}") message_id = self.generate_message_id() response = self._invoke_bedrock(chat_request) @@ -406,6 +420,9 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse: def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: """Default implementation for Chat Stream API""" + + logger.info(f"BedrockModel.chat_stream: {chat_request}") + response = self._invoke_bedrock(chat_request, stream=True) message_id = self.generate_message_id() @@ -444,6 +461,7 @@ def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ + logger.info(f"BedrockModel._parse_system_prompts: {chat_request}") system_prompts = [] for message in chat_request.messages: @@ -467,6 +485,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ + + logger.info(f"BedrockModel._parse_messages: {chat_request}") + messages = [] for message in chat_request.messages: if isinstance(message, UserMessage): diff --git a/src/api/models/bedrock_agents.py b/src/api/models/bedrock_agents.py new file mode 100644 index 0000000..50aac06 --- /dev/null +++ b/src/api/models/bedrock_agents.py @@ -0,0 +1,1064 @@ +import base64 +import json 
+import logging +import re +import time +from abc import ABC +from typing import AsyncIterable, Iterable, Literal + +import boto3 +from botocore.config import Config +import numpy as np +import requests +import tiktoken +from fastapi import HTTPException + +from api.models.base import BaseChatModel, BaseEmbeddingsModel +from api.models.bedrock import BedrockModel +from api.schema import ( + # Chat + ChatResponse, + ChatRequest, + Choice, + ChatResponseMessage, + Usage, + ChatStreamResponse, + ImageContent, + TextContent, + ToolCall, + ChoiceDelta, + UserMessage, + AssistantMessage, + ToolMessage, + Function, + ResponseFunction, + # Embeddings + EmbeddingsRequest, + EmbeddingsResponse, + EmbeddingsUsage, + Embedding, +) +from api.setting import DEBUG, AWS_REGION + +KB_PREFIX = 'kb-' +AGENT_PREFIX = 'ag-' + +DEFAULT_KB_MODEL = 'anthropic.claude-3-haiku-20240307-v1:0' +DEFAULT_KB_MODEL_ARN = f'arn:aws:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}' + +logger = logging.getLogger(__name__) + +config = Config(connect_timeout=1, read_timeout=120, retries={"max_attempts": 1}) + +bedrock_client = boto3.client( + service_name="bedrock-runtime", + region_name=AWS_REGION, + config=config, +) + +bedrock_agent = boto3.client( + service_name="bedrock-agent-runtime", + region_name=AWS_REGION, + config=config, +) + +SUPPORTED_BEDROCK_EMBEDDING_MODELS = { + "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", + "cohere.embed-english-v3": "Cohere Embed English", + # Disable Titan embedding. + # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", + # "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" +} + +ENCODER = tiktoken.get_encoding("cl100k_base") + + +class BedrockAgents(BedrockModel): + # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + + + _supported_models = { + + } + + # get list of active knowledgebases + def get_kbs(self): + + bedrock_ag = boto3.client( + service_name="bedrock-agent", + region_name=AWS_REGION, + config=config, + ) + + # List knowledge bases + response = bedrock_ag.list_knowledge_bases(maxResults=100) + + # Print knowledge base information + for kb in response['knowledgeBaseSummaries']: + name = f"{KB_PREFIX}{kb['name']}" + val = { + "system": True, # Supports system prompts for context setting + "multimodal": True, # Capable of processing both text and images + "tool_call": True, + "stream_tool_call": True, + "kb_id": kb['knowledgeBaseId'], + "model_id": DEFAULT_KB_MODEL + } + self._supported_models[name] = val + + def get_latest_agent_alias(self, client, agent_id): + + # List all aliases for the agent + response = client.list_agent_aliases( + agentId=agent_id, + maxResults=100 # Adjust based on your needs + ) + + if not response.get('agentAliasSummaries'): + return None + + # Sort aliases by creation time to get the latest one + aliases = response['agentAliasSummaries'] + latest_alias = None + latest_creation_time = None + + for alias in aliases: + # Only consider aliases that are in PREPARED state + if alias['agentAliasStatus'] == 'PREPARED': + creation_time = alias.get('creationDateTime') + if latest_creation_time is None or creation_time > latest_creation_time: + latest_creation_time = creation_time + latest_alias = alias + + if latest_alias: + return latest_alias['agentAliasId'] + + return None + + + + def get_agents(self): + bedrock_ag = boto3.client( + service_name="bedrock-agent", + region_name=AWS_REGION, + config=config, + ) + # List Agents + 
response = bedrock_ag.list_agents(maxResults=100) + + # Prepare agent for display + for agent in response['agentSummaries']: + + if (agent['agentStatus'] != 'PREPARED'): + continue + + name = f"{AGENT_PREFIX}{agent['agentName']}" + agentId = agent['agentId'] + + aliasId = self.get_latest_agent_alias(bedrock_ag, agentId) + if (aliasId is None): + continue + + val = { + "system": False, # Supports system prompts for context setting + "multimodal": True, # Capable of processing both text and images + "tool_call": True, + "stream_tool_call": False, + "agent_id": agentId, + "alias_id": aliasId + } + self._supported_models[name] = val + + def get_models(self): + + client = boto3.client( + service_name="bedrock", + region_name=AWS_REGION, + config=config, + ) + response = client.list_foundation_models(byInferenceType='ON_DEMAND') + + # Prepare agent for display + for model in response['modelSummaries']: + if ((model['modelLifecycle']['status'] == 'ACTIVE') and + #('ON_DEMAND' in model['inferenceTypesSupported']) and + ('EMBEDDING' not in model['outputModalities'])) : + name = f"{model['modelId']}" + stream_support = False + + if (('responseStreamingSupported' in model.keys()) and + (model['responseStreamingSupported'] is True)): + stream_support = True + + val = { + "system": True, # Supports system prompts for context setting + "multimodal": len(model['inputModalities'])>1, # Capable of processing both text and images + "tool_call": True, + "stream_tool_call": stream_support, + } + self._supported_models[name] = val + + def supported_models(self): + logger.info("BedrockAgents.supported_models") + return list(self._supported_models.keys()) + + def list_models(self) -> list[str]: + logger.info("BedrockAgents.list_models") + self.get_models() + self.get_kbs() + self.get_agents() + return list(self._supported_models.keys()) + + def validate(self, chat_request: ChatRequest): + """Perform basic validation on requests""" + #logger.info(f"BedrockAgents.validate: {chat_request}") + + error = "" + # check if model is supported + if chat_request.model not in self._supported_models.keys(): + error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" + + # check if tool call is supported + elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream): + tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call" + error = f"{tool_call_info} is currently not supported by {chat_request.model}" + + if error: + raise HTTPException( + status_code=400, + detail=error, + ) + + def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): + """Common logic for invoke bedrock models""" + if DEBUG: + logger.info("BedrockAgents._invoke_bedrock: Raw request: " + chat_request.model_dump_json()) + + # convert OpenAI chat request to Bedrock SDK request + args = self._parse_request(chat_request) + if DEBUG: + logger.info("Bedrock request: " + json.dumps(str(args))) + + try: + + if stream: + response = bedrock_client.converse_stream(**args) + else: + response = bedrock_client.converse(**args) + + + except bedrock_client.exceptions.ValidationException as e: + logger.error("Validation Error: " + str(e)) + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) + return response + + def chat(self, chat_request: ChatRequest) -> ChatResponse: + """Default implementation for Chat API.""" + #chat: {chat_request}") + 
message_id = self.generate_message_id() + response = self._invoke_bedrock(chat_request) + + output_message = response["output"]["message"] + input_tokens = response["usage"]["inputTokens"] + output_tokens = response["usage"]["outputTokens"] + finish_reason = response["stopReason"] + + chat_response = self._create_response( + model=chat_request.model, + message_id=message_id, + content=output_message["content"], + finish_reason=finish_reason, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + if DEBUG: + logger.info("Proxy response :" + chat_response.model_dump_json()) + return chat_response + + def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: + + """Default implementation for Chat Stream API""" + logger.info(f"BedrockAgents.chat_stream: {chat_request}") + + response = '' + message_id = self.generate_message_id() + + if (chat_request.model.startswith(KB_PREFIX)): + response = self._invoke_kb(chat_request, stream=True) + elif (chat_request.model.startswith(AGENT_PREFIX)): + response = self._invoke_agent(chat_request, stream=True) + + _event_stream = response["completion"] + + chunk_count = 1 + message = ChatResponseMessage( + role="assistant", + content="", + ) + stream_response = ChatStreamResponse( + id=message_id, + model=chat_request.model, + choices=[ + ChoiceDelta( + index=0, + delta=message, + logprobs=None, + finish_reason=None, + ) + ], + usage=None, + ) + yield self.stream_response_to_bytes(stream_response) + + for event in _event_stream: + #print(f'\n\nChunk {chunk_count}: {event}') + chunk_count += 1 + if "chunk" in event: + _data = event["chunk"]["bytes"].decode("utf8") + message = ChatResponseMessage(content=_data) + + stream_response = ChatStreamResponse( + id=message_id, + model=chat_request.model, + choices=[ + ChoiceDelta( + index=0, + delta=message, + logprobs=None, + finish_reason=None, + ) + ], + usage=None, + ) + yield self.stream_response_to_bytes(stream_response) + + #message = self._make_fully_cited_answer(_data, event, False, 0) + + # return an [DONE] message at the end. + yield self.stream_response_to_bytes() + return None + else: + response = self._invoke_bedrock(chat_request, stream=True) + + print(response) + stream = response.get("stream") + # chunk_count = 1 + for chunk in stream: + #print(f'\n\nChunk {chunk_count}: {chunk}') + # chunk_count += 1 + + stream_response = self._create_response_stream( + model_id=chat_request.model, message_id=message_id, chunk=chunk + ) + #print(f'stream_response: {stream_response}') + if not stream_response: + continue + if DEBUG: + logger.info("Proxy response :" + stream_response.model_dump_json()) + if stream_response.choices: + yield self.stream_response_to_bytes(stream_response) + elif ( + chat_request.stream_options + and chat_request.stream_options.include_usage + ): + # An empty choices for Usage as per OpenAI doc below: + # if you set stream_options: {"include_usage": true}. + # an additional chunk will be streamed before the data: [DONE] message. + # The usage field on this chunk shows the token usage statistics for the entire request, + # and the choices field will always be an empty array. + # All other chunks will also include a usage field, but with a null value. + yield self.stream_response_to_bytes(stream_response) + + # return an [DONE] message at the end. + yield self.stream_response_to_bytes() + + def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]: + """Create system prompts. + Note that not all models support system prompts. 
+ + example output: [{"text" : system_prompt}] + + See example: + https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + """ + #logger.info(f"BedrockAgents._parse_system_prompts: {chat_request}") + + system_prompts = [] + for message in chat_request.messages: + if message.role != "system": + # ignore system messages here + continue + assert isinstance(message.content, str) + system_prompts.append({"text": message.content}) + + return system_prompts + + def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: + """ + Converse API only support user and assistant messages. + + example output: [{ + "role": "user", + "content": [{"text": input_text}] + }] + + See example: + https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + """ + + #logger.info(f"BedrockAgents._parse_messages: {chat_request}") + + messages = [] + for message in chat_request.messages: + if isinstance(message, UserMessage): + messages.append( + { + "role": message.role, + "content": self._parse_content_parts( + message, chat_request.model + ), + } + ) + elif isinstance(message, AssistantMessage): + if message.content: + # Text message + messages.append( + { + "role": message.role, + "content": self._parse_content_parts( + message, chat_request.model + ), + } + ) + else: + # Tool use message + tool_input = json.loads(message.tool_calls[0].function.arguments) + messages.append( + { + "role": message.role, + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": tool_input + } + } + ], + } + ) + elif isinstance(message, ToolMessage): + # Bedrock does not support tool role, + # Add toolResult to content + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + messages.append( + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"text": message.content}], + } + } + ], + } + ) + + else: + # ignore others, such as system messages + continue + return self._reframe_multi_payloard(messages) + + + def _reframe_multi_payloard(self, messages: list) -> list: + """ Receive messages and reformat them to comply with the Claude format + +With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but +with Claude format requests, you cannot repeatedly receive messages from the same role. + +This method searches through the OpenAI format messages in order and reformats them to the Claude format. 
+ +``` +openai_format_messages=[ +{"role": "user", "content": "hogehoge"}, +{"role": "user", "content": "fugafuga"}, +] + +bedrock_format_messages=[ +{ + "role": "user", + "content": [ + {"text": "hogehoge"}, + {"text": "fugafuga"} + ] +}, +] +``` + """ + reformatted_messages = [] + current_role = None + current_content = [] + + # Search through the list of messages and combine messages from the same role into one list + for message in messages: + next_role = message['role'] + next_content = message['content'] + + # If the next role is different from the previous message, add the previous role's messages to the list + if next_role != current_role: + if current_content: + reformatted_messages.append({ + "role": current_role, + "content": current_content + }) + # Switch to the new role + current_role = next_role + current_content = [] + + # Add the message content to current_content + if isinstance(next_content, str): + current_content.append({"text": next_content}) + elif isinstance(next_content, list): + current_content.extend(next_content) + + # Add the last role's messages to the list + if current_content: + reformatted_messages.append({ + "role": current_role, + "content": current_content + }) + + return reformatted_messages + + # This function invokes knowledgebase + def _invoke_kb(self, chat_request: ChatRequest, stream=False): + """Common logic for invoke kb with default model""" + if DEBUG: + logger.info("BedrockAgents._invoke_kb: Raw request: " + chat_request.model_dump_json()) + + # convert OpenAI chat request to Bedrock SDK request + args = self._parse_request(chat_request) + + + if DEBUG: + logger.info("Bedrock request: " + json.dumps(str(args))) + + model = self._supported_models[chat_request.model] + logger.info(f"model: {model}") + + args['modelId'] = model['model_id'] + logger.info(f"args: {args}") + + ################ + + try: + query = args['messages'][0]['content'][0]['text'] + messages = args['messages'] + query = messages[len(messages)-1]['content'][0]['text'] + logger.info(f"Query: {query}") + + # Step 1 - Retrieve Context + retrieval_request_body = { + "retrievalQuery": { + "text": query + }, + "retrievalConfiguration": { + "vectorSearchConfiguration": { + "numberOfResults": 2 + } + } + } + + # Make the retrieve request + response = bedrock_agent.retrieve(knowledgeBaseId=model['kb_id'], **retrieval_request_body) + logger.info(f"retrieve response: {response}") + + # Extract and return the results + context = '' + if "retrievalResults" in response: + for result in response["retrievalResults"]: + result = result["content"]["text"] + #logger.info(f"Result: {result}") + context = f"{context}\n{result}" + + + # Step 2 - Append context in the prompt + args['messages'][0]['content'][0]['text'] = f"Context: {context} \n\n {query}" + + #print(args) + + # Step 3 - Make the converse request + if stream: + response = bedrock_client.converse_stream(**args) + else: + response = bedrock_client.converse(**args) + + logger.info(f'kb response: {response}') + + except Exception as e: + print(f"Error retrieving from knowledge base: {str(e)}") + raise + + ############### + return response + + # This function invokes knowledgebase + def _invoke_agent(self, chat_request: ChatRequest, stream=False): + """Common logic for invoke agent """ + if DEBUG: + logger.info("BedrockAgents._invoke_agent: Raw request: " + chat_request.model_dump_json()) + + # convert OpenAI chat request to Bedrock SDK request + args = self._parse_request(chat_request) + + + if DEBUG: + logger.info("Bedrock request: " + 
json.dumps(str(args))) + + model = self._supported_models[chat_request.model] + #logger.info(f"model: {model}") + logger.info(f"args: {args}") + + ################ + + try: + query = args['messages'][0]['content'][0]['text'] + messages = args['messages'] + query = messages[len(messages)-1]['content'][0]['text'] + query = f"My customer id is 1. {query}" + logger.info(f"Query: {query}") + + # Step 1 - Retrieve Context + request_params = { + 'agentId': model['agent_id'], + 'agentAliasId': model['alias_id'], + 'sessionId': 'unique-session-id', # Generate a unique session ID + 'inputText': query + } + + # Make the retrieve request + # Invoke the agent + response = bedrock_agent.invoke_agent(**request_params) + return response + #logger.info(f'agent response: {response} ----\n\n') + + _event_stream = response["completion"] + + chunk_count = 1 + for event in _event_stream: + #print(f'\n\nChunk {chunk_count}: {event}') + chunk_count += 1 + if "chunk" in event: + _data = event["chunk"]["bytes"].decode("utf8") + _agent_answer = self._make_fully_cited_answer( + _data, event, False, 0) + + + #print(f'_agent_answer: {_agent_answer}') + + # Process the response + #completion = response.get('completion', '') + return response + + except Exception as e: + print(f"Error retrieving from knowledge base: {str(e)}") + raise + + ############### + return response + + def _make_fully_cited_answer( + self, orig_agent_answer, event, enable_trace=False, trace_level="none"): + _citations = event.get("chunk", {}).get("attribution", {}).get("citations", []) + if _citations: + if enable_trace: + print( + f"got {len(event['chunk']['attribution']['citations'])} citations \n" + ) + else: + return orig_agent_answer + + # remove tags to work around a bug + _pattern = r"\n\n\n\d+\n\n\n" + _cleaned_text = re.sub(_pattern, "", orig_agent_answer) + _pattern = "" + _cleaned_text = re.sub(_pattern, "", _cleaned_text) + _pattern = "" + _cleaned_text = re.sub(_pattern, "", _cleaned_text) + + _fully_cited_answer = "" + _curr_citation_idx = 0 + + for _citation in _citations: + if enable_trace and trace_level == "all": + print(f"full citation: {_citation}") + + _start = _citation["generatedResponsePart"]["textResponsePart"]["span"][ + "start" + ] - ( + _curr_citation_idx + 1 + ) # +1 + _end = ( + _citation["generatedResponsePart"]["textResponsePart"]["span"]["end"] + - (_curr_citation_idx + 2) + + 4 + ) # +2 + _refs = _citation.get("retrievedReferences", []) + if len(_refs) > 0: + _ref_url = ( + _refs[0].get("location", {}).get("s3Location", {}).get("uri", "") + ) + else: + _ref_url = "" + _fully_cited_answer = _cleaned_text + break + + _fully_cited_answer += _cleaned_text[_start:_end] + " [" + _ref_url + "] " + + if _curr_citation_idx == 0: + _answer_prefix = _cleaned_text[:_start] + _fully_cited_answer = _answer_prefix + _fully_cited_answer + + _curr_citation_idx += 1 + + if enable_trace and trace_level == "all": + print(f"\n\ncitation {_curr_citation_idx}:") + print( + f"got {len(_citation['retrievedReferences'])} retrieved references for this citation\n" + ) + print(f"citation span... start: {_start}, end: {_end}") + print( + f"citation based on span:====\n{_cleaned_text[_start:_end]}\n====" + ) + print(f"citation url: {_ref_url}\n============") + + if enable_trace and trace_level == "all": + print( + f"\nfullly cited answer:*************\n{_fully_cited_answer}\n*************" + ) + + return _fully_cited_answer + + def _parse_request(self, chat_request: ChatRequest) -> dict: + """Create default converse request body. 
+ + Also perform validations to tool call etc. + + Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + """ + + messages = self._parse_messages(chat_request) + system_prompts = self._parse_system_prompts(chat_request) + + # Base inference parameters. + inference_config = { + "temperature": chat_request.temperature, + "maxTokens": chat_request.max_tokens, + "topP": chat_request.top_p, + } + + args = { + "modelId": chat_request.model, + "messages": messages, + "system": system_prompts, + "inferenceConfig": inference_config, + } + # add tool config + if chat_request.tools: + args["toolConfig"] = { + "tools": [ + self._convert_tool_spec(t.function) for t in chat_request.tools + ] + } + + if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"): + if isinstance(chat_request.tool_choice, str): + # auto (default) is mapped to {"auto" : {}} + # required is mapped to {"any" : {}} + if chat_request.tool_choice == "required": + args["toolConfig"]["toolChoice"] = {"any": {}} + else: + args["toolConfig"]["toolChoice"] = {"auto": {}} + else: + # Specific tool to use + assert "function" in chat_request.tool_choice + args["toolConfig"]["toolChoice"] = { + "tool": {"name": chat_request.tool_choice["function"].get("name", "")}} + return args + + def _create_response( + self, + model: str, + message_id: str, + content: list[dict] = None, + finish_reason: str | None = None, + input_tokens: int = 0, + output_tokens: int = 0, + ) -> ChatResponse: + + message = ChatResponseMessage( + role="assistant", + ) + if finish_reason == "tool_use": + # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples + tool_calls = [] + for part in content: + if "toolUse" in part: + tool = part["toolUse"] + tool_calls.append( + ToolCall( + id=tool["toolUseId"], + type="function", + function=ResponseFunction( + name=tool["name"], + arguments=json.dumps(tool["input"]), + ), + ) + ) + message.tool_calls = tool_calls + message.content = None + else: + message.content = "" + if content: + message.content = content[0]["text"] + + response = ChatResponse( + id=message_id, + model=model, + choices=[ + Choice( + index=0, + message=message, + finish_reason=self._convert_finish_reason(finish_reason), + logprobs=None, + ) + ], + usage=Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + ), + ) + response.system_fingerprint = "fp" + response.object = "chat.completion" + response.created = int(time.time()) + return response + + def _create_response_stream( + self, model_id: str, message_id: str, chunk: dict + ) -> ChatStreamResponse | None: + """Parsing the Bedrock stream response chunk. 
+ + Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + """ + if DEBUG: + logger.info("Bedrock response chunk: " + str(chunk)) + + finish_reason = None + message = None + usage = None + #logger.info(f'chunk: {chunk}') + if "messageStart" in chunk: + message = ChatResponseMessage( + role=chunk["messageStart"]["role"], + content="", + ) + if "contentBlockStart" in chunk: + # tool call start + delta = chunk["contentBlockStart"]["start"] + if "toolUse" in delta: + # first index is content + index = chunk["contentBlockStart"]["contentBlockIndex"] - 1 + message = ChatResponseMessage( + tool_calls=[ + ToolCall( + index=index, + type="function", + id=delta["toolUse"]["toolUseId"], + function=ResponseFunction( + name=delta["toolUse"]["name"], + arguments="", + ), + ) + ] + ) + if "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + # stream content + message = ChatResponseMessage( + content=delta["text"], + ) + else: + # tool use + index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1 + message = ChatResponseMessage( + tool_calls=[ + ToolCall( + index=index, + function=ResponseFunction( + arguments=delta["toolUse"]["input"], + ) + ) + ] + ) + if "messageStop" in chunk: + message = ChatResponseMessage() + finish_reason = chunk["messageStop"]["stopReason"] + + if "metadata" in chunk: + # usage information in metadata. + metadata = chunk["metadata"] + if "usage" in metadata: + # token usage + return ChatStreamResponse( + id=message_id, + model=model_id, + choices=[], + usage=Usage( + prompt_tokens=metadata["usage"]["inputTokens"], + completion_tokens=metadata["usage"]["outputTokens"], + total_tokens=metadata["usage"]["totalTokens"], + ), + ) + if message: + return ChatStreamResponse( + id=message_id, + model=model_id, + choices=[ + ChoiceDelta( + index=0, + delta=message, + logprobs=None, + finish_reason=self._convert_finish_reason(finish_reason), + ) + ], + usage=usage, + ) + + return None + + def _parse_image(self, image_url: str) -> tuple[bytes, str]: + """Try to get the raw data from an image url. + + Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html + returns a tuple of (Image Data, Content Type) + """ + pattern = r"^data:(image/[a-z]*);base64,\s*" + content_type = re.search(pattern, image_url) + # if already base64 encoded. 
+ # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' + if content_type: + image_data = re.sub(pattern, "", image_url) + return base64.b64decode(image_data), content_type.group(1) + + # Send a request to the image URL + response = requests.get(image_url) + # Check if the request was successful + if response.status_code == 200: + + content_type = response.headers.get("Content-Type") + if not content_type.startswith("image"): + content_type = "image/jpeg" + # Get the image content + image_content = response.content + return image_content, content_type + else: + raise HTTPException( + status_code=500, detail="Unable to access the image url" + ) + + def _parse_content_parts( + self, + message: UserMessage, + model_id: str, + ) -> list[dict]: + if isinstance(message.content, str): + return [ + { + "text": message.content, + } + ] + content_parts = [] + for part in message.content: + if isinstance(part, TextContent): + content_parts.append( + { + "text": part.text, + } + ) + elif isinstance(part, ImageContent): + if not self._is_multimodal_supported(model_id): + raise HTTPException( + status_code=400, + detail=f"Multimodal message is currently not supported by {model_id}", + ) + image_data, content_type = self._parse_image(part.image_url.url) + content_parts.append( + { + "image": { + "format": content_type[6:], # image/ + "source": {"bytes": image_data}, + }, + } + ) + else: + # Ignore.. + continue + return content_parts + + def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["stream_tool_call"] if stream else feature["tool_call"] + + def _is_multimodal_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["multimodal"] + + def _is_system_prompt_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["system"] + + def _convert_tool_spec(self, func: Function) -> dict: + return { + "toolSpec": { + "name": func.name, + "description": func.description, + "inputSchema": { + "json": func.parameters, + }, + } + } + + def _convert_finish_reason(self, finish_reason: str | None) -> str | None: + """ + Below is a list of finish reason according to OpenAI doc: + + - stop: if the model hit a natural stop point or a provided stop sequence, + - length: if the maximum number of tokens specified in the request was reached, + - content_filter: if content was omitted due to a flag from our content filters, + - tool_calls: if the model called a tool + """ + if finish_reason: + finish_reason_mapping = { + "tool_use": "tool_calls", + "finished": "stop", + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "complete": "stop", + "content_filtered": "content_filter" + } + return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower()) + return None + diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py index 1e48a48..19d55cd 100644 --- a/src/api/routers/chat.py +++ b/src/api/routers/chat.py @@ -2,12 +2,15 @@ from fastapi import APIRouter, Depends, Body from fastapi.responses import StreamingResponse - +import logging from api.auth import api_key_auth -from api.models.bedrock import BedrockModel +#from api.models.bedrock import BedrockModel +from api.models.bedrock_agents import BedrockAgents from api.schema import ChatRequest, ChatResponse, ChatStreamResponse 
from api.setting import DEFAULT_MODEL +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/chat", dependencies=[Depends(api_key_auth)], @@ -32,14 +35,18 @@ async def chat_completions( ), ] ): + # this method gets called by front-end + + logger.info(f"chat_completions: {chat_request}") + if chat_request.model.lower().startswith("gpt-"): chat_request.model = DEFAULT_MODEL # Exception will be raised if model not supported. - model = BedrockModel() + model = BedrockAgents() model.validate(chat_request) if chat_request.stream: - return StreamingResponse( - content=model.chat_stream(chat_request), media_type="text/event-stream" - ) + response = StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream") + logger.info(f"\n\nStreaming response: {response}\n\n") + return response return model.chat(chat_request) diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py index e5cde31..01f83d0 100644 --- a/src/api/routers/embeddings.py +++ b/src/api/routers/embeddings.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, Body from api.auth import api_key_auth -from api.models.bedrock import get_embeddings_model +#from api.models.bedrock import get_embeddings_model from api.schema import EmbeddingsRequest, EmbeddingsResponse from api.setting import DEFAULT_EMBEDDING_MODEL diff --git a/src/api/routers/model.py b/src/api/routers/model.py index ce5e8a1..36c4d02 100644 --- a/src/api/routers/model.py +++ b/src/api/routers/model.py @@ -1,10 +1,13 @@ from typing import Annotated - +import logging from fastapi import APIRouter, Depends, HTTPException, Path from api.auth import api_key_auth -from api.models.bedrock import BedrockModel +#from api.models.bedrock import BedrockModel +from api.models.bedrock_agents import BedrockAgents from api.schema import Models, Model +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/models", @@ -12,16 +15,18 @@ # responses={404: {"description": "Not found"}}, ) -chat_model = BedrockModel() - +#chat_model = BedrockModel() +chat_model = BedrockAgents() async def validate_model_id(model_id: str): + logger.info(f"validate_model_id: {model_id}") if model_id not in chat_model.list_models(): raise HTTPException(status_code=500, detail="Unsupported Model Id") @router.get("", response_model=Models) async def list_models(): + model_list = [ Model(id=model_id) for model_id in chat_model.list_models() ] @@ -38,5 +43,6 @@ async def get_model( Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"), ] ): + logger.info(f"get_model: {model_id}") await validate_model_id(model_id) return Model(id=model_id) From 6c88ef6094b53083c14e6f93ecdf246ca1d0b4f7 Mon Sep 17 00:00:00 2001 From: Deepesh Dhapola Date: Sat, 28 Dec 2024 09:22:12 +0530 Subject: [PATCH 2/3] merged from main --- src/api/models/bedrock.py | 466 ++++++++++---------------------------- 1 file changed, 117 insertions(+), 349 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 2bb4b74..6dfccbd 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -7,10 +7,10 @@ from typing import AsyncIterable, Iterable, Literal import boto3 -from botocore.config import Config import numpy as np import requests import tiktoken +from botocore.config import Config from fastapi import HTTPException from api.models.base import BaseChatModel, BaseEmbeddingsModel @@ -37,19 +37,33 @@ EmbeddingsUsage, Embedding, - ) -from api.setting import DEBUG, AWS_REGION +from api.setting import DEBUG, 
AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL logger = logging.getLogger(__name__) -config = Config(connect_timeout=1, read_timeout=120, retries={"max_attempts": 1}) +config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1}) bedrock_runtime = boto3.client( service_name="bedrock-runtime", region_name=AWS_REGION, config=config, ) +bedrock_client = boto3.client( + service_name='bedrock', + region_name=AWS_REGION, + config=config, +) + + +def get_inference_region_prefix(): + if AWS_REGION.startswith('ap-'): + return 'apac' + return AWS_REGION[:2] + + +# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +cr_inference_prefix = get_inference_region_prefix() SUPPORTED_BEDROCK_EMBEDDING_MODELS = { "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", @@ -62,310 +76,84 @@ ENCODER = tiktoken.get_encoding("cl100k_base") +def list_bedrock_models() -> dict: + """Automatically getting a list of supported models. + + Returns a model list combines: + - ON_DEMAND models. + - Cross-Region Inference Profiles (if enabled via Env) + """ + model_list = {} + try: + profile_list = [] + if ENABLE_CROSS_REGION_INFERENCE: + # List system defined inference profile IDs + response = bedrock_client.list_inference_profiles( + maxResults=1000, + typeEquals='SYSTEM_DEFINED' + ) + profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']] + + # List foundation models, only cares about text outputs here. + response = bedrock_client.list_foundation_models( + byOutputModality='TEXT' + ) + + for model in response['modelSummaries']: + model_id = model.get('modelId', 'N/A') + stream_supported = model.get('responseStreamingSupported', True) + status = model['modelLifecycle'].get('status', 'ACTIVE') + + # currently, use this to filter out rerank models and legacy models + if not stream_supported or status != "ACTIVE": + continue + + inference_types = model.get('inferenceTypesSupported', []) + input_modalities = model['inputModalities'] + # Add on-demand model list + if 'ON_DEMAND' in inference_types: + model_list[model_id] = { + 'modalities': input_modalities + } + + # Add cross-region inference model list. + profile_id = cr_inference_prefix + '.' + model_id + if profile_id in profile_list: + model_list[profile_id] = { + 'modalities': input_modalities + } + + except Exception as e: + logger.error(f"Unable to list models: {str(e)}") + + if not model_list: + # In case stack not updated. + model_list[DEFAULT_MODEL] = { + 'modalities': ["TEXT", "IMAGE"] + } + + return model_list + + +# Initialize the model list. 
+bedrock_model_list = list_bedrock_models() + + class BedrockModel(BaseChatModel): - # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features - _supported_models = { - "amazon.titan-text-premier-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "anthropic.claude-instant-v1": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "anthropic.claude-v2:1": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "anthropic.claude-v2": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "anthropic.claude-3-sonnet-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "anthropic.claude-3-opus-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "anthropic.claude-3-haiku-20240307-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "anthropic.claude-3-5-sonnet-20240620-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "anthropic.claude-3-5-sonnet-20241022-v2:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "meta.llama2-13b-chat-v1": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "meta.llama2-70b-chat-v1": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "meta.llama3-8b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "meta.llama3-70b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - # Llama 3.1 8b cross-region inference profile - "us.meta.llama3-1-8b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "meta.llama3-1-8b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - # Llama 3.1 70b cross-region inference profile - "us.meta.llama3-1-70b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "meta.llama3-1-70b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "meta.llama3-1-405b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - # Llama 3.2 1B cross-region inference profile - "us.meta.llama3-2-1b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - # Llama 3.2 3B cross-region inference profile - "us.meta.llama3-2-3b-instruct-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - # Llama 3.2 11B cross-region inference profile - "us.meta.llama3-2-11b-instruct-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": False, - }, - # Llama 3.2 90B cross-region inference profile - "us.meta.llama3-2-90b-instruct-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": False, - }, - "mistral.mistral-7b-instruct-v0:2": { - 
"system": False, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "mistral.mixtral-8x7b-instruct-v0:1": { - "system": False, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "mistral.mistral-small-2402-v1:0": { - "system": True, - "multimodal": False, - "tool_call": False, - "stream_tool_call": False, - }, - "mistral.mistral-large-2402-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "mistral.mistral-large-2407-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "cohere.command-r-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "cohere.command-r-plus-v1:0": { - "system": True, - "multimodal": False, - "tool_call": True, - "stream_tool_call": False, - }, - "apac.anthropic.claude-3-sonnet-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "apac.anthropic.claude-3-haiku-20240307-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "apac.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # claude 3 Haiku cross-region inference profile - "us.anthropic.claude-3-haiku-20240307-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "eu.anthropic.claude-3-haiku-20240307-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # claude 3 Opus cross-region inference profile - "us.anthropic.claude-3-opus-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # claude 3 Sonnet cross-region inference profile - "us.anthropic.claude-3-sonnet-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "eu.anthropic.claude-3-sonnet-20240229-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # claude 3.5 Sonnet cross-region inference profile - "us.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # claude 3.5 Sonnet v2 cross-region inference profile(Now only us-west-2) - "us.anthropic.claude-3-5-sonnet-20241022-v2:0": { - "system": True, - "multimodal": True, - "tool_call": True, - "stream_tool_call": True, - }, - # Amazon Nova models - AWS's proprietary large language models - "us.amazon.nova-lite-v1:0": { - "system": True, # Supports system prompts for context setting - "multimodal": True, # Capable of processing both text and images - "tool_call": True, - "stream_tool_call": True, - }, - "us.amazon.nova-micro-v1:0": { - "system": True, # Supports system prompts for context setting - "multimodal": False, # Text-only model, no image processing capabilities - "tool_call": True, - "stream_tool_call": True, - }, - "us.amazon.nova-pro-v1:0": { - "system": True, # Supports system prompts for context setting - "multimodal": True, # Capable of processing both text and images - "tool_call": True, - "stream_tool_call": True, - }, - - "knowledgebase": { - "system": True, # Supports system 
prompts for context setting - "multimodal": True, # Capable of processing both text and images - "tool_call": True, - "stream_tool_call": True, - }, - } - - def supported_models(self): - logger.info("BedrockModel.supported_models") - return list(self._supported_models.keys()) def list_models(self) -> list[str]: - logger.info("BedrockModel.list_models") - return list(self._supported_models.keys()) + """Always refresh the latest model list""" + global bedrock_model_list + bedrock_model_list = list_bedrock_models() + return list(bedrock_model_list.keys()) def validate(self, chat_request: ChatRequest): """Perform basic validation on requests""" - logger.info(f"BedrockModel.validate: {chat_request}") - error = "" # check if model is supported - if chat_request.model not in self._supported_models.keys(): + if chat_request.model not in bedrock_model_list.keys(): error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" - # check if tool call is supported - elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream): - tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call" - error = f"{tool_call_info} is currently not supported by {chat_request.model}" - if error: raise HTTPException( status_code=400, @@ -375,7 +163,7 @@ def validate(self, chat_request: ChatRequest): def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): """Common logic for invoke bedrock models""" if DEBUG: - logger.info("BedrockModel._invoke_bedrock: Raw request: " + chat_request.model_dump_json()) + logger.info("Raw request: " + chat_request.model_dump_json()) # convert OpenAI chat request to Bedrock SDK request args = self._parse_request(chat_request) @@ -397,7 +185,7 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): def chat(self, chat_request: ChatRequest) -> ChatResponse: """Default implementation for Chat API.""" - logger.info(f"BedrockModel.chat: {chat_request}") + message_id = self.generate_message_id() response = self._invoke_bedrock(chat_request) @@ -420,9 +208,6 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse: def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: """Default implementation for Chat Stream API""" - - logger.info(f"BedrockModel.chat_stream: {chat_request}") - response = self._invoke_bedrock(chat_request, stream=True) message_id = self.generate_message_id() @@ -461,7 +246,6 @@ def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ - logger.info(f"BedrockModel._parse_system_prompts: {chat_request}") system_prompts = [] for message in chat_request.messages: @@ -485,9 +269,6 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ - - logger.info(f"BedrockModel._parse_messages: {chat_request}") - messages = [] for message in chat_request.messages: if isinstance(message, UserMessage): @@ -550,31 +331,29 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: continue return self._reframe_multi_payloard(messages) - def _reframe_multi_payloard(self, messages: list) -> list: """ Receive messages and reformat them to comply with the Claude format -With OpenAI format requests, it's not a problem to repeatedly receive 
messages from the same role, but -with Claude format requests, you cannot repeatedly receive messages from the same role. - -This method searches through the OpenAI format messages in order and reformats them to the Claude format. - -``` -openai_format_messages=[ -{"role": "user", "content": "hogehoge"}, -{"role": "user", "content": "fugafuga"}, -] - -bedrock_format_messages=[ -{ - "role": "user", - "content": [ - {"text": "hogehoge"}, - {"text": "fugafuga"} - ] -}, -] -``` + With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but + with Claude format requests, you cannot repeatedly receive messages from the same role. + + This method searches through the OpenAI format messages in order and reformats them to the Claude format. + + ``` + openai_format_messages=[ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "Who are you?"}, + ] + + bedrock_format_messages=[ + { + "role": "user", + "content": [ + {"text": "Hello"}, + {"text": "Who are you?"} + ] + }, + ] """ reformatted_messages = [] current_role = None @@ -611,7 +390,6 @@ def _reframe_multi_payloard(self, messages: list) -> list: return reformatted_messages - def _parse_request(self, chat_request: ChatRequest) -> dict: """Create default converse request body. @@ -860,7 +638,7 @@ def _parse_content_parts( } ) elif isinstance(part, ImageContent): - if not self._is_multimodal_supported(model_id): + if not self.is_supported_modality(model_id, modality="IMAGE"): raise HTTPException( status_code=400, detail=f"Multimodal message is currently not supported by {model_id}", @@ -879,23 +657,13 @@ def _parse_content_parts( continue return content_parts - def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["stream_tool_call"] if stream else feature["tool_call"] - - def _is_multimodal_supported(self, model_id: str) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["multimodal"] - - def _is_system_prompt_supported(self, model_id: str) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["system"] + @staticmethod + def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool: + model = bedrock_model_list.get(model_id) + modalities = model.get('modalities', []) + if modality in modalities: + return True + return False def _convert_tool_spec(self, func: Function) -> dict: return { @@ -1083,4 +851,4 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: raise HTTPException( status_code=400, detail="Unsupported embedding model id " + model_id, - ) + ) \ No newline at end of file From 9d2c08c07b16be230ef9eb77ffc417114f35b4ac Mon Sep 17 00:00:00 2001 From: Deepesh Dhapola Date: Sun, 29 Dec 2024 13:33:19 +0530 Subject: [PATCH 3/3] BedrockAgents class is now subclass of BedrockModel class. 
model list is maintained in a singleton class --- src/api/app.py | 1 + src/api/models/bedrock.py | 130 ++--- src/api/models/bedrock_agents.py | 787 +++---------------------------- src/api/models/model_manager.py | 35 ++ src/api/routers/chat.py | 4 +- src/api/setting.py | 10 + 6 files changed, 172 insertions(+), 795 deletions(-) create mode 100644 src/api/models/model_manager.py diff --git a/src/api/app.py b/src/api/app.py index 7fa2f01..56f815d 100644 --- a/src/api/app.py +++ b/src/api/app.py @@ -42,6 +42,7 @@ async def health(): return {"status": "OK"} + @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): return PlainTextResponse(str(exc), status_code=400) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 6dfccbd..e9fdf1b 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -5,6 +5,7 @@ import time from abc import ABC from typing import AsyncIterable, Iterable, Literal +from api.models.model_manager import ModelManager import boto3 import numpy as np @@ -75,83 +76,88 @@ def get_inference_region_prefix(): ENCODER = tiktoken.get_encoding("cl100k_base") +# Initialize the model list. +#bedrock_model_list = list_bedrock_models() -def list_bedrock_models() -> dict: - """Automatically getting a list of supported models. - - Returns a model list combines: - - ON_DEMAND models. - - Cross-Region Inference Profiles (if enabled via Env) - """ - model_list = {} - try: - profile_list = [] - if ENABLE_CROSS_REGION_INFERENCE: - # List system defined inference profile IDs - response = bedrock_client.list_inference_profiles( - maxResults=1000, - typeEquals='SYSTEM_DEFINED' - ) - profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']] - - # List foundation models, only cares about text outputs here. - response = bedrock_client.list_foundation_models( - byOutputModality='TEXT' - ) - - for model in response['modelSummaries']: - model_id = model.get('modelId', 'N/A') - stream_supported = model.get('responseStreamingSupported', True) - status = model['modelLifecycle'].get('status', 'ACTIVE') - - # currently, use this to filter out rerank models and legacy models - if not stream_supported or status != "ACTIVE": - continue - - inference_types = model.get('inferenceTypesSupported', []) - input_modalities = model['inputModalities'] - # Add on-demand model list - if 'ON_DEMAND' in inference_types: - model_list[model_id] = { - 'modalities': input_modalities - } - - # Add cross-region inference model list. - profile_id = cr_inference_prefix + '.' + model_id - if profile_id in profile_list: - model_list[profile_id] = { - 'modalities': input_modalities - } - except Exception as e: - logger.error(f"Unable to list models: {str(e)}") +class BedrockModel(BaseChatModel): - if not model_list: - # In case stack not updated. - model_list[DEFAULT_MODEL] = { - 'modalities': ["TEXT", "IMAGE"] - } + #bedrock_model_list = None + model_manager = None + def __init__(self): + super().__init__() + self.model_manager = ModelManager() - return model_list + def list_bedrock_models(self) -> dict: + """Automatically getting a list of supported models. + Returns a model list combines: + - ON_DEMAND models. 
+ - Cross-Region Inference Profiles (if enabled via Env) + """ + #model_list = {} + try: + profile_list = [] + if ENABLE_CROSS_REGION_INFERENCE: + # List system defined inference profile IDs + response = bedrock_client.list_inference_profiles( + maxResults=1000, + typeEquals='SYSTEM_DEFINED' + ) + profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']] -# Initialize the model list. -bedrock_model_list = list_bedrock_models() + # List foundation models, only cares about text outputs here. + response = bedrock_client.list_foundation_models( + byOutputModality='TEXT' + ) + for model in response['modelSummaries']: + model_id = model.get('modelId', 'N/A') + stream_supported = model.get('responseStreamingSupported', True) + status = model['modelLifecycle'].get('status', 'ACTIVE') + + # currently, use this to filter out rerank models and legacy models + if not stream_supported or status != "ACTIVE": + continue + + inference_types = model.get('inferenceTypesSupported', []) + input_modalities = model['inputModalities'] + # Add on-demand model list + if 'ON_DEMAND' in inference_types: + model[model_id] = { + 'modalities': input_modalities + } + self.model_manager.add_model(model) + # model_list[model_id] = { + # 'modalities': input_modalities + # } + + # Add cross-region inference model list. + profile_id = cr_inference_prefix + '.' + model_id + if profile_id in profile_list: + model[profile_id] = { + 'modalities': input_modalities + } + self.model_manager.add_model(model) -class BedrockModel(BaseChatModel): + except Exception as e: + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) def list_models(self) -> list[str]: """Always refresh the latest model list""" - global bedrock_model_list - bedrock_model_list = list_bedrock_models() - return list(bedrock_model_list.keys()) + #global bedrock_model_list + self.list_bedrock_models() + return list(self.model_manager.get_all_models().keys()) def validate(self, chat_request: ChatRequest): """Perform basic validation on requests""" + error = "" + + ###### TODO - failing here as kb and agents are not in the bedrock_model_list # check if model is supported - if chat_request.model not in bedrock_model_list.keys(): + if chat_request.model not in self.model_manager.get_all_models().keys(): error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" if error: @@ -659,7 +665,7 @@ def _parse_content_parts( @staticmethod def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool: - model = bedrock_model_list.get(model_id) + model = ModelManager().models.get(model_id) modalities = model.get('modalities', []) if modality in modalities: return True diff --git a/src/api/models/bedrock_agents.py b/src/api/models/bedrock_agents.py index 50aac06..6f4206f 100644 --- a/src/api/models/bedrock_agents.py +++ b/src/api/models/bedrock_agents.py @@ -4,7 +4,7 @@ import re import time from abc import ABC -from typing import AsyncIterable, Iterable, Literal +from typing import AsyncIterable import boto3 from botocore.config import Config @@ -12,86 +12,58 @@ import requests import tiktoken from fastapi import HTTPException +from api.models.model_manager import ModelManager + +from api.models.bedrock import ( + BedrockModel, + bedrock_client, + bedrock_runtime) -from api.models.base import BaseChatModel, BaseEmbeddingsModel -from api.models.bedrock import BedrockModel from api.schema import ( - # Chat ChatResponse, ChatRequest, - Choice, ChatResponseMessage, - Usage, 
ChatStreamResponse, - ImageContent, - TextContent, - ToolCall, - ChoiceDelta, - UserMessage, - AssistantMessage, - ToolMessage, - Function, - ResponseFunction, - # Embeddings - EmbeddingsRequest, - EmbeddingsResponse, - EmbeddingsUsage, - Embedding, + ChoiceDelta ) -from api.setting import DEBUG, AWS_REGION - -KB_PREFIX = 'kb-' -AGENT_PREFIX = 'ag-' - -DEFAULT_KB_MODEL = 'anthropic.claude-3-haiku-20240307-v1:0' -DEFAULT_KB_MODEL_ARN = f'arn:aws:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}' + +from api.setting import (DEBUG, AWS_REGION, DEFAULT_KB_MODEL, KB_PREFIX, AGENT_PREFIX) logger = logging.getLogger(__name__) - config = Config(connect_timeout=1, read_timeout=120, retries={"max_attempts": 1}) -bedrock_client = boto3.client( - service_name="bedrock-runtime", - region_name=AWS_REGION, - config=config, -) - bedrock_agent = boto3.client( + service_name="bedrock-agent", + region_name=AWS_REGION, + config=config, + ) + +bedrock_agent_runtime = boto3.client( service_name="bedrock-agent-runtime", region_name=AWS_REGION, config=config, ) -SUPPORTED_BEDROCK_EMBEDDING_MODELS = { - "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", - "cohere.embed-english-v3": "Cohere Embed English", - # Disable Titan embedding. - # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", - # "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" -} - -ENCODER = tiktoken.get_encoding("cl100k_base") - class BedrockAgents(BedrockModel): - # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + #bedrock_model_list = None + def __init__(self): + super().__init__() + model_manager = ModelManager() - _supported_models = { - - } + def list_models(self) -> list[str]: + """Always refresh the latest model list""" + super().list_models() + self.get_kbs() + self.get_agents() + return list(self.model_manager.get_all_models().keys()) - # get list of active knowledgebases + # get list of active knowledge bases def get_kbs(self): - bedrock_ag = boto3.client( - service_name="bedrock-agent", - region_name=AWS_REGION, - config=config, - ) - # List knowledge bases - response = bedrock_ag.list_knowledge_bases(maxResults=100) + response = bedrock_agent.list_knowledge_bases(maxResults=100) # Print knowledge base information for kb in response['knowledgeBaseSummaries']: @@ -104,7 +76,10 @@ def get_kbs(self): "kb_id": kb['knowledgeBaseId'], "model_id": DEFAULT_KB_MODEL } - self._supported_models[name] = val + #self.model_manager.get_all_models()[name] = val + model = {} + model[name]=val + self.model_manager.add_model(model) def get_latest_agent_alias(self, client, agent_id): @@ -135,8 +110,6 @@ def get_latest_agent_alias(self, client, agent_id): return None - - def get_agents(self): bedrock_ag = boto3.client( service_name="bedrock-agent", @@ -144,7 +117,7 @@ def get_agents(self): config=config, ) # List Agents - response = bedrock_ag.list_agents(maxResults=100) + response = bedrock_agent.list_agents(maxResults=100) # Prepare agent for display for agent in response['agentSummaries']: @@ -160,79 +133,21 @@ def get_agents(self): continue val = { - "system": False, # Supports system prompts for context setting + "system": False, # Supports system prompts for context setting. 
These are already set in Bedrock Agent configuration "multimodal": True, # Capable of processing both text and images - "tool_call": True, + "tool_call": False, # Tool Use not required for Agents "stream_tool_call": False, "agent_id": agentId, "alias_id": aliasId } - self._supported_models[name] = val + #self.model_manager.get_all_models()[name] = val + model = {} + model[name]=val + self.model_manager.add_model(model) - def get_models(self): - - client = boto3.client( - service_name="bedrock", - region_name=AWS_REGION, - config=config, - ) - response = client.list_foundation_models(byInferenceType='ON_DEMAND') - - # Prepare agent for display - for model in response['modelSummaries']: - if ((model['modelLifecycle']['status'] == 'ACTIVE') and - #('ON_DEMAND' in model['inferenceTypesSupported']) and - ('EMBEDDING' not in model['outputModalities'])) : - name = f"{model['modelId']}" - stream_support = False - - if (('responseStreamingSupported' in model.keys()) and - (model['responseStreamingSupported'] is True)): - stream_support = True - - val = { - "system": True, # Supports system prompts for context setting - "multimodal": len(model['inputModalities'])>1, # Capable of processing both text and images - "tool_call": True, - "stream_tool_call": stream_support, - } - self._supported_models[name] = val - - def supported_models(self): - logger.info("BedrockAgents.supported_models") - return list(self._supported_models.keys()) - - def list_models(self) -> list[str]: - logger.info("BedrockAgents.list_models") - self.get_models() - self.get_kbs() - self.get_agents() - return list(self._supported_models.keys()) - - def validate(self, chat_request: ChatRequest): - """Perform basic validation on requests""" - #logger.info(f"BedrockAgents.validate: {chat_request}") - - error = "" - # check if model is supported - if chat_request.model not in self._supported_models.keys(): - error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" - - # check if tool call is supported - elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream): - tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call" - error = f"{tool_call_info} is currently not supported by {chat_request.model}" - - if error: - raise HTTPException( - status_code=400, - detail=error, - ) def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): """Common logic for invoke bedrock models""" - if DEBUG: - logger.info("BedrockAgents._invoke_bedrock: Raw request: " + chat_request.model_dump_json()) # convert OpenAI chat request to Bedrock SDK request args = self._parse_request(chat_request) @@ -242,9 +157,9 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): try: if stream: - response = bedrock_client.converse_stream(**args) + response = bedrock_runtime.converse_stream(**args) else: - response = bedrock_client.converse(**args) + response = bedrock_runtime.converse(**args) except bedrock_client.exceptions.ValidationException as e: @@ -258,6 +173,7 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): def chat(self, chat_request: ChatRequest) -> ChatResponse: """Default implementation for Chat API.""" #chat: {chat_request}") + message_id = self.generate_message_id() response = self._invoke_bedrock(chat_request) @@ -281,7 +197,6 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse: def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: """Default implementation 
for Chat Stream API""" - logger.info(f"BedrockAgents.chat_stream: {chat_request}") response = '' message_id = self.generate_message_id() @@ -314,7 +229,6 @@ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: yield self.stream_response_to_bytes(stream_response) for event in _event_stream: - #print(f'\n\nChunk {chunk_count}: {event}') chunk_count += 1 if "chunk" in event: _data = event["chunk"]["bytes"].decode("utf8") @@ -343,17 +257,11 @@ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: else: response = self._invoke_bedrock(chat_request, stream=True) - print(response) stream = response.get("stream") - # chunk_count = 1 for chunk in stream: - #print(f'\n\nChunk {chunk_count}: {chunk}') - # chunk_count += 1 - stream_response = self._create_response_stream( model_id=chat_request.model, message_id=message_id, chunk=chunk ) - #print(f'stream_response: {stream_response}') if not stream_response: continue if DEBUG: @@ -375,164 +283,7 @@ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: # return an [DONE] message at the end. yield self.stream_response_to_bytes() - def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]: - """Create system prompts. - Note that not all models support system prompts. - - example output: [{"text" : system_prompt}] - - See example: - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples - """ - #logger.info(f"BedrockAgents._parse_system_prompts: {chat_request}") - - system_prompts = [] - for message in chat_request.messages: - if message.role != "system": - # ignore system messages here - continue - assert isinstance(message.content, str) - system_prompts.append({"text": message.content}) - - return system_prompts - - def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: - """ - Converse API only support user and assistant messages. 
- - example output: [{ - "role": "user", - "content": [{"text": input_text}] - }] - - See example: - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples - """ - - #logger.info(f"BedrockAgents._parse_messages: {chat_request}") - - messages = [] - for message in chat_request.messages: - if isinstance(message, UserMessage): - messages.append( - { - "role": message.role, - "content": self._parse_content_parts( - message, chat_request.model - ), - } - ) - elif isinstance(message, AssistantMessage): - if message.content: - # Text message - messages.append( - { - "role": message.role, - "content": self._parse_content_parts( - message, chat_request.model - ), - } - ) - else: - # Tool use message - tool_input = json.loads(message.tool_calls[0].function.arguments) - messages.append( - { - "role": message.role, - "content": [ - { - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": tool_input - } - } - ], - } - ) - elif isinstance(message, ToolMessage): - # Bedrock does not support tool role, - # Add toolResult to content - # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"text": message.content}], - } - } - ], - } - ) - - else: - # ignore others, such as system messages - continue - return self._reframe_multi_payloard(messages) - - - def _reframe_multi_payloard(self, messages: list) -> list: - """ Receive messages and reformat them to comply with the Claude format - -With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but -with Claude format requests, you cannot repeatedly receive messages from the same role. - -This method searches through the OpenAI format messages in order and reformats them to the Claude format. 
- -``` -openai_format_messages=[ -{"role": "user", "content": "hogehoge"}, -{"role": "user", "content": "fugafuga"}, -] - -bedrock_format_messages=[ -{ - "role": "user", - "content": [ - {"text": "hogehoge"}, - {"text": "fugafuga"} - ] -}, -] -``` - """ - reformatted_messages = [] - current_role = None - current_content = [] - - # Search through the list of messages and combine messages from the same role into one list - for message in messages: - next_role = message['role'] - next_content = message['content'] - - # If the next role is different from the previous message, add the previous role's messages to the list - if next_role != current_role: - if current_content: - reformatted_messages.append({ - "role": current_role, - "content": current_content - }) - # Switch to the new role - current_role = next_role - current_content = [] - - # Add the message content to current_content - if isinstance(next_content, str): - current_content.append({"text": next_content}) - elif isinstance(next_content, list): - current_content.extend(next_content) - - # Add the last role's messages to the list - if current_content: - reformatted_messages.append({ - "role": current_role, - "content": current_content - }) - - return reformatted_messages + # This function invokes knowledgebase def _invoke_kb(self, chat_request: ChatRequest, stream=False): @@ -547,11 +298,9 @@ def _invoke_kb(self, chat_request: ChatRequest, stream=False): if DEBUG: logger.info("Bedrock request: " + json.dumps(str(args))) - model = self._supported_models[chat_request.model] - logger.info(f"model: {model}") - + model = self.model_manager.get_all_models()[chat_request.model] args['modelId'] = model['model_id'] - logger.info(f"args: {args}") + ################ @@ -559,7 +308,6 @@ def _invoke_kb(self, chat_request: ChatRequest, stream=False): query = args['messages'][0]['content'][0]['text'] messages = args['messages'] query = messages[len(messages)-1]['content'][0]['text'] - logger.info(f"Query: {query}") # Step 1 - Retrieve Context retrieval_request_body = { @@ -574,34 +322,28 @@ def _invoke_kb(self, chat_request: ChatRequest, stream=False): } # Make the retrieve request - response = bedrock_agent.retrieve(knowledgeBaseId=model['kb_id'], **retrieval_request_body) - logger.info(f"retrieve response: {response}") + response = bedrock_agent_runtime.retrieve(knowledgeBaseId=model['kb_id'], **retrieval_request_body) # Extract and return the results context = '' if "retrievalResults" in response: for result in response["retrievalResults"]: result = result["content"]["text"] - #logger.info(f"Result: {result}") context = f"{context}\n{result}" # Step 2 - Append context in the prompt args['messages'][0]['content'][0]['text'] = f"Context: {context} \n\n {query}" - #print(args) - # Step 3 - Make the converse request if stream: - response = bedrock_client.converse_stream(**args) + response = bedrock_runtime.converse_stream(**args) else: - response = bedrock_client.converse(**args) - - logger.info(f'kb response: {response}') - + response = bedrock_runtime.converse(**args) + except Exception as e: - print(f"Error retrieving from knowledge base: {str(e)}") - raise + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) ############### return response @@ -619,9 +361,7 @@ def _invoke_agent(self, chat_request: ChatRequest, stream=False): if DEBUG: logger.info("Bedrock request: " + json.dumps(str(args))) - model = self._supported_models[chat_request.model] - #logger.info(f"model: {model}") - logger.info(f"args: {args}") + model = 
self.model_manager.get_all_models()[chat_request.model] ################ @@ -629,9 +369,8 @@ def _invoke_agent(self, chat_request: ChatRequest, stream=False): query = args['messages'][0]['content'][0]['text'] messages = args['messages'] query = messages[len(messages)-1]['content'][0]['text'] - query = f"My customer id is 1. {query}" - logger.info(f"Query: {query}") + # Step 1 - Retrieve Context request_params = { 'agentId': model['agent_id'], @@ -642,423 +381,11 @@ def _invoke_agent(self, chat_request: ChatRequest, stream=False): # Make the retrieve request # Invoke the agent - response = bedrock_agent.invoke_agent(**request_params) + response = bedrock_agent_runtime.invoke_agent(**request_params) return response - #logger.info(f'agent response: {response} ----\n\n') - - _event_stream = response["completion"] - - chunk_count = 1 - for event in _event_stream: - #print(f'\n\nChunk {chunk_count}: {event}') - chunk_count += 1 - if "chunk" in event: - _data = event["chunk"]["bytes"].decode("utf8") - _agent_answer = self._make_fully_cited_answer( - _data, event, False, 0) - - - #print(f'_agent_answer: {_agent_answer}') - # Process the response - #completion = response.get('completion', '') - return response - except Exception as e: - print(f"Error retrieving from knowledge base: {str(e)}") - raise - - ############### - return response - - def _make_fully_cited_answer( - self, orig_agent_answer, event, enable_trace=False, trace_level="none"): - _citations = event.get("chunk", {}).get("attribution", {}).get("citations", []) - if _citations: - if enable_trace: - print( - f"got {len(event['chunk']['attribution']['citations'])} citations \n" - ) - else: - return orig_agent_answer - - # remove tags to work around a bug - _pattern = r"\n\n\n\d+\n\n\n" - _cleaned_text = re.sub(_pattern, "", orig_agent_answer) - _pattern = "" - _cleaned_text = re.sub(_pattern, "", _cleaned_text) - _pattern = "" - _cleaned_text = re.sub(_pattern, "", _cleaned_text) - - _fully_cited_answer = "" - _curr_citation_idx = 0 - - for _citation in _citations: - if enable_trace and trace_level == "all": - print(f"full citation: {_citation}") - - _start = _citation["generatedResponsePart"]["textResponsePart"]["span"][ - "start" - ] - ( - _curr_citation_idx + 1 - ) # +1 - _end = ( - _citation["generatedResponsePart"]["textResponsePart"]["span"]["end"] - - (_curr_citation_idx + 2) - + 4 - ) # +2 - _refs = _citation.get("retrievedReferences", []) - if len(_refs) > 0: - _ref_url = ( - _refs[0].get("location", {}).get("s3Location", {}).get("uri", "") - ) - else: - _ref_url = "" - _fully_cited_answer = _cleaned_text - break - - _fully_cited_answer += _cleaned_text[_start:_end] + " [" + _ref_url + "] " - - if _curr_citation_idx == 0: - _answer_prefix = _cleaned_text[:_start] - _fully_cited_answer = _answer_prefix + _fully_cited_answer - - _curr_citation_idx += 1 - - if enable_trace and trace_level == "all": - print(f"\n\ncitation {_curr_citation_idx}:") - print( - f"got {len(_citation['retrievedReferences'])} retrieved references for this citation\n" - ) - print(f"citation span... start: {_start}, end: {_end}") - print( - f"citation based on span:====\n{_cleaned_text[_start:_end]}\n====" - ) - print(f"citation url: {_ref_url}\n============") - - if enable_trace and trace_level == "all": - print( - f"\nfullly cited answer:*************\n{_fully_cited_answer}\n*************" - ) - - return _fully_cited_answer - - def _parse_request(self, chat_request: ChatRequest) -> dict: - """Create default converse request body. 
- - Also perform validations to tool call etc. - - Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - """ - - messages = self._parse_messages(chat_request) - system_prompts = self._parse_system_prompts(chat_request) - - # Base inference parameters. - inference_config = { - "temperature": chat_request.temperature, - "maxTokens": chat_request.max_tokens, - "topP": chat_request.top_p, - } - - args = { - "modelId": chat_request.model, - "messages": messages, - "system": system_prompts, - "inferenceConfig": inference_config, - } - # add tool config - if chat_request.tools: - args["toolConfig"] = { - "tools": [ - self._convert_tool_spec(t.function) for t in chat_request.tools - ] - } - - if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"): - if isinstance(chat_request.tool_choice, str): - # auto (default) is mapped to {"auto" : {}} - # required is mapped to {"any" : {}} - if chat_request.tool_choice == "required": - args["toolConfig"]["toolChoice"] = {"any": {}} - else: - args["toolConfig"]["toolChoice"] = {"auto": {}} - else: - # Specific tool to use - assert "function" in chat_request.tool_choice - args["toolConfig"]["toolChoice"] = { - "tool": {"name": chat_request.tool_choice["function"].get("name", "")}} - return args - - def _create_response( - self, - model: str, - message_id: str, - content: list[dict] = None, - finish_reason: str | None = None, - input_tokens: int = 0, - output_tokens: int = 0, - ) -> ChatResponse: - - message = ChatResponseMessage( - role="assistant", - ) - if finish_reason == "tool_use": - # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples - tool_calls = [] - for part in content: - if "toolUse" in part: - tool = part["toolUse"] - tool_calls.append( - ToolCall( - id=tool["toolUseId"], - type="function", - function=ResponseFunction( - name=tool["name"], - arguments=json.dumps(tool["input"]), - ), - ) - ) - message.tool_calls = tool_calls - message.content = None - else: - message.content = "" - if content: - message.content = content[0]["text"] - - response = ChatResponse( - id=message_id, - model=model, - choices=[ - Choice( - index=0, - message=message, - finish_reason=self._convert_finish_reason(finish_reason), - logprobs=None, - ) - ], - usage=Usage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - ), - ) - response.system_fingerprint = "fp" - response.object = "chat.completion" - response.created = int(time.time()) - return response - - def _create_response_stream( - self, model_id: str, message_id: str, chunk: dict - ) -> ChatStreamResponse | None: - """Parsing the Bedrock stream response chunk. 
- - Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples - """ - if DEBUG: - logger.info("Bedrock response chunk: " + str(chunk)) - - finish_reason = None - message = None - usage = None - #logger.info(f'chunk: {chunk}') - if "messageStart" in chunk: - message = ChatResponseMessage( - role=chunk["messageStart"]["role"], - content="", - ) - if "contentBlockStart" in chunk: - # tool call start - delta = chunk["contentBlockStart"]["start"] - if "toolUse" in delta: - # first index is content - index = chunk["contentBlockStart"]["contentBlockIndex"] - 1 - message = ChatResponseMessage( - tool_calls=[ - ToolCall( - index=index, - type="function", - id=delta["toolUse"]["toolUseId"], - function=ResponseFunction( - name=delta["toolUse"]["name"], - arguments="", - ), - ) - ] - ) - if "contentBlockDelta" in chunk: - delta = chunk["contentBlockDelta"]["delta"] - if "text" in delta: - # stream content - message = ChatResponseMessage( - content=delta["text"], - ) - else: - # tool use - index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1 - message = ChatResponseMessage( - tool_calls=[ - ToolCall( - index=index, - function=ResponseFunction( - arguments=delta["toolUse"]["input"], - ) - ) - ] - ) - if "messageStop" in chunk: - message = ChatResponseMessage() - finish_reason = chunk["messageStop"]["stopReason"] - - if "metadata" in chunk: - # usage information in metadata. - metadata = chunk["metadata"] - if "usage" in metadata: - # token usage - return ChatStreamResponse( - id=message_id, - model=model_id, - choices=[], - usage=Usage( - prompt_tokens=metadata["usage"]["inputTokens"], - completion_tokens=metadata["usage"]["outputTokens"], - total_tokens=metadata["usage"]["totalTokens"], - ), - ) - if message: - return ChatStreamResponse( - id=message_id, - model=model_id, - choices=[ - ChoiceDelta( - index=0, - delta=message, - logprobs=None, - finish_reason=self._convert_finish_reason(finish_reason), - ) - ], - usage=usage, - ) - - return None - - def _parse_image(self, image_url: str) -> tuple[bytes, str]: - """Try to get the raw data from an image url. - - Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html - returns a tuple of (Image Data, Content Type) - """ - pattern = r"^data:(image/[a-z]*);base64,\s*" - content_type = re.search(pattern, image_url) - # if already base64 encoded. 
- # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' - if content_type: - image_data = re.sub(pattern, "", image_url) - return base64.b64decode(image_data), content_type.group(1) - - # Send a request to the image URL - response = requests.get(image_url) - # Check if the request was successful - if response.status_code == 200: - - content_type = response.headers.get("Content-Type") - if not content_type.startswith("image"): - content_type = "image/jpeg" - # Get the image content - image_content = response.content - return image_content, content_type - else: - raise HTTPException( - status_code=500, detail="Unable to access the image url" - ) - - def _parse_content_parts( - self, - message: UserMessage, - model_id: str, - ) -> list[dict]: - if isinstance(message.content, str): - return [ - { - "text": message.content, - } - ] - content_parts = [] - for part in message.content: - if isinstance(part, TextContent): - content_parts.append( - { - "text": part.text, - } - ) - elif isinstance(part, ImageContent): - if not self._is_multimodal_supported(model_id): - raise HTTPException( - status_code=400, - detail=f"Multimodal message is currently not supported by {model_id}", - ) - image_data, content_type = self._parse_image(part.image_url.url) - content_parts.append( - { - "image": { - "format": content_type[6:], # image/ - "source": {"bytes": image_data}, - }, - } - ) - else: - # Ignore.. - continue - return content_parts - - def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["stream_tool_call"] if stream else feature["tool_call"] - - def _is_multimodal_supported(self, model_id: str) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["multimodal"] - - def _is_system_prompt_supported(self, model_id: str) -> bool: - feature = self._supported_models.get(model_id) - if not feature: - return False - return feature["system"] - - def _convert_tool_spec(self, func: Function) -> dict: - return { - "toolSpec": { - "name": func.name, - "description": func.description, - "inputSchema": { - "json": func.parameters, - }, - } - } - - def _convert_finish_reason(self, finish_reason: str | None) -> str | None: - """ - Below is a list of finish reason according to OpenAI doc: - - - stop: if the model hit a natural stop point or a provided stop sequence, - - length: if the maximum number of tokens specified in the request was reached, - - content_filter: if content was omitted due to a flag from our content filters, - - tool_calls: if the model called a tool - """ - if finish_reason: - finish_reason_mapping = { - "tool_use": "tool_calls", - "finished": "stop", - "end_turn": "stop", - "max_tokens": "length", - "stop_sequence": "stop", - "complete": "stop", - "content_filtered": "content_filter" - } - return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower()) - return None + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) + \ No newline at end of file diff --git a/src/api/models/model_manager.py b/src/api/models/model_manager.py new file mode 100644 index 0000000..2974634 --- /dev/null +++ b/src/api/models/model_manager.py @@ -0,0 +1,35 @@ +# This is a singleton class to maintain list of models +class ModelManager: + _instance = None + _models = None + + def __new__(cls, *args, **kwargs): + # Ensure that only one instance of ModelManager is created + if cls._instance is None: 
+            cls._instance = super().__new__(cls)
+            cls._instance._models = {}  # Initialize the list of models
+
+        return cls._instance
+
+    def get_all_models(self):
+        return self._models
+
+    def add_model(self, model):
+        """Add a model to the list."""
+        if self._models is None:
+            self._models = {}
+        self._models.update(model)
+
+
+    def get_model_by_name(self, model_name: str):
+        """Get a single model entry by name."""
+        return self._models.get(model_name)
+
+    def clear_models(self):
+        """Clear the list of models."""
+        self._models.clear()
+        self._models = {}
+
+    def __repr__(self):
+        return f"ModelManager(models={self._models})"
+
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 19d55cd..c12181b 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -37,8 +37,6 @@ async def chat_completions(
 ):
 
     # this method gets called by front-end
-    logger.info(f"chat_completions: {chat_request}")
-
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
 
@@ -47,6 +45,6 @@ async def chat_completions(
     model.validate(chat_request)
     if chat_request.stream:
         response = StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
-        logger.info(f"\n\nStreaming response: {response}\n\n")
         return response
+
     return model.chat(chat_request)
diff --git a/src/api/setting.py b/src/api/setting.py
index 9026202..940cd0b 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -20,3 +20,13 @@
     "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
 )
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+
+KB_PREFIX = 'kb-'
+AGENT_PREFIX = 'ag-'
+
+DEFAULT_KB_MODEL = os.environ.get(
+    "DEFAULT_KB_MODEL", "anthropic.claude-3-haiku-20240307-v1:0"
+)
+
+
+DEFAULT_KB_MODEL_ARN = f'arn:aws:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}'
\ No newline at end of file
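
The singleton ModelManager in src/api/models/model_manager.py acts as a shared in-memory registry: BedrockModel.list_bedrock_models populates it with foundation models and inference profiles, and BedrockAgents layers knowledge bases (kb- prefix) and agents (ag- prefix) on top. A minimal sketch of the intended usage follows; the model names and the KB id are illustrative only, not values from this patch.

```python
from api.models.model_manager import ModelManager

# Both calls return the same instance because __new__ caches it on the class.
registry = ModelManager()
assert registry is ModelManager()

# Each entry maps a model name to its capability/metadata dict.
registry.add_model({"anthropic.claude-3-haiku-20240307-v1:0": {"modalities": ["TEXT", "IMAGE"]}})
registry.add_model({"kb-product-docs": {"kb_id": "KB123456",  # hypothetical knowledge base
                                        "model_id": "anthropic.claude-3-haiku-20240307-v1:0"}})

print(list(registry.get_all_models().keys()))
# ['anthropic.claude-3-haiku-20240307-v1:0', 'kb-product-docs']
```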
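For reference, the refresh path in list_bedrock_models boils down to two Bedrock control-plane calls: list_inference_profiles (when cross-region inference is enabled) and list_foundation_models filtered to streaming, ACTIVE, ON_DEMAND text models. A condensed sketch of that filtering, assuming a region and using an illustrative "us." cross-region prefix (the gateway derives the real prefix from its region):

```python
import boto3

bedrock = boto3.client("bedrock", region_name="us-east-1")  # assumption: your region

profiles = {
    p["inferenceProfileId"]
    for p in bedrock.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")["inferenceProfileSummaries"]
}

models = {}
for m in bedrock.list_foundation_models(byOutputModality="TEXT")["modelSummaries"]:
    # Skip rerank/legacy models: require streaming support and an ACTIVE lifecycle.
    if not m.get("responseStreamingSupported", True) or m["modelLifecycle"].get("status", "ACTIVE") != "ACTIVE":
        continue
    if "ON_DEMAND" in m.get("inferenceTypesSupported", []):
        models[m["modelId"]] = {"modalities": m["inputModalities"]}
    profile_id = "us." + m["modelId"]  # illustrative cross-region inference profile id
    if profile_id in profiles:
        models[profile_id] = {"modalities": m["inputModalities"]}

print(sorted(models))
```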
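_invoke_kb implements a two-step retrieval-augmented flow: fetch matching chunks with the bedrock-agent-runtime retrieve API, prepend them to the latest user message, then call Converse on the backing foundation model. A standalone sketch of the same flow outside the proxy; the region, knowledge base id and query are placeholders:

```python
import boto3

AWS_REGION = "us-east-1"          # assumption: adjust to your deployment
KB_ID = "EXAMPLEKBID"             # placeholder knowledge base id
MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"

agent_runtime = boto3.client("bedrock-agent-runtime", region_name=AWS_REGION)
runtime = boto3.client("bedrock-runtime", region_name=AWS_REGION)

query = "How do I rotate my API keys?"

# Step 1 - retrieve the top matching chunks from the knowledge base
retrieval = agent_runtime.retrieve(
    knowledgeBaseId=KB_ID,
    retrievalQuery={"text": query},
    retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 5}},
)
context = "\n".join(r["content"]["text"] for r in retrieval.get("retrievalResults", []))

# Step 2 - stuff the retrieved context into the prompt and call Converse
response = runtime.converse(
    modelId=MODEL_ID,
    messages=[{"role": "user", "content": [{"text": f"Context: {context}\n\n{query}"}]}],
)
print(response["output"]["message"]["content"][0]["text"])
```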
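_invoke_agent forwards the latest user message to the selected agent alias via invoke_agent; the response is a completion event stream whose chunk events carry UTF-8 text, which chat_stream repackages as OpenAI-style deltas. A minimal sketch of that call, with placeholder region, agent and alias ids (the alias would normally be the latest PREPARED one picked by get_latest_agent_alias):

```python
import uuid
import boto3

AWS_REGION = "us-east-1"         # assumption
AGENT_ID = "EXAMPLEAGENT"        # placeholder
AGENT_ALIAS_ID = "EXAMPLEALIAS"  # placeholder

agent_runtime = boto3.client("bedrock-agent-runtime", region_name=AWS_REGION)

response = agent_runtime.invoke_agent(
    agentId=AGENT_ID,
    agentAliasId=AGENT_ALIAS_ID,
    sessionId=str(uuid.uuid4()),
    inputText="What is the status of order 42?",
)

# The completion is an event stream; chunk events carry the generated text.
for event in response["completion"]:
    if "chunk" in event:
        print(event["chunk"]["bytes"].decode("utf8"), end="")
```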
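Because knowledge bases and agents are registered under the kb-/ag- prefixes, a client should be able to address them through the gateway's OpenAI-compatible chat completions endpoint like any other model name. A hedged usage sketch with the openai Python client; the base URL, API key and model name are placeholders for your own deployment, not values defined by this patch:

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/api/v1",  # assumption: wherever this gateway is running
    api_key="bedrock",                        # placeholder; key handling depends on the deployment
)

resp = client.chat.completions.create(
    model="kb-product-docs",  # hypothetical knowledge base registered by BedrockAgents.get_kbs()
    messages=[{"role": "user", "content": "Summarize the retention policy."}],
)
print(resp.choices[0].message.content)
```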