
Commit 87c83d9

Introduce truss server passthrough for OpenAI methods (#1364)
1 parent aee0bd8 commit 87c83d9

File tree

10 files changed: +419 -98 lines changed

pyproject.toml
+1 -1

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.60rc004"
+version = "0.9.60rc005"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"

truss/templates/server/model_wrapper.py

+174 -82
Large diffs are not rendered by default.
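
The bulk of the new passthrough logic lives in this collapsed model_wrapper.py diff. For orientation only, here is a hypothetical sketch of the idea, not the committed code: the wrapper exposes `completions` / `chat_completions` entry points only when the user's `Model` class defines them, which is what lets the server below answer 404 for unsupported routes. The `MethodName` values are assumptions inferred from how truss_server.py uses them.

# Hypothetical sketch of the passthrough idea; the real model_wrapper.py
# implementation is collapsed above and may differ in structure.
from enum import Enum
from typing import Any, Callable, Optional


class MethodName(str, Enum):
    PREDICT = "predict"
    COMPLETIONS = "completions"
    CHAT_COMPLETIONS = "chat_completions"


def resolve_user_method(user_model: Any, name: MethodName) -> Optional[Callable]:
    # A method "descriptor" exists only if the user's Model defines the method;
    # missing methods are what the server maps to HTTP 404 below.
    return getattr(user_model, name.value, None)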

truss/templates/server/truss_server.py
+84 -15

@@ -7,7 +7,7 @@
 import sys
 from http import HTTPStatus
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
 
 import pydantic
 import uvicorn
@@ -17,7 +17,7 @@
 from fastapi import Depends, FastAPI, HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from fastapi.routing import APIRoute as FastAPIRoute
-from model_wrapper import InputType, ModelWrapper, OutputType
+from model_wrapper import MODEL_BASENAME, MethodName, ModelWrapper
 from opentelemetry import propagate as otel_propagate
 from opentelemetry import trace
 from opentelemetry.sdk import trace as sdk_trace
@@ -38,6 +38,9 @@
 TIMEOUT_GRACEFUL_SHUTDOWN = 120
 INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
 
+if TYPE_CHECKING:
+    from model_wrapper import InputType, MethodDescriptor, OutputType
+
 
 async def parse_body(request: Request) -> bytes:
     """
@@ -63,7 +66,7 @@ def __init__(self, model: ModelWrapper, tracer: sdk_trace.Tracer) -> None:
         self._model = model
         self._tracer = tracer
 
-    def _safe_lookup_model(self, model_name: str) -> ModelWrapper:
+    def _safe_lookup_model(self, model_name: str = MODEL_BASENAME) -> ModelWrapper:
         if model_name != self._model.name:
             raise errors.ModelMissingError(model_name)
         return self._model
@@ -116,7 +119,7 @@ async def _parse_body(
         body_raw: bytes,
         truss_schema: Optional[TrussSchema],
         span: trace.Span,
-    ) -> InputType:
+    ) -> "InputType":
         if self.is_binary(request):
             with tracing.section_as_event(span, "binary-deserialize"):
                 inputs = serialization.truss_msgpack_deserialize(body_raw)
@@ -148,36 +151,38 @@ async def _parse_body(
 
         return inputs
 
-    async def predict(
-        self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
+    async def _execute_request(
+        self,
+        model: ModelWrapper,
+        method: Callable[["InputType", Request], Awaitable["OutputType"]],
+        method_name: MethodName,
+        request: Request,
+        body_raw: bytes,
     ) -> Response:
         """
-        This method calls the user-provided predict method
+        Executes a predictive endpoint
         """
         if await request.is_disconnected():
-            msg = "Client disconnected. Skipping `predict`."
+            msg = f"Client disconnected. Skipping `{method_name}`."
             logging.info(msg)
             raise ClientDisconnect(msg)
 
-        model: ModelWrapper = self._safe_lookup_model(model_name)
-
         self.check_healthy(model)
         trace_ctx = otel_propagate.extract(request.headers) or None
         # This is the top-level span in the truss-server, so we set the context here.
         # Nested spans "inherit" context automatically.
         with self._tracer.start_as_current_span(
-            "predict-endpoint", context=trace_ctx
+            f"{method_name}-endpoint", context=trace_ctx
         ) as span:
-            inputs: Optional[InputType]
+            inputs: Optional["InputType"]
             if model.model_descriptor.skip_input_parsing:
                 inputs = None
             else:
                 inputs = await self._parse_body(
                     request, body_raw, model.model_descriptor.truss_schema, span
                 )
-            # Calls ModelWrapper which runs: preprocess, predict, postprocess.
             with tracing.section_as_event(span, "model-call"):
-                result: OutputType = await model(inputs, request)
+                result: "OutputType" = await method(inputs, request)
 
             # In the case that the model returns a Generator object, return a
             # StreamingResponse instead.
@@ -190,8 +195,59 @@ async def predict(
                 return result
             return self._serialize_result(result, self.is_binary(request), span)
 
+    async def chat_completions(
+        self, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model()
+        self._raise_if_not_supported(
+            MethodName.CHAT_COMPLETIONS, model.model_descriptor.chat_completions
+        )
+
+        return await self._execute_request(
+            model=model,
+            method=model.chat_completions,
+            method_name=MethodName.CHAT_COMPLETIONS,
+            request=request,
+            body_raw=body_raw,
+        )
+
+    def _raise_if_not_supported(
+        self, method_name: MethodName, descriptor: Optional["MethodDescriptor"]
+    ):
+        if not descriptor:
+            raise HTTPException(status_code=404, detail=f"{method_name} not supported.")
+
+    async def completions(
+        self, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model()
+        self._raise_if_not_supported(
+            MethodName.COMPLETIONS, model.model_descriptor.completions
+        )
+
+        return await self._execute_request(
+            model=model,
+            method=model.completions,
+            method_name=MethodName.COMPLETIONS,
+            request=request,
+            body_raw=body_raw,
+        )
+
+    async def predict(
+        self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
+    ) -> Response:
+        model = self._safe_lookup_model(model_name)
+
+        return await self._execute_request(
+            model=model,
+            method=model,  # We overwrote __call__ on ModelWrapper
+            method_name=MethodName.PREDICT,
+            request=request,
+            body_raw=body_raw,
+        )
+
     def _serialize_result(
-        self, result: OutputType, is_binary: bool, span: trace.Span
+        self, result: "OutputType", is_binary: bool, span: trace.Span
     ) -> Response:
         response_headers = {}
         if is_binary:
@@ -338,6 +394,19 @@ def create_application(self):
                 methods=["POST"],
                 tags=["V1"],
             ),
+            # OpenAI Spec
+            FastAPIRoute(
+                r"/v1/chat/completions",
+                self._endpoints.chat_completions,
+                methods=["POST"],
+                tags=["V1"],
+            ),
+            FastAPIRoute(
+                r"/v1/completions",
+                self._endpoints.completions,
+                methods=["POST"],
+                tags=["V1"],
+            ),
             # Endpoint aliases for Sagemaker hosting
             FastAPIRoute(r"/ping", self._endpoints.invocations_ready),
             FastAPIRoute(
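
For orientation (not part of the commit), a minimal sketch of how the new routes could be exercised against a locally running truss server. The port 8090 and the payload shapes are borrowed from the integration tests further down and are assumptions, not fixed API contracts.

import requests

# Assumed local server address; the integration tests below run on port 8090.
BASE = "http://localhost:8090"

# The existing predict route is unchanged.
print(requests.post(f"{BASE}/v1/models/model:predict", json={"increment": 1}).json())

# New OpenAI-style passthrough routes. Each returns 404 unless the user's
# Model class defines the corresponding `completions` / `chat_completions` method.
print(requests.post(f"{BASE}/v1/completions", json={"increment": 2}).status_code)
print(requests.post(f"{BASE}/v1/chat/completions", json={"increment": 3}).status_code)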

truss/tests/conftest.py
+5

@@ -553,6 +553,11 @@ def trt_llm_truss_container_fs(tmp_path, test_data_path):
     return _build_truss_fs(test_data_path / "test_trt_llm_truss", tmp_path)
 
 
+@pytest.fixture
+def open_ai_container_fs(tmp_path, test_data_path):
+    return _build_truss_fs(test_data_path / "test_openai", tmp_path)
+
+
 @pytest.fixture
 def truss_control_container_fs(tmp_path, test_data_path):
     test_truss_dir = test_data_path / "test_truss"

truss/tests/templates/server/test_model_wrapper.py
+24

@@ -190,6 +190,30 @@ async def mock_predict(return_value, request: Request):
     assert resp == expected_predict_response
 
 
+@pytest.mark.anyio
+async def test_open_ai_completion_endpoints(open_ai_container_fs, helpers):
+    app_path = open_ai_container_fs / "app"
+    with _clear_model_load_modules(), helpers.sys_paths(app_path), _change_directory(
+        app_path
+    ):
+        model_wrapper_module = importlib.import_module("model_wrapper")
+        model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
+        config = yaml.safe_load((app_path / "config.yaml").read_text())
+
+        model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
+        model_wrapper.load()
+
+        mock_req = MagicMock(spec=Request)
+        predict_resp = await model_wrapper.predict({}, mock_req)
+        assert predict_resp == "predict"
+
+        completions_resp = await model_wrapper.completions({}, mock_req)
+        assert completions_resp == "completions"
+
+        chat_completions_resp = await model_wrapper.chat_completions({}, mock_req)
+        assert chat_completions_resp == "chat_completions"
+
+
 @contextmanager
 def _change_directory(new_directory: Path):
     original_directory = os.getcwd()

truss/tests/test_data/test_openai/__init__.py

Whitespace-only changes.

truss/tests/test_data/test_openai/config.yaml
+12

@@ -0,0 +1,12 @@
+environment_variables: {}
+external_package_dirs: []
+model_metadata: {}
+model_name: Test OpenAI Compatibility
+python_version: py39
+resources:
+  accelerator: null
+  cpu: '1'
+  memory: 2Gi
+  use_gpu: false
+secrets: {}
+system_packages: []

truss/tests/test_data/test_openai/model/__init__.py

Whitespace-only changes.

truss/tests/test_data/test_openai/model/model.py
+15

@@ -0,0 +1,15 @@
+from typing import Dict
+
+
+class Model:
+    def __init__(self, **kwargs):
+        pass
+
+    def chat_completions(self, input: Dict) -> str:
+        return "chat_completions"
+
+    def completions(self, input: Dict) -> str:
+        return "completions"
+
+    def predict(self, input: Dict) -> str:
+        return "predict"

truss/tests/test_model_inference.py
+104

@@ -33,6 +33,8 @@
 
 DEFAULT_LOG_ERROR = "Internal Server Error"
 PREDICT_URL = "http://localhost:8090/v1/models/model:predict"
+COMPLETIONS_URL = "http://localhost:8090/v1/completions"
+CHAT_COMPLETIONS_URL = "http://localhost:8090/v1/chat/completions"
 
 
 @pytest.fixture
@@ -1714,3 +1716,105 @@ def make_request(consume_chunks, timeout, task_id):
 
     result = make_request(True, timeout=0.55, task_id=4)
     print(f"Final chunks: {result}")
+
+
+@pytest.mark.integration
+def test_custom_openai_endpoints():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+    from typing import Dict
+
+    class Model:
+        def __init__(self):
+            pass
+
+        def load(self):
+            self._predict_count = 0
+            self._completions_count = 0
+
+        async def predict(self, inputs: Dict) -> int:
+            self._predict_count += inputs["increment"]
+            return self._predict_count
+
+        async def completions(self, inputs: Dict) -> int:
+            self._completions_count += inputs["increment"]
+            return self._completions_count
+    """
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"increment": 1})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+        response = requests.post(COMPLETIONS_URL, json={"increment": 2})
+        assert response.status_code == 200
+        assert response.json() == 2
+
+        response = requests.post(CHAT_COMPLETIONS_URL, json={"increment": 3})
+        assert response.status_code == 404
+
+
+@pytest.mark.integration
+def test_postprocess_async_generator_streaming():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+    from typing import Dict, List, Generator
+
+    class Model:
+        def __init__(self):
+            pass
+
+        def load(self):
+            pass
+
+        async def predict(self, inputs: Dict) -> List[str]:
+            nums: List[int] = inputs["nums"]
+            return nums
+
+        async def postprocess(self, nums: List[str]) -> Generator[str, None, None]:
+            for num in nums:
+                yield num
+    """
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]}, stream=True)
+        assert response.headers.get("transfer-encoding") == "chunked"
+        assert [
+            byte_string.decode() for byte_string in list(response.iter_content())
+        ] == ["1", "2"]
+
+
+@pytest.mark.integration
+def test_preprocess_async_generator():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+    from typing import Dict, List, AsyncGenerator
+
+    class Model:
+        def __init__(self):
+            pass
+
+        def load(self):
+            pass
+
+        async def preprocess(self, inputs: Dict) -> AsyncGenerator[str, None]:
+            for num in inputs["nums"]:
+                yield num
+
+        async def predict(self, nums: AsyncGenerator[str, None]) -> List[str]:
+            return [num async for num in nums]
+    """
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]})
+        assert response.status_code == 200
+        assert response.json() == ["1", "2"]
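
One pattern the tests above do not exercise directly, sketched here as a hypothetical user model (not part of the commit): because `chat_completions` goes through the same `_execute_request` path as `predict`, a model that returns a generator from `chat_completions` should be served as a chunked streaming response, per the generator handling noted in the truss_server.py diff. Names and payload shape below are assumptions.

# Hypothetical user model illustrating streaming through the new passthrough;
# chunk contents and method body are illustrative only.
from typing import AsyncGenerator, Dict


class Model:
    def load(self):
        pass

    async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]:
        # Yielding chunks should trigger the StreamingResponse branch in the server.
        for token in ["Hel", "lo", "!"]:
            yield token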
