Skip to content

Commit b3a10a1

Browse files
fix(s2n-quic): ConfirmComplete with handshake de-duplication (#2307)
Previously, we stored only a single `Waker`. As a result, when the underlying s2n-quic connection/handle was returned to the application multiple times (as happens with handshake de-duplication), only one of the waiting application tasks was woken. This change modifies the `ConfirmComplete` logic so that we now store as many wakers as needed (using Tokio's `watch` channel) and wake all of them whenever the state changes.
1 parent b4ba62d commit b3a10a1

File tree

3 files changed

+98
-90
lines changed

3 files changed

+98
-90
lines changed

quic/s2n-quic/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ s2n-quic-rustls = { version = "=0.44.1", path = "../s2n-quic-rustls", optional =
7777
s2n-quic-tls = { version = "=0.44.1", path = "../s2n-quic-tls", optional = true }
7878
s2n-quic-tls-default = { version = "=0.44.1", path = "../s2n-quic-tls-default", optional = true }
7979
s2n-quic-transport = { version = "=0.44.1", path = "../s2n-quic-transport" }
80-
tokio = { version = "1", default-features = false }
80+
tokio = { version = "1", default-features = false, features = ["sync"] }
8181
zerocopy = { version = "0.7", optional = true, features = ["derive"] }
8282
zeroize = { version = "1", optional = true, default-features = false }
8383

quic/s2n-quic/src/provider/dc/confirm.rs

+43-47
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use crate::Connection;
5-
use core::task::{Context, Poll, Waker};
65
use s2n_quic_core::{
76
connection,
87
connection::Error,
@@ -13,6 +12,7 @@ use s2n_quic_core::{
1312
},
1413
};
1514
use std::io;
15+
use tokio::sync::watch;
1616

1717
/// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc
1818
/// waits for the dc handshake to complete
@@ -21,58 +21,54 @@ impl ConfirmComplete {
2121
/// Blocks the task until the provided connection has either completed the dc handshake or closed
2222
/// with an error
2323
pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> {
24-
core::future::poll_fn(|cx| {
25-
conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx))
26-
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
27-
})
28-
.await
24+
let mut receiver = conn
25+
.query_event_context_mut(|context: &mut ConfirmContext| context.sender.subscribe())
26+
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
27+
28+
loop {
29+
match &*receiver.borrow_and_update() {
30+
// if we're ready or have errored then let the application know
31+
State::Ready => return Ok(()),
32+
State::Failed(error) => return Err((*error).into()),
33+
State::Waiting(_) => {}
34+
}
35+
36+
if receiver.changed().await.is_err() {
37+
return Err(io::Error::new(
38+
io::ErrorKind::Other,
39+
"never reached terminal state",
40+
));
41+
}
42+
}
2943
}
3044
}
3145

32-
#[derive(Default)]
3346
pub struct ConfirmContext {
34-
waker: Option<Waker>,
35-
state: State,
47+
sender: watch::Sender<State>,
48+
}
49+
50+
impl Default for ConfirmContext {
51+
fn default() -> Self {
52+
let (sender, _receiver) = watch::channel(State::default());
53+
Self { sender }
54+
}
3655
}
3756

3857
impl ConfirmContext {
3958
/// Updates the state on the context
4059
fn update(&mut self, state: State) {
41-
self.state = state;
42-
43-
// notify the application that the state was updated
44-
self.wake();
45-
}
46-
47-
/// Polls the context for handshake completion
48-
fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
49-
match self.state {
50-
// if we're ready or have errored then let the application know
51-
State::Ready => Poll::Ready(Ok(())),
52-
State::Failed(error) => Poll::Ready(Err(error.into())),
53-
State::Waiting(_) => {
54-
// store the waker so we can notify the application of state updates
55-
self.waker = Some(cx.waker().clone());
56-
Poll::Pending
57-
}
58-
}
59-
}
60-
61-
/// notify the application of a state update
62-
fn wake(&mut self) {
63-
if let Some(waker) = self.waker.take() {
64-
waker.wake();
65-
}
60+
self.sender.send_replace(state);
6661
}
6762
}
6863

