Skip to content

Commit a77d866

Browse files
authored
feat(client): Add connection capturing API to hyper-util (#112)
- rework features to allow enabling only tokio/sync for the client - a `capture_connection` API
1 parent f87fe0d commit a77d866

File tree

5 files changed

+235
-5
lines changed

5 files changed

+235
-5
lines changed

Cargo.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pin-project-lite = "0.2.4"
2626
futures-channel = { version = "0.3", optional = true }
2727
socket2 = { version = "0.5", optional = true, features = ["all"] }
2828
tracing = { version = "0.1", default-features = false, features = ["std"], optional = true }
29-
tokio = { version = "1", optional = true, features = ["net", "rt", "time"] }
29+
tokio = { version = "1", optional = true, default-features = false }
3030
tower-service ={ version = "0.3", optional = true }
3131
tower = { version = "0.4.1", optional = true, default-features = false, features = ["make", "util"] }
3232

@@ -57,7 +57,7 @@ full = [
5757
]
5858

5959
client = ["hyper/client", "dep:tracing", "dep:futures-channel", "dep:tower", "dep:tower-service"]
60-
client-legacy = ["client", "dep:socket2"]
60+
client-legacy = ["client", "dep:socket2", "tokio/sync"]
6161

6262
server = ["hyper/server"]
6363
server-auto = ["server", "http1", "http2"]
@@ -67,7 +67,7 @@ service = ["dep:tower", "dep:tower-service"]
6767
http1 = ["hyper/http1"]
6868
http2 = ["hyper/http2"]
6969

70-
tokio = ["dep:tokio"]
70+
tokio = ["dep:tokio", "tokio/net", "tokio/rt", "tokio/time"]
7171

7272
# internal features used in CI
7373
__internal_happy_eyeballs_tests = []

src/client/legacy/client.rs

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use hyper::rt::Timer;
1818
use hyper::{body::Body, Method, Request, Response, Uri, Version};
1919
use tracing::{debug, trace, warn};
2020

21+
use super::connect::capture::CaptureConnectionExtension;
2122
#[cfg(feature = "tokio")]
2223
use super::connect::HttpConnector;
2324
use super::connect::{Alpn, Connect, Connected, Connection};
@@ -265,6 +266,10 @@ where
265266
) -> Result<Response<hyper::body::Incoming>, Error> {
266267
let mut pooled = self.connection_for(pool_key).await?;
267268

269+
req.extensions_mut()
270+
.get_mut::<CaptureConnectionExtension>()
271+
.map(|conn| conn.set(&pooled.conn_info));
272+
268273
if pooled.is_http1() {
269274
if req.version() == Version::HTTP_2 {
270275
warn!("Connection is HTTP/1, but request requires HTTP/2");

src/client/legacy/connect/capture.rs

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
use std::{ops::Deref, sync::Arc};
2+
3+
use http::Request;
4+
use tokio::sync::watch;
5+
6+
use super::Connected;
7+
8+
/// [`CaptureConnection`] allows callers to capture [`Connected`] information
9+
///
10+
/// To capture a connection for a request, use [`capture_connection`].
11+
#[derive(Debug, Clone)]
12+
pub struct CaptureConnection {
13+
rx: watch::Receiver<Option<Connected>>,
14+
}
15+
16+
/// Capture the connection for a given request
17+
///
18+
/// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait.
19+
/// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon
20+
/// as the connection is established.
21+
///
22+
/// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none.
23+
///
24+
/// # Examples
25+
///
26+
/// **Synchronous access**:
27+
/// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been
28+
/// established. This is ideal for situations where you are certain the connection has already
29+
/// been established (e.g. after the response future has already completed).
30+
/// ```rust
31+
/// use hyper_util::client::legacy::connect::capture_connection;
32+
/// let mut request = http::Request::builder()
33+
/// .uri("http://foo.com")
34+
/// .body(())
35+
/// .unwrap();
36+
///
37+
/// let captured_connection = capture_connection(&mut request);
38+
/// // some time later after the request has been sent...
39+
/// let connection_info = captured_connection.connection_metadata();
40+
/// println!("we are connected! {:?}", connection_info.as_ref());
41+
/// ```
42+
///
43+
/// **Asynchronous access**:
44+
/// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the
45+
/// connection is available.
46+
///
47+
/// ```rust
48+
/// # #[cfg(feature = "tokio")]
49+
/// # async fn example() {
50+
/// use hyper_util::client::legacy::connect::capture_connection;
51+
/// use hyper_util::client::legacy::Client;
52+
/// use hyper_util::rt::TokioExecutor;
53+
/// use bytes::Bytes;
54+
/// use http_body_util::Empty;
55+
/// let mut request = http::Request::builder()
56+
/// .uri("http://foo.com")
57+
/// .body(Empty::<Bytes>::new())
58+
/// .unwrap();
59+
///
60+
/// let mut captured = capture_connection(&mut request);
61+
/// tokio::task::spawn(async move {
62+
/// let connection_info = captured.wait_for_connection_metadata().await;
63+
/// println!("we are connected! {:?}", connection_info.as_ref());
64+
/// });
65+
///
66+
/// let client = Client::builder(TokioExecutor::new()).build_http();
67+
/// client.request(request).await.expect("request failed");
68+
/// # }
69+
/// ```
70+
pub fn capture_connection<B>(request: &mut Request<B>) -> CaptureConnection {
71+
let (tx, rx) = CaptureConnection::new();
72+
request.extensions_mut().insert(tx);
73+
rx
74+
}
75+
76+
/// TxSide for [`CaptureConnection`]
77+
///
78+
/// This is inserted into `Extensions` to allow Hyper to back channel connection info
79+
#[derive(Clone)]
80+
pub(crate) struct CaptureConnectionExtension {
81+
tx: Arc<watch::Sender<Option<Connected>>>,
82+
}
83+
84+
impl CaptureConnectionExtension {
85+
pub(crate) fn set(&self, connected: &Connected) {
86+
self.tx.send_replace(Some(connected.clone()));
87+
}
88+
}
89+
90+
impl CaptureConnection {
91+
/// Internal API to create the tx and rx half of [`CaptureConnection`]
92+
pub(crate) fn new() -> (CaptureConnectionExtension, Self) {
93+
let (tx, rx) = watch::channel(None);
94+
(
95+
CaptureConnectionExtension { tx: Arc::new(tx) },
96+
CaptureConnection { rx },
97+
)
98+
}
99+
100+
/// Retrieve the connection metadata, if available
101+
pub fn connection_metadata(&self) -> impl Deref<Target = Option<Connected>> + '_ {
102+
self.rx.borrow()
103+
}
104+
105+
/// Wait for the connection to be established
106+
///
107+
/// If a connection was established, this will always return `Some(...)`. If the request never
108+
/// successfully connected (e.g. DNS resolution failure), this method will never return.
109+
pub async fn wait_for_connection_metadata(
110+
&mut self,
111+
) -> impl Deref<Target = Option<Connected>> + '_ {
112+
if self.rx.borrow().is_some() {
113+
return self.rx.borrow();
114+
}
115+
let _ = self.rx.changed().await;
116+
self.rx.borrow()
117+
}
118+
}
119+
120+
#[cfg(all(test, not(miri)))]
121+
mod test {
122+
use super::*;
123+
124+
#[test]
125+
fn test_sync_capture_connection() {
126+
let (tx, rx) = CaptureConnection::new();
127+
assert!(
128+
rx.connection_metadata().is_none(),
129+
"connection has not been set"
130+
);
131+
tx.set(&Connected::new().proxy(true));
132+
assert_eq!(
133+
rx.connection_metadata()
134+
.as_ref()
135+
.expect("connected should be set")
136+
.is_proxied(),
137+
true
138+
);
139+
140+
// ensure it can be called multiple times
141+
assert_eq!(
142+
rx.connection_metadata()
143+
.as_ref()
144+
.expect("connected should be set")
145+
.is_proxied(),
146+
true
147+
);
148+
}
149+
150+
#[tokio::test]
151+
async fn async_capture_connection() {
152+
let (tx, mut rx) = CaptureConnection::new();
153+
assert!(
154+
rx.connection_metadata().is_none(),
155+
"connection has not been set"
156+
);
157+
let test_task = tokio::spawn(async move {
158+
assert_eq!(
159+
rx.wait_for_connection_metadata()
160+
.await
161+
.as_ref()
162+
.expect("connection should be set")
163+
.is_proxied(),
164+
true
165+
);
166+
// can be awaited multiple times
167+
assert!(
168+
rx.wait_for_connection_metadata().await.is_some(),
169+
"should be awaitable multiple times"
170+
);
171+
172+
assert_eq!(rx.connection_metadata().is_some(), true);
173+
});
174+
// can't be finished, we haven't set the connection yet
175+
assert_eq!(test_task.is_finished(), false);
176+
tx.set(&Connected::new().proxy(true));
177+
178+
assert!(test_task.await.is_ok());
179+
}
180+
181+
#[tokio::test]
182+
async fn capture_connection_sender_side_dropped() {
183+
let (tx, mut rx) = CaptureConnection::new();
184+
assert!(
185+
rx.connection_metadata().is_none(),
186+
"connection has not been set"
187+
);
188+
drop(tx);
189+
assert!(rx.wait_for_connection_metadata().await.is_none());
190+
}
191+
}

src/client/legacy/connect/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ pub mod dns;
7474
#[cfg(feature = "tokio")]
7575
mod http;
7676

77+
pub(crate) mod capture;
78+
pub use capture::{capture_connection, CaptureConnection};
79+
7780
pub use self::sealed::Connect;
7881

7982
/// Describes a type returned by a connector.
@@ -169,7 +172,6 @@ impl Connected {
169172

170173
// Don't public expose that `Connected` is `Clone`, unsure if we want to
171174
// keep that contract...
172-
#[cfg(feature = "http2")]
173175
pub(super) fn clone(&self) -> Connected {
174176
Connected {
175177
alpn: self.alpn,

tests/legacy_client.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use http_body_util::{Empty, Full, StreamBody};
1818
use hyper::body::Bytes;
1919
use hyper::body::Frame;
2020
use hyper::Request;
21-
use hyper_util::client::legacy::connect::HttpConnector;
21+
use hyper_util::client::legacy::connect::{capture_connection, HttpConnector};
2222
use hyper_util::client::legacy::Client;
2323
use hyper_util::rt::{TokioExecutor, TokioIo};
2424

@@ -876,3 +876,35 @@ fn alpn_h2() {
876876
);
877877
drop(client);
878878
}
879+
880+
#[cfg(not(miri))]
881+
#[test]
882+
fn capture_connection_on_client() {
883+
let _ = pretty_env_logger::try_init();
884+
885+
let rt = runtime();
886+
let connector = DebugConnector::new();
887+
888+
let client = Client::builder(TokioExecutor::new()).build(connector);
889+
890+
let server = TcpListener::bind("127.0.0.1:0").unwrap();
891+
let addr = server.local_addr().unwrap();
892+
thread::spawn(move || {
893+
let mut sock = server.accept().unwrap().0;
894+
//drop(server);
895+
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
896+
sock.set_write_timeout(Some(Duration::from_secs(5)))
897+
.unwrap();
898+
let mut buf = [0; 4096];
899+
sock.read(&mut buf).expect("read 1");
900+
sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
901+
.expect("write 1");
902+
});
903+
let mut req = Request::builder()
904+
.uri(&*format!("http://{}/a", addr))
905+
.body(Empty::<Bytes>::new())
906+
.unwrap();
907+
let captured_conn = capture_connection(&mut req);
908+
rt.block_on(client.request(req)).expect("200 OK");
909+
assert!(captured_conn.connection_metadata().is_some());
910+
}

0 commit comments

Comments
 (0)