
Commit cf1f363

Adding functionality to run benchmarks for Amazon nova models (#3251)
Co-authored-by: Sai Kiran Reddy Jakka <jakkasa@amazon.com>
1 parent fca12c0 · commit cf1f363

9 files changed: +213 -8 lines

setup.cfg

+3-3
@@ -131,9 +131,9 @@ allenai =
     ai2-olmo~=0.2

 amazon =
-    boto3~=1.28.57
-    awscli~=1.29.57
-    botocore~=1.31.57
+    boto3~=1.34.131
+    awscli~=1.32.1
+    botocore~=1.34.1

 anthropic =
     anthropic~=0.17,<0.39  # TODO(#3212): Limit anthropic to >=0.39 after resolving #3212.
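The version bump is load-bearing: the BedrockNovaClient introduced below calls the Bedrock Converse API, which older boto3 releases do not expose. A minimal sanity-check sketch, assuming the amazon extra is installed (e.g. pip install "crfm-helm[amazon]") and AWS credentials are configured:

import boto3

# The Converse API used by BedrockNovaClient only exists on newer SDKs,
# which is why the "amazon" extra now pins boto3~=1.34.131.
client = boto3.client("bedrock-runtime", region_name="us-east-1")
assert hasattr(client, "converse"), "boto3 too old for the Bedrock Converse API"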

src/helm/benchmark/model_metadata_registry.py

+3
@@ -22,6 +22,9 @@
 # OpenAI Chat format
 OPENAI_CHATGPT_MODEL_TAG: str = "OPENAI_CHATGPT_MODEL_TAG"

+# For NOVA models
+NOVA_MODEL_TAG: str = "NOVA_MODEL_TAG"
+
 # For Anthropic models
 ANTHROPIC_CLAUDE_1_MODEL_TAG: str = "ANTHROPIC_CLAUDE_1_MODEL_TAG"
 ANTHROPIC_CLAUDE_2_MODEL_TAG: str = "ANTHROPIC_CLAUDE_2_MODEL_TAG"

src/helm/benchmark/run_expander.py

+27
@@ -348,6 +348,33 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]:
         return [run_spec]


+class NovaRunExpander(RunExpander):
+    """
+    Custom prompt for Amazon Nova models.
+    These models need more explicit instructions about following the format.
+    """
+
+    name = "amazon-nova"
+
+    PROMPT = "Do not provide any additional explanation. Follow the format shown in the provided examples strictly."
+
+    def __init__(self):
+        pass
+
+    def expand(self, run_spec: RunSpec) -> List[RunSpec]:
+        return [
+            replace(
+                run_spec,
+                name=run_spec.name,
+                adapter_spec=replace(
+                    run_spec.adapter_spec,
+                    global_prefix=NovaRunExpander.PROMPT + "\n\n",
+                ),
+            ),
+        ]
+
+
 class FollowFormatInstructionsRunExpander(RunExpander):
     """Adds more explicit instructions about following the format to prompts.
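The expander's effect is easy to see in isolation: it sets the adapter's global_prefix, which HELM prepends to every prompt. A sketch, where only PROMPT comes from the class above and the example prompt is hypothetical:

# Every prompt sent to a Nova model gets the instruction plus a blank line.
PROMPT = "Do not provide any additional explanation. Follow the format shown in the provided examples strictly."
final_prompt = PROMPT + "\n\n" + "Question: What is 2 + 2?\nAnswer:"  # hypothetical prompt
print(final_prompt)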

src/helm/benchmark/run_spec_factory.py

+5
@@ -14,6 +14,7 @@
     ANTHROPIC_CLAUDE_1_MODEL_TAG,
     ANTHROPIC_CLAUDE_2_MODEL_TAG,
     ANTHROPIC_CLAUDE_3_MODEL_TAG,
+    NOVA_MODEL_TAG,
     BUGGY_TEMP_0_TAG,
     CHATML_MODEL_TAG,
     GOOGLE_GEMINI_PRO_VISION_V1_TAG,
@@ -31,6 +32,7 @@
     RUN_EXPANDERS,
     AnthropicClaude2RunExpander,
     AnthropicClaude3RunExpander,
+    NovaRunExpander,
     ChatMLRunExpander,
     GlobalPrefixRunExpander,
     IDEFICSInstructRunExpander,
@@ -122,6 +124,9 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec:
         chatml_expander = ChatMLRunExpander()
         run_spec = singleton(chatml_expander.expand(run_spec))

+    if NOVA_MODEL_TAG in model.tags:
+        run_spec = singleton(NovaRunExpander().expand(run_spec))
+
     # Anthropic Claude 1 and 2 prompts
     if ANTHROPIC_CLAUDE_1_MODEL_TAG in model.tags or ANTHROPIC_CLAUDE_2_MODEL_TAG in model.tags:
         run_spec = singleton(AnthropicClaude2RunExpander().expand(run_spec))
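Expander dispatch is driven purely by model tags, so whether a model will get the Nova treatment can be checked from its metadata. A sketch, assuming HELM's get_model_metadata lookup helper:

from helm.benchmark.model_metadata_registry import NOVA_MODEL_TAG, get_model_metadata

# alter_run_spec applies NovaRunExpander whenever NOVA_MODEL_TAG is present.
metadata = get_model_metadata("amazon/nova-pro-v1:0")
print(NOVA_MODEL_TAG in metadata.tags)  # expected: True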

src/helm/clients/bedrock_client.py

+78-1
@@ -3,11 +3,13 @@
 import json
 import os
 from typing import Any, Dict, List, Mapping, Optional
+from datetime import datetime

 from helm.common.cache import CacheConfig
 from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
 from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
-from helm.clients.bedrock_utils import get_bedrock_client
+from helm.common.hierarchical_logger import htrack_block
+from helm.clients.bedrock_utils import get_bedrock_client, get_bedrock_client_v1
 from helm.tokenizers.tokenizer import Tokenizer


@@ -96,6 +98,81 @@ def do_it() -> Dict[Any, Any]:
         )


+class BedrockNovaClient(CachingClient):
+    """
+    Amazon Bedrock is a fully managed service that provides a selection of leading foundation models (FMs)
+    from Amazon and other partner model providers.
+    """
+
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        bedrock_model_id: Optional[str] = None,
+        assumed_role: Optional[str] = None,
+        region: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.bedrock_model_id = bedrock_model_id
+        self.bedrock_client = get_bedrock_client_v1(
+            assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
+            region=region or os.environ.get("AWS_DEFAULT_REGION", None),
+        )
+
+    def convert_request_to_raw_request(self, request: Request) -> Dict:
+        model_id = request.model.replace("/", ".")
+        messages = [{"role": "user", "content": [{"text": request.prompt}]}]
+        return {
+            "modelId": model_id,
+            "inferenceConfig": {
+                "temperature": request.temperature,
+                "maxTokens": request.max_tokens,
+                "topP": request.top_p,
+            },
+            "messages": messages,
+        }
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = self.convert_request_to_raw_request(request)
+        response = self.bedrock_client.converse(**raw_request)
+        completions = self.convert_raw_response_to_completions(response, request)
+        dt = datetime.strptime(response["ResponseMetadata"]["HTTPHeaders"]["date"], "%a, %d %b %Y %H:%M:%S GMT")
+
+        return RequestResult(
+            success=True,
+            cached=False,
+            request_time=response["metrics"]["latencyMs"],
+            request_datetime=int(dt.timestamp()),
+            completions=completions,
+            embedding=[],
+        )
+
+    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
+        completions: List[GeneratedOutput] = []
+        raw_completion = response["output"]
+        output_text = raw_completion["message"]["content"][0]["text"]
+        finish_reason = response["stopReason"]
+        completion = truncate_and_tokenize_response_text(
+            output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
+        )
+        completions.append(completion)
+        return completions
+
+
 # Amazon Bedrock Client for Titan Models
 class BedrockTitanClient(BedrockClient):
     _COMPLETION_REASON_TO_FINISH_REASON = {
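For reference, this is the shape of the payload convert_request_to_raw_request builds for converse(), with hypothetical values filled in; note how the HELM model name amazon/nova-pro-v1:0 becomes the Bedrock model ID amazon.nova-pro-v1:0 via replace("/", "."):

# Sketch of a raw Converse request as built above (values are hypothetical).
raw_request = {
    "modelId": "amazon.nova-pro-v1:0",  # request.model with "/" replaced by "."
    "inferenceConfig": {"temperature": 0.0, "maxTokens": 512, "topP": 1.0},
    "messages": [{"role": "user", "content": [{"text": "Hello, Nova."}]}],
}
# In the response, response["output"]["message"]["content"][0]["text"] is the
# completion, response["stopReason"] the finish reason, and
# response["metrics"]["latencyMs"] the latency recorded as request_time.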

src/helm/clients/bedrock_utils.py

+47-1
@@ -1,13 +1,14 @@
 """Helper utilities for working with Amazon Bedrock."""

 import os
-from typing import Optional
+from typing import Optional, Dict

 from helm.common.hierarchical_logger import hlog
 from helm.common.optional_dependencies import handle_module_not_found_error

 try:
     import boto3
+    from boto3 import Session
     from botocore.config import Config
 except ModuleNotFoundError as e:
     handle_module_not_found_error(e, ["aws"])
@@ -70,3 +71,48 @@ def get_bedrock_client(

     hlog(f"Amazon Bedrock client successfully created with endpoint {bedrock_client._endpoint}")
     return bedrock_client
+
+
+def get_bedrock_client_v1(
+    assumed_role: Optional[str] = None,
+    service_name: str = "bedrock-runtime",
+    region: str = "us-east-1",
+    read_timeout: int = 5000,
+    connect_timeout: int = 5000,
+    retries: Dict = {"max_attempts": 10},
+):
+    if region is None:
+        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
+    else:
+        target_region = region
+
+    boto_config = Config(read_timeout=read_timeout, connect_timeout=connect_timeout, retries=retries)
+
+    if target_region is None:
+        raise ValueError("region environment variable is not set.")
+
+    if assumed_role:
+        # Assume the given role and build a session from its temporary credentials
+        session = boto3.Session(region_name=target_region)
+        sts = session.client("sts")
+        creds = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")["Credentials"]
+        session = Session(
+            aws_access_key_id=creds["AccessKeyId"],
+            aws_secret_access_key=creds["SecretAccessKey"],
+            aws_session_token=creds["SessionToken"],  # STS temporary credentials require the session token
+        )
+        return session.client(
+            service_name=service_name,
+            region_name=target_region,
+            config=boto_config,
+        )
+
+    # default to instance role to get the aws credentials or aws configured credentials
+    return boto3.client(service_name=service_name, region_name=target_region, config=boto_config)
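A usage sketch for the new helper, mirroring how BedrockNovaClient constructs its client above (both environment variables are optional; region otherwise falls back to the us-east-1 default):

import os

from helm.clients.bedrock_utils import get_bedrock_client_v1

# With BEDROCK_ASSUME_ROLE set, the helper assumes that role via STS;
# otherwise it uses instance-role or locally configured credentials.
client = get_bedrock_client_v1(
    assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE"),
    region=os.environ.get("AWS_DEFAULT_REGION", "us-east-1"),
)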

src/helm/clients/huggingface_client.py

-1
@@ -106,7 +106,6 @@ def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
         encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
             0 if self.device is None else self.device
         )
-
         stopping_criteria: Optional[StoppingCriteriaList] = None
         optional_args = {}
         if len(raw_request["stop_sequences"]) > 0:

src/helm/config/model_deployments.yaml

+23-1
@@ -104,7 +104,29 @@ model_deployments:
       class_name: "helm.benchmark.window_services.image_generation.clip_window_service.CLIPWindowService"


-  # Amazon
+  # Amazon Nova models
+
+  - name: amazon/nova-pro-v1:0
+    model_name: amazon/nova-pro-v1:0
+    tokenizer_name: huggingface/gpt2
+    max_sequence_length: 300000
+    client_spec:
+      class_name: "helm.clients.bedrock_client.BedrockNovaClient"
+
+  - name: amazon/nova-lite-v1:0
+    model_name: amazon/nova-lite-v1:0
+    tokenizer_name: huggingface/gpt2
+    max_sequence_length: 300000
+    client_spec:
+      class_name: "helm.clients.bedrock_client.BedrockNovaClient"
+
+  - name: amazon/nova-micro-v1:0
+    model_name: amazon/nova-micro-v1:0
+    tokenizer_name: huggingface/gpt2
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.bedrock_client.BedrockNovaClient"
+
   # Titan on Amazon Bedrock

   - name: amazon/titan-text-lite-v1
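With these deployments registered, the Nova models are runnable by name through HELM's standard CLI, with something like helm-run --run-entries "mmlu:subject=philosophy,model=amazon/nova-pro-v1:0" --suite my-suite --max-eval-instances 10 (run entry and suite name are placeholders; flags assumed from HELM's usual interface). Note that huggingface/gpt2 serves here as a stand-in tokenizer for window-service token accounting and response truncation; the Converse API tokenizes server-side.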

src/helm/config/model_metadata.yaml

+27-1
@@ -219,7 +219,33 @@ models:
     tags: [TEXT_TO_IMAGE_MODEL_TAG]


-  # Amazon
+  # Amazon Nova models
+  # References for Amazon Nova models:
+  # https://aws.amazon.com/ai/generative-ai/nova/
+  - name: amazon/nova-pro-v1:0
+    display_name: Amazon Nova Pro
+    description: Amazon Nova Pro Model
+    creator_organization_name: Amazon
+    access: limited
+    release_date: 2024-12-03
+    tags: [NOVA_MODEL_TAG, TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: amazon/nova-lite-v1:0
+    display_name: Amazon Nova Lite
+    description: Amazon Nova Lite Model
+    creator_organization_name: Amazon
+    access: limited
+    release_date: 2024-12-03
+    tags: [NOVA_MODEL_TAG, TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: amazon/nova-micro-v1:0
+    display_name: Amazon Nova Micro
+    description: Amazon Nova Micro Model
+    creator_organization_name: Amazon
+    access: limited
+    release_date: 2024-12-03
+    tags: [NOVA_MODEL_TAG, TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
   # Titan Models
   # References for Amazon Titan models:
   # - https://aws.amazon.com/bedrock/titan/