6964
impl Drop for ConfirmContext {
7065
// make sure the application is notified that we're closing the connection
7166
fn drop(&mut self) {
72-
if matches!(self.state, State::Waiting(_)) {
73-
self.state = State::Failed(connection::Error::unspecified());
74-
}
75-
self.wake();
67+
self.sender.send_modify(|state| {
68+
if matches!(state, State::Waiting(_)) {
69+
*state = State::Failed(connection::Error::unspecified());
70+
}
71+
});
7672
}
7773
}
7874

@@ -107,14 +103,14 @@ impl Subscriber for ConfirmComplete {
107103
meta: &ConnectionMeta,
108104
event: &events::ConnectionClosed,
109105
) {
110-
ensure!(matches!(context.state, State::Waiting(_)));
111-
112-
match (&meta.endpoint_type, event.error, &context.state) {
113-
(
114-
EndpointType::Server { .. },
115-
Error::Closed { .. },
116-
State::Waiting(Some(DcState::PathSecretsReady { .. })),
117-
) => {
106+
ensure!(matches!(*context.sender.borrow(), State::Waiting(_)));
107+
let is_ready = matches!(
108+
*context.sender.borrow(),
109+
State::Waiting(Some(DcState::PathSecretsReady { .. }))
110+
);
111+
112+
match (&meta.endpoint_type, event.error, is_ready) {
113+
(EndpointType::Server { .. }, Error::Closed { .. }, true) => {
118114
// The client may close the connection immediately after the dc handshake completes,
119115
// before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS.
120116
// Since the server has already moved into the PathSecretsReady state, this can be considered
@@ -132,7 +128,7 @@ impl Subscriber for ConfirmComplete {
132128
_meta: &ConnectionMeta,
133129
event: &events::DcStateChanged,
134130
) {
135-
ensure!(matches!(context.state, State::Waiting(_)));
131+
ensure!(matches!(*context.sender.borrow(), State::Waiting(_)));
136132

137133
match event.state {
138134
DcState::NoVersionNegotiated { .. } => context.update(State::Failed(

quic/s2n-quic/src/tests/dc.rs

+54-42
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ fn dc_handshake_self_test() -> Result<()> {
6969
.with_tls(certificates::CERT_PEM)?
7070
.with_dc(MockDcEndpoint::new(&CLIENT_TOKENS))?;
7171

72-
self_test(server, client, None, None)?;
72+
self_test(server, client, true, None, None)?;
7373

7474
Ok(())
7575
}
@@ -114,7 +114,7 @@ fn dc_mtls_handshake_self_test() -> Result<()> {
114114
.with_tls(client_tls)?
115115
.with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?;
116116

117-
self_test(server, client, None, None)?;
117+
self_test(server, client, true, None, None)?;
118118

119119
Ok(())
120120
}
@@ -143,7 +143,7 @@ fn dc_mtls_handshake_auth_failure_self_test() -> Result<()> {
143143
}
144144
.into();
145145

146-
self_test(server, client, Some(expected_client_error), None)?;
146+
self_test(server, client, true, Some(expected_client_error), None)?;
147147

148148
Ok(())
149149
}
@@ -181,6 +181,7 @@ fn dc_mtls_handshake_server_not_supported_self_test() -> Result<()> {
181181
self_test(
182182
server,
183183
client,
184+
true,
184185
Some(connection::Error::invalid_configuration(
185186
"peer does not support specified dc versions",
186187
)),
@@ -228,6 +229,7 @@ fn dc_mtls_handshake_client_not_supported_self_test() -> Result<()> {
228229
self_test(
229230
server,
230231
client,
232+
false,
231233
Some(expected_client_error),
232234
Some(connection::Error::invalid_configuration(
233235
"peer does not support specified dc versions",
@@ -266,7 +268,7 @@ fn dc_possible_secret_control_packet(
266268
.with_dc(dc_endpoint)?
267269
.with_packet_interceptor(RandomShort::default())?;
268270

269-
let (client_events, _server_events) = self_test(server, client, None, None)?;
271+
let (client_events, _server_events) = self_test(server, client, true, None, None)?;
270272

271273
assert_eq!(
272274
1,
@@ -297,6 +299,7 @@ fn dc_possible_secret_control_packet(
297299
fn self_test<S: ServerProviders, C: ClientProviders>(
298300
server: server::Builder<S>,
299301
client: client::Builder<C>,
302+
client_has_dc: bool,
300303
expected_client_error: Option<connection::Error>,
301304
expected_server_error: Option<connection::Error>,
302305
) -> Result<(DcRecorder, DcRecorder)> {
@@ -318,18 +321,21 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
318321

319322
let addr = server.local_addr()?;
320323

324+
let expected_count = 1 + client_has_dc as usize;
321325
spawn(async move {
322-
if let Some(mut conn) = server.accept().await {
323-
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;
324-
325-
if let Some(error) = expected_server_error {
326-
assert_eq!(error, convert_io_result(result).unwrap());
327-
328-
if expected_client_error.is_some() {
329-
conn.close(SERVER_CLOSE_ERROR_CODE.into());
326+
for _ in 0..expected_count {
327+
if let Some(mut conn) = server.accept().await {
328+
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;
329+
330+
if let Some(error) = expected_server_error {
331+
assert_eq!(error, convert_io_result(result).unwrap());
332+
333+
if expected_client_error.is_some() {
334+
conn.close(SERVER_CLOSE_ERROR_CODE.into());
335+
}
336+
} else {
337+
assert!(result.is_ok());
330338
}
331-
} else {
332-
assert!(result.is_ok());
333339
}
334340
}
335341
});
@@ -340,35 +346,41 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
340346
.with_random(Random::with_seed(456))?
341347
.start()?;
342348

343-
let client_events = client_events.clone();
344-
345-
primary::spawn(async move {
346-
let connect = Connect::new(addr).with_server_name("localhost");
347-
let mut conn = client.connect(connect).await.unwrap();
348-
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;
349-
350-
if let Some(error) = expected_client_error {
351-
assert_eq!(error, convert_io_result(result).unwrap());
352-
353-
if expected_server_error.is_some() {
354-
conn.close(CLIENT_CLOSE_ERROR_CODE.into());
355-
// wait for the server to assert the expected error before dropping
356-
delay(Duration::from_millis(100)).await;
349+
for _ in 0..expected_count {
350+
primary::spawn({
351+
let client = client.clone();
352+
let client_events = client_events.clone();
353+
async move {
354+
let connect = Connect::new(addr)
355+
.with_server_name("localhost")
356+
.with_deduplicate(client_has_dc);
357+
let mut conn = client.connect(connect).await.unwrap();
358+
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;
359+
360+
if let Some(error) = expected_client_error {
361+
assert_eq!(error, convert_io_result(result).unwrap());
362+
363+
if expected_server_error.is_some() {
364+
conn.close(CLIENT_CLOSE_ERROR_CODE.into());
365+
// wait for the server to assert the expected error before dropping
366+
delay(Duration::from_millis(100)).await;
367+
}
368+
} else {
369+
assert!(result.is_ok());
370+
let client_events = client_events
371+
.dc_state_changed_events()
372+
.lock()
373+
.unwrap()
374+
.clone();
375+
assert_dc_complete(&client_events);
376+
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
377+
// before the client closes the connection. This is only necessary to confirm the `dc::State`
378+
// on the server moves to `DcState::Complete`
379+
delay(Duration::from_millis(100)).await;
380+
}
357381
}
358-
} else {
359-
assert!(result.is_ok());
360-
let client_events = client_events
361-
.dc_state_changed_events()
362-
.lock()
363-
.unwrap()
364-
.clone();
365-
assert_dc_complete(&client_events);
366-
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
367-
// before the client closes the connection. This is only necessary to confirm the `dc::State`
368-
// on the server moves to `DcState::Complete`
369-
delay(Duration::from_millis(100)).await;
370-
}
371-
});
382+
});
383+
}
372384

373385
Ok(addr)
374386
})

0 commit comments

Comments
 (0)