Skip to content

Commit d42dfae

Browse files
authored
Gracefully handle client disconnects (#625)
* Handle client disconnects more gracefully.
* Handle prod servers; bump pyproject.
* Adjust error message.
* Add log statement to see if it shows up.
* Bump version.
* Remove log.
* Copy the control server logic for the inference server.
* Bump project version.
* Remove log line.
1 parent 32a3157 commit d42dfae

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.6.5rc2"
3+
version = "0.6.5rc5"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/templates/control/control/endpoints.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi.responses import JSONResponse, StreamingResponse
77
from helpers.errors import ModelLoadFailed, ModelNotReady
88
from httpx import URL, ConnectError, RemoteProtocolError
9-
from starlette.requests import Request
9+
from starlette.requests import ClientDisconnect, Request
1010
from starlette.responses import Response
1111
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed
1212

@@ -28,15 +28,20 @@ async def proxy(request: Request):
2828
client: httpx.AsyncClient = request.app.state.proxy_client
2929
url = URL(path=request.url.path, query=request.url.query.encode("utf-8"))
3030

31-
# 5 mins request and 2 min connect timeouts
32-
# Large values; we don't want requests to fail due to timeout on the proxy
33-
timeout = httpx.Timeout(5 * 60.0, connect=2 * 60.0)
31+
# 2 min connect timeouts, no timeout for requests.
32+
# We don't want requests to fail due to timeout on the proxy
33+
timeout = httpx.Timeout(None, connect=2 * 60.0)
34+
try:
35+
request_body = await request.body()
36+
except ClientDisconnect:
37+
# If the client disconnects, we don't need to proxy the request
38+
return Response(status_code=499)
3439

3540
inf_serv_req = client.build_request(
3641
request.method,
3742
url,
3843
headers=request.headers.raw,
39-
content=await request.body(),
44+
content=request_body,
4045
timeout=timeout,
4146
)
4247

truss/templates/server/common/truss_server.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import shared.util as utils
1616
import uvicorn
1717
from common.termination_handler_middleware import TerminationHandlerMiddleware
18-
from fastapi import Depends, FastAPI, Request
18+
from fastapi import Depends, FastAPI, HTTPException, Request
1919
from fastapi.responses import ORJSONResponse, StreamingResponse
2020
from fastapi.routing import APIRoute as FastAPIRoute
2121
from model_wrapper import ModelWrapper
@@ -26,6 +26,7 @@
2626
truss_msgpack_serialize,
2727
)
2828
from starlette.middleware.base import BaseHTTPMiddleware
29+
from starlette.requests import ClientDisconnect
2930
from starlette.responses import Response
3031

3132
# [IMPORTANT] A lot of things depend on this currently.
@@ -41,7 +42,10 @@ async def parse_body(request: Request) -> bytes:
4142
"""
4243
Used by FastAPI to read body in an asynchronous manner
4344
"""
44-
return await request.body()
45+
try:
46+
return await request.body()
47+
except ClientDisconnect as exc:
48+
raise HTTPException(status_code=499, detail="Client disconnected") from exc
4549

4650

4751
FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s [%(funcName)s():%(lineno)s] %(message)s"

0 commit comments

Comments
 (0)