diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 1919b3bd5d..3f3a020c9a 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -230,3 +230,11 @@ The [GrokPromptDriver](../../reference/griptape/drivers/prompt/grok_prompt_drive ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_grok.py" ``` + +### Perplexity Sonar + +The [PerplexitySonarPromptDriver](../../reference/griptape/drivers/prompt/perplexity_sonar_prompt_driver.md) uses [Perplexity Sonar's chat completion](https://docs.perplexity.ai/api-reference/chat-completions) endpoint. + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_perplexity.py" +``` diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_perplexity.py b/docs/griptape-framework/drivers/src/prompt_drivers_perplexity.py new file mode 100644 index 0000000000..e14a8067ff --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_perplexity.py @@ -0,0 +1,12 @@ +import os + +from griptape.drivers.prompt.perplexity_sonar import PerplexitySonarPromptDriver +from griptape.rules import Rule +from griptape.structures import Agent + +agent = Agent( + prompt_driver=PerplexitySonarPromptDriver(model="sonar-pro", api_key=os.environ["PERPLEXITY_SONAR_API_KEY"]), + rules=[Rule("Be precise and concise")], +) + +agent.run("How many stars are there in our galaxy?") diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 5c702dd910..5dc9fddbdd 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -14,6 +14,7 @@ from .prompt.dummy import DummyPromptDriver from .prompt.ollama import OllamaPromptDriver from .prompt.grok import GrokPromptDriver +from .prompt.perplexity_sonar import PerplexitySonarPromptDriver from .memory.conversation import BaseConversationMemoryDriver from .memory.conversation.local import LocalConversationMemoryDriver @@ -141,6 +142,7 @@ "DummyPromptDriver", "OllamaPromptDriver", "GrokPromptDriver", + "PerplexitySonarPromptDriver", "BaseConversationMemoryDriver", "LocalConversationMemoryDriver", "AmazonDynamoDbConversationMemoryDriver", diff --git a/griptape/drivers/prompt/perplexity_sonar/__init__.py b/griptape/drivers/prompt/perplexity_sonar/__init__.py new file mode 100644 index 0000000000..a9612b1611 --- /dev/null +++ b/griptape/drivers/prompt/perplexity_sonar/__init__.py @@ -0,0 +1,5 @@ +from griptape.drivers.prompt.perplexity_sonar_prompt_driver import PerplexitySonarPromptDriver + +__all__ = [ + "PerplexitySonarPromptDriver", +] diff --git a/griptape/drivers/prompt/perplexity_sonar_prompt_driver.py b/griptape/drivers/prompt/perplexity_sonar_prompt_driver.py new file mode 100644 index 0000000000..f6ed76275e --- /dev/null +++ b/griptape/drivers/prompt/perplexity_sonar_prompt_driver.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from attrs import define, field + +from griptape.drivers.prompt.openai import OpenAiChatPromptDriver + + +@define +class PerplexitySonarPromptDriver(OpenAiChatPromptDriver): + base_url: str = field(default="https://api.perplexity.ai", kw_only=True, metadata={"serializable": True}) diff --git a/tests/unit/drivers/prompt/test_perplexity_sonar_prompt_driver.py b/tests/unit/drivers/prompt/test_perplexity_sonar_prompt_driver.py new file mode 100644 index 0000000000..5436e6cdcb --- /dev/null +++ b/tests/unit/drivers/prompt/test_perplexity_sonar_prompt_driver.py @@ -0,0 +1,188 @@ +from unittest.mock import ANY + +import pytest + +from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, AudioDeltaMessageContent, TextDeltaMessageContent +from griptape.drivers.prompt.perplexity_sonar import PerplexitySonarPromptDriver +from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin + + +class TestPerplexitySonarPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): + def test_init(self): + assert PerplexitySonarPromptDriver(api_key="foo", model="gpt-4") + + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) + @pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + modalities, + ): + # Given + driver = PerplexitySonarPromptDriver( + model="sonar-pro", + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + modalities=modalities, + extra_params={"foo": "bar"}, + ) + + # When + message = driver.try_run(prompt_stack) + + # Then + mock_chat_completion_create.assert_called_once_with( + model=driver.model, + temperature=driver.temperature, + user=driver.user, + messages=messages, + modalities=modalities, + **{ + "audio": driver.audio, + } + if "audio" in driver.modalities + else {}, + seed=driver.seed, + **{ + "parallel_tool_calls": driver.parallel_tool_calls, + } + if prompt_stack.tools and driver.use_native_tools + else {}, + **{ + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if structured_output_strategy == "native" + else {}, + foo="bar", + ) + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], AudioArtifact) + assert message.value[1].value == b"assistant-audio-data" + assert message.value[1].format == "wav" + assert message.value[1].meta == { + "audio_id": "audio-id", + "transcript": "assistant-audio-transcription", + "expires_at": ANY, + } + assert isinstance(message.value[2], ActionArtifact) + assert message.value[2].value.tag == "mock-id" + assert message.value[2].value.name == "MockTool" + assert message.value[2].value.path == "test" + assert message.value[2].value.input == {"foo": "bar"} + + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) + @pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + modalities, + ): + # Given + driver = PerplexitySonarPromptDriver( + model="sonar-pro", + stream=True, + modalities=modalities, + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, + ) + + # When + stream = driver.try_stream(prompt_stack) + event = next(stream) + + # Then + mock_chat_completion_stream_create.assert_called_once_with( + model=driver.model, + temperature=driver.temperature, + user=driver.user, + **{ + "audio": driver.audio, + } + if "audio" in driver.modalities + else {}, + stream=True, + messages=messages, + modalities=modalities, + **{ + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + } + if use_native_tools + else {}, + seed=driver.seed, + **{ + "parallel_tool_calls": driver.parallel_tool_calls, + } + if prompt_stack.tools and driver.use_native_tools + else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if structured_output_strategy == "native" + else {}, + foo="bar", + stream_options={"include_usage": True}, + ) + + assert isinstance(event.content, TextDeltaMessageContent) + assert event.content.text == "model-output" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.tag == "mock-id" + assert event.content.name == "MockTool" + assert event.content.path == "test" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.partial_input == '{"foo": "bar"}' + + event = next(stream) + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 10 + + event = next(stream) + assert isinstance(event.content, AudioDeltaMessageContent) + assert event.content.id == "audio-id" + + event = next(stream) + assert isinstance(event.content, AudioDeltaMessageContent) + assert event.content.data == "YXNzaXN0YW50LWF1ZGlvLWRhdGE=" + + event = next(stream) + assert isinstance(event.content, AudioDeltaMessageContent) + assert event.content.expires_at == ANY + assert event.content.transcript == "assistant-audio-transcription"