Skip to content

Commit 86fdf00

Browse files
cratelyndavidpdrsn
andauthored
util: add a channel body (#140)
* Add channel body * review: use `sync::oneshot` for error channel this applies a review suggestion here: https://github.com/hyperium/http-body/pull/100/files#r1399781061 this commit refactors the channel-backed body in #100, changing the `mpsc::Receiver<E>` used to transmit an error into a `oneshot::Receiver<E>`. this should improve memory usage, and make the channel a smaller structure. in order to achieve this, some minor adjustments are made: * use pin projection, projecting pinnedness to the oneshot receiver, polling it via `core::future::Future::poll(..)` to yield a body frame. * add `Debug` bounds were needed. as an alternative, see tokio-rs/tokio#7059, which proposed a `poll_recv(..)` inherent method for a oneshot channel receiver. * review: use `&mut self` method receivers this applies a review suggestion here: https://github.com/hyperium/http-body/pull/100/files#r1399780355 this commit refactors the channel-backed body in #100, changing the signature of `send_*` methods on the sender to require a mutable reference. * review: fix `<Channel<D, E> as Body>::poll_frame()` see: #140 (comment) this commit adds test coverage exposing the bug, and tightens the pattern used to match frames yielded by the data channel. now, when the channel is closed, a `None` will flow onwards and poll the error channel. `None` will be returned when the error channel is closed, which also indicates that the associated `Sender` has been dropped. --------- Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
1 parent 7339aec commit 86fdf00

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

http-body-util/Cargo.toml

+8
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@ keywords = ["http"]
2626
categories = ["web-programming"]
2727
rust-version = "1.61"
2828

29+
[features]
30+
default = []
31+
channel = ["dep:tokio"]
32+
full = ["channel"]
33+
2934
[dependencies]
3035
bytes = "1"
3136
futures-core = { version = "0.3", default-features = false }
3237
http = "1"
3338
http-body = { version = "1", path = "../http-body" }
3439
pin-project-lite = "0.2"
3540

41+
# optional dependencies
42+
tokio = { version = "1", features = ["sync"], optional = true }
43+
3644
[dev-dependencies]
3745
futures-util = { version = "0.3", default-features = false }
3846
tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] }

http-body-util/src/channel.rs

+234
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
//! A body backed by a channel.
2+
3+
use std::{
4+
fmt::Display,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
use bytes::Buf;
10+
use http::HeaderMap;
11+
use http_body::{Body, Frame};
12+
use pin_project_lite::pin_project;
13+
use tokio::sync::{mpsc, oneshot};
14+
15+
pin_project! {
16+
/// A body backed by a channel.
17+
pub struct Channel<D, E = std::convert::Infallible> {
18+
rx_frame: mpsc::Receiver<Frame<D>>,
19+
#[pin]
20+
rx_error: oneshot::Receiver<E>,
21+
}
22+
}
23+
24+
impl<D, E> Channel<D, E> {
25+
/// Create a new channel body.
26+
///
27+
/// The channel will buffer up to the provided number of messages. Once the buffer is full,
28+
/// attempts to send new messages will wait until a message is received from the channel. The
29+
/// provided buffer capacity must be at least 1.
30+
pub fn new(buffer: usize) -> (Sender<D, E>, Self) {
31+
let (tx_frame, rx_frame) = mpsc::channel(buffer);
32+
let (tx_error, rx_error) = oneshot::channel();
33+
(Sender { tx_frame, tx_error }, Self { rx_frame, rx_error })
34+
}
35+
}
36+
37+
impl<D, E> Body for Channel<D, E>
38+
where
39+
D: Buf,
40+
{
41+
type Data = D;
42+
type Error = E;
43+
44+
fn poll_frame(
45+
self: Pin<&mut Self>,
46+
cx: &mut Context<'_>,
47+
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
48+
let this = self.project();
49+
50+
match this.rx_frame.poll_recv(cx) {
51+
Poll::Ready(frame @ Some(_)) => return Poll::Ready(frame.map(Ok)),
52+
Poll::Ready(None) | Poll::Pending => {}
53+
}
54+
55+
use core::future::Future;
56+
match this.rx_error.poll(cx) {
57+
Poll::Ready(Ok(error)) => return Poll::Ready(Some(Err(error))),
58+
Poll::Ready(Err(_)) => return Poll::Ready(None),
59+
Poll::Pending => {}
60+
}
61+
62+
Poll::Pending
63+
}
64+
}
65+
66+
impl<D, E: std::fmt::Debug> std::fmt::Debug for Channel<D, E> {
67+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68+
f.debug_struct("Channel")
69+
.field("rx_frame", &self.rx_frame)
70+
.field("rx_error", &self.rx_error)
71+
.finish()
72+
}
73+
}
74+
75+
/// A sender half created through [`Channel::new`].
76+
pub struct Sender<D, E = std::convert::Infallible> {
77+
tx_frame: mpsc::Sender<Frame<D>>,
78+
tx_error: oneshot::Sender<E>,
79+
}
80+
81+
impl<D, E> Sender<D, E> {
82+
/// Send a frame on the channel.
83+
pub async fn send(&mut self, frame: Frame<D>) -> Result<(), SendError> {
84+
self.tx_frame.send(frame).await.map_err(|_| SendError)
85+
}
86+
87+
/// Send data on data channel.
88+
pub async fn send_data(&mut self, buf: D) -> Result<(), SendError> {
89+
self.send(Frame::data(buf)).await
90+
}
91+
92+
/// Send trailers on trailers channel.
93+
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), SendError> {
94+
self.send(Frame::trailers(trailers)).await
95+
}
96+
97+
/// Aborts the body in an abnormal fashion.
98+
pub fn abort(self, error: E) {
99+
self.tx_error.send(error).ok();
100+
}
101+
}
102+
103+
impl<D, E: std::fmt::Debug> std::fmt::Debug for Sender<D, E> {
104+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105+
f.debug_struct("Sender")
106+
.field("tx_frame", &self.tx_frame)
107+
.field("tx_error", &self.tx_error)
108+
.finish()
109+
}
110+
}
111+
112+
/// The error returned if [`Sender`] fails to send because the receiver is closed.
113+
#[derive(Debug)]
114+
#[non_exhaustive]
115+
pub struct SendError;
116+
117+
impl Display for SendError {
118+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119+
write!(f, "failed to send frame")
120+
}
121+
}
122+
123+
impl std::error::Error for SendError {}
124+
125+
#[cfg(test)]
126+
mod tests {
127+
use bytes::Bytes;
128+
use http::{HeaderName, HeaderValue};
129+
130+
use crate::BodyExt;
131+
132+
use super::*;
133+
134+
#[tokio::test]
135+
async fn empty() {
136+
let (tx, body) = Channel::<Bytes>::new(1024);
137+
drop(tx);
138+
139+
let collected = body.collect().await.unwrap();
140+
assert!(collected.trailers().is_none());
141+
assert!(collected.to_bytes().is_empty());
142+
}
143+
144+
#[tokio::test]
145+
async fn can_send_data() {
146+
let (mut tx, body) = Channel::<Bytes>::new(1024);
147+
148+
tokio::spawn(async move {
149+
tx.send_data(Bytes::from("Hel")).await.unwrap();
150+
tx.send_data(Bytes::from("lo!")).await.unwrap();
151+
});
152+
153+
let collected = body.collect().await.unwrap();
154+
assert!(collected.trailers().is_none());
155+
assert_eq!(collected.to_bytes(), "Hello!");
156+
}
157+
158+
#[tokio::test]
159+
async fn can_send_trailers() {
160+
let (mut tx, body) = Channel::<Bytes>::new(1024);
161+
162+
tokio::spawn(async move {
163+
let mut trailers = HeaderMap::new();
164+
trailers.insert(
165+
HeaderName::from_static("foo"),
166+
HeaderValue::from_static("bar"),
167+
);
168+
tx.send_trailers(trailers).await.unwrap();
169+
});
170+
171+
let collected = body.collect().await.unwrap();
172+
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
173+
assert!(collected.to_bytes().is_empty());
174+
}
175+
176+
#[tokio::test]
177+
async fn can_send_both_data_and_trailers() {
178+
let (mut tx, body) = Channel::<Bytes>::new(1024);
179+
180+
tokio::spawn(async move {
181+
tx.send_data(Bytes::from("Hel")).await.unwrap();
182+
tx.send_data(Bytes::from("lo!")).await.unwrap();
183+
let mut trailers = HeaderMap::new();
184+
trailers.insert(
185+
HeaderName::from_static("foo"),
186+
HeaderValue::from_static("bar"),
187+
);
188+
tx.send_trailers(trailers).await.unwrap();
189+
});
190+
191+
let collected = body.collect().await.unwrap();
192+
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
193+
assert_eq!(collected.to_bytes(), "Hello!");
194+
}
195+
196+
/// A stand-in for an error type, for unit tests.
197+
type Error = &'static str;
198+
/// An example error message.
199+
const MSG: Error = "oh no";
200+
201+
#[tokio::test]
202+
async fn aborts_before_trailers() {
203+
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);
204+
205+
tokio::spawn(async move {
206+
tx.send_data(Bytes::from("Hel")).await.unwrap();
207+
tx.send_data(Bytes::from("lo!")).await.unwrap();
208+
tx.abort(MSG);
209+
});
210+
211+
let err = body.collect().await.unwrap_err();
212+
assert_eq!(err, MSG);
213+
}
214+
215+
#[tokio::test]
216+
async fn aborts_after_trailers() {
217+
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);
218+
219+
tokio::spawn(async move {
220+
tx.send_data(Bytes::from("Hel")).await.unwrap();
221+
tx.send_data(Bytes::from("lo!")).await.unwrap();
222+
let mut trailers = HeaderMap::new();
223+
trailers.insert(
224+
HeaderName::from_static("foo"),
225+
HeaderValue::from_static("bar"),
226+
);
227+
tx.send_trailers(trailers).await.unwrap();
228+
tx.abort(MSG);
229+
});
230+
231+
let err = body.collect().await.unwrap_err();
232+
assert_eq!(err, MSG);
233+
}
234+
}

http-body-util/src/lib.rs

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ mod full;
1515
mod limited;
1616
mod stream;
1717

18+
#[cfg(feature = "channel")]
19+
pub mod channel;
20+
1821
mod util;
1922

2023
use self::combinators::{BoxBody, MapErr, MapFrame, UnsyncBoxBody};
@@ -26,6 +29,9 @@ pub use self::full::Full;
2629
pub use self::limited::{LengthLimitError, Limited};
2730
pub use self::stream::{BodyDataStream, BodyStream, StreamBody};
2831

32+
#[cfg(feature = "channel")]
33+
pub use self::channel::Channel;
34+
2935
/// An extension trait for [`http_body::Body`] adding various combinators and adapters
3036
pub trait BodyExt: http_body::Body {
3137
/// Returns a future that resolves to the next [`Frame`], if any.

0 commit comments

Comments
 (0)