@@ -1,16 +1,26 @@
 import asyncio
-from typing import Any, Dict
+import logging
+from typing import Any, Callable, Dict

 import httpx
-from fastapi import APIRouter
+from fastapi import APIRouter, WebSocket
 from fastapi.responses import JSONResponse, StreamingResponse
 from helpers.errors import ModelLoadFailed, ModelNotReady
-from httpx import URL, ConnectError, RemoteProtocolError
+from httpx_ws import aconnect_ws
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import Response
 from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
+from wsproto.events import BytesMessage, TextMessage

 INFERENCE_SERVER_START_WAIT_SECS = 60
+BASE_RETRY_EXCEPTIONS = (
+    retry_if_exception_type(httpx.ConnectError)
+    | retry_if_exception_type(httpx.RemoteProtocolError)
+    | retry_if_exception_type(httpx.ReadError)
+    | retry_if_exception_type(httpx.ReadTimeout)
+    | retry_if_exception_type(httpx.ConnectTimeout)
+    | retry_if_exception_type(ModelNotReady)
+)


 control_app = APIRouter()
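For reference, tenacity's `retry_if_exception_type` conditions OR-compose with `|`, which is how `BASE_RETRY_EXCEPTIONS` above is built; a minimal standalone sketch of the same pattern (the endpoint URL is a hypothetical stand-in, not part of this diff):

import httpx
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed

# `|` yields a combined condition that retries when either exception is raised.
either = retry_if_exception_type(httpx.ConnectError) | retry_if_exception_type(
    httpx.ReadTimeout
)
for attempt in Retrying(retry=either, stop=stop_after_attempt(3), wait=wait_fixed(1)):
    with attempt:
        # Hypothetical endpoint; a refused connection triggers a retry,
        # and tenacity raises RetryError once the attempts are exhausted.
        httpx.get("http://localhost:9000/health")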
@@ -20,14 +30,14 @@ def index():
     return {}


-async def proxy(request: Request):
+async def proxy_http(request: Request):
     inference_server_process_controller = (
         request.app.state.inference_server_process_controller
     )
     client: httpx.AsyncClient = request.app.state.proxy_client

     path = _reroute_if_health_check(request.url.path)
-    url = URL(path=path, query=request.url.query.encode("utf-8"))
+    url = httpx.URL(path=path, query=request.url.query.encode("utf-8"))

     # 2 min connect timeouts, no timeout for requests.
     # We don't want requests to fail due to timeout on the proxy
@@ -47,19 +57,7 @@ async def proxy(request: Request):
     )

     # Wait a bit for inference server to start
-    for attempt in Retrying(
-        retry=(
-            retry_if_exception_type(ConnectError)
-            | retry_if_exception_type(ModelNotReady)
-            | retry_if_exception_type(RemoteProtocolError)
-            | retry_if_exception_type(httpx.ReadError)
-            | retry_if_exception_type(httpx.ReadTimeout)
-            | retry_if_exception_type(httpx.ConnectTimeout)
-        ),
-        stop=_custom_stop_strategy,
-        wait=wait_fixed(1),
-        reraise=False,
-    ):
+    for attempt in inference_retries():
         with attempt:
             try:
                 if inference_server_process_controller.is_inference_server_intentionally_stopped():
@@ -68,7 +66,7 @@ async def proxy(request: Request):

                 if await _is_model_not_ready(resp):
                     raise ModelNotReady("Model has started running, but not ready yet.")
-            except (RemoteProtocolError, ConnectError) as exp:
+            except (httpx.RemoteProtocolError, httpx.ConnectError) as exp:
                 # This check is a bit expensive so we don't do it before every request, we
                 # do it only if request fails with connection error. If the inference server
                 # process is running then we continue waiting for it to start (by retrying),
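The `inference_retries()` call above is defined in the final hunk below: it wraps `Retrying` in a generator, so call sites keep tenacity's `for attempt: ... with attempt:` protocol. A minimal self-contained sketch of that wrapper pattern, with hypothetical names standing in for the real ones:

import random

from tenacity import Retrying, stop_after_attempt, wait_fixed


def flaky_retries():  # hypothetical stand-in for inference_retries()
    # Yielding tenacity's attempt contexts keeps the caller's loop unchanged.
    for attempt in Retrying(stop=stop_after_attempt(5), wait=wait_fixed(0.1), reraise=False):
        yield attempt


for attempt in flaky_retries():
    with attempt:
        # By default tenacity retries on any exception; with reraise=False it
        # raises RetryError once the stop condition is reached.
        if random.random() < 0.5:  # stand-in for a not-yet-ready server
            raise ConnectionError("server not up yet")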
@@ -94,7 +92,59 @@ async def proxy(request: Request):
     return response


-control_app.add_route("/v1/{path:path}", proxy, ["GET", "POST"])
+def inference_retries(
+    retry_condition: Callable[[RetryCallState], bool] = BASE_RETRY_EXCEPTIONS,
+):
+    for attempt in Retrying(
+        retry=retry_condition,
+        stop=_custom_stop_strategy,
+        wait=wait_fixed(1),
+        reraise=False,
+    ):
+        yield attempt
+
+
+async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
+    try:
+        await ws.close()
+    except RuntimeError as close_error:
+        logger.debug(f"Duplicate close of websocket: `{close_error}`.")
+
+
+async def proxy_ws(client_ws: WebSocket):
+    await client_ws.accept()
+    proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
+    logger = client_ws.app.state.logger
+
+    for attempt in inference_retries():
+        with attempt:
+            async with aconnect_ws("/v1/websocket", proxy_client) as server_ws:  # type: ignore
+                # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
+                # for sending data, so it's not easy to create a unified wrapper.
+                async def forward_to_server():
+                    while True:
+                        message = await client_ws.receive()
+                        if "text" in message:
+                            await server_ws.send_text(message["text"])
+                        elif "bytes" in message:
+                            await server_ws.send_bytes(message["bytes"])
+
+                async def forward_to_client():
+                    while True:
+                        message = await server_ws.receive()
+                        if isinstance(message, TextMessage):
+                            await client_ws.send_text(message.data)
+                        elif isinstance(message, BytesMessage):
+                            await client_ws.send_bytes(message.data)
+
+                try:
+                    await asyncio.gather(forward_to_client(), forward_to_server())
+                finally:
+                    await _safe_close_ws(client_ws, logger)
+
+
+control_app.add_websocket_route("/v1/websocket", proxy_ws)
+control_app.add_route("/v1/{path:path}", proxy_http, ["GET", "POST"])


 @control_app.post("/control/patch")
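End to end, `proxy_ws` relays frames in both directions between the client and the inference server. A hypothetical smoke test for the new route (the `websockets` package and the localhost address are assumptions, not part of this diff):

import asyncio

import websockets


async def main():
    # Connect through the control server; proxy_ws relays to the inference server.
    async with websockets.connect("ws://localhost:8080/v1/websocket") as ws:
        await ws.send("ping")    # relayed by forward_to_server()
        print(await ws.recv())   # reply relayed back by forward_to_client()


asyncio.run(main())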