Skip to content

Commit 1524515

Browse files
committedDec 10, 2023
Add concurrency_limit middleware
The tower::limit::ConcurrencyLimit middleware does not work properly when used with services that return responses with streaming bodies. The middleware tracks concurrency only from the call to the service until the response future resolves. But a response with a streaming body can't be considered finished until the body has been consumed. Add a concurrency_limit module with a ConcurrencyLimit middleware implementation that holds on to the semaphore permit until the response and its body have been consumed. It uses the original middleware from the `tower` crate and the `tower-http::metrics::InFlightRequests` middleware as inspiration.
1 parent 6f964b1 commit 1524515

File tree

3 files changed

+317
-0
lines changed

3 files changed

+317
-0
lines changed
 

‎tower-http/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ full = [
6262
"auth",
6363
"catch-panic",
6464
"compression-full",
65+
"concurrency-limit",
6566
"cors",
6667
"decompression-full",
6768
"follow-redirect",
@@ -86,6 +87,7 @@ full = [
8687
add-extension = []
8788
auth = ["base64", "validate-request"]
8889
catch-panic = ["tracing", "futures-util/std"]
90+
concurrency-limit = ["tokio", "tokio-util"]
8991
cors = []
9092
follow-redirect = ["futures-util", "iri-string", "tower/util"]
9193
fs = ["futures-util", "tokio/fs", "tokio-util/io", "tokio/io-util", "dep:http-range-header", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing"]

‎tower-http/src/concurrency_limit.rs

+312
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
//! Limit the max number of concurrently processed requests.
2+
//!
3+
//! The service sets a maximum limit to the number of concurrently processed requests. The
4+
//! processing of a request starts when it is received by the service (`tower::Service::call` is
5+
//! called) and is considered complete when the response body is consumed, dropped, or an error
6+
//! happens.
7+
//!
8+
//! Internally, it uses a semaphore to track and limit the number of in-flight requests.
9+
//!
10+
//! # Relation to `ConcurrencyLimit` from `tower` crate
11+
//!
12+
//! The `tower::limit::concurrency::ConcurrencyLimit` service uses a different definition of
13+
//! 'request processing'. It starts when request is received by `tower::Service::call`, and ends
14+
//! immediately after the response is produced.
15+
//!
16+
//! In some cases it may not work properly with [`http::Response`], as it does not account for
17+
//! process of consuming response body.
18+
//!
19+
//! When a stream is used as the response body, the process of consuming it (i.e. streaming it to the caller) may
20+
//! take longer and use more resources than producing the response itself. And often it is the
21+
//! number of concurrently processed streams that we want to limit.
22+
//!
23+
//! The service version from [`tower-http`](crate) takes response body consumption into
24+
//! consideration and *will* limit number of concurrent streams correctly.
25+
//!
26+
//! ```
27+
//! use std::convert::Infallible;
28+
//! use bytes::Bytes;
29+
//! use http::{Request, Response};
30+
//! use http_body_util::Full;
31+
//! use tower::{Service, ServiceExt, ServiceBuilder};
32+
//! use tower_http::concurrency_limit::ConcurrencyLimitLayer;
33+
//!
34+
//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
35+
//! // ...
36+
//! # Ok(Response::new(Full::default()))
37+
//! }
38+
//!
39+
//! # #[tokio::main]
40+
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
41+
//! let mut service = ServiceBuilder::new()
42+
//! // limit number of concurrent requests to 3
43+
//! .layer(ConcurrencyLimitLayer::new(3))
44+
//! .service_fn(handle);
45+
//!
46+
//! // Call the service.
47+
//! let response = service
48+
//! .ready()
49+
//! .await?
50+
//! .call(Request::new(Full::default()))
51+
//! .await?;
52+
//! # Ok(())
53+
//! # }
54+
//! ```
55+
//!
56+
57+
use http::{Request, Response};
58+
use pin_project_lite::pin_project;
59+
use std::{
60+
future::Future,
61+
pin::Pin,
62+
sync::Arc,
63+
task::{ready, Context, Poll},
64+
};
65+
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
66+
use tokio_util::sync::PollSemaphore;
67+
68+
/// Limit max number of concurrent requests (per service)
69+
///
70+
/// The layer enforces a same concurrency limit for each inner service separately. In other words,
71+
/// [`ConcurrencyLimit`] middleware constructed from this layer for each service will have its own
72+
/// semaphore and will track and limit requests separately.
73+
///
74+
/// To track and limit multiple services together, see [`SharedConcurrencyLimitLayer`]
75+
///
76+
/// See the [module docs](crate::concurrency_limit) for more details.
77+
#[derive(Clone, Debug)]
78+
pub struct ConcurrencyLimitLayer {
79+
max: usize,
80+
}
81+
82+
impl ConcurrencyLimitLayer {
83+
/// Create new [`ConcurrencyLimitLayer`] with semaphore size
84+
pub fn new(max: usize) -> Self {
85+
Self { max }
86+
}
87+
}
88+
89+
impl<S> tower_layer::Layer<S> for ConcurrencyLimitLayer {
90+
type Service = ConcurrencyLimit<S>;
91+
92+
fn layer(&self, service: S) -> Self::Service {
93+
ConcurrencyLimit::new(service, Arc::new(Semaphore::new(self.max)))
94+
}
95+
}
96+
97+
/// Limit max number of concurrent requests (shared)
98+
///
99+
/// The layer enforces a single concurrency limit for multiple inner services at once. In other
100+
/// words, [`ConcurrencyLimit`] middleware constructed from this layer for each service will have
101+
/// one shared semaphore and will track and limit requests together..
102+
///
103+
/// To track and limit each service separately, see [`ConcurrencyLimitLayer`].
104+
///
105+
/// See the [module docs](crate::concurrency_limit) for more details.
106+
#[derive(Clone, Debug)]
107+
pub struct SharedConcurrencyLimitLayer {
108+
semaphore: Arc<Semaphore>,
109+
}
110+
111+
impl SharedConcurrencyLimitLayer {
112+
/// Create new [`ConcurrencyLimitLayer`] with shared semaphore
113+
pub fn new(max: usize) -> Self {
114+
Self {
115+
semaphore: Arc::new(Semaphore::new(max)),
116+
}
117+
}
118+
119+
/// Create [`ConcurrencyLimitLayer`] from semaphore
120+
pub fn from_semaphore(semaphore: Arc<Semaphore>) -> Self {
121+
Self { semaphore }
122+
}
123+
}
124+
125+
impl<S> tower_layer::Layer<S> for SharedConcurrencyLimitLayer {
126+
type Service = ConcurrencyLimit<S>;
127+
128+
fn layer(&self, service: S) -> Self::Service {
129+
ConcurrencyLimit::new(service, self.semaphore.clone())
130+
}
131+
}
132+
133+
/// Middleware that limits max number fo concurrent in-flight requests.
134+
///
135+
/// See the [module docs](crate::concurrency_limit) for more details.
136+
#[derive(Debug)]
137+
pub struct ConcurrencyLimit<S> {
138+
inner: S,
139+
semaphore: PollSemaphore,
140+
permit: Option<OwnedSemaphorePermit>,
141+
}
142+
143+
impl<S> ConcurrencyLimit<S> {
144+
/// Create new [`ConcurrencyLimit`] with associated semaphore
145+
pub fn new(inner: S, semaphore: Arc<Semaphore>) -> Self {
146+
Self {
147+
inner,
148+
semaphore: PollSemaphore::new(semaphore),
149+
permit: None,
150+
}
151+
}
152+
153+
define_inner_service_accessors!();
154+
}
155+
156+
// Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`. Instead, when cloning the
157+
// service, create a new service with the same semaphore, but with the permit in the un-acquired
158+
// state.
159+
impl<T: Clone> Clone for ConcurrencyLimit<T> {
160+
fn clone(&self) -> Self {
161+
Self {
162+
inner: self.inner.clone(),
163+
semaphore: self.semaphore.clone(),
164+
permit: None,
165+
}
166+
}
167+
}
168+
169+
impl<S, R, Body> tower_service::Service<Request<R>> for ConcurrencyLimit<S>
170+
where
171+
S: tower_service::Service<Request<R>, Response = Response<Body>>,
172+
{
173+
type Response = Response<ResponseBody<Body>>;
174+
type Error = S::Error;
175+
type Future = ResponseFuture<S::Future>;
176+
177+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
178+
if self.permit.is_none() {
179+
self.permit = ready!(self.semaphore.poll_acquire(cx));
180+
}
181+
182+
self.inner.poll_ready(cx)
183+
}
184+
185+
fn call(&mut self, request: Request<R>) -> Self::Future {
186+
let permit = self
187+
.permit
188+
.take()
189+
.expect("max requests in-flight; poll_ready must be called first");
190+
191+
let future = self.inner.call(request);
192+
ResponseFuture {
193+
inner: future,
194+
permit: Some(permit),
195+
}
196+
}
197+
}
198+
199+
pin_project! {
    /// Response future for [`ConcurrencyLimit`].
    pub struct ResponseFuture<F> {
        #[pin]
        inner: F,

        // Kept inside an `Option` so the permit can be moved out on completion and handed
        // to the `ResponseBody`; the permit must be dropped only after the body has been
        // consumed (or dropped).
        permit: Option<OwnedSemaphorePermit>,
    }
}
212+
213+
impl<F, B, E> Future for ResponseFuture<F>
214+
where
215+
F: Future<Output = Result<Response<B>, E>>,
216+
{
217+
type Output = Result<Response<ResponseBody<B>>, E>;
218+
219+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220+
let this = self.project();
221+
let response = ready!(this.inner.poll(cx))?;
222+
223+
let permit = this.permit.take().unwrap();
224+
let response = response.map(move |body| ResponseBody {
225+
inner: body,
226+
permit,
227+
});
228+
229+
Poll::Ready(Ok(response))
230+
}
231+
}
232+
233+
pin_project! {
    /// Response body for [`ConcurrencyLimit`].
    ///
    /// Holds a semaphore permit for as long as it exists, thereby limiting the number of
    /// `struct` instances in concurrent existence.
    pub struct ResponseBody<B> {
        #[pin]
        inner: B,
        // Released (on drop) once the body has been fully consumed or discarded.
        permit: OwnedSemaphorePermit,
    }
}
244+
245+
impl<B> http_body::Body for ResponseBody<B>
246+
where
247+
B: http_body::Body,
248+
{
249+
type Data = B::Data;
250+
type Error = B::Error;
251+
252+
#[inline]
253+
fn poll_frame(
254+
self: Pin<&mut Self>,
255+
cx: &mut Context<'_>,
256+
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
257+
self.project().inner.poll_frame(cx)
258+
}
259+
260+
#[inline]
261+
fn is_end_stream(&self) -> bool {
262+
self.inner.is_end_stream()
263+
}
264+
265+
#[inline]
266+
fn size_hint(&self) -> http_body::SizeHint {
267+
self.inner.size_hint()
268+
}
269+
}
270+
271+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::Request;
    use tower::{BoxError, ServiceBuilder};
    use tower_service::Service;

    /// Echo service: responds with the request body unchanged.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    #[tokio::test]
    async fn basic() {
        let semaphore = Arc::new(Semaphore::new(1));
        assert_eq!(1, semaphore.available_permits());

        let mut svc = ServiceBuilder::new()
            .layer(SharedConcurrencyLimitLayer::from_semaphore(
                semaphore.clone(),
            ))
            .service_fn(echo);

        // Driving the service to readiness pre-acquires a permit, decreasing the available
        // count.
        std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        assert_eq!(0, semaphore.available_permits());

        // Creating the response future carries the permit along; nothing is released yet.
        let fut = svc.call(Request::new(Body::empty()));

        // Awaiting the future moves the permit into the response body; still held.
        let response = fut.await.unwrap();
        assert_eq!(0, semaphore.available_permits());

        // Consuming the response body finally releases the permit, restoring the count.
        crate::test_helpers::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(1, semaphore.available_permits());
    }
}

‎tower-http/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ mod compression_utils;
294294
))]
295295
pub use compression_utils::CompressionLevel;
296296

297+
#[cfg(feature = "concurrency-limit")]
298+
pub mod concurrency_limit;
299+
297300
#[cfg(feature = "map-response-body")]
298301
pub mod map_response_body;
299302

0 commit comments

Comments
 (0)