Commit c5c451c

Add Command R and Command R+ models (#2548)
Co-authored-by: Yifan Mai <yifan@cs.stanford.edu>
1 parent: 13abf8f

File tree: 5 files changed, +156 −4 lines

setup.cfg (+4)

@@ -129,6 +129,9 @@ anthropic =
     anthropic~=0.17
     websocket-client~=1.3.2 # For legacy stanford-online-all-v4-s3
 
+cohere =
+    cohere~=5.3
+
 mistral =
     mistralai~=0.0.11
 
@@ -154,6 +157,7 @@ models =
     crfm-helm[allenai]
     crfm-helm[amazon]
     crfm-helm[anthropic]
+    crfm-helm[cohere]
     crfm-helm[google]
     crfm-helm[mistral]
     crfm-helm[openai]
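
The new cohere extra pins Cohere's official Python SDK (cohere~=5.3), which the chat client below imports, and crfm-helm[models] now pulls it in transitively. For just this client's dependency, a direct install of the extra (pip install "crfm-helm[cohere]") should suffice, assuming the published package name crfm-helm shown above.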

src/helm/clients/cohere_client.py (+98 −3)

@@ -1,8 +1,9 @@
 import json
 import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@
     GeneratedOutput,
     Token,
 )
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+try:
+    import cohere
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["cohere"])
 
 
 class CohereClient(CachingClient):
@@ -152,3 +158,92 @@ def do_it():
             completions=completions,
             embedding=[],
         )
+
+
+class CohereRawChatRequest(TypedDict):
+    message: str
+    model: Optional[str]
+    preamble: Optional[str]
+    chat_history: Optional[Sequence[cohere.ChatMessage]]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    k: Optional[int]
+    p: Optional[float]
+    seed: Optional[float]
+    stop_sequences: Optional[Sequence[str]]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+
+
+def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+    # TODO: Support chat
+    model = request.model.replace("cohere/", "")
+    return {
+        "message": request.prompt,
+        "model": model,
+        "preamble": None,
+        "chat_history": None,
+        "temperature": request.temperature,
+        "max_tokens": request.max_tokens,
+        "k": request.top_k_per_token,
+        "p": request.top_p,
+        "stop_sequences": request.stop_sequences,
+        "seed": float(request.random) if request.random is not None else None,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+    }
+
+
+class CohereChatClient(CachingClient):
+    """
+    Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+    Cohere models will only support chat soon: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+    """
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.client = cohere.Client(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+        # TODO: Support multiple completions
+        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+        # TODO: Support messages
+        assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+        try:
+
+            def do_it():
+                """
+                Send the request to the Cohere Chat API. Responses will be structured like this:
+                cohere.Chat {
+                    message: What's up?
+                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+                    ...
+                }
+                """
+                raw_response = self.client.chat(**raw_request).dict()
+                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+                return raw_response
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"CohereClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+        completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
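
For orientation, here is a hedged usage sketch of the new CohereChatClient, separate from the commit itself. The Request field names come from helm.common.request; the API key and SQLite cache path are placeholders, and fields left unset fall back to the dataclass defaults.

from helm.common.cache import SqliteCacheConfig
from helm.common.request import Request
from helm.clients.cohere_client import CohereChatClient

# Placeholders: supply a real Cohere API key and any writable cache path.
client = CohereChatClient(
    api_key="YOUR_COHERE_API_KEY",
    cache_config=SqliteCacheConfig("cohere_cache.sqlite"),
)

# The client strips the "cohere/" prefix before calling the chat endpoint.
result = client.make_request(Request(model="cohere/command-r", prompt="What's up?"))
if result.success:
    print(result.completions[0].text)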

src/helm/config/model_deployments.yaml (+19)

@@ -325,6 +325,25 @@ model_deployments:
     window_service_spec:
       class_name: "helm.benchmark.window_services.cohere_window_service.CohereWindowService"
 
+  - name: cohere/command-r
+    model_name: cohere/command-r
+    tokenizer_name: cohere/c4ai-command-r-v01
+    max_sequence_length: 128000
+    max_request_length: 128000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
+  - name: cohere/command-r-plus
+    model_name: cohere/command-r-plus
+    tokenizer_name: cohere/c4ai-command-r-plus
+    # "We have a known issue where prompts between 112K - 128K in length
+    # result in bad generations."
+    # Source: https://docs.cohere.com/docs/command-r-plus
+    max_sequence_length: 110000
+    max_request_length: 110000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
   # Craiyon
 
   - name: craiyon/dalle-mini
src/helm/config/model_metadata.yaml (+19 −1)

@@ -468,7 +468,25 @@ models:
     creator_organization_name: Cohere
     access: limited
     release_date: 2023-09-29
-    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
+  - name: cohere/command-r
+    display_name: Cohere Command R
+    description: Command R is a multilingual 35B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 35000000000
+    release_date: 2024-03-11
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
+  - name: cohere/command-r-plus
+    display_name: Cohere Command R Plus
+    description: Command R+ is a multilingual 104B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 104000000000
+    release_date: 2024-04-04
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
 
   # Craiyon
   - name: craiyon/dalle-mini
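
Although both descriptions advertise a 128K context, note that the command-r-plus deployment above intentionally caps max_sequence_length and max_request_length at 110000, per the bad-generation issue quoted from Cohere's documentation.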

src/helm/config/tokenizer_configs.yaml (+16)

@@ -83,6 +83,22 @@ tokenizer_configs:
     end_of_text_token: ""
     prefix_token: ":"
 
+  - name: cohere/c4ai-command-r-v01
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-v01
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: "<BOS_TOKEN>"
+
+  - name: cohere/c4ai-command-r-plus
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-plus
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: "<BOS_TOKEN>"
+
   # Databricks
   - name: databricks/dbrx-instruct
     tokenizer_spec:
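
As a quick, hedged sanity check (not part of the commit), the Hugging Face tokenizer these configs point at can be loaded directly to confirm the special tokens; this assumes transformers is installed and the model terms have been accepted on the Hub:

from transformers import AutoTokenizer

# Load the tokenizer referenced by pretrained_model_name_or_path above.
tok = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(tok.bos_token)  # expected to match prefix_token: <BOS_TOKEN>
print(tok.eos_token)  # expected to match end_of_text_token: <|END_OF_TURN_TOKEN|>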
