Skip to content

Commit 3eddce1

Browse files
authored
Remove postprocess from lock (#700)
* Add test. * Move postprocess out of the predict lock. * Bump version. * Bump pyproject.
1 parent 8d33e9f commit 3eddce1

File tree

4 files changed

+171
-6
lines changed

4 files changed

+171
-6
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.7.12"
3+
version = "0.7.13rc1"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/templates/server/common/truss_server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def predict(
142142

143143
# In the case that the model returns a Generator object, return a
144144
# StreamingResponse instead.
145-
if isinstance(response, AsyncGenerator):
145+
if isinstance(response, (AsyncGenerator, Generator)):
146146
# media_type in StreamingResponse sets the Content-Type header
147147
return StreamingResponse(response, media_type="application/octet-stream")
148148

truss/templates/server/model_wrapper.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ async def postprocess(
225225
if inspect.isasyncgenfunction(
226226
self._model.postprocess
227227
) or inspect.isgeneratorfunction(self._model.postprocess):
228-
return self._model.postprocess(response, headers)
228+
return self._model.postprocess(response)
229229

230230
if inspect.iscoroutinefunction(self._model.postprocess):
231231
return await _intercept_exceptions_async(self._model.postprocess)(response)
@@ -264,10 +264,16 @@ async def __call__(
264264
async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
265265
response = await self.predict(payload, headers)
266266

267-
processed_response = await self.postprocess(response)
268-
269267
# Streaming cases
270268
if inspect.isgenerator(response) or inspect.isasyncgen(response):
269+
if hasattr(self._model, "postprocess"):
270+
logging.warning(
271+
"Predict returned a streaming response, while a postprocess is defined."
272+
"Note that in this case, the postprocess will run within the predict lock."
273+
)
274+
275+
response = await self.postprocess(response)
276+
271277
async_generator = _force_async_generator(response)
272278

273279
if headers and headers.get("accept") == "application/json":
@@ -309,7 +315,8 @@ async def _response_generator():
309315

310316
return _response_generator()
311317

312-
return processed_response
318+
processed_response = await self.postprocess(response)
319+
return processed_response
313320

314321

315322
class ResponseChunk:

truss/tests/test_model_inference.py

+158
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,164 @@ def predict(self, request):
397397
assert "Internal Server Error" in response.json()["error"]
398398

399399

400+
@pytest.mark.integration
401+
def test_postprocess_with_streaming_predict():
402+
"""
403+
Test a Truss that has streaming response from both predict and postprocess.
404+
In this case, the postprocess step continues to happen within the predict lock,
405+
so we don't bother testing the lock scenario, just the behavior that the postprocess
406+
function is applied.
407+
"""
408+
model = """
409+
import time
410+
411+
class Model:
412+
def postprocess(self, response):
413+
for item in response:
414+
time.sleep(1)
415+
yield item + " modified"
416+
417+
def predict(self, request):
418+
for i in range(2):
419+
yield str(i)
420+
"""
421+
422+
config = "model_name: error-truss"
423+
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
424+
truss_dir = Path(tmp_work_dir, "truss")
425+
426+
_create_truss(truss_dir, config, textwrap.dedent(model))
427+
428+
tr = TrussHandle(truss_dir)
429+
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
430+
truss_server_addr = "http://localhost:8090"
431+
full_url = f"{truss_server_addr}/v1/models/model:predict"
432+
response = requests.post(full_url, json={}, stream=True)
433+
# Note that the postprocess function is applied to the
434+
# streamed response.
435+
assert response.content == b"0 modified1 modified"
436+
437+
438+
@pytest.mark.integration
439+
def test_streaming_postprocess():
440+
"""
441+
Tests a Truss where predict returns non-streaming, but postprocess is streamed, and
442+
ensures that the postprocess step does not happen within the predict lock. To do this,
443+
we sleep for two seconds during the postprocess streaming process, and fire off two
444+
requests with a total timeout of 3 seconds, ensuring that if they were serialized
445+
the test would fail.
446+
"""
447+
model = """
448+
import time
449+
450+
class Model:
451+
def postprocess(self, response):
452+
for item in response:
453+
time.sleep(1)
454+
yield item + " modified"
455+
456+
def predict(self, request):
457+
return ["0", "1"]
458+
"""
459+
460+
config = "model_name: error-truss"
461+
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
462+
truss_dir = Path(tmp_work_dir, "truss")
463+
464+
_create_truss(truss_dir, config, textwrap.dedent(model))
465+
466+
tr = TrussHandle(truss_dir)
467+
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
468+
truss_server_addr = "http://localhost:8090"
469+
full_url = f"{truss_server_addr}/v1/models/model:predict"
470+
471+
def make_request(delay: int):
472+
# For streamed responses, requests does not start receiving content from server until
473+
# `iter_content` is called, so we must call this in order to get an actual timeout.
474+
time.sleep(delay)
475+
response = requests.post(full_url, json={}, stream=True)
476+
477+
assert response.status_code == 200
478+
assert response.content == b"0 modified1 modified"
479+
480+
with ThreadPoolExecutor() as e:
481+
# We use concurrent.futures.wait instead of the timeout property
482+
# on requests, since requests timeout property has a complex interaction
483+
# with streaming.
484+
first_request = e.submit(make_request, 0)
485+
second_request = e.submit(make_request, 0.2)
486+
futures = [first_request, second_request]
487+
done, _ = concurrent.futures.wait(futures, timeout=3)
488+
# Ensure that both requests complete within the 3 second timeout,
489+
# as the predict lock is not held through the postprocess step
490+
assert first_request in done
491+
assert second_request in done
492+
493+
for future in done:
494+
# Ensure that both futures completed without error
495+
future.result()
496+
497+
498+
@pytest.mark.integration
499+
def test_postprocess():
500+
"""
501+
Tests a Truss that has a postprocess step defined, and ensures that the
502+
postprocess does not happen within the predict lock. To do this, we sleep
503+
for two seconds during the postprocess, and fire off two requests with a total
504+
timeout of 3 seconds, ensuring that if they were serialized the test would fail.
505+
"""
506+
507+
model = """
508+
import time
509+
510+
class Model:
511+
def postprocess(self, response):
512+
updated_items = []
513+
for item in response:
514+
time.sleep(1)
515+
updated_items.append(item + " modified")
516+
return updated_items
517+
518+
def predict(self, request):
519+
return ["0", "1"]
520+
521+
"""
522+
523+
config = "model_name: error-truss"
524+
with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
525+
truss_dir = Path(tmp_work_dir, "truss")
526+
527+
_create_truss(truss_dir, config, textwrap.dedent(model))
528+
529+
tr = TrussHandle(truss_dir)
530+
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
531+
truss_server_addr = "http://localhost:8090"
532+
full_url = f"{truss_server_addr}/v1/models/model:predict"
533+
534+
def make_request(delay: int):
535+
time.sleep(delay)
536+
response = requests.post(full_url, json={})
537+
assert response.status_code == 200
538+
assert response.json() == ["0 modified", "1 modified"]
539+
540+
with ThreadPoolExecutor() as e:
541+
# We use concurrent.futures.wait instead of the timeout property
542+
# on requests, since requests timeout property has a complex interaction
543+
# with streaming.
544+
first_request = e.submit(make_request, 0)
545+
second_request = e.submit(make_request, 0.2)
546+
futures = [first_request, second_request]
547+
done, _ = concurrent.futures.wait(futures, timeout=3)
548+
# Ensure that both requests complete within the 3 second timeout,
549+
# as the predict lock is not held through the postprocess step
550+
assert first_request in done
551+
assert second_request in done
552+
553+
for future in done:
554+
# Ensure that both futures completed without error
555+
future.result()
556+
557+
400558
@pytest.mark.integration
401559
def test_truss_with_errors():
402560
model = """

0 commit comments

Comments
 (0)