Skip to content

Commit

Permalink
feat(s2n-quic-dc): add WithMap router
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft committed Mar 11, 2025
1 parent c8de638 commit 9cbb70f
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 71 deletions.
223 changes: 194 additions & 29 deletions dc/s2n-quic-dc/src/socket/recv/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@
use crate::{
credentials::Credentials,
packet::{self, stream},
path::secret,
socket::recv::descriptor,
};
use s2n_codec::DecoderBufferMut;
use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress};

/// Routes incoming packet segments to the appropriate destination
pub trait Router {
/// Wraps `self` in a router that intercepts secret control messages and forwards
/// them to the provided [`secret::Map`].
fn with_map(self, map: secret::Map) -> WithMap<Self>
where
Self: Sized,
{
WithMap { inner: self, map }
}

fn is_open(&self) -> bool;

#[inline(always)]
Expand All @@ -27,37 +37,48 @@ pub trait Router {
// We don't check `remaining` since we currently assume one packet per segment.
// If we ever support multiple packets per segment, we'll need to split the segment up even
// further and correctly dispatch to the right place.
Ok((packet, _remaining)) => match packet {
packet::Packet::Control(packet) => {
let tag = packet.tag();
let stream_id = packet.stream_id().copied();
let credentials = *packet.credentials();
self.handle_control_packet(remote_address, ecn, packet);
self.dispatch_control_packet(tag, stream_id, credentials, segment);
}
packet::Packet::Stream(packet) => {
let tag = packet.tag();
let stream_id = *packet.stream_id();
let credentials = *packet.credentials();
self.handle_stream_packet(remote_address, ecn, packet);
self.dispatch_stream_packet(tag, stream_id, credentials, segment);
}
packet::Packet::Datagram(packet) => {
let tag = packet.tag();
let credentials = *packet.credentials();
self.handle_datagram_packet(remote_address, ecn, packet);
self.dispatch_datagram_packet(tag, credentials, segment);
Ok((packet, remaining)) => {
if cfg!(test) {
assert!(remaining.is_empty());
}
packet::Packet::StaleKey(packet) => {
self.handle_stale_key_packet(packet, remote_address);
match packet {
packet::Packet::Control(packet) => {
let tag = packet.tag();
let stream_id = packet.stream_id().copied();
let credentials = *packet.credentials();
tracing::trace!(?tag, ?stream_id, ?credentials, "parsed_control_packet");
self.handle_control_packet(remote_address, ecn, packet);
self.dispatch_control_packet(tag, stream_id, credentials, segment);
}
packet::Packet::Stream(packet) => {
let tag = packet.tag();
let stream_id = *packet.stream_id();
let credentials = *packet.credentials();
tracing::trace!(?tag, ?stream_id, ?credentials, "parsed_stream_packet");
self.handle_stream_packet(remote_address, ecn, packet);
self.dispatch_stream_packet(tag, stream_id, credentials, segment);
}
packet::Packet::Datagram(packet) => {
let tag = packet.tag();
let credentials = *packet.credentials();
tracing::trace!(?tag, ?credentials, "parsed_datagram_packet");
self.handle_datagram_packet(remote_address, ecn, packet);
self.dispatch_datagram_packet(tag, credentials, segment);
}
packet::Packet::StaleKey(packet) => {
tracing::trace!(?packet, "parsed_stale_key_packet");
self.handle_stale_key_packet(packet, remote_address);
}
packet::Packet::ReplayDetected(packet) => {
tracing::trace!(?packet, "parsed_replay_detected_packet");
self.handle_replay_detected_packet(packet, remote_address);
}
packet::Packet::UnknownPathSecret(packet) => {
tracing::trace!(?packet, "parsed_unknown_path_secret_packet");
self.handle_unknown_path_secret_packet(packet, remote_address);
}
}
packet::Packet::ReplayDetected(packet) => {
self.handle_replay_detected_packet(packet, remote_address);
}
packet::Packet::UnknownPathSecret(packet) => {
self.handle_unknown_path_secret_packet(packet, remote_address);
}
},
}
Err(error) => {
self.on_decode_error(error, remote_address, segment);
}
Expand Down Expand Up @@ -196,3 +217,147 @@ pub trait Router {
);
}
}

#[derive(Clone)]
pub struct WithMap<Inner> {
inner: Inner,
map: crate::path::secret::Map,
}

impl<Inner: Router> Router for WithMap<Inner> {
#[inline]
fn is_open(&self) -> bool {
self.inner.is_open()
}

#[inline]
fn tag_len(&self) -> usize {
self.inner.tag_len()
}

#[inline]
fn handle_control_packet(
&mut self,
remote_address: SocketAddress,
ecn: ExplicitCongestionNotification,
packet: packet::control::decoder::Packet,
) {
self.inner
.handle_control_packet(remote_address, ecn, packet);
}

#[inline]
fn dispatch_control_packet(
&mut self,
tag: packet::control::Tag,
id: Option<stream::Id>,
credentials: Credentials,
segment: descriptor::Filled,
) {
self.inner
.dispatch_control_packet(tag, id, credentials, segment);
}

#[inline]
fn handle_stream_packet(
&mut self,
remote_address: SocketAddress,
ecn: ExplicitCongestionNotification,
packet: packet::stream::decoder::Packet,
) {
self.inner.handle_stream_packet(remote_address, ecn, packet);
}

#[inline]
fn dispatch_stream_packet(
&mut self,
tag: stream::Tag,
id: stream::Id,
credentials: Credentials,
segment: descriptor::Filled,
) {
self.inner
.dispatch_stream_packet(tag, id, credentials, segment);
}

#[inline]
fn handle_datagram_packet(
&mut self,
remote_address: SocketAddress,
ecn: ExplicitCongestionNotification,
packet: packet::datagram::decoder::Packet,
) {
self.inner
.handle_datagram_packet(remote_address, ecn, packet);
}

#[inline]
fn dispatch_datagram_packet(
&mut self,
tag: packet::datagram::Tag,
credentials: Credentials,
segment: descriptor::Filled,
) {
self.inner
.dispatch_datagram_packet(tag, credentials, segment);
}

#[inline]
fn handle_stale_key_packet(
&mut self,
packet: packet::secret_control::stale_key::Packet,
remote_address: SocketAddress,
) {
// TODO check if the packet was authentic before forwarding the packet on to inner
self.map.handle_control_packet(
&packet::secret_control::Packet::StaleKey(packet),
&remote_address.into(),
);
self.inner.handle_stale_key_packet(packet, remote_address);
}

#[inline]
fn handle_replay_detected_packet(
&mut self,
packet: packet::secret_control::replay_detected::Packet,
remote_address: SocketAddress,
) {
// TODO check if the packet was authentic before forwarding the packet on to inner
self.map.handle_control_packet(
&packet::secret_control::Packet::ReplayDetected(packet),
&remote_address.into(),
);
self.inner
.handle_replay_detected_packet(packet, remote_address);
}

#[inline]
fn handle_unknown_path_secret_packet(
&mut self,
packet: packet::secret_control::unknown_path_secret::Packet,
remote_address: SocketAddress,
) {
// TODO check if the packet was authentic before forwarding the packet on to inner
self.map.handle_control_packet(
&packet::secret_control::Packet::UnknownPathSecret(packet),
&remote_address.into(),
);
self.inner
.handle_unknown_path_secret_packet(packet, remote_address);
}

#[inline]
fn on_unhandled_packet(&mut self, remote_address: SocketAddress, packet: packet::Packet) {
self.inner.on_unhandled_packet(remote_address, packet);
}

#[inline]
fn on_decode_error(
&mut self,
error: s2n_codec::DecoderError,
remote_address: SocketAddress,
segment: descriptor::Filled,
) {
self.inner.on_decode_error(error, remote_address, segment);
}
}
65 changes: 58 additions & 7 deletions dc/s2n-quic-dc/src/socket/recv/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

use crate::{
socket::recv::{pool, router::Router},
stream::socket::fd::udp,
stream::socket::{fd::udp, Socket},
};
use std::net::UdpSocket;
use std::{io, os::fd::AsRawFd, task::Poll};

/// Receives packets from a blocking [`UdpSocket`] and dispatches into the provided [`Router`]
pub fn blocking<R: Router>(socket: UdpSocket, mut pool: pool::Pool, mut router: R) {
loop {
let mut unfilled = pool.alloc_or_grow();
loop {
/// Receives packets from a blocking [`std::net::UdpSocket`] and dispatches into the provided [`Router`]
pub fn blocking<S: AsRawFd, R: Router>(socket: S, mut alloc: pool::Pool, mut router: R) {
while router.is_open() {
let mut unfilled = alloc.alloc_or_grow();
while router.is_open() {
let res = unfilled.recv_with(|addr, cmsg, buffer| {
udp::recv(&socket, addr, cmsg, &mut [buffer], Default::default())
});
Expand All @@ -32,3 +32,54 @@ pub fn blocking<R: Router>(socket: UdpSocket, mut pool: pool::Pool, mut router:
}
}
}

/// Receives packets from a non-blocking [`std::net::UdpSocket`] and dispatches into the provided [`Router`]
pub async fn non_blocking<S: Socket, R: Router>(socket: S, mut alloc: pool::Pool, mut router: R) {
let mut pending = None;
core::future::poll_fn(move |cx| {
while router.is_open() {
let unfilled = pending.take().unwrap_or_else(|| alloc.alloc_or_grow());

let res = unfilled.recv_with(|addr, cmsg, buffer| {
match socket.poll_recv(cx, addr, cmsg, &mut [buffer]) {
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
Poll::Ready(Ok(len)) => Ok(len),
Poll::Ready(Err(err)) => Err(err),
}
});

match res {
Ok(segments) => {
for segment in segments {
router.on_segment(segment);
}

// poll the socket again
continue;
}
Err((desc, err)) => {
// put the unfilled segment back in the pool
pending = Some(desc);

let kind = err.kind();

// if we got blocked then yield the future
if kind == io::ErrorKind::WouldBlock {
return Poll::Pending;
}

// if tokio is shutting down, it starts returning an `Other` error
if kind == io::ErrorKind::Other {
tracing::info!("worker shutting down due to: {err}");
break;
}

tracing::error!("socket recv error (kind={:?}): {err}", err.kind());
}
}
}

Poll::Ready(())
})
.await;
}
Loading

0 comments on commit 9cbb70f

Please sign in to comment.