Skip to content

Commit 1cbe627

Browse files
squidarth and bolasim authored
Return 500 error codes on user error. (#545)
* Return 500 error codes on user error. * Fix issue w/ missing values. * Remove stacktrace. * Fix pr. * Fix tests. * Fix tests. * Fix all tests. * Fix test. * Fix broken TGI model_id swap * Bump rc version * Fix create dir that already exists * bump version to test * Drop old make from in-memory model code * Drop trainable config from cli command * Fix training failures * Fix tests * Try again to make test pass * Add flask dependency for integration tests * Skip flaky test * Bump pyproject. * Update for PR feedback. --------- Co-authored-by: Bola Malek <bola@baseten.co>
1 parent 4eb33ea commit 1cbe627

File tree

5 files changed

+185
-41
lines changed

5 files changed

+185
-41
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.6.5rc7"
3+
version = "0.6.5rc11"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/templates/server/common/errors.py

+4
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,7 @@ async def not_implemented_error_handler(_, exc):
9494
return JSONResponse(
9595
status_code=HTTPStatus.NOT_IMPLEMENTED, content={"error": str(exc)}
9696
)
97+
98+
99+
async def http_exception_handler(_, exc):
    """Render an ``HTTPException`` as a JSON error response.

    Args:
        _: Unused first handler argument (the request, per the exception
            handler calling convention used by the sibling handlers).
        exc: The raised exception. Its ``status_code`` becomes the response
            status and its ``detail`` is returned under the ``"error"`` key.

    Returns:
        A ``JSONResponse`` whose body is ``{"error": exc.detail}``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail},
        # Relay headers attached to the exception (e.g. ``WWW-Authenticate``)
        # instead of silently dropping them. ``getattr`` keeps this safe for
        # exception objects that carry no ``headers`` attribute.
        headers=getattr(exc, "headers", None),
    )

truss/templates/server/common/truss_server.py

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def create_application(self):
234234
errors.ModelNotFound: errors.model_not_found_handler,
235235
errors.ModelNotReady: errors.model_not_ready_handler,
236236
NotImplementedError: errors.not_implemented_error_handler,
237+
HTTPException: errors.http_exception_handler,
237238
Exception: errors.generic_exception_handler,
238239
},
239240
)

truss/templates/server/model_wrapper.py

+40-21
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import sys
77
import time
8-
import traceback
98
from collections.abc import Generator
109
from contextlib import asynccontextmanager
1110
from enum import Enum
@@ -16,6 +15,7 @@
1615
from anyio import Semaphore, to_thread
1716
from common.patches import apply_patches
1817
from common.retry import retry
18+
from fastapi import HTTPException
1919
from shared.secrets_resolver import SecretsResolver
2020

2121
MODEL_BASENAME = "model"
@@ -177,23 +177,11 @@ async def preprocess(
177177
return payload
178178

179179
if inspect.iscoroutinefunction(self._model.preprocess):
180-
return await self._model.preprocess(payload)
180+
return await _intercept_exceptions_async(self._model.preprocess)(payload)
181181
else:
182-
return await to_thread.run_sync(self._model.preprocess, payload)
183-
184-
def _predict_sync_with_error_handling(self, payload):
185-
try:
186-
return self._model.predict(payload)
187-
except Exception:
188-
logging.exception("Exception while running predict")
189-
return {"error": {"traceback": traceback.format_exc()}}
190-
191-
async def _predict_async_with_error_handling(self, payload):
192-
try:
193-
return await self._model.predict(payload)
194-
except Exception:
195-
logging.exception("Exception while running predict")
196-
return {"error": {"traceback": traceback.format_exc()}}
182+
return await to_thread.run_sync(
183+
_intercept_exceptions_sync(self._model.preprocess), payload
184+
)
197185

198186
async def predict(
199187
self,
@@ -214,9 +202,11 @@ async def predict(
214202
return self._model.predict(payload)
215203

216204
if inspect.iscoroutinefunction(self._model.predict):
217-
return await self._predict_async_with_error_handling(payload)
205+
return await _intercept_exceptions_async(self._model.predict)(payload)
218206

219-
return await to_thread.run_sync(self._predict_sync_with_error_handling, payload)
207+
return await to_thread.run_sync(
208+
_intercept_exceptions_sync(self._model.predict), payload
209+
)
220210

221211
async def postprocess(
222212
self,
@@ -238,9 +228,11 @@ async def postprocess(
238228
return self._model.postprocess(response, headers)
239229

240230
if inspect.iscoroutinefunction(self._model.postprocess):
241-
return await self._model.postprocess(response)
231+
return await _intercept_exceptions_async(self._model.postprocess)(response)
242232

243-
return await to_thread.run_sync(self._model.postprocess, response)
233+
return await to_thread.run_sync(
234+
_intercept_exceptions_sync(self._model.postprocess), response
235+
)
244236

245237
async def write_response_to_queue(
246238
self, queue: asyncio.Queue, generator: AsyncGenerator
@@ -368,3 +360,30 @@ def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
368360

369361
def _elapsed_ms(since_micro_seconds: float) -> int:
370362
return int((time.perf_counter() - since_micro_seconds) * 1000)
363+
364+
365+
def _handle_exception():
    """Log the exception currently being handled and raise a generic HTTP 500.

    Meant to be called from inside an ``except`` block: ``logging.exception``
    picks up the active exception and writes its stack trace to the logs.

    Raises:
        HTTPException: Always, with status code 500 and a detail payload of
            ``{"message": "Internal Server Error"}``.
    """
    # The stack trace goes to the logs only, so the user can debug the failure
    # there while the HTTP response stays free of internals.
    logging.exception("Internal Server Error")

    error_detail = {"message": "Internal Server Error"}
    raise HTTPException(status_code=500, detail=error_detail)
370+
371+
372+
def _intercept_exceptions_sync(func):
373+
def inner(*args, **kwargs):
374+
try:
375+
return func(*args, **kwargs)
376+
except Exception:
377+
_handle_exception()
378+
379+
return inner
380+
381+
382+
def _intercept_exceptions_async(func):
383+
async def inner(*args, **kwargs):
384+
try:
385+
return await func(*args, **kwargs)
386+
except Exception:
387+
_handle_exception()
388+
389+
return inner

truss/tests/test_model_inference.py

+139-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import concurrent
22
import inspect
3+
import json
34
import logging
45
import tempfile
56
import textwrap
@@ -32,6 +33,19 @@ def _create_truss(truss_dir: Path, config_contents: str, model_contents: str):
3233
file.write(model_contents)
3334

3435

36+
def _log_contains_error(line: dict, error: str) -> bool:
    """Return whether a parsed log record is the server's error log for ``error``.

    Args:
        line: One log record decoded from JSON.
        error: Text expected inside the record's ``exc_info`` (the traceback).

    Returns:
        True when the record is at ERROR level, carries the message the model
        wrapper logs on failure, and its ``exc_info`` contains ``error``.
    """
    return (
        line.get("levelname") == "ERROR"
        # Must match the message logged by the model wrapper's exception
        # handler, which is "Internal Server Error".
        and line.get("message") == "Internal Server Error"
        # Records without a traceback have no (or a null) ``exc_info``.
        and error in (line.get("exc_info") or "")
    )


def assert_logs_contain_error(logs: str, error: str):
    """Assert that at least one log line records the given error.

    Args:
        logs: Raw container logs, one record per line.
        error: Text expected inside the traceback of an error record.

    Raises:
        AssertionError: If no line is an error record containing ``error``.
    """
    records = []
    for raw_line in logs.splitlines():
        try:
            record = json.loads(raw_line)
        except ValueError:
            # Not every log line is JSON (blank lines, plain-text output);
            # such lines cannot be the structured error record, so skip them.
            continue
        if isinstance(record, dict):
            records.append(record)

    assert any(
        _log_contains_error(record, error) for record in records
    ), f"no ERROR log record containing {error!r} found"
47+
48+
3549
class PropagatingThread(Thread):
3650
"""
3751
PropagatingThread allows us to run threads and keep track of exceptions
@@ -317,20 +331,11 @@ def predict(self, request):
317331
return self._secrets["secret"]
318332

319333
config = """model_name: secrets-truss
320-
cpu: "3"
321-
memory: 14Gi
322-
use_gpu: true
323-
accelerator: A10G
324334
secrets:
325335
secret: null
326336
"""
327337

328-
config_with_no_secret = """model_name: secrets-truss
329-
cpu: "3"
330-
memory: 14Gi
331-
use_gpu: true
332-
accelerator: A10G
333-
"""
338+
config_with_no_secret = "model_name: secrets-truss"
334339

335340
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
336341
truss_dir = Path(tmp_work_dir, "truss")
@@ -344,9 +349,8 @@ def predict(self, request):
344349
full_url = f"{truss_server_addr}/v1/models/model:predict"
345350

346351
response = requests.post(full_url, json={})
347-
assert response.json() == "secret_value"
348352

349-
_create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
353+
assert response.json() == "secret_value"
350354

351355
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
352356
# Case where the secret is not specified in the config
@@ -357,33 +361,149 @@ def predict(self, request):
357361
)
358362
tr = TrussHandle(truss_dir)
359363
LocalConfigHandler.set_secret("secret", "secret_value")
360-
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
364+
container = tr.docker_run(
365+
local_port=8090, detach=True, wait_for_server_ready=True
366+
)
361367
truss_server_addr = "http://localhost:8090"
362368
full_url = f"{truss_server_addr}/v1/models/model:predict"
363369

