Commit d5db001

Fix streaming for OpenAI clients (#1371)
1 parent 911dddf commit d5db001

File tree

3 files changed: +77 -22 lines changed

pyproject.toml
truss/templates/server/model_wrapper.py
truss/tests/test_model_inference.py
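For context on what this commit addresses: OpenAI SDK clients send an "accept: application/json" header even when they expect a streamed response, identifying themselves only via a "user-agent" header such as "OpenAI/Python 1.61.0". The sketch below shows the client side of that interaction; it is illustrative only and not part of this commit, and the base URL, API key, and model name are hypothetical placeholders.

# Illustrative client-side sketch: base_url, api_key, and model are placeholders.
from openai import OpenAI

client = OpenAI(
    base_url="https://<your-truss-deployment>/v1",  # hypothetical OpenAI-compatible endpoint
    api_key="<api-key>",                            # placeholder credential
)

# The SDK sends "accept: application/json" plus a "user-agent: OpenAI/Python x.y.z" header;
# after this fix the server keys off the user agent and streams chunks back.
stream = client.chat.completions.create(
    model="<model-name>",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk)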

pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.60rc005"
+version = "0.9.60rc006"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"

truss/templates/server/model_wrapper.py (+35 -21)

@@ -716,19 +716,37 @@ async def _process_model_fn(
         result = await self._execute_async_model_fn(inputs, request, descriptor)

         if inspect.isgenerator(result) or inspect.isasyncgen(result):
-            if request.headers.get("accept") == "application/json":
-                return await _gather_generator(result)
-            else:
-                return await self._stream_with_background_task(
-                    result,
-                    fn_span,
-                    detached_ctx,
-                    # No semaphores needed for non-predict model functions.
-                    release_and_end=lambda: None,
-                )
+            return await self._handle_generator_response(
+                request, result, fn_span, detached_ctx, release_and_end=lambda: None
+            )

         return result

+    def _should_gather_generator(self, request: starlette.requests.Request) -> bool:
+        # The OpenAI SDK sends an accept header for JSON even in a streaming context,
+        # but we need to stream results back for client compatibility. Luckily,
+        # we can differentiate by looking at the user agent (e.g. OpenAI/Python 1.61.0)
+        user_agent = request.headers.get("user-agent", "")
+        if "openai" in user_agent.lower():
+            return False
+        # TODO(nikhil): determine if we can safely deprecate this behavior.
+        return request.headers.get("accept") == "application/json"
+
+    async def _handle_generator_response(
+        self,
+        request: starlette.requests.Request,
+        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
+        span: trace.Span,
+        trace_ctx: trace.Context,
+        release_and_end: Callable[[], None],
+    ):
+        if self._should_gather_generator(request):
+            return await _gather_generator(generator)
+        else:
+            return await self._stream_with_background_task(
+                generator, span, trace_ctx, release_and_end
+            )
+
     async def completions(
         self, inputs: InputType, request: starlette.requests.Request
     ) -> OutputType:

@@ -801,17 +819,13 @@ async def __call__(
                     "the predict method."
                 )

-            if request.headers.get("accept") == "application/json":
-                # In the case of a streaming response, consume stream
-                # if the http accept header is set, and json is requested.
-                return await _gather_generator(predict_result)
-            else:
-                return await self._stream_with_background_task(
-                    predict_result,
-                    span_predict,
-                    detached_ctx,
-                    release_and_end=get_defer_fn(),
-                )
+            return await self._handle_generator_response(
+                request,
+                predict_result,
+                span_predict,
+                detached_ctx,
+                release_and_end=get_defer_fn(),
+            )

         if isinstance(predict_result, starlette.responses.Response):
             if self.model_descriptor.postprocess:
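The branch in _should_gather_generator can also be seen from the client side with plain requests calls. Below is a minimal sketch, assuming a locally running Truss server whose model function yields results as a generator; the URL and payload mirror the integration test further down and are illustrative, not values mandated by this commit.

import requests

PREDICT_URL = "http://localhost:8090/v1/models/model:predict"  # illustrative local URL

# A non-OpenAI client asking for JSON keeps the pre-existing behavior: the generator
# is gathered server-side and returned as a single JSON body.
gathered = requests.post(
    PREDICT_URL,
    json={"nums": ["1", "2"]},
    headers={"accept": "application/json"},
)
print(gathered.json())

# The same accept header sent with an OpenAI SDK user agent is now streamed back
# as a chunked response instead of being gathered.
streamed = requests.post(
    PREDICT_URL,
    json={"nums": ["1", "2"]},
    stream=True,
    headers={"accept": "application/json", "user-agent": "OpenAI/Python 1.61.0"},
)
print(streamed.headers.get("transfer-encoding"))  # expected: "chunked"
for chunk in streamed.iter_content():
    print(chunk.decode())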

truss/tests/test_model_inference.py (+41)

@@ -1818,3 +1818,44 @@ async def predict(self, nums: AsyncGenerator[str, None]) -> List[str]:
     response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]})
     assert response.status_code == 200
     assert response.json() == ["1", "2"]
+
+
+@pytest.mark.integration
+def test_openai_client_streaming():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+    from typing import Dict, AsyncGenerator
+
+    class Model:
+        def __init__(self):
+            pass
+
+        def load(self):
+            pass
+
+        async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]:
+            for num in inputs["nums"]:
+                yield num
+
+        async def predict(self, inputs: Dict):
+            pass
+    """
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(
+            CHAT_COMPLETIONS_URL,
+            json={"nums": ["1", "2"]},
+            stream=True,
+            # Despite requesting json, we should still stream results back.
+            headers={
+                "accept": "application/json",
+                "user-agent": "OpenAI/Python 1.61.0",
+            },
+        )
+        assert response.headers.get("transfer-encoding") == "chunked"
+        assert [
+            byte_string.decode() for byte_string in list(response.iter_content())
+        ] == ["1", "2"]
