Skip to content

Commit

Permalink
Merge pull request #443 from mikeyobrien/main
Browse files Browse the repository at this point in the history
Add thinking support for claude-3-7-sonnet
  • Loading branch information
tjbck authored Mar 6, 2025
2 parents ff41479 + 10242dc commit f89ab37
Showing 1 changed file with 135 additions and 39 deletions.
174 changes: 135 additions & 39 deletions examples/pipelines/providers/anthropic_manifold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
license: MIT
description: A pipeline for generating text and processing images using the Anthropic API.
requirements: requests, sseclient-py
environment_variables: ANTHROPIC_API_KEY
environment_variables: ANTHROPIC_API_KEY, ANTHROPIC_THINKING_BUDGET_TOKENS, ANTHROPIC_ENABLE_THINKING
"""

import os
Expand All @@ -18,6 +18,17 @@

from utils.pipelines.main import pop_system_message

REASONING_EFFORT_BUDGET_TOKEN_MAP = {
"none": None,
"low": 1024,
"medium": 4096,
"high": 16384,
"max": 32768,
}

# Maximum combined token limit for Claude 3.7
MAX_COMBINED_TOKENS = 64000


class Pipeline:
class Valves(BaseModel):
Expand All @@ -29,16 +40,20 @@ def __init__(self):
self.name = "anthropic/"

self.valves = self.Valves(
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")}
**{
"ANTHROPIC_API_KEY": os.getenv(
"ANTHROPIC_API_KEY", "your-api-key-here"
),
}
)
self.url = 'https://api.anthropic.com/v1/messages'
self.url = "https://api.anthropic.com/v1/messages"
self.update_headers()

def update_headers(self):
self.headers = {
'anthropic-version': '2023-06-01',
'content-type': 'application/json',
'x-api-key': self.valves.ANTHROPIC_API_KEY
"anthropic-version": "2023-06-01",
"content-type": "application/json",
"x-api-key": self.valves.ANTHROPIC_API_KEY,
}

def get_anthropic_models(self):
Expand Down Expand Up @@ -88,7 +103,7 @@ def pipe(
) -> Union[str, Generator, Iterator]:
try:
# Remove unnecessary keys
for key in ['user', 'chat_id', 'title']:
for key in ["user", "chat_id", "title"]:
body.pop(key, None)

system_message, messages = pop_system_message(messages)
Expand All @@ -102,28 +117,40 @@ def pipe(
if isinstance(message.get("content"), list):
for item in message["content"]:
if item["type"] == "text":
processed_content.append({"type": "text", "text": item["text"]})
processed_content.append(
{"type": "text", "text": item["text"]}
)
elif item["type"] == "image_url":
if image_count >= 5:
raise ValueError("Maximum of 5 images per API call exceeded")
raise ValueError(
"Maximum of 5 images per API call exceeded"
)

processed_image = self.process_image(item["image_url"])
processed_content.append(processed_image)

if processed_image["source"]["type"] == "base64":
image_size = len(processed_image["source"]["data"]) * 3 / 4
image_size = (
len(processed_image["source"]["data"]) * 3 / 4
)
else:
image_size = 0

total_image_size += image_size
if total_image_size > 100 * 1024 * 1024:
raise ValueError("Total size of images exceeds 100 MB limit")
raise ValueError(
"Total size of images exceeds 100 MB limit"
)

image_count += 1
else:
processed_content = [{"type": "text", "text": message.get("content", "")}]
processed_content = [
{"type": "text", "text": message.get("content", "")}
]

processed_messages.append({"role": message["role"], "content": processed_content})
processed_messages.append(
{"role": message["role"], "content": processed_content}
)

# Prepare the payload
payload = {
Expand All @@ -139,38 +166,107 @@ def pipe(
}

if body.get("stream", False):
supports_thinking = "claude-3-7" in model_id
reasoning_effort = body.get("reasoning_effort", "none")
budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)

# Allow users to input an integer value representing budget tokens
if (
not budget_tokens
and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
):
try:
budget_tokens = int(reasoning_effort)
except ValueError as e:
print("Failed to convert reasoning effort to int", e)
budget_tokens = None

if supports_thinking and budget_tokens:
# Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
max_tokens = payload.get("max_tokens", 4096)
combined_tokens = budget_tokens + max_tokens

if combined_tokens > MAX_COMBINED_TOKENS:
error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
print(error_message)
return error_message

payload["max_tokens"] = combined_tokens
payload["thinking"] = {
"type": "enabled",
"budget_tokens": budget_tokens,
}
# Thinking requires temperature 1.0 and does not support top_p, top_k
payload["temperature"] = 1.0
if "top_k" in payload:
del payload["top_k"]
if "top_p" in payload:
del payload["top_p"]
return self.stream_response(payload)
else:
return self.get_completion(payload)
except Exception as e:
return f"Error: {e}"

def stream_response(self, payload: dict) -> Generator:
response = requests.post(self.url, headers=self.headers, json=payload, stream=True)

if response.status_code == 200:
client = sseclient.SSEClient(response)
for event in client.events():
try:
data = json.loads(event.data)
if data["type"] == "content_block_start":
yield data["content_block"]["text"]
elif data["type"] == "content_block_delta":
yield data["delta"]["text"]
elif data["type"] == "message_stop":
break
except json.JSONDecodeError:
print(f"Failed to parse JSON: {event.data}")
except KeyError as e:
print(f"Unexpected data structure: {e}")
print(f"Full data: {data}")
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
"""Used for title and tag generation"""
try:
response = requests.post(
self.url, headers=self.headers, json=payload, stream=True
)
print(f"{response} for {payload}")

if response.status_code == 200:
client = sseclient.SSEClient(response)
for event in client.events():
try:
data = json.loads(event.data)
if data["type"] == "content_block_start":
if data["content_block"]["type"] == "thinking":
yield "<think>"
else:
yield data["content_block"]["text"]
elif data["type"] == "content_block_delta":
if data["delta"]["type"] == "thinking_delta":
yield data["delta"]["thinking"]
elif data["delta"]["type"] == "signature_delta":
yield "\n </think> \n\n"
else:
yield data["delta"]["text"]
elif data["type"] == "message_stop":
break
except json.JSONDecodeError:
print(f"Failed to parse JSON: {event.data}")
yield f"Error: Failed to parse JSON response"
except KeyError as e:
print(f"Unexpected data structure: {e} for payload {payload}")
print(f"Full data: {data}")
yield f"Error: Unexpected data structure: {e}"
else:
error_message = f"Error: {response.status_code} - {response.text}"
print(error_message)
yield error_message
except Exception as e:
error_message = f"Error: {str(e)}"
print(error_message)
yield error_message

def get_completion(self, payload: dict) -> str:
response = requests.post(self.url, headers=self.headers, json=payload)
if response.status_code == 200:
res = response.json()
return res["content"][0]["text"] if "content" in res and res["content"] else ""
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
try:
response = requests.post(self.url, headers=self.headers, json=payload)
print(response, payload)
if response.status_code == 200:
res = response.json()
for content in res["content"]:
if not content.get("text"):
continue
return content["text"]
return ""
else:
error_message = f"Error: {response.status_code} - {response.text}"
print(error_message)
return error_message
except Exception as e:
error_message = f"Error: {str(e)}"
print(error_message)
return error_message

0 comments on commit f89ab37

Please sign in to comment.