Skip to content

Commit 7227160

Browse files
authored
s2n-tls-tokio: use s2n_shutdown_send instead of s2n_shutdown (aws#4374)
1 parent c128140 commit 7227160

File tree

3 files changed

+47
-92
lines changed

3 files changed

+47
-92
lines changed

bindings/rust/s2n-tls-tokio/src/lib.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -363,15 +363,19 @@ where
363363
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
364364
ready!(self.as_mut().poll_blinding(ctx))?;
365365

366-
// s2n_shutdown must not be called again if it errors
366+
// s2n_shutdown_send must not be called again if it errors
367367
if self.shutdown_error.is_none() {
368368
let result = ready!(self.as_mut().with_io(ctx, |mut context| {
369-
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
369+
context
370+
.conn
371+
.as_mut()
372+
.poll_shutdown_send()
373+
.map(|r| r.map(|_| ()))
370374
}));
371375
if let Err(error) = result {
372376
self.shutdown_error = Some(error);
373-
// s2n_shutdown reading might have triggered blinding again
374-
ready!(self.as_mut().poll_blinding(ctx))?;
377+
// s2n_shutdown_send only writes, so will never trigger blinding again.
378+
// So we do not need to poll_blinding again after this error.
375379
}
376380
};
377381

bindings/rust/s2n-tls-tokio/tests/shutdown.rs

+23-88
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ use tokio::{
1616

1717
pub mod common;
1818

19-
// An arbitrary but very long timeout.
20-
// No valid single IO operation should take anywhere near 10 minutes.
21-
pub const LONG_TIMEOUT: time::Duration = time::Duration::from_secs(600);
22-
2319
async fn read_until_shutdown<S: AsyncRead + AsyncWrite + Unpin>(
2420
stream: &mut TlsStream<S>,
2521
) -> Result<(), std::io::Error> {
@@ -166,18 +162,6 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
166162
let (mut client, mut server) =
167163
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
168164

169-
// Attempt to shutdown the client. This will eventually fail because the
170-
// server has not written the close_notify message yet, but it will at least
171-
// write the close_notify message that the server needs.
172-
//
173-
// Because this test begins paused and relies on auto-advancing, this does
174-
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
175-
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
176-
//
177-
// TODO: replace this with a half-close once the bindings support half-close.
178-
let timeout = time::timeout(LONG_TIMEOUT, client.shutdown()).await;
179-
assert!(timeout.is_err());
180-
181165
// Setup a bad record for the next read
182166
overrides.next_read(Some(Box::new(|_, _, buf| {
183167
// Parsing the header is one of the blinded operations
@@ -202,53 +186,9 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
202186
// Server MUST eventually successfully shutdown
203187
assert!(result.is_ok());
204188

205-
// Shutdown MUST have sent the close_notify message needed by the peer
206-
// to also shutdown successfully.
207-
client.shutdown().await?;
208-
209-
Ok(())
210-
}
211-
212-
#[tokio::test(start_paused = true)]
213-
async fn shutdown_with_blinding_bad_close_record() -> Result<(), Box<dyn std::error::Error>> {
214-
let clock = common::TokioTime::default();
215-
let mut server_config = common::server_config()?;
216-
server_config.set_monotonic_clock(clock)?;
217-
218-
let client = TlsConnector::new(common::client_config()?.build()?);
219-
let server = TlsAcceptor::new(server_config.build()?);
220-
221-
let (server_stream, client_stream) = common::get_streams().await?;
222-
let server_stream = common::TestStream::new(server_stream);
223-
let overrides = server_stream.overrides();
224-
let (mut client, mut server) =
225-
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
226-
227-
// Setup a bad record for the next read
228-
overrides.next_read(Some(Box::new(|_, _, buf| {
229-
// Parsing the header is one of the blinded operations
230-
// in s2n_shutdown, so provide a malformed header.
231-
let zeroed_header = [23, 0, 0, 0, 0];
232-
buf.put_slice(&zeroed_header);
233-
Ok(()).into()
234-
})));
235-
236-
let time_start = time::Instant::now();
237-
let result = server.shutdown().await;
238-
let time_elapsed = time_start.elapsed();
239-
240-
// Shutdown MUST NOT complete faster than minimal blinding time.
241-
assert!(time_elapsed > common::MIN_BLINDING_SECS);
242-
243-
// Shutdown MUST eventually complete with the correct error after blinding.
244-
let io_error = result.unwrap_err();
245-
let error: error::Error = io_error.try_into()?;
246-
assert!(error.kind() == error::ErrorType::ProtocolError);
247-
assert!(error.name() == "S2N_ERR_BAD_MESSAGE");
248-
249-
// Shutdown MUST have sent the close_notify message needed by the peer
250-
// to also shutdown successfully.
251-
client.shutdown().await?;
189+
// Shutdown MUST have sent the close_notify message needed for EOF.
190+
let mut received = [0; 1];
191+
assert!(client.read(&mut received).await? == 0);
252192

253193
Ok(())
254194
}
@@ -295,7 +235,7 @@ async fn shutdown_with_poll_blinding() -> Result<(), Box<dyn std::error::Error>>
295235
Ok(())
296236
}
297237

298-
#[tokio::test(start_paused = true)]
238+
#[tokio::test]
299239
async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
300240
let client = TlsConnector::new(common::client_config()?.build()?);
301241
let server = TlsAcceptor::new(common::server_config()?.build()?);
@@ -304,20 +244,9 @@ async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
304244
let server_stream = common::TestStream::new(server_stream);
305245
let overrides = server_stream.overrides();
306246

307-
let (mut client, mut server) =
247+
let (_, mut server) =
308248
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
309249

310-
// Attempt to shutdown the client. This will eventually fail because the
311-
// server has not written the close_notify message yet, but it will at least
312-
// write the close_notify message that the server needs.
313-
//
314-
// Because this test begins paused and relies on auto-advancing, this does
315-
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
316-
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
317-
//
318-
// TODO: replace this with a half-close once the bindings support half-close.
319-
_ = time::timeout(time::Duration::from_secs(600), client.shutdown()).await;
320-
321250
// The underlying stream should return a unique error on shutdown
322251
overrides.next_shutdown(Some(Box::new(|_, _| {
323252
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
@@ -343,22 +272,22 @@ async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box<dyn std::erro
343272
let (_, mut server) =
344273
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
345274

346-
// Both s2n_shutdown and the underlying stream should error on shutdown
347-
overrides.next_read(Some(Box::new(|_, _, _| {
275+
// Both s2n_shutdown_send and the underlying stream should error on shutdown
276+
overrides.next_write(Some(Box::new(|_, _, _| {
348277
Ready(Err(io::Error::from(io::ErrorKind::Other)))
349278
})));
350279
overrides.next_shutdown(Some(Box::new(|_, _| {
351280
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
352281
})));
353282

354-
// Shutdown should complete with the correct error from s2n_shutdown
283+
// Shutdown should complete with the correct error from s2n_shutdown_send
355284
let result = server.shutdown().await;
356285
let io_error = result.unwrap_err();
357286
let error: error::Error = io_error.try_into()?;
358287
// Any non-blocking read error is translated as "IOError"
359288
assert!(error.kind() == error::ErrorType::IOError);
360289

361-
// Even if s2n_shutdown fails, we need to close the underlying stream.
290+
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
362291
// Make sure we called our mock shutdown, consuming it.
363292
assert!(overrides.is_consumed());
364293

@@ -374,14 +303,11 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
374303
let server_stream = common::TestStream::new(server_stream);
375304
let overrides = server_stream.overrides();
376305

377-
let (_, mut server) =
306+
let (mut client, mut server) =
378307
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
379308

380-
// We want s2n_shutdown to fail on read in order to ensure that it is only
381-
// called once on failure.
382-
// If s2n_shutdown were called again, the second call would hang waiting
383-
// for nonexistent input from the peer.
384-
overrides.next_read(Some(Box::new(|_, _, _| {
309+
// We want s2n_shutdown_send to produce an error on write
310+
overrides.next_write(Some(Box::new(|_, _, _| {
385311
Ready(Err(io::Error::from(io::ErrorKind::Other)))
386312
})));
387313

@@ -391,16 +317,25 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
391317
Pending
392318
})));
393319

394-
// Shutdown should complete with the correct error from s2n_shutdown
320+
// Shutdown should complete with the correct error from s2n_shutdown_send
395321
let result = server.shutdown().await;
396322
let io_error = result.unwrap_err();
397323
let error: error::Error = io_error.try_into()?;
398324
// Any non-blocking read error is translated as "IOError"
399325
assert!(error.kind() == error::ErrorType::IOError);
400326

401-
// Even if s2n_shutdown fails, we need to close the underlying stream.
327+
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
402328
// Make sure we at least called our mock shutdown, consuming it.
403329
assert!(overrides.is_consumed());
404330

331+
// Since s2n_shutdown_send failed, we should NOT have sent a close_notify.
332+
// Make sure the peer doesn't receive a close_notify.
333+
// If this is not true, then we're incorrectly calling s2n_shutdown_send
334+
// again after an error.
335+
let mut received = [0; 1];
336+
let io_error = client.read(&mut received).await.unwrap_err();
337+
let error: error::Error = io_error.try_into()?;
338+
assert!(error.kind() == error::ErrorType::ConnectionClosed);
339+
405340
Ok(())
406341
}

bindings/rust/s2n-tls/src/connection.rs

+16
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,22 @@ impl Connection {
553553
}
554554
}
555555

556+
/// Attempts a graceful shutdown of the write side of a TLS connection.
557+
///
558+
/// Unlike Self::poll_shutdown, no reponse from the peer is necessary.
559+
/// If using TLS1.3, the connection can continue to be used for reading afterwards.
560+
pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
561+
if !self.remaining_blinding_delay()?.is_zero() {
562+
return Poll::Pending;
563+
}
564+
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
565+
unsafe {
566+
s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
567+
.into_poll()
568+
.map_ok(|_| self)
569+
}
570+
}
571+
556572
/// Returns the TLS alert code, if any
557573
pub fn alert(&self) -> Option<u8> {
558574
let alert =

0 commit comments

Comments
 (0)