Skip to content

Commit 59564ff

Browse files
feat(s2n-quic-dc): add ClientConfirm subscriber (#2274)
* feat(s2n-quic-dc): add ClientConfirm subscriber * PR feedback * handle server case * clippy
1 parent fcd9a1b commit 59564ff

File tree

3 files changed

+208
-26
lines changed

3 files changed

+208
-26
lines changed

quic/s2n-quic/src/provider/dc.rs

+4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33

44
//! Provides dc support
55
6+
mod confirm;
7+
68
use s2n_quic_core::dc::Disabled;
79

810
// these imports are only accessible if the unstable feature is enabled
911
#[allow(unused_imports)]
12+
pub use confirm::ConfirmComplete;
13+
#[allow(unused_imports)]
1014
pub use s2n_quic_core::dc::{ApplicationParams, ConnectionInfo, Endpoint, Path};
1115

1216
pub trait Provider {
+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use crate::Connection;
5+
use core::task::{Context, Poll, Waker};
6+
use s2n_quic_core::{
7+
connection,
8+
connection::Error,
9+
ensure,
10+
event::{
11+
api as events,
12+
api::{ConnectionInfo, ConnectionMeta, DcState, EndpointType, Subscriber},
13+
},
14+
};
15+
use std::io;
16+
17+
/// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc
18+
/// waits for the dc handshake to complete
19+
pub struct ConfirmComplete;
20+
impl ConfirmComplete {
21+
/// Blocks the task until the provided connection has either completed the dc handshake or closed
22+
/// with an error
23+
pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> {
24+
core::future::poll_fn(|cx| {
25+
conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx))
26+
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
27+
})
28+
.await
29+
}
30+
}
31+
32+
#[derive(Default)]
33+
pub struct ConfirmContext {
34+
waker: Option<Waker>,
35+
state: State,
36+
}
37+
38+
impl ConfirmContext {
39+
/// Updates the state on the context
40+
fn update(&mut self, state: State) {
41+
self.state = state;
42+
43+
// notify the application that the state was updated
44+
self.wake();
45+
}
46+
47+
/// Polls the context for handshake completion
48+
fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
49+
match self.state {
50+
// if we're ready or have errored then let the application know
51+
State::Ready => Poll::Ready(Ok(())),
52+
State::Failed(error) => Poll::Ready(Err(error.into())),
53+
State::Waiting(_) => {
54+
// store the waker so we can notify the application of state updates
55+
self.waker = Some(cx.waker().clone());
56+
Poll::Pending
57+
}
58+
}
59+
}
60+
61+
/// notify the application of a state update
62+
fn wake(&mut self) {
63+
if let Some(waker) = self.waker.take() {
64+
waker.wake();
65+
}
66+
}
67+
}
68+
69+
impl Drop for ConfirmContext {
70+
// make sure the application is notified that we're closing the connection
71+
fn drop(&mut self) {
72+
if matches!(self.state, State::Waiting(_)) {
73+
self.state = State::Failed(connection::Error::unspecified());
74+
}
75+
self.wake();
76+
}
77+
}
78+
79+
enum State {
80+
Waiting(Option<DcState>),
81+
Ready,
82+
Failed(connection::Error),
83+
}
84+
85+
impl Default for State {
86+
fn default() -> Self {
87+
State::Waiting(None)
88+
}
89+
}
90+
91+
impl Subscriber for ConfirmComplete {
92+
type ConnectionContext = ConfirmContext;
93+
94+
#[inline]
95+
fn create_connection_context(
96+
&mut self,
97+
_: &ConnectionMeta,
98+
_info: &ConnectionInfo,
99+
) -> Self::ConnectionContext {
100+
ConfirmContext::default()
101+
}
102+
103+
#[inline]
104+
fn on_connection_closed(
105+
&mut self,
106+
context: &mut Self::ConnectionContext,
107+
meta: &ConnectionMeta,
108+
event: &events::ConnectionClosed,
109+
) {
110+
ensure!(matches!(context.state, State::Waiting(_)));
111+
112+
match (&meta.endpoint_type, event.error, &context.state) {
113+
(
114+
EndpointType::Server { .. },
115+
Error::Closed { .. },
116+
State::Waiting(Some(DcState::PathSecretsReady { .. })),
117+
) => {
118+
// The client may close the connection immediately after the dc handshake completes,
119+
// before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS.
120+
// Since the server has already moved into the PathSecretsReady state, this can be considered
121+
// as a successful completion of the dc handshake.
122+
context.update(State::Ready)
123+
}
124+
_ => context.update(State::Failed(event.error)),
125+
}
126+
}
127+
128+
#[inline]
129+
fn on_dc_state_changed(
130+
&mut self,
131+
context: &mut Self::ConnectionContext,
132+
_meta: &ConnectionMeta,
133+
event: &events::DcStateChanged,
134+
) {
135+
ensure!(matches!(context.state, State::Waiting(_)));
136+
137+
if let DcState::Complete { .. } = event.state {
138+
// notify the application that the dc handshake has completed
139+
context.update(State::Ready);
140+
} else {
141+
context.update(State::Waiting(Some(event.state.clone())));
142+
}
143+
}
144+
}

