1
1
import concurrent
2
2
import inspect
3
+ import json
3
4
import logging
4
5
import tempfile
5
6
import textwrap
@@ -32,6 +33,19 @@ def _create_truss(truss_dir: Path, config_contents: str, model_contents: str):
32
33
file .write (model_contents )
33
34
34
35
36
+ def _log_contains_error (line : dict , error : str ):
37
+ return (
38
+ line ["levelname" ] == "ERROR"
39
+ and line ["message" ] == "Exception while running predict"
40
+ and error in line ["exc_info" ]
41
+ )
42
+
43
+
44
+ def assert_logs_contain_error (logs : str , error : str ):
45
+ loglines = logs .splitlines ()
46
+ assert any (_log_contains_error (json .loads (line ), error ) for line in loglines )
47
+
48
+
35
49
class PropagatingThread (Thread ):
36
50
"""
37
51
PropagatingThread allows us to run threads and keep track of exceptions
@@ -317,20 +331,11 @@ def predict(self, request):
317
331
return self ._secrets ["secret" ]
318
332
319
333
config = """model_name: secrets-truss
320
- cpu: "3"
321
- memory: 14Gi
322
- use_gpu: true
323
- accelerator: A10G
324
334
secrets:
325
335
secret: null
326
336
"""
327
337
328
- config_with_no_secret = """model_name: secrets-truss
329
- cpu: "3"
330
- memory: 14Gi
331
- use_gpu: true
332
- accelerator: A10G
333
- """
338
+ config_with_no_secret = "model_name: secrets-truss"
334
339
335
340
with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
336
341
truss_dir = Path (tmp_work_dir , "truss" )
@@ -344,9 +349,8 @@ def predict(self, request):
344
349
full_url = f"{ truss_server_addr } /v1/models/model:predict"
345
350
346
351
response = requests .post (full_url , json = {})
347
- assert response .json () == "secret_value"
348
352
349
- _create_truss ( truss_dir , config , textwrap . dedent ( inspect . getsource ( Model )))
353
+ assert response . json () == "secret_value"
350
354
351
355
with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
352
356
# Case where the secret is not specified in the config
@@ -357,33 +361,149 @@ def predict(self, request):
357
361
)
358
362
tr = TrussHandle (truss_dir )
359
363
LocalConfigHandler .set_secret ("secret" , "secret_value" )
360
- _ = tr .docker_run (local_port = 8090 , detach = True , wait_for_server_ready = True )
364
+ container = tr .docker_run (
365
+ local_port = 8090 , detach = True , wait_for_server_ready = True
366
+ )
361
367
truss_server_addr = "http://localhost:8090"
362
368
full_url = f"{ truss_server_addr } /v1/models/model:predict"
363
369
364
370
response = requests .post (full_url , json = {})
365
371
366
372
assert "error" in response .json ()
367
- assert "not specified in the config" in response .json ()["error" ]["traceback" ]
373
+ assert_logs_contain_error (container .logs (), "not specified in the config" )
374
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
368
375
369
376
with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
370
- # Case where the secret is not specified in the config
377
+ # Case where the secret is not mounted
371
378
truss_dir = Path (tmp_work_dir , "truss" )
372
379
373
380
_create_truss (truss_dir , config , textwrap .dedent (inspect .getsource (Model )))
374
381
tr = TrussHandle (truss_dir )
375
382
LocalConfigHandler .remove_secret ("secret" )
376
- _ = tr .docker_run (local_port = 8090 , detach = True , wait_for_server_ready = True )
383
+ container = tr .docker_run (
384
+ local_port = 8090 , detach = True , wait_for_server_ready = True
385
+ )
386
+ truss_server_addr = "http://localhost:8090"
387
+ full_url = f"{ truss_server_addr } /v1/models/model:predict"
388
+
389
+ response = requests .post (full_url , json = {})
390
+ assert response .status_code == 500
391
+ assert_logs_contain_error (
392
+ container .logs (), "'secret' not found. Please check available secrets."
393
+ )
394
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
395
+
396
+
397
+ @pytest .mark .integration
398
+ def test_truss_with_errors ():
399
+ model = """
400
+ class Model:
401
+ def predict(self, request):
402
+ raise ValueError("error")
403
+ """
404
+
405
+ config = "model_name: error-truss"
406
+
407
+ with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
408
+ truss_dir = Path (tmp_work_dir , "truss" )
409
+
410
+ _create_truss (truss_dir , config , textwrap .dedent (model ))
411
+
412
+ tr = TrussHandle (truss_dir )
413
+ container = tr .docker_run (
414
+ local_port = 8090 , detach = True , wait_for_server_ready = True
415
+ )
416
+ truss_server_addr = "http://localhost:8090"
417
+ full_url = f"{ truss_server_addr } /v1/models/model:predict"
418
+
419
+ response = requests .post (full_url , json = {})
420
+ assert response .status_code == 500
421
+ assert "error" in response .json ()
422
+
423
+ assert_logs_contain_error (container .logs (), "ValueError: error" )
424
+
425
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
426
+
427
+ model_preprocess_error = """
428
+ class Model:
429
+ def preprocess(self, request):
430
+ raise ValueError("error")
431
+
432
+ def predict(self, request):
433
+ return {"a": "b"}
434
+ """
435
+
436
+ with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
437
+ truss_dir = Path (tmp_work_dir , "truss" )
438
+
439
+ _create_truss (truss_dir , config , textwrap .dedent (model_preprocess_error ))
440
+
441
+ tr = TrussHandle (truss_dir )
442
+ container = tr .docker_run (
443
+ local_port = 8090 , detach = True , wait_for_server_ready = True
444
+ )
377
445
truss_server_addr = "http://localhost:8090"
378
446
full_url = f"{ truss_server_addr } /v1/models/model:predict"
379
447
380
448
response = requests .post (full_url , json = {})
449
+ assert response .status_code == 500
450
+ assert "error" in response .json ()
451
+
452
+ assert_logs_contain_error (container .logs (), "ValueError: error" )
453
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
454
+
455
+ model_postprocess_error = """
456
+ class Model:
457
+ def predict(self, request):
458
+ return {"a": "b"}
459
+
460
+ def postprocess(self, response):
461
+ raise ValueError("error")
462
+ """
463
+
464
+ with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
465
+ truss_dir = Path (tmp_work_dir , "truss" )
466
+
467
+ _create_truss (truss_dir , config , textwrap .dedent (model_postprocess_error ))
468
+
469
+ tr = TrussHandle (truss_dir )
470
+ container = tr .docker_run (
471
+ local_port = 8090 , detach = True , wait_for_server_ready = True
472
+ )
473
+ truss_server_addr = "http://localhost:8090"
474
+ full_url = f"{ truss_server_addr } /v1/models/model:predict"
381
475
476
+ response = requests .post (full_url , json = {})
477
+ assert response .status_code == 500
382
478
assert "error" in response .json ()
383
- assert (
384
- "not found. Please check available secrets."
385
- in response .json ()["error" ]["traceback" ]
479
+ assert_logs_contain_error (container .logs (), "ValueError: error" )
480
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
481
+
482
+ model_async = """
483
+ class Model:
484
+ async def predict(self, request):
485
+ raise ValueError("error")
486
+ """
487
+
488
+ with ensure_kill_all (), tempfile .TemporaryDirectory (dir = "." ) as tmp_work_dir :
489
+ truss_dir = Path (tmp_work_dir , "truss" )
490
+
491
+ _create_truss (truss_dir , config , textwrap .dedent (model_async ))
492
+
493
+ tr = TrussHandle (truss_dir )
494
+ container = tr .docker_run (
495
+ local_port = 8090 , detach = True , wait_for_server_ready = True
386
496
)
497
+ truss_server_addr = "http://localhost:8090"
498
+ full_url = f"{ truss_server_addr } /v1/models/model:predict"
499
+
500
+ response = requests .post (full_url , json = {})
501
+ assert response .status_code == 500
502
+ assert "error" in response .json ()
503
+
504
+ assert_logs_contain_error (container .logs (), "ValueError: error" )
505
+
506
+ assert "Error while running predict" in response .json ()["error" ]["message" ]
387
507
388
508
389
509
@pytest .mark .integration
0 commit comments