@@ -7,7 +7,7 @@
 import sys
 from http import HTTPStatus
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union

 import pydantic
 import uvicorn
@@ -17,7 +17,7 @@
 from fastapi import Depends, FastAPI, HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from fastapi.routing import APIRoute as FastAPIRoute
-from model_wrapper import InputType, ModelWrapper, OutputType
+from model_wrapper import MODEL_BASENAME, MethodName, ModelWrapper
 from opentelemetry import propagate as otel_propagate
 from opentelemetry import trace
 from opentelemetry.sdk import trace as sdk_trace
@@ -38,6 +38,9 @@
 TIMEOUT_GRACEFUL_SHUTDOWN = 120
 INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()

+if TYPE_CHECKING:
+    from model_wrapper import InputType, MethodDescriptor, OutputType
+

 async def parse_body(request: Request) -> bytes:
     """
@@ -63,7 +66,7 @@ def __init__(self, model: ModelWrapper, tracer: sdk_trace.Tracer) -> None:
         self._model = model
         self._tracer = tracer

-    def _safe_lookup_model(self, model_name: str) -> ModelWrapper:
+    def _safe_lookup_model(self, model_name: str = MODEL_BASENAME) -> ModelWrapper:
         if model_name != self._model.name:
             raise errors.ModelMissingError(model_name)
         return self._model
@@ -116,7 +119,7 @@ async def _parse_body(
         body_raw: bytes,
         truss_schema: Optional[TrussSchema],
         span: trace.Span,
-    ) -> InputType:
+    ) -> "InputType":
         if self.is_binary(request):
             with tracing.section_as_event(span, "binary-deserialize"):
                 inputs = serialization.truss_msgpack_deserialize(body_raw)
@@ -148,36 +151,38 @@ async def _parse_body(

         return inputs

-    async def predict(
-        self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
+    async def _execute_request(
+        self,
+        model: ModelWrapper,
+        method: Callable[["InputType", Request], Awaitable["OutputType"]],
+        method_name: MethodName,
+        request: Request,
+        body_raw: bytes,
     ) -> Response:
         """
-        This method calls the user-provided predict method
+        Executes a predictive endpoint.
         """
         if await request.is_disconnected():
-            msg = "Client disconnected. Skipping `predict`."
+            msg = f"Client disconnected. Skipping `{method_name}`."
             logging.info(msg)
             raise ClientDisconnect(msg)

-        model: ModelWrapper = self._safe_lookup_model(model_name)
-
         self.check_healthy(model)
         trace_ctx = otel_propagate.extract(request.headers) or None
         # This is the top-level span in the truss-server, so we set the context here.
         # Nested spans "inherit" context automatically.
         with self._tracer.start_as_current_span(
-            "predict-endpoint", context=trace_ctx
+            f"{method_name}-endpoint", context=trace_ctx
         ) as span:
-            inputs: Optional[InputType]
+            inputs: Optional["InputType"]
             if model.model_descriptor.skip_input_parsing:
                 inputs = None
             else:
                 inputs = await self._parse_body(
                     request, body_raw, model.model_descriptor.truss_schema, span
                 )
-            # Calls ModelWrapper which runs: preprocess, predict, postprocess.
             with tracing.section_as_event(span, "model-call"):
-                result: OutputType = await model(inputs, request)
+                result: "OutputType" = await method(inputs, request)

             # In the case that the model returns a Generator object, return a
             # StreamingResponse instead.
@@ -190,8 +195,59 @@ async def predict(
                 return result
             return self._serialize_result(result, self.is_binary(request), span)

+    async def chat_completions(
+        self, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model()
+        self._raise_if_not_supported(
+            MethodName.CHAT_COMPLETIONS, model.model_descriptor.chat_completions
+        )
+
+        return await self._execute_request(
+            model=model,
+            method=model.chat_completions,
+            method_name=MethodName.CHAT_COMPLETIONS,
+            request=request,
+            body_raw=body_raw,
+        )
+
+    def _raise_if_not_supported(
+        self, method_name: MethodName, descriptor: Optional["MethodDescriptor"]
+    ):
+        if not descriptor:
+            raise HTTPException(status_code=404, detail=f"{method_name} not supported.")
+
+    async def completions(
+        self, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model()
+        self._raise_if_not_supported(
+            MethodName.COMPLETIONS, model.model_descriptor.completions
+        )
+
+        return await self._execute_request(
+            model=model,
+            method=model.completions,
+            method_name=MethodName.COMPLETIONS,
+            request=request,
+            body_raw=body_raw,
+        )
+
+    async def predict(
+        self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model(model_name)
+
+        return await self._execute_request(
+            model=model,
+            # ModelWrapper overrides __call__, which runs preprocess, predict,
+            # and postprocess.
+            method=model,
+            method_name=MethodName.PREDICT,
+            request=request,
+            body_raw=body_raw,
+        )
+
     def _serialize_result(
-        self, result: OutputType, is_binary: bool, span: trace.Span
+        self, result: "OutputType", is_binary: bool, span: trace.Span
     ) -> Response:
         response_headers = {}
         if is_binary:
@@ -338,6 +394,19 @@ def create_application(self):
                 methods=["POST"],
                 tags=["V1"],
             ),
+            # OpenAI spec
+            FastAPIRoute(
+                r"/v1/chat/completions",
+                self._endpoints.chat_completions,
+                methods=["POST"],
+                tags=["V1"],
+            ),
+            FastAPIRoute(
+                r"/v1/completions",
+                self._endpoints.completions,
+                methods=["POST"],
+                tags=["V1"],
+            ),
             # Endpoint aliases for Sagemaker hosting
             FastAPIRoute(r"/ping", self._endpoints.invocations_ready),
             FastAPIRoute(
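
For reviewers, a minimal sketch of how the new OpenAI-spec route can be exercised once the server is running. The host/port and the payload fields are assumptions for illustration; only the route path and HTTP method come from this diff, and the accepted payload shape is whatever the served model's `chat_completions` implementation expects.

    import requests

    # Assumed local address; the real host/port depend on deployment.
    BASE_URL = "http://localhost:8080"

    # Illustrative payload; the server passes the parsed body to the model's
    # `chat_completions` method, so the schema is model-defined.
    resp = requests.post(
        f"{BASE_URL}/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello!"}]},
    )
    resp.raise_for_status()
    print(resp.json())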
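If the served model does not implement a method, `_raise_if_not_supported` raises a 404 before any model call. A sketch of that path, under the same address assumption (FastAPI serializes `HTTPException` as a JSON body with a `detail` field):

    import requests

    resp = requests.post(
        "http://localhost:8080/v1/completions",
        json={"prompt": "Hello"},  # illustrative payload
    )
    if resp.status_code == 404:
        # Raised by _raise_if_not_supported when the model defines no `completions`.
        print(resp.json()["detail"])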