quic/s2n-quic/src/tests/dc.rs

+60-26
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use super::*;
5-
use crate::{client, client::ClientProviders, server, server::ServerProviders};
5+
use crate::{client, client::ClientProviders, provider::dc, server, server::ServerProviders};
66
use s2n_quic_core::{
77
dc::testing::MockDcEndpoint,
88
event::{api::DcState, Timestamp},
99
stateless_reset::token::testing::{TEST_TOKEN_1, TEST_TOKEN_2},
1010
};
11+
use std::io::ErrorKind;
1112

1213
// Client Server
1314
//
@@ -42,7 +43,7 @@ fn dc_handshake_self_test() {
4243
let server = Server::builder().with_tls(SERVER_CERTS).unwrap();
4344
let client = Client::builder().with_tls(certificates::CERT_PEM).unwrap();
4445

45-
self_test(server, client);
46+
self_test(server, client, None);
4647
}
4748

4849
// Client Server
@@ -81,17 +82,28 @@ fn dc_mtls_handshake_self_test() {
8182
let client_tls = build_client_mtls_provider(certificates::MTLS_CA_CERT).unwrap();
8283
let client = Client::builder().with_tls(client_tls).unwrap();
8384

84-
self_test(server, client);
85+
self_test(server, client, None);
86+
}
87+
88+
#[test]
89+
fn dc_mtls_handshake_auth_failure_self_test() {
90+
let server_tls = build_server_mtls_provider(certificates::UNTRUSTED_CERT_PEM).unwrap();
91+
let server = Server::builder().with_tls(server_tls).unwrap();
92+
93+
let client_tls = build_client_mtls_provider(certificates::MTLS_CA_CERT).unwrap();
94+
let client = Client::builder().with_tls(client_tls).unwrap();
95+
96+
self_test(server, client, Some(ErrorKind::ConnectionReset));
8597
}
8698

