Skip to content

Commit

Permalink
feat(s2n-quic-dc): add SendOnly UDP socket
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft committed Mar 12, 2025
1 parent c8de638 commit d3ae894
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 39 deletions.
2 changes: 2 additions & 0 deletions dc/s2n-quic-dc/src/stream/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::TransportFeatures;
pub mod application;
pub mod fd;
mod handle;
mod send_only;
#[cfg(feature = "tokio")]
mod tokio;
mod tracing;
Expand All @@ -14,6 +15,7 @@ pub use self::tracing::Tracing;
pub use crate::socket::*;
pub use application::Application;
pub use handle::{Ext, Flags, Socket};
pub use send_only::SendOnly;

pub type ArcApplication = std::sync::Arc<dyn Application>;

Expand Down
14 changes: 12 additions & 2 deletions dc/s2n-quic-dc/src/stream/socket/application/builder.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::stream::socket::ArcApplication;
use std::io;
use crate::stream::socket::{application, fd, ArcApplication, SendOnly, Tracing};
use std::{io, sync::Arc};

pub trait Builder: 'static + Send + Sync {
fn build(self: Box<Self>) -> io::Result<ArcApplication>;
}

impl<S: fd::udp::Socket> Builder for SendOnly<Arc<S>> {
#[inline]
fn build(self: Box<Self>) -> io::Result<ArcApplication> {
let v = Tracing(*self);
let v = application::Single(v);
let v = Arc::new(v);
Ok(v)
}
}

#[cfg(feature = "tokio")]
mod tokio_impl {
use super::*;
Expand Down
36 changes: 35 additions & 1 deletion dc/s2n-quic-dc/src/stream/socket/fd/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,42 @@ use crate::msg::{
use s2n_quic_core::inet::ExplicitCongestionNotification;
use std::{
io::{self, IoSlice, IoSliceMut},
net::SocketAddr,
os::fd::AsRawFd,
};

pub trait Socket: 'static + AsRawFd + Send + Sync {
fn local_addr(&self) -> io::Result<SocketAddr>;
}

impl Socket for std::net::UdpSocket {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(*self).local_addr()
}
}

impl Socket for tokio::net::UdpSocket {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(*self).local_addr()
}
}

impl<T: Socket> Socket for std::sync::Arc<T> {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(**self).local_addr()
}
}

impl<T: Socket> Socket for Box<T> {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(**self).local_addr()
}
}

pub use super::peek;

#[inline]
Expand Down Expand Up @@ -66,12 +99,13 @@ pub fn send<T>(
addr: &Addr,
ecn: ExplicitCongestionNotification,
buffer: &[IoSlice],
flags: Flags,
) -> io::Result<usize>
where
T: AsRawFd,
{
send_msghdr(addr, ecn, buffer, |msghdr| {
libc_call(|| unsafe { libc::sendmsg(fd.as_raw_fd(), msghdr, 0) as _ })
libc_call(|| unsafe { libc::sendmsg(fd.as_raw_fd(), msghdr, flags) as _ })
})
}

Expand Down
102 changes: 102 additions & 0 deletions dc/s2n-quic-dc/src/stream/socket/send_only.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{fd::udp, Protocol, Socket, TransportFeatures};
use crate::msg::{addr::Addr, cmsg};
use core::task::{Context, Poll};
use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification};
use std::{
io::{self, IoSlice, IoSliceMut},
net::SocketAddr,
};

#[derive(Clone, Debug)]
pub struct SendOnly<T: udp::Socket>(pub T);

impl<T> Socket for SendOnly<T>
where
T: udp::Socket,
{
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
self.0.local_addr()
}

#[inline]
fn protocol(&self) -> Protocol {
Protocol::Udp
}

#[inline]
fn features(&self) -> TransportFeatures {
TransportFeatures::UDP
}

#[inline]
fn poll_peek_len(&self, _cx: &mut Context) -> Poll<io::Result<usize>> {
unimplemented!()
}

#[inline]
fn poll_recv(
&self,
_cx: &mut Context,
_addr: &mut Addr,
_cmsg: &mut cmsg::Receiver,
_buffer: &mut [IoSliceMut],
) -> Poll<io::Result<usize>> {
unimplemented!()
}

#[inline]
fn try_send(
&self,
addr: &Addr,
ecn: ExplicitCongestionNotification,
buffer: &[IoSlice],
) -> io::Result<usize> {
// no point in sending empty packets
ensure!(!buffer.is_empty(), Ok(0));

debug_assert!(
buffer.iter().any(|s| !s.is_empty()),
"trying to send from an empty buffer"
);

debug_assert!(
addr.get().port() != 0,
"cannot send packet to unspecified port"
);

loop {
match udp::send(&self.0, addr, ecn, buffer, libc::MSG_DONTWAIT) {
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {
// try the operation again if we were interrupted
continue;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// we got a WouldBlock so pretend we sent it - we have no way of registering interest
return Ok(buffer.iter().map(|s| s.len()).sum());
}
res => return res,
}
}
}

#[inline]
fn poll_send(
&self,
_cx: &mut Context,
addr: &Addr,
ecn: ExplicitCongestionNotification,
buffer: &[IoSlice],
) -> Poll<io::Result<usize>> {
self.try_send(addr, ecn, buffer).into()
}

#[inline]
fn send_finish(&self) -> io::Result<()> {
// UDP sockets don't need a shut down
Ok(())
}
}
39 changes: 3 additions & 36 deletions dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,12 @@ use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification, ready};
use std::{
io::{self, IoSlice, IoSliceMut},
net::SocketAddr,
os::fd::AsRawFd,
};
use tokio::io::unix::{AsyncFd, TryIoError};

trait UdpSocket: 'static + AsRawFd + Send + Sync {
fn local_addr(&self) -> io::Result<SocketAddr>;
}

impl UdpSocket for std::net::UdpSocket {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(*self).local_addr()
}
}

impl UdpSocket for tokio::net::UdpSocket {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(*self).local_addr()
}
}

impl<T: UdpSocket> UdpSocket for std::sync::Arc<T> {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(**self).local_addr()
}
}

impl<T: UdpSocket> UdpSocket for Box<T> {
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
(**self).local_addr()
}
}

impl<T> Socket for AsyncFd<T>
where
T: UdpSocket,
T: udp::Socket,
{
#[inline]
fn local_addr(&self) -> io::Result<SocketAddr> {
Expand Down Expand Up @@ -152,7 +119,7 @@ where
);

loop {
match udp::send(self.get_ref(), addr, ecn, buffer) {
match udp::send(self.get_ref(), addr, ecn, buffer, Default::default()) {
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {
// try the operation again if we were interrupted
continue;
Expand Down Expand Up @@ -186,7 +153,7 @@ where
loop {
let mut socket = ready!(self.poll_write_ready(cx))?;

let res = socket.try_io(|fd| udp::send(fd, addr, ecn, buffer));
let res = socket.try_io(|fd| udp::send(fd, addr, ecn, buffer, Default::default()));

match res {
Ok(Ok(len)) => return Ok(len).into(),
Expand Down

0 comments on commit d3ae894

Please sign in to comment.