From 3275cfb638fe6cac6b0bbb1f60ee59eb499f6c2a Mon Sep 17 00:00:00 2001 From: Rustin Date: Sun, 7 Jan 2024 00:22:26 +0800 Subject: [PATCH] io: make `copy` cooperative (#6265) --- tokio/src/io/util/copy.rs | 75 +++++++++++++++++++++++++++- tokio/tests/io_copy.rs | 15 ++++++ tokio/tests/io_copy_bidirectional.rs | 25 ++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index 8bd0bff7f2b..56310c86f59 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -82,6 +82,19 @@ impl CopyBuffer { R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, { + ready!(crate::trace::trace_leaf(cx)); + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + // Keep track of task budget + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); loop { // If our buffer is empty, then we need to read some data to // continue. @@ -90,13 +103,49 @@ impl CopyBuffer { self.cap = 0; match self.poll_fill_buf(cx, reader.as_mut()) { - Poll::Ready(Ok(())) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Ok(())) => { + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); + } + Poll::Ready(Err(err)) => { + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); + return Poll::Ready(Err(err)); + } Poll::Pending => { // Try flushing when the reader has no progress to avoid deadlock // when the reader depends on buffered writer. if self.need_flush { ready!(writer.as_mut().poll_flush(cx))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); self.need_flush = false; } @@ -108,6 +157,17 @@ impl CopyBuffer { // If our buffer has some data, let's write it out! while self.pos < self.cap { let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); if i == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, @@ -132,6 +192,17 @@ impl CopyBuffer { // data and finish the transfer. if self.pos == self.cap && self.read_done { ready!(writer.as_mut().poll_flush(cx))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); return Poll::Ready(Ok(self.amt)); } } diff --git a/tokio/tests/io_copy.rs b/tokio/tests/io_copy.rs index 005e1701191..82d92a9688b 100644 --- a/tokio/tests/io_copy.rs +++ b/tokio/tests/io_copy.rs @@ -85,3 +85,18 @@ async fn proxy() { assert_eq!(n, 1024); } + +#[tokio::test] +async fn copy_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let mut reader: &[u8] = b"hello"; + let mut writer: Vec = vec![]; + let _ = io::copy(&mut reader, &mut writer).await; + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} diff --git a/tokio/tests/io_copy_bidirectional.rs b/tokio/tests/io_copy_bidirectional.rs index 10eba3166ac..3cdce32d0ce 100644 --- a/tokio/tests/io_copy_bidirectional.rs +++ b/tokio/tests/io_copy_bidirectional.rs @@ -138,3 +138,28 @@ async fn immediate_exit_on_read_error() { assert!(copy_bidirectional(&mut a, &mut b).await.is_err()); } + +#[tokio::test] +async fn copy_bidirectional_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let payload = b"here, take this"; + + let mut a = tokio_test::io::Builder::new() + .read(payload) + .write(payload) + .build(); + + let mut b = tokio_test::io::Builder::new() + .read(payload) + .write(payload) + .build(); + + let _ = copy_bidirectional(&mut a, &mut b).await; + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}