@@ -397,6 +397,164 @@ def predict(self, request):
        assert "Internal Server Error" in response.json()["error"]


+@pytest.mark.integration
+def test_postprocess_with_streaming_predict():
+    """
+    Tests a Truss that has a streaming response from both predict and postprocess.
+    In this case, the postprocess step continues to happen within the predict lock,
+    so we don't bother testing the lock scenario here, just that the postprocess
+    function is applied.
+    """
+    model = """
+    import time
+
+    class Model:
+        def postprocess(self, response):
+            for item in response:
+                time.sleep(1)
+                yield item + " modified"
+
+        def predict(self, request):
+            for i in range(2):
+                yield str(i)
+    """
+
+    config = "model_name: error-truss"
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+
+        _create_truss(truss_dir, config, textwrap.dedent(model))
+
+        tr = TrussHandle(truss_dir)
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+        truss_server_addr = "http://localhost:8090"
+        full_url = f"{truss_server_addr}/v1/models/model:predict"
+        response = requests.post(full_url, json={}, stream=True)
+        # Note that the postprocess function is applied to the
+        # streamed response.
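+        # The streamed chunks arrive back-to-back with no delimiter, which is why the
+        # full body is the two postprocessed chunks concatenated together.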
+        assert response.content == b"0 modified1 modified"
+
+
+@pytest.mark.integration
+def test_streaming_postprocess():
+    """
+    Tests a Truss where predict returns a non-streaming response but postprocess streams,
+    and ensures that the postprocess step does not happen within the predict lock. To do
+    this, we sleep for two seconds during the postprocess streaming, and fire off two
+    requests with a total timeout of 3 seconds, ensuring that if they were serialized
+    the test would fail.
+    """
+    model = """
+    import time
+
+    class Model:
+        def postprocess(self, response):
+            for item in response:
+                time.sleep(1)
+                yield item + " modified"
+
+        def predict(self, request):
+            return ["0", "1"]
+    """
+
+    config = "model_name: error-truss"
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+
+        _create_truss(truss_dir, config, textwrap.dedent(model))
+
+        tr = TrussHandle(truss_dir)
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+        truss_server_addr = "http://localhost:8090"
+        full_url = f"{truss_server_addr}/v1/models/model:predict"
+
+        def make_request(delay: float):
+            # For streamed responses, requests does not start pulling content from the
+            # server until the body is consumed (e.g. via `iter_content` or `.content`),
+            # so we read `response.content` below to measure the full response time.
+            time.sleep(delay)
+            response = requests.post(full_url, json={}, stream=True)
+
+            assert response.status_code == 200
+            assert response.content == b"0 modified1 modified"
+
+        with ThreadPoolExecutor() as e:
+            # We use concurrent.futures.wait instead of the timeout property
+            # on requests, since the requests timeout property has a complex
+            # interaction with streaming.
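+            # Each request spends roughly 2 seconds in postprocess (2 chunks x 1 s sleep),
+            # so if the two requests were serialized behind the predict lock they would
+            # need ~4 seconds and miss the 3 second wait below.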
+            first_request = e.submit(make_request, 0)
+            second_request = e.submit(make_request, 0.2)
+            futures = [first_request, second_request]
+            done, _ = concurrent.futures.wait(futures, timeout=3)
+            # Ensure that both requests complete within the 3 second timeout,
+            # as the predict lock is not held through the postprocess step
+            assert first_request in done
+            assert second_request in done
+
+            for future in done:
+                # Ensure that both futures completed without error
+                future.result()
+
+
+@pytest.mark.integration
+def test_postprocess():
+    """
+    Tests a Truss that has a postprocess step defined, and ensures that the
+    postprocess does not happen within the predict lock. To do this, we sleep
+    for two seconds during the postprocess, and fire off two requests with a total
+    timeout of 3 seconds, ensuring that if they were serialized the test would fail.
+    """
+
+    model = """
+    import time
+
+    class Model:
+        def postprocess(self, response):
+            updated_items = []
+            for item in response:
+                time.sleep(1)
+                updated_items.append(item + " modified")
+            return updated_items
+
+        def predict(self, request):
+            return ["0", "1"]
+
+    """
+
+    config = "model_name: error-truss"
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+
+        _create_truss(truss_dir, config, textwrap.dedent(model))
+
+        tr = TrussHandle(truss_dir)
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+        truss_server_addr = "http://localhost:8090"
+        full_url = f"{truss_server_addr}/v1/models/model:predict"
+
+        def make_request(delay: float):
+            time.sleep(delay)
+            response = requests.post(full_url, json={})
+            assert response.status_code == 200
+            assert response.json() == ["0 modified", "1 modified"]
+
+        with ThreadPoolExecutor() as e:
+            # We use concurrent.futures.wait instead of the timeout property
+            # on requests, since the requests timeout property has a complex
+            # interaction with streaming.
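+            # Each request spends roughly 2 seconds in postprocess (2 items x 1 s sleep),
+            # so serialized execution would need ~4 seconds and miss the 3 second wait below.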
+            first_request = e.submit(make_request, 0)
+            second_request = e.submit(make_request, 0.2)
+            futures = [first_request, second_request]
+            done, _ = concurrent.futures.wait(futures, timeout=3)
+            # Ensure that both requests complete within the 3 second timeout,
+            # as the predict lock is not held through the postprocess step
+            assert first_request in done
+            assert second_request in done
+
+            for future in done:
+                # Ensure that both futures completed without error
+                future.result()
+
+
@pytest.mark.integration
def test_truss_with_errors():
    model = """