364370
response = requests.post(full_url, json={})
365371

366372
assert "error" in response.json()
367-
assert "not specified in the config" in response.json()["error"]["traceback"]
373+
assert_logs_contain_error(container.logs(), "not specified in the config")
374+
assert "Error while running predict" in response.json()["error"]["message"]
368375

369376
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
370-
# Case where the secret is not specified in the config
377+
# Case where the secret is not mounted
371378
truss_dir = Path(tmp_work_dir, "truss")
372379

373380
_create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
374381
tr = TrussHandle(truss_dir)
375382
LocalConfigHandler.remove_secret("secret")
376-
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
383+
container = tr.docker_run(
384+
local_port=8090, detach=True, wait_for_server_ready=True
385+
)
386+
truss_server_addr = "http://localhost:8090"
387+
full_url = f"{truss_server_addr}/v1/models/model:predict"
388+
389+
response = requests.post(full_url, json={})
390+
assert response.status_code == 500
391+
assert_logs_contain_error(
392+
container.logs(), "'secret' not found. Please check available secrets."
393+
)
394+
assert "Error while running predict" in response.json()["error"]["message"]
395+
396+
397+
@pytest.mark.integration
def test_truss_with_errors():
    """Check that a failing model stage yields HTTP 500 and a logged traceback.

    Four models are exercised in turn, raising ``ValueError("error")`` from a
    sync ``predict``, from ``preprocess``, from ``postprocess`` and from an
    async ``predict``. For each one the server must answer with status 500 and
    a generic error body, while the traceback appears in the container logs.
    """
    config = "model_name: error-truss"

    # Maps the failing stage (used only in assertion messages) to model source.
    failing_models = {
        "predict": """
            class Model:
                def predict(self, request):
                    raise ValueError("error")
            """,
        "preprocess": """
            class Model:
                def preprocess(self, request):
                    raise ValueError("error")

                def predict(self, request):
                    return {"a": "b"}
            """,
        "postprocess": """
            class Model:
                def predict(self, request):
                    return {"a": "b"}

                def postprocess(self, response):
                    raise ValueError("error")
            """,
        "async predict": """
            class Model:
                async def predict(self, request):
                    raise ValueError("error")
            """,
    }

    for stage, model_source in failing_models.items():
        # Each case gets a fresh truss directory and container; ensure_kill_all
        # tears the container down so port 8090 is free for the next case.
        with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
            truss_dir = Path(tmp_work_dir, "truss")

            _create_truss(truss_dir, config, textwrap.dedent(model_source))

            tr = TrussHandle(truss_dir)
            container = tr.docker_run(
                local_port=8090, detach=True, wait_for_server_ready=True
            )
            truss_server_addr = "http://localhost:8090"
            full_url = f"{truss_server_addr}/v1/models/model:predict"

            response = requests.post(full_url, json={})
            assert response.status_code == 500, stage

            body = response.json()
            assert "error" in body, stage

            # The traceback is only available from the logs, not the response.
            assert_logs_contain_error(container.logs(), "ValueError: error")

            # The response carries the generic detail raised by the model
            # wrapper ({"message": "Internal Server Error"}), not the old
            # "Error while running predict" text.
            assert "Internal Server Error" in body["error"]["message"], stage
387507

388508

389509
@pytest.mark.integration

0 commit comments

Comments
 (0)