Commit 1cd1d0a

Adds control server passthrough for websockets (#1441)
1 parent 1b02e8a commit 1cd1d0a

File tree

6 files changed: +172 −24 lines


poetry.lock

+35-2
Some generated files are not rendered by default.

pyproject.toml

+2-1
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.67rc001"
+version = "0.9.67rc004"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
@@ -91,6 +91,7 @@ click = { version = "^8.0.3", optional = false }
 fastapi = { version =">=0.109.1", optional = false }
 google-cloud-storage = { version = "2.10.0", optional = false }
 httpx = { version = ">=0.24.1", optional = false }
+httpx-ws = { version = "^0.7.1", optional = false }
 inquirerpy = { version = "^0.3.4", optional = false }
 libcst = { version = "<1.2.0", optional = false }
 loguru = { version = ">=0.7.2", optional = false }
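
The new httpx-ws dependency provides the async websocket client that the control server uses to dial the inference server. As a rough, standalone sketch of how it attaches to an existing httpx.AsyncClient (the base URL and message below are placeholders, not taken from this commit):

import asyncio

import httpx
from httpx_ws import aconnect_ws


async def main():
    # Placeholder address standing in for the inference server behind the proxy.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # aconnect_ws upgrades a request on the existing client into a websocket session.
        async with aconnect_ws("/v1/websocket", client) as ws:
            await ws.send_text("ping")
            print(await ws.receive_text())


asyncio.run(main())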

truss/templates/control/control/application.py

+1-1
@@ -94,7 +94,7 @@ async def start_background_inference_startup():
 @app.on_event("shutdown")
 def on_shutdown():
     # FastApi handles the term signal to start the shutdown flow. Here we
-    # make sure that the inference server is stopeed when control server
+    # make sure that the inference server is stopped when control server
     # shuts down. Inference server has logic to wait until all requests are
     # finished before exiting. By waiting on that, we inherit the same
     # behavior for control server.

truss/templates/control/control/endpoints.py

+70-20
@@ -1,16 +1,26 @@
 import asyncio
-from typing import Any, Dict
+import logging
+from typing import Any, Callable, Dict

 import httpx
-from fastapi import APIRouter
+from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
 from helpers.errors import ModelLoadFailed, ModelNotReady
-from httpx import URL, ConnectError, RemoteProtocolError
+from httpx_ws import aconnect_ws
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import Response
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
+from wsproto.events import BytesMessage, TextMessage

 INFERENCE_SERVER_START_WAIT_SECS = 60
+BASE_RETRY_EXCEPTIONS = (
+    retry_if_exception_type(httpx.ConnectError)
+    | retry_if_exception_type(httpx.RemoteProtocolError)
+    | retry_if_exception_type(httpx.ReadError)
+    | retry_if_exception_type(httpx.ReadTimeout)
+    | retry_if_exception_type(httpx.ConnectTimeout)
+    | retry_if_exception_type(ModelNotReady)
+)

 control_app = APIRouter()

@@ -20,14 +30,14 @@ def index():
     return {}


-async def proxy(request: Request):
+async def proxy_http(request: Request):
     inference_server_process_controller = (
         request.app.state.inference_server_process_controller
     )
     client: httpx.AsyncClient = request.app.state.proxy_client

     path = _reroute_if_health_check(request.url.path)
-    url = URL(path=path, query=request.url.query.encode("utf-8"))
+    url = httpx.URL(path=path, query=request.url.query.encode("utf-8"))

     # 2 min connect timeouts, no timeout for requests.
     # We don't want requests to fail due to timeout on the proxy
@@ -47,19 +57,7 @@ async def proxy(request: Request):
     )

     # Wait a bit for inference server to start
-    for attempt in Retrying(
-        retry=(
-            retry_if_exception_type(ConnectError)
-            | retry_if_exception_type(ModelNotReady)
-            | retry_if_exception_type(RemoteProtocolError)
-            | retry_if_exception_type(httpx.ReadError)
-            | retry_if_exception_type(httpx.ReadTimeout)
-            | retry_if_exception_type(httpx.ConnectTimeout)
-        ),
-        stop=_custom_stop_strategy,
-        wait=wait_fixed(1),
-        reraise=False,
-    ):
+    for attempt in inference_retries():
         with attempt:
             try:
                 if inference_server_process_controller.is_inference_server_intentionally_stopped():
@@ -68,7 +66,7 @@ async def proxy(request: Request):

                 if await _is_model_not_ready(resp):
                     raise ModelNotReady("Model has started running, but not ready yet.")
-            except (RemoteProtocolError, ConnectError) as exp:
+            except (httpx.RemoteProtocolError, httpx.ConnectError) as exp:
                 # This check is a bit expensive so we don't do it before every request, we
                 # do it only if request fails with connection error. If the inference server
                 # process is running then we continue waiting for it to start (by retrying),
@@ -94,7 +92,59 @@ async def proxy(request: Request):
     return response


-control_app.add_route("/v1/{path:path}", proxy, ["GET", "POST"])
+def inference_retries(
+    retry_condition: Callable[[RetryCallState], bool] = BASE_RETRY_EXCEPTIONS,
+):
+    for attempt in Retrying(
+        retry=retry_condition,
+        stop=_custom_stop_strategy,
+        wait=wait_fixed(1),
+        reraise=False,
+    ):
+        yield attempt
+
+
+async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
+    try:
+        await ws.close()
+    except RuntimeError as close_error:
+        logger.debug(f"Duplicate close of websocket: `{close_error}`.")
+
+
+async def proxy_ws(client_ws: WebSocket):
+    await client_ws.accept()
+    proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
+    logger = client_ws.app.state.logger
+
+    for attempt in inference_retries():
+        with attempt:
+            async with aconnect_ws("/v1/websocket", proxy_client) as server_ws:  # type: ignore
+                # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
+                # for sending data, so it's not easy to create a unified wrapper.
+                async def forward_to_server():
+                    while True:
+                        message = await client_ws.receive()
+                        if "text" in message:
+                            await server_ws.send_text(message["text"])
+                        elif "bytes" in message:
+                            await server_ws.send_bytes(message["bytes"])
+
+                async def forward_to_client():
+                    while True:
+                        message = await server_ws.receive()
+                        if isinstance(message, TextMessage):
+                            await client_ws.send_text(message.data)
+                        elif isinstance(message, BytesMessage):
+                            await client_ws.send_bytes(message.data)
+
+                try:
+                    await asyncio.gather(forward_to_client(), forward_to_server())
+                finally:
+                    await _safe_close_ws(client_ws, logger)
+
+
+control_app.add_websocket_route("/v1/websocket", proxy_ws)
 control_app.add_route("/v1/{path:path}", proxy_http, ["GET", "POST"])


 @control_app.post("/control/patch")
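
To make the "slightly different abstractions" comment above concrete: Starlette's WebSocket.receive() yields plain ASGI message dicts, while httpx-ws's receive() yields wsproto event objects, which is why the two forwarding loops branch on different shapes. A small illustration (the literal values here are made up):

from wsproto.events import BytesMessage, TextMessage

# Shape seen by forward_to_server (Starlette / FastAPI side): an ASGI message dict.
asgi_message = {"type": "websocket.receive", "text": "hello"}
assert "text" in asgi_message

# Shape seen by forward_to_client (httpx-ws side): a wsproto event with a `.data` attribute.
ws_event = TextMessage(data="hello")
assert isinstance(ws_event, TextMessage) and ws_event.data == "hello"

# Binary frames follow the same pattern with a "bytes" key / BytesMessage event.
assert BytesMessage(data=b"hello").data == b"hello"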

truss/templates/control/requirements.txt

+1
@@ -5,6 +5,7 @@ uvicorn==0.24.0
 uvloop==0.19.0
 tenacity==8.1.0
 httpx==0.27.0
+httpx-ws>=0.7.0
 python-json-logger==2.0.2
 loguru==0.7.2
 websockets<=14.0

truss/tests/templates/control/control/test_server_integration.py

+63
@@ -15,6 +15,7 @@
 import psutil
 import pytest
 import requests
+import websockets

 PATCH_PING_MAX_DELAY_SECS = 3

@@ -91,6 +92,68 @@ def inner():
    assert resp.content == "01234".encode("utf-8")


+@pytest.mark.asyncio
+@pytest.mark.integration
+async def test_truss_control_server_text_websocket(
+    control_server: ControlServerDetails,
+):
+    ws_model_code = """
+import fastapi
+
+class Model:
+    async def websocket(self, websocket: fastapi.WebSocket):
+        try:
+            while True:
+                text = await websocket.receive_text()
+                await websocket.send_text(text + " pong")
+        except fastapi.WebSocketDisconnect:
+            pass
+"""
+
+    ctrl_url = f"ws://localhost:{control_server.control_server_port}"
+    _patch(ws_model_code, control_server)
+
+    async with websockets.connect(f"{ctrl_url}/v1/websocket") as websocket:
+        await websocket.send("hello")
+        response = await websocket.recv()
+        assert response == "hello pong"
+
+        await websocket.send("world")
+        response = await websocket.recv()
+        assert response == "world pong"
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+async def test_truss_control_server_binary_websocket(
+    control_server: ControlServerDetails,
+):
+    ws_model_code = """
+import fastapi
+
+class Model:
+    async def websocket(self, websocket: fastapi.WebSocket):
+        try:
+            while True:
+                text = await websocket.receive_bytes()
+                await websocket.send_bytes(text + b" pong")
+        except fastapi.WebSocketDisconnect:
+            pass
+"""
+
+    ctrl_url = f"ws://localhost:{control_server.control_server_port}"
+    _patch(ws_model_code, control_server)
+
+    async with websockets.connect(f"{ctrl_url}/v1/websocket") as websocket:
+        await websocket.send(b"hello")
+        response = await websocket.recv()
+        assert response == b"hello pong"
+
+        await websocket.send(b"world")
+        response = await websocket.recv()
+        assert response == b"world pong"
+
+
 @pytest.mark.integration
 def test_truss_control_server_health_check(control_server: ControlServerDetails):
     ctrl_url = f"http://localhost:{control_server.control_server_port}"
