diff --git a/tokio-util/src/sync/cancellation_token.rs b/tokio-util/src/sync/cancellation_token.rs index 66fbf1a73e7..f4f0ce9f9b8 100644 --- a/tokio-util/src/sync/cancellation_token.rs +++ b/tokio-util/src/sync/cancellation_token.rs @@ -287,6 +287,22 @@ impl CancellationToken { } .await } + + /// Runs a future to completion and returns its result wrapped inside of an `Option` + /// unless the `CancellationToken` is cancelled. In that case the function returns + /// `None` and the future gets dropped. + /// + /// The function takes self by value. + /// + /// # Cancel safety + /// + /// This method is only cancel safe if `fut` is cancel safe. + pub async fn run_until_cancelled_owned(self, fut: F) -> Option + where + F: Future, + { + self.run_until_cancelled(fut).await + } } // ===== impl WaitForCancellationFuture ===== diff --git a/tokio-util/tests/sync_cancellation_token.rs b/tokio-util/tests/sync_cancellation_token.rs index db33114a2e3..9332a8f9d02 100644 --- a/tokio-util/tests/sync_cancellation_token.rs +++ b/tokio-util/tests/sync_cancellation_token.rs @@ -493,3 +493,58 @@ fn run_until_cancelled_test() { ); } } + +#[test] +fn run_until_cancelled_owned_test() { + let (waker, _) = new_count_waker(); + + { + let token = CancellationToken::new(); + let to_cancel = token.clone(); + + let takes_ownership = move |token: CancellationToken| { + token.run_until_cancelled_owned(std::future::pending::<()>()) + }; + + let fut = takes_ownership(token); + pin!(fut); + + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + to_cancel.cancel(); + + assert_eq!( + Poll::Ready(None), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + } + + { + let (tx, rx) = oneshot::channel::<()>(); + + let token = CancellationToken::new(); + let takes_ownership = move |token: CancellationToken, rx: oneshot::Receiver<()>| { + token.run_until_cancelled_owned(async move { + rx.await.unwrap(); + 42 + }) + }; + let fut = takes_ownership(token, rx); + pin!(fut); + + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + tx.send(()).unwrap(); + + assert_eq!( + Poll::Ready(Some(42)), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + } +}