@@ -716,19 +716,37 @@ async def _process_model_fn(
         result = await self._execute_async_model_fn(inputs, request, descriptor)
 
         if inspect.isgenerator(result) or inspect.isasyncgen(result):
-            if request.headers.get("accept") == "application/json":
-                return await _gather_generator(result)
-            else:
-                return await self._stream_with_background_task(
-                    result,
-                    fn_span,
-                    detached_ctx,
-                    # No semaphores needed for non-predict model functions.
-                    release_and_end=lambda: None,
-                )
+            return await self._handle_generator_response(
+                request, result, fn_span, detached_ctx, release_and_end=lambda: None
+            )
 
         return result
 
+    def _should_gather_generator(self, request: starlette.requests.Request) -> bool:
+        # The OpenAI SDK sends an accept header for JSON even in a streaming context,
+        # but we need to stream results back for client compatibility. Luckily,
+        # we can differentiate by looking at the user agent (e.g. OpenAI/Python 1.61.0).
+        user_agent = request.headers.get("user-agent", "")
+        if "openai" in user_agent.lower():
+            return False
+        # TODO(nikhil): determine if we can safely deprecate this behavior.
+        return request.headers.get("accept") == "application/json"
+
+    async def _handle_generator_response(
+        self,
+        request: starlette.requests.Request,
+        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
+        span: trace.Span,
+        trace_ctx: trace.Context,
+        release_and_end: Callable[[], None],
+    ):
+        if self._should_gather_generator(request):
+            return await _gather_generator(generator)
+        else:
+            return await self._stream_with_background_task(
+                generator, span, trace_ctx, release_and_end
+            )
+
     async def completions(
         self, inputs: InputType, request: starlette.requests.Request
     ) -> OutputType:
@@ -801,17 +819,13 @@ async def __call__(
                     "the predict method."
                 )
 
-            if request.headers.get("accept") == "application/json":
-                # In the case of a streaming response, consume stream
-                # if the http accept header is set, and json is requested.
-                return await _gather_generator(predict_result)
-            else:
-                return await self._stream_with_background_task(
-                    predict_result,
-                    span_predict,
-                    detached_ctx,
-                    release_and_end=get_defer_fn(),
-                )
+            return await self._handle_generator_response(
+                request,
+                predict_result,
+                span_predict,
+                detached_ctx,
+                release_and_end=get_defer_fn(),
+            )
 
         if isinstance(predict_result, starlette.responses.Response):
             if self.model_descriptor.postprocess: