From 10242dc8a21a6c5c7ce7e51efa21ea171875fb83 Mon Sep 17 00:00:00 2001 From: Mikey O'Brien Date: Mon, 24 Feb 2025 19:33:54 -0600 Subject: [PATCH] Add thinking support for claude-3-7-sonnet --- .../providers/anthropic_manifold_pipeline.py | 174 ++++++++++++++---- 1 file changed, 135 insertions(+), 39 deletions(-) diff --git a/examples/pipelines/providers/anthropic_manifold_pipeline.py b/examples/pipelines/providers/anthropic_manifold_pipeline.py index 3540c618..29dd2d70 100644 --- a/examples/pipelines/providers/anthropic_manifold_pipeline.py +++ b/examples/pipelines/providers/anthropic_manifold_pipeline.py @@ -6,7 +6,7 @@ license: MIT description: A pipeline for generating text and processing images using the Anthropic API. requirements: requests, sseclient-py -environment_variables: ANTHROPIC_API_KEY +environment_variables: ANTHROPIC_API_KEY, ANTHROPIC_THINKING_BUDGET_TOKENS, ANTHROPIC_ENABLE_THINKING """ import os @@ -18,6 +18,17 @@ from utils.pipelines.main import pop_system_message +REASONING_EFFORT_BUDGET_TOKEN_MAP = { + "none": None, + "low": 1024, + "medium": 4096, + "high": 16384, + "max": 32768, +} + +# Maximum combined token limit for Claude 3.7 +MAX_COMBINED_TOKENS = 64000 + class Pipeline: class Valves(BaseModel): @@ -29,16 +40,20 @@ def __init__(self): self.name = "anthropic/" self.valves = self.Valves( - **{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")} + **{ + "ANTHROPIC_API_KEY": os.getenv( + "ANTHROPIC_API_KEY", "your-api-key-here" + ), + } ) - self.url = 'https://api.anthropic.com/v1/messages' + self.url = "https://api.anthropic.com/v1/messages" self.update_headers() def update_headers(self): self.headers = { - 'anthropic-version': '2023-06-01', - 'content-type': 'application/json', - 'x-api-key': self.valves.ANTHROPIC_API_KEY + "anthropic-version": "2023-06-01", + "content-type": "application/json", + "x-api-key": self.valves.ANTHROPIC_API_KEY, } def get_anthropic_models(self): @@ -88,7 +103,7 @@ def pipe( ) -> Union[str, Generator, Iterator]: try: # Remove unnecessary keys - for key in ['user', 'chat_id', 'title']: + for key in ["user", "chat_id", "title"]: body.pop(key, None) system_message, messages = pop_system_message(messages) @@ -102,28 +117,40 @@ def pipe( if isinstance(message.get("content"), list): for item in message["content"]: if item["type"] == "text": - processed_content.append({"type": "text", "text": item["text"]}) + processed_content.append( + {"type": "text", "text": item["text"]} + ) elif item["type"] == "image_url": if image_count >= 5: - raise ValueError("Maximum of 5 images per API call exceeded") + raise ValueError( + "Maximum of 5 images per API call exceeded" + ) processed_image = self.process_image(item["image_url"]) processed_content.append(processed_image) if processed_image["source"]["type"] == "base64": - image_size = len(processed_image["source"]["data"]) * 3 / 4 + image_size = ( + len(processed_image["source"]["data"]) * 3 / 4 + ) else: image_size = 0 total_image_size += image_size if total_image_size > 100 * 1024 * 1024: - raise ValueError("Total size of images exceeds 100 MB limit") + raise ValueError( + "Total size of images exceeds 100 MB limit" + ) image_count += 1 else: - processed_content = [{"type": "text", "text": message.get("content", "")}] + processed_content = [ + {"type": "text", "text": message.get("content", "")} + ] - processed_messages.append({"role": message["role"], "content": processed_content}) + processed_messages.append( + {"role": message["role"], "content": processed_content} + ) # Prepare the payload payload = { @@ -139,6 +166,42 @@ def pipe( } if body.get("stream", False): + supports_thinking = "claude-3-7" in model_id + reasoning_effort = body.get("reasoning_effort", "none") + budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort) + + # Allow users to input an integer value representing budget tokens + if ( + not budget_tokens + and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys() + ): + try: + budget_tokens = int(reasoning_effort) + except ValueError as e: + print("Failed to convert reasoning effort to int", e) + budget_tokens = None + + if supports_thinking and budget_tokens: + # Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit + max_tokens = payload.get("max_tokens", 4096) + combined_tokens = budget_tokens + max_tokens + + if combined_tokens > MAX_COMBINED_TOKENS: + error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}" + print(error_message) + return error_message + + payload["max_tokens"] = combined_tokens + payload["thinking"] = { + "type": "enabled", + "budget_tokens": budget_tokens, + } + # Thinking requires temperature 1.0 and does not support top_p, top_k + payload["temperature"] = 1.0 + if "top_k" in payload: + del payload["top_k"] + if "top_p" in payload: + del payload["top_p"] return self.stream_response(payload) else: return self.get_completion(payload) @@ -146,31 +209,64 @@ def pipe( return f"Error: {e}" def stream_response(self, payload: dict) -> Generator: - response = requests.post(self.url, headers=self.headers, json=payload, stream=True) - - if response.status_code == 200: - client = sseclient.SSEClient(response) - for event in client.events(): - try: - data = json.loads(event.data) - if data["type"] == "content_block_start": - yield data["content_block"]["text"] - elif data["type"] == "content_block_delta": - yield data["delta"]["text"] - elif data["type"] == "message_stop": - break - except json.JSONDecodeError: - print(f"Failed to parse JSON: {event.data}") - except KeyError as e: - print(f"Unexpected data structure: {e}") - print(f"Full data: {data}") - else: - raise Exception(f"Error: {response.status_code} - {response.text}") + """Used for title and tag generation""" + try: + response = requests.post( + self.url, headers=self.headers, json=payload, stream=True + ) + print(f"{response} for {payload}") + + if response.status_code == 200: + client = sseclient.SSEClient(response) + for event in client.events(): + try: + data = json.loads(event.data) + if data["type"] == "content_block_start": + if data["content_block"]["type"] == "thinking": + yield "" + else: + yield data["content_block"]["text"] + elif data["type"] == "content_block_delta": + if data["delta"]["type"] == "thinking_delta": + yield data["delta"]["thinking"] + elif data["delta"]["type"] == "signature_delta": + yield "\n \n\n" + else: + yield data["delta"]["text"] + elif data["type"] == "message_stop": + break + except json.JSONDecodeError: + print(f"Failed to parse JSON: {event.data}") + yield f"Error: Failed to parse JSON response" + except KeyError as e: + print(f"Unexpected data structure: {e} for payload {payload}") + print(f"Full data: {data}") + yield f"Error: Unexpected data structure: {e}" + else: + error_message = f"Error: {response.status_code} - {response.text}" + print(error_message) + yield error_message + except Exception as e: + error_message = f"Error: {str(e)}" + print(error_message) + yield error_message def get_completion(self, payload: dict) -> str: - response = requests.post(self.url, headers=self.headers, json=payload) - if response.status_code == 200: - res = response.json() - return res["content"][0]["text"] if "content" in res and res["content"] else "" - else: - raise Exception(f"Error: {response.status_code} - {response.text}") + try: + response = requests.post(self.url, headers=self.headers, json=payload) + print(response, payload) + if response.status_code == 200: + res = response.json() + for content in res["content"]: + if not content.get("text"): + continue + return content["text"] + return "" + else: + error_message = f"Error: {response.status_code} - {response.text}" + print(error_message) + return error_message + except Exception as e: + error_message = f"Error: {str(e)}" + print(error_message) + return error_message