|
1 | 1 | import json
|
2 | 2 | import requests
|
3 |
| -from typing import List |
| 3 | +from typing import List, Optional, Sequence, TypedDict |
4 | 4 |
|
5 | 5 | from helm.common.cache import CacheConfig
|
| 6 | +from helm.common.optional_dependencies import handle_module_not_found_error |
6 | 7 | from helm.common.request import (
|
7 | 8 | wrap_request_time,
|
8 | 9 | EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
|
|
11 | 12 | GeneratedOutput,
|
12 | 13 | Token,
|
13 | 14 | )
|
14 |
| -from .client import CachingClient, truncate_sequence |
15 |
| -from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION |
| 15 | +from helm.clients.client import CachingClient, truncate_sequence |
| 16 | +from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION |
| 17 | + |
| 18 | +try: |
| 19 | + import cohere |
| 20 | +except ModuleNotFoundError as e: |
| 21 | + handle_module_not_found_error(e, ["cohere"]) |
16 | 22 |
|
17 | 23 |
|
18 | 24 | class CohereClient(CachingClient):
|
@@ -152,3 +158,92 @@ def do_it():
|
152 | 158 | completions=completions,
|
153 | 159 | embedding=[],
|
154 | 160 | )
|
| 161 | + |
| 162 | + |
class CohereRawChatRequest(TypedDict):
    """Keyword arguments accepted by `cohere.Client.chat()`.

    Mirrors the Cohere Chat API request body; every field except `message`
    is optional and may be None to use the API's default.
    """

    # The user prompt sent as the chat message.
    message: str
    # Cohere model name (without the "cohere/" prefix).
    model: Optional[str]
    # System preamble; None means use the API default.
    preamble: Optional[str]
    # Prior turns of the conversation; None for a single-turn request.
    chat_history: Optional[Sequence[cohere.ChatMessage]]
    temperature: Optional[float]
    max_tokens: Optional[int]
    # Top-k sampling parameter.
    k: Optional[int]
    # Nucleus (top-p) sampling parameter.
    p: Optional[float]
    # NOTE(review): the SDK seed is typed float here because callers pass
    # float(request.random) — confirm against the Cohere API, which may expect int.
    seed: Optional[float]
    stop_sequences: Optional[Sequence[str]]
    frequency_penalty: Optional[float]
    presence_penalty: Optional[float]
| 176 | + |
| 177 | + |
def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
    """Translate a HELM Request into keyword arguments for the Cohere chat endpoint."""
    # TODO: Support chat
    # The Cohere SDK expects a numeric seed; HELM's `random` field is converted when present.
    seed = None
    if request.random is not None:
        seed = float(request.random)
    raw_request: CohereRawChatRequest = {
        "message": request.prompt,
        # Strip HELM's "cohere/" organization prefix to get the bare model name.
        "model": request.model.replace("cohere/", ""),
        "preamble": None,
        "chat_history": None,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "k": request.top_k_per_token,
        "p": request.top_p,
        "stop_sequences": request.stop_sequences,
        "seed": seed,
        "frequency_penalty": request.frequency_penalty,
        "presence_penalty": request.presence_penalty,
    }
    return raw_request
| 195 | + |
| 196 | + |
class CohereChatClient(CachingClient):
    """
    Leverages the chat endpoint: https://docs.cohere.com/reference/chat

    Cohere models will only support chat soon: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
    """

    def __init__(self, api_key: str, cache_config: CacheConfig):
        super().__init__(cache_config=cache_config)
        # Authenticated Cohere SDK client used for all chat calls.
        self.client = cohere.Client(api_key=api_key)

    def make_request(self, request: Request) -> RequestResult:
        """Send a single chat request to Cohere and wrap the response in a RequestResult.

        Returns a failed RequestResult (success=False) on network errors or
        malformed responses rather than raising.
        """
        if request.embedding:
            # The chat endpoint cannot produce embeddings.
            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
        # TODO: Support multiple completions
        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
        # TODO: Support messages
        assert not request.messages, "CohereChatClient currently does not support the messages API"

        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)

        try:

            def do_it():
                """
                Send the request to the Cohere Chat API. Responses will be structured like this:
                cohere.Chat {
                    message: What's up?
                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
                    ...
                }
                """
                raw_response = self.client.chat(**raw_request).dict()
                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
                return raw_response

            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
        except (requests.exceptions.RequestException, AssertionError) as e:
            # Fix: report the correct client name (previously said "CohereClient",
            # which misattributed failures of this chat client).
            error: str = f"CohereChatClient error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        # The chat endpoint does not expose per-token logprobs, so report 0.0 with no tokens.
        completions: List[GeneratedOutput] = [GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])]

        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            request_datetime=response["request_datetime"],
            completions=completions,
            embedding=[],
        )
0 commit comments