Support for Bedrock Agents #89

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions src/api/app.py
@@ -42,6 +42,7 @@ async def health():
return {"status": "OK"}



@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return PlainTextResponse(str(exc), status_code=400)
132 changes: 69 additions & 63 deletions src/api/models/bedrock.py
@@ -5,6 +5,7 @@
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal
from api.models.model_manager import ModelManager

import boto3
import numpy as np
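
The new import pulls in ModelManager from api.models.model_manager, a file that is not part of this diff. Below is a minimal sketch of the interface the call sites in this file appear to assume — add_model(), get_all_models(), and a models dict shared across instances, so that ModelManager() inside is_supported_modality sees the registry that list_bedrock_models populated. Everything in this sketch is inferred, not taken from the PR:

class ModelManager:
    _instance = None

    def __new__(cls):
        # Assumed shared singleton: every ModelManager() returns the same
        # object, so the registry persists across call sites.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.models = {}
        return cls._instance

    def add_model(self, model: dict) -> None:
        # Assumes the summary dict returned by list_foundation_models,
        # which carries 'modelId' and 'inputModalities' keys.
        self.models[model.get('modelId', 'N/A')] = {
            'modalities': model.get('inputModalities', [])
        }

    def get_all_models(self) -> dict:
        return self.models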
@@ -75,83 +76,88 @@ def get_inference_region_prefix():

ENCODER = tiktoken.get_encoding("cl100k_base")

# Initialize the model list.
#bedrock_model_list = list_bedrock_models()

def list_bedrock_models() -> dict:
    """Automatically get a list of supported models.

    Returns a model list that combines:
    - ON_DEMAND models.
    - Cross-region inference profiles (if enabled via env var).
    """
model_list = {}
try:
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(
maxResults=1000,
typeEquals='SYSTEM_DEFINED'
)
profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]

        # List foundation models; we only care about text output here.
response = bedrock_client.list_foundation_models(
byOutputModality='TEXT'
)

for model in response['modelSummaries']:
model_id = model.get('modelId', 'N/A')
stream_supported = model.get('responseStreamingSupported', True)
status = model['modelLifecycle'].get('status', 'ACTIVE')

            # Currently used to filter out rerank and legacy models.
if not stream_supported or status != "ACTIVE":
continue

inference_types = model.get('inferenceTypesSupported', [])
input_modalities = model['inputModalities']
# Add on-demand model list
if 'ON_DEMAND' in inference_types:
model_list[model_id] = {
'modalities': input_modalities
}

# Add cross-region inference model list.
profile_id = cr_inference_prefix + '.' + model_id
if profile_id in profile_list:
model_list[profile_id] = {
'modalities': input_modalities
}

except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
    if not model_list:
        # Fallback in case the stack has not been updated.
        model_list[DEFAULT_MODEL] = {
            'modalities': ["TEXT", "IMAGE"]
        }

    return model_list


class BedrockModel(BaseChatModel):

    #bedrock_model_list = None
    model_manager = None

    def __init__(self):
        super().__init__()
        self.model_manager = ModelManager()
    def list_bedrock_models(self) -> dict:
        """Automatically get a list of supported models.

        Returns a model list that combines:
        - ON_DEMAND models.
        - Cross-region inference profiles (if enabled via env var).
        """
        #model_list = {}
try:
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(
maxResults=1000,
typeEquals='SYSTEM_DEFINED'
)
profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]

            # List foundation models; we only care about text output here.
            response = bedrock_client.list_foundation_models(
                byOutputModality='TEXT'
            )

for model in response['modelSummaries']:
model_id = model.get('modelId', 'N/A')
stream_supported = model.get('responseStreamingSupported', True)
status = model['modelLifecycle'].get('status', 'ACTIVE')

                # Currently used to filter out rerank and legacy models.
if not stream_supported or status != "ACTIVE":
continue

inference_types = model.get('inferenceTypesSupported', [])
input_modalities = model['inputModalities']
                # Register the on-demand model.
                if 'ON_DEMAND' in inference_types:
                    # The entry is attached to the summary dict and the whole
                    # summary is handed to the manager; how add_model() keys
                    # the registry is not visible in this diff.
                    model[model_id] = {
                        'modalities': input_modalities
                    }
                    self.model_manager.add_model(model)
                    # model_list[model_id] = {
                    #     'modalities': input_modalities
                    # }

                # Register the matching cross-region inference profile, if any.
                profile_id = cr_inference_prefix + '.' + model_id
                if profile_id in profile_list:
                    model[profile_id] = {
                        'modalities': input_modalities
                    }
                    self.model_manager.add_model(model)

        except Exception as e:
            logger.error(e)
            raise HTTPException(status_code=500, detail=str(e))

def list_models(self) -> list[str]:
"""Always refresh the latest model list"""
global bedrock_model_list
bedrock_model_list = list_bedrock_models()
return list(bedrock_model_list.keys())
#global bedrock_model_list
self.list_bedrock_models()
return list(self.model_manager.get_all_models().keys())
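
A minimal usage sketch of the refactored flow, assuming BedrockModel can be constructed with no arguments, as the new __init__ suggests:

model = BedrockModel()
supported = model.list_models()  # re-populates the ModelManager, then returns registered IDs
print(supported)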

def validate(self, chat_request: ChatRequest):
"""Perform basic validation on requests"""

error = ""

        ###### TODO: validation fails here because knowledge bases and agents are never added to the model list
        # Check whether the model is supported.
if chat_request.model not in bedrock_model_list.keys():
if chat_request.model not in self.model_manager.get_all_models().keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"

if error:
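
(The remainder of validate() is collapsed in this view.) The TODO above is the crux of this PR: knowledge bases and agents are never added to the registry, so they fail this membership check. One possible fix is to register agents alongside the foundation models, sketched here with the boto3 bedrock-agent control-plane client — the register_bedrock_agents helper and the 'agent.' ID prefix are assumptions, not code from this PR:

def register_bedrock_agents(model_manager) -> None:
    # List agents via the Bedrock Agents control plane and register each
    # one so validate() will accept its ID.
    agent_client = boto3.client("bedrock-agent")
    response = agent_client.list_agents(maxResults=100)
    for agent in response.get("agentSummaries", []):
        # Hypothetical 'agent.' prefix to distinguish agents from models.
        model_manager.models[f"agent.{agent['agentId']}"] = {
            'modalities': ["TEXT"]
        }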
@@ -659,7 +665,7 @@ def _parse_content_parts(

@staticmethod
def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
model = bedrock_model_list.get(model_id)
        model = ModelManager().models.get(model_id) or {}  # guard against unknown model IDs
        modalities = model.get('modalities', [])
if modality in modalities:
return True
@@ -851,4 +857,4 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
raise HTTPException(
status_code=400,
detail="Unsupported embedding model id " + model_id,
)