8799
fn self_test<S: ServerProviders, C: ClientProviders>(
88100
server: server::Builder<S>,
89101
client: client::Builder<C>,
102+
expected_error: Option<ErrorKind>,
90103
) {
91104
let model = Model::default();
92105
let rtt = Duration::from_millis(100);
93106
model.set_delay(rtt / 2);
94-
const LEN: usize = 1000;
95107

96108
let server_subscriber = DcStateChanged::new();
97109
let server_events = server_subscriber.clone();
@@ -103,60 +115,68 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
103115
test(model, |handle| {
104116
let mut server = server
105117
.with_io(handle.builder().build()?)?
106-
.with_event((tracing_events(), server_subscriber))?
118+
.with_event((dc::ConfirmComplete, (tracing_events(), server_subscriber)))?
107119
.with_random(Random::with_seed(456))?
108120
.with_dc(MockDcEndpoint::new(&server_tokens))?
109121
.start()?;
110122

111123
let addr = server.local_addr()?;
112124
spawn(async move {
113-
let mut conn = server.accept().await.unwrap();
114-
let mut stream = conn.open_bidirectional_stream().await.unwrap();
115-
stream.send(vec![42; LEN].into()).await.unwrap();
116-
stream.flush().await.unwrap();
125+
let conn = server.accept().await;
126+
if expected_error.is_some() {
127+
assert!(conn.is_none());
128+
} else {
129+
assert!(dc::ConfirmComplete::wait_ready(&mut conn.unwrap())
130+
.await
131+
.is_ok());
132+
}
117133
});
118134

119135
let client = client
120136
.with_io(handle.builder().build().unwrap())?
121-
.with_event((tracing_events(), client_subscriber))?
137+
.with_event((dc::ConfirmComplete, (tracing_events(), client_subscriber)))?
122138
.with_random(Random::with_seed(456))?
123139
.with_dc(MockDcEndpoint::new(&client_tokens))?
124140
.start()?;
125141

142+
let client_events = client_events.clone();
143+
126144
primary::spawn(async move {
127145
let connect = Connect::new(addr).with_server_name("localhost");
128146
let mut conn = client.connect(connect).await.unwrap();
129-
let mut stream = conn.accept_bidirectional_stream().await.unwrap().unwrap();
147+
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;
130148

131-
let mut recv_len = 0;
132-
while let Some(chunk) = stream.receive().await.unwrap() {
133-
recv_len += chunk.len();
149+
if let Some(error) = expected_error {
150+
assert_eq!(error, result.err().unwrap().kind());
151+
} else {
152+
assert!(result.is_ok());
153+
let client_events = client_events.events().lock().unwrap().clone();
154+
assert_dc_complete(&client_events);
155+
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
156+
// before the client closes the connection. This is only necessary to confirm the `dc::State`
157+
// on the server moves to `DcState::Complete`
158+
delay(Duration::from_millis(100)).await;
134159
}
135-
assert_eq!(LEN, recv_len);
136160
});
137161

138162
Ok(addr)
139163
})
140164
.unwrap();
141165

166+
if expected_error.is_some() {
167+
return;
168+
}
169+
142170
let server_events = server_events.events().lock().unwrap().clone();
143171
let client_events = client_events.events().lock().unwrap().clone();
144172

173+
assert_dc_complete(&server_events);
174+
assert_dc_complete(&client_events);
175+
145176
// 3 state transitions (VersionNegotiated -> PathSecretsReady -> Complete)
146177
assert_eq!(3, server_events.len());
147178
assert_eq!(3, client_events.len());
148179

149-
for events in [server_events.clone(), client_events.clone()] {
150-
if let DcState::VersionNegotiated { version, .. } = events[0].state {
151-
assert_eq!(version, s2n_quic_core::dc::SUPPORTED_VERSIONS[0]);
152-
} else {
153-
panic!("VersionNegotiated should be the first dc state");
154-
}
155-
156-
assert!(matches!(events[1].state, DcState::PathSecretsReady { .. }));
157-
assert!(matches!(events[2].state, DcState::Complete { .. }));
158-
}
159-
160180
// Server path secrets are ready in 1.5 RTTs measured from the start of the test, since it takes
161181
// .5 RTT for the Initial from the client to reach the server
162182
assert_eq!(
@@ -175,6 +195,20 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
175195
assert_eq!(rtt * 2, client_events[2].timestamp.duration_since_start());
176196
}
177197

198+
fn assert_dc_complete(events: &[DcStateChangedEvent]) {
199+
// 3 state transitions (VersionNegotiated -> PathSecretsReady -> Complete)
200+
assert_eq!(3, events.len());
201+
202+
if let DcState::VersionNegotiated { version, .. } = events[0].state {
203+
assert_eq!(version, s2n_quic_core::dc::SUPPORTED_VERSIONS[0]);
204+
} else {
205+
panic!("VersionNegotiated should be the first dc state");
206+
}
207+
208+
assert!(matches!(events[1].state, DcState::PathSecretsReady { .. }));
209+
assert!(matches!(events[2].state, DcState::Complete { .. }));
210+
}
211+
178212
#[derive(Clone)]
179213
struct DcStateChangedEvent {
180214
timestamp: Timestamp,

0 commit comments

Comments
 (0)