Skip to content

Commit

Permalink
net: debug_assert on creating a tokio socket from a blocking one (#7166)
Browse files Browse the repository at this point in the history
See #5595 and #7172.

This adds a debug assertion that checks that a supplied underlying std socket is set to nonblocking mode when constructing a tokio socket object from such an object.

This only works on unix.
  • Loading branch information
Noah-Kennedy authored Mar 5, 2025
1 parent 0284d1b commit 042433c
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ unexpected_cfgs = { level = "warn", check-cfg = [
'cfg(fuzzing)',
'cfg(loom)',
'cfg(mio_unsupported_force_poll_poll)',
'cfg(tokio_allow_from_blocking_fd)',
'cfg(tokio_internal_mt_counters)',
'cfg(tokio_no_parking_lot)',
'cfg(tokio_no_tuning_tests)',
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::io::{Interest, PollEvented};
use crate::net::tcp::TcpStream;
use crate::util::check_socket_for_blocking;

cfg_not_wasi! {
use crate::net::{to_socket_addrs, ToSocketAddrs};
Expand Down Expand Up @@ -209,6 +210,10 @@ impl TcpListener {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::net::TcpListener::set_nonblocking
///
/// # Examples
Expand Down Expand Up @@ -236,6 +241,8 @@ impl TcpListener {
/// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
#[track_caller]
pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
check_socket_for_blocking(&listener)?;

let io = mio::net::TcpListener::from_std(listener);
let io = PollEvented::new(io)?;
Ok(TcpListener { io })
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cfg_not_wasi! {
use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::util::check_socket_for_blocking;

use std::fmt;
use std::io;
Expand Down Expand Up @@ -173,6 +174,10 @@ impl TcpStream {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::net::TcpStream::set_nonblocking
///
/// # Examples
Expand Down Expand Up @@ -200,6 +205,8 @@ impl TcpStream {
/// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
#[track_caller]
pub fn from_std(stream: std::net::TcpStream) -> io::Result<TcpStream> {
check_socket_for_blocking(&stream)?;

let io = mio::net::TcpStream::from_std(stream);
let io = PollEvented::new(io)?;
Ok(TcpStream { io })
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/udp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::io::{Interest, PollEvented, ReadBuf, Ready};
use crate::net::{to_socket_addrs, ToSocketAddrs};
use crate::util::check_socket_for_blocking;

use std::fmt;
use std::io;
Expand Down Expand Up @@ -192,6 +193,10 @@ impl UdpSocket {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::net::UdpSocket::set_nonblocking
///
/// # Panics
Expand Down Expand Up @@ -220,6 +225,8 @@ impl UdpSocket {
/// ```
#[track_caller]
pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> {
check_socket_for_blocking(&socket)?;

let io = mio::net::UdpSocket::from_std(socket);
UdpSocket::new(io)
}
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/unix/datagram/socket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::io::{Interest, PollEvented, ReadBuf, Ready};
use crate::net::unix::SocketAddr;
use crate::util::check_socket_for_blocking;

use std::fmt;
use std::io;
Expand Down Expand Up @@ -449,6 +450,10 @@ impl UnixDatagram {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::os::unix::net::UnixDatagram::set_nonblocking
///
/// # Panics
Expand Down Expand Up @@ -484,6 +489,8 @@ impl UnixDatagram {
/// ```
#[track_caller]
pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> {
check_socket_for_blocking(&datagram)?;

let socket = mio::net::UnixDatagram::from_std(datagram);
let io = PollEvented::new(socket)?;
Ok(UnixDatagram { io })
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/unix/listener.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::io::{Interest, PollEvented};
use crate::net::unix::{SocketAddr, UnixStream};
use crate::util::check_socket_for_blocking;

use std::fmt;
use std::io;
Expand Down Expand Up @@ -106,6 +107,10 @@ impl UnixListener {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::os::unix::net::UnixListener::set_nonblocking
///
/// # Examples
Expand Down Expand Up @@ -133,6 +138,8 @@ impl UnixListener {
/// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
#[track_caller]
pub fn from_std(listener: net::UnixListener) -> io::Result<UnixListener> {
check_socket_for_blocking(&listener)?;

let listener = mio::net::UnixListener::from_std(listener);
let io = PollEvented::new(listener)?;
Ok(UnixListener { io })
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/net/unix/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::net::unix::split::{split, ReadHalf, WriteHalf};
use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::unix::ucred::{self, UCred};
use crate::net::unix::SocketAddr;
use crate::util::check_socket_for_blocking;

use std::fmt;
use std::future::poll_fn;
Expand Down Expand Up @@ -791,6 +792,10 @@ impl UnixStream {
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// Passing a listener in blocking mode is always erroneous,
/// and the behavior in that case may change in the future.
/// For example, it could panic.
///
/// [`set_nonblocking`]: std::os::unix::net::UnixStream::set_nonblocking
///
/// # Examples
Expand Down Expand Up @@ -818,6 +823,8 @@ impl UnixStream {
/// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
#[track_caller]
pub fn from_std(stream: net::UnixStream) -> io::Result<UnixStream> {
check_socket_for_blocking(&stream)?;

let stream = mio::net::UnixStream::from_std(stream);
let io = PollEvented::new(stream)?;

Expand Down
29 changes: 29 additions & 0 deletions tokio/src/util/blocking_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#[cfg(unix)]
use std::os::fd::AsFd;

#[cfg(unix)]
#[allow(unused_variables)]
#[track_caller]
pub(crate) fn check_socket_for_blocking<S: AsFd>(s: &S) -> crate::io::Result<()> {
#[cfg(not(tokio_allow_from_blocking_fd))]
{
let sock = socket2::SockRef::from(s);

debug_assert!(
sock.nonblocking()?,
"Registering a blocking socket with the tokio runtime is unsupported. \
If you wish to do anyways, please add `--cfg tokio_allow_from_blocking_fd` to your \
RUSTFLAGS. See github.com/tokio-rs/tokio/issues/7172 for details."
);
}

Ok(())
}

#[cfg(not(unix))]
#[allow(unused_variables)]
pub(crate) fn check_socket_for_blocking<S>(s: &S) -> crate::io::Result<()> {
// we cannot retrieve the nonblocking status on windows
// and i dont know how to support wasi yet
Ok(())
}
6 changes: 6 additions & 0 deletions tokio/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ cfg_io_driver! {
#[cfg(feature = "rt")]
pub(crate) mod atomic_cell;

#[cfg(feature = "net")]
mod blocking_check;
#[cfg(feature = "net")]
#[allow(unused_imports)]
pub(crate) use blocking_check::check_socket_for_blocking;

pub(crate) mod metric_atomics;

#[cfg(any(feature = "rt", feature = "signal", feature = "process"))]
Expand Down
7 changes: 5 additions & 2 deletions tokio/tests/io_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ fn panics_when_io_disabled() {
let rt = runtime::Builder::new_current_thread().build().unwrap();

rt.block_on(async {
let _ =
tokio::net::TcpListener::from_std(std::net::TcpListener::bind("127.0.0.1:0").unwrap());
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

listener.set_nonblocking(true).unwrap();

let _ = tokio::net::TcpListener::from_std(listener);
});
}
6 changes: 6 additions & 0 deletions tokio/tests/io_driver_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ fn tcp_doesnt_block() {
let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

listener.set_nonblocking(true).unwrap();

TcpListener::from_std(listener).unwrap()
};

Expand All @@ -33,6 +36,9 @@ fn drop_wakes() {
let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

listener.set_nonblocking(true).unwrap();

TcpListener::from_std(listener).unwrap()
};

Expand Down
6 changes: 5 additions & 1 deletion tokio/tests/no_rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ async fn timeout_value() {
)]
#[cfg_attr(miri, ignore)] // No `socket` in miri.
fn io_panics_when_no_tokio_context() {
let _ = tokio::net::TcpListener::from_std(std::net::TcpListener::bind("127.0.0.1:0").unwrap());
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

listener.set_nonblocking(true).unwrap();

let _ = tokio::net::TcpListener::from_std(listener);
}
7 changes: 7 additions & 0 deletions tokio/tests/tcp_peek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@ use std::{io::Write, net};
#[tokio::test]
async fn peek() {
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();

let addr = listener.local_addr().unwrap();
let t = thread::spawn(move || assert_ok!(listener.accept()).0);

let left = net::TcpStream::connect(addr).unwrap();

left.set_nonblocking(true).unwrap();

let mut right = t.join().unwrap();

right.set_nonblocking(true).unwrap();

let _ = right.write(&[1, 2, 3, 4]).unwrap();

let mut left: TcpStream = left.try_into().unwrap();
Expand Down

0 comments on commit 042433c

Please sign in to comment.