Skip to content

Commit d53a491

Browse files
committed
review: fix <Channel<D, E> as Body>::poll_frame()
see: #140 (comment) this commit adds test coverage exposing the bug, and tightens the pattern used to match frames yielded by the data channel. now, when the channel is closed, a `None` will flow onwards and poll the error channel. `None` will be returned when the error channel is closed, which also indicates that the associated `Sender` has been dropped.
1 parent c6a4ffd commit d53a491

File tree

1 file changed

+85
-4
lines changed

1 file changed

+85
-4
lines changed

http-body-util/src/channel.rs

+85-4
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ where
4848
let this = self.project();
4949

5050
match this.rx_frame.poll_recv(cx) {
51-
Poll::Ready(frame) => return Poll::Ready(frame.map(Ok)),
52-
Poll::Pending => {}
51+
Poll::Ready(frame @ Some(_)) => return Poll::Ready(frame.map(Ok)),
52+
Poll::Ready(None) | Poll::Pending => {}
5353
}
5454

5555
use core::future::Future;
5656
match this.rx_error.poll(cx) {
57-
Poll::Ready(err) => return Poll::Ready(err.ok().map(Err)),
57+
Poll::Ready(Ok(error)) => return Poll::Ready(Some(Err(error))),
58+
Poll::Ready(Err(_)) => return Poll::Ready(None),
5859
Poll::Pending => {}
5960
}
6061

@@ -131,13 +132,54 @@ mod tests {
131132
use super::*;
132133

133134
#[tokio::test]
134-
async fn works() {
135+
async fn empty() {
136+
let (tx, body) = Channel::<Bytes>::new(1024);
137+
drop(tx);
138+
139+
let collected = body.collect().await.unwrap();
140+
assert!(collected.trailers().is_none());
141+
assert!(collected.to_bytes().is_empty());
142+
}
143+
144+
#[tokio::test]
145+
async fn can_send_data() {
135146
let (mut tx, body) = Channel::<Bytes>::new(1024);
136147

137148
tokio::spawn(async move {
138149
tx.send_data(Bytes::from("Hel")).await.unwrap();
139150
tx.send_data(Bytes::from("lo!")).await.unwrap();
151+
});
152+
153+
let collected = body.collect().await.unwrap();
154+
assert!(collected.trailers().is_none());
155+
assert_eq!(collected.to_bytes(), "Hello!");
156+
}
157+
158+
#[tokio::test]
159+
async fn can_send_trailers() {
160+
let (mut tx, body) = Channel::<Bytes>::new(1024);
161+
162+
tokio::spawn(async move {
163+
let mut trailers = HeaderMap::new();
164+
trailers.insert(
165+
HeaderName::from_static("foo"),
166+
HeaderValue::from_static("bar"),
167+
);
168+
tx.send_trailers(trailers).await.unwrap();
169+
});
170+
171+
let collected = body.collect().await.unwrap();
172+
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
173+
assert!(collected.to_bytes().is_empty());
174+
}
175+
176+
#[tokio::test]
177+
async fn can_send_both_data_and_trailers() {
178+
let (mut tx, body) = Channel::<Bytes>::new(1024);
140179

180+
tokio::spawn(async move {
181+
tx.send_data(Bytes::from("Hel")).await.unwrap();
182+
tx.send_data(Bytes::from("lo!")).await.unwrap();
141183
let mut trailers = HeaderMap::new();
142184
trailers.insert(
143185
HeaderName::from_static("foo"),
@@ -150,4 +192,43 @@ mod tests {
150192
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
151193
assert_eq!(collected.to_bytes(), "Hello!");
152194
}
195+
196+
/// A stand-in for an error type, for unit tests.
197+
type Error = &'static str;
198+
/// An example error message.
199+
const MSG: Error = "oh no";
200+
201+
#[tokio::test]
202+
async fn aborts_before_trailers() {
203+
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);
204+
205+
tokio::spawn(async move {
206+
tx.send_data(Bytes::from("Hel")).await.unwrap();
207+
tx.send_data(Bytes::from("lo!")).await.unwrap();
208+
tx.abort(MSG);
209+
});
210+
211+
let err = body.collect().await.unwrap_err();
212+
assert_eq!(err, MSG);
213+
}
214+
215+
#[tokio::test]
216+
async fn aborts_after_trailers() {
217+
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);
218+
219+
tokio::spawn(async move {
220+
tx.send_data(Bytes::from("Hel")).await.unwrap();
221+
tx.send_data(Bytes::from("lo!")).await.unwrap();
222+
let mut trailers = HeaderMap::new();
223+
trailers.insert(
224+
HeaderName::from_static("foo"),
225+
HeaderValue::from_static("bar"),
226+
);
227+
tx.send_trailers(trailers).await.unwrap();
228+
tx.abort(MSG);
229+
});
230+
231+
let err = body.collect().await.unwrap_err();
232+
assert_eq!(err, MSG);
233+
}
153234
}

0 commit comments

Comments
 (0)