From 25e2706eef03b253f8219572df11e2fb920c1586 Mon Sep 17 00:00:00 2001
From: Cameron Bytheway
Date: Mon, 10 Mar 2025 11:57:33 -0600
Subject: [PATCH] feat(s2n-quic-dc): implement queue allocator/dispatcher
 (#2517)

---
 dc/s2n-quic-dc/src/socket/recv/router.rs      |   2 +
 dc/s2n-quic-dc/src/stream/recv.rs             |   1 +
 dc/s2n-quic-dc/src/stream/recv/dispatch.rs    | 209 ++++++++++++
 .../corpus.tar.gz                             |   3 +
 .../src/stream/recv/dispatch/descriptor.rs    | 195 +++++++++++
 .../src/stream/recv/dispatch/free_list.rs     | 148 +++++++++
 .../src/stream/recv/dispatch/handle.rs        | 124 +++++++
 .../src/stream/recv/dispatch/pool.rs          | 211 ++++++++++++
 .../src/stream/recv/dispatch/queue.rs         | 184 +++++++++++
 .../src/stream/recv/dispatch/sender.rs        |  68 ++++
 .../src/stream/recv/dispatch/tests.rs         | 306 ++++++++++++++++++
 dc/s2n-quic-dc/src/sync/mpsc.rs               |  32 +-
 dc/s2n-quic-dc/src/sync/ring_deque.rs         |  31 +-
 dc/s2n-quic-dc/src/testing.rs                 |  15 +
 14 files changed, 1515 insertions(+), 14 deletions(-)
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/__fuzz__/stream__recv__dispatch__tests__model/corpus.tar.gz
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/descriptor.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/free_list.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/handle.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/pool.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/queue.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/sender.rs
 create mode 100644 dc/s2n-quic-dc/src/stream/recv/dispatch/tests.rs

diff --git a/dc/s2n-quic-dc/src/socket/recv/router.rs b/dc/s2n-quic-dc/src/socket/recv/router.rs
index 5b8dd9a77..68d8d1fb3 100644
--- a/dc/s2n-quic-dc/src/socket/recv/router.rs
+++ b/dc/s2n-quic-dc/src/socket/recv/router.rs
@@ -11,6 +11,8 @@ use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress};
 
 /// Routes incoming packet segments to the appropriate destination
 pub trait Router {
+    fn is_open(&self) -> bool;
+
     #[inline(always)]
     fn tag_len(&self) -> usize {
         16
diff --git a/dc/s2n-quic-dc/src/stream/recv.rs b/dc/s2n-quic-dc/src/stream/recv.rs
index 2426e18d1..0ade98c00 100644
--- a/dc/s2n-quic-dc/src/stream/recv.rs
+++ b/dc/s2n-quic-dc/src/stream/recv.rs
@@ -4,6 +4,7 @@
 mod ack;
 pub mod application;
 pub(crate) mod buffer;
+pub mod dispatch;
 mod error;
 mod packet;
 mod probes;
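The new `is_open` requirement gives socket drivers a way to notice that every receiver behind a router has gone away and stop polling on its behalf. A standalone sketch of that contract — the trait below only mirrors the shape of the real `Router`, it is not the crate's definition:

```rust
// Minimal illustration of the `is_open` contract: the driver keeps feeding
// segments only while the router still has live receivers behind it.
trait Router {
    fn is_open(&self) -> bool;
    fn dispatch(&mut self, segment: &[u8]);
}

fn drive<R: Router>(router: &mut R, segments: impl IntoIterator<Item = Vec<u8>>) {
    for segment in segments {
        // once `is_open` returns false, the caller can tear down the socket task
        if !router.is_open() {
            break;
        }
        router.dispatch(&segment);
    }
}
```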
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch.rs
new file mode 100644
index 000000000..d2612189e
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch.rs
@@ -0,0 +1,209 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{credentials, packet, socket::recv::descriptor as desc, sync::ring_deque};
+use s2n_quic_core::{inet::SocketAddress, varint::VarInt};
+use tracing::debug;
+
+mod descriptor;
+mod free_list;
+mod handle;
+mod pool;
+mod queue;
+mod sender;
+
+#[cfg(test)]
+mod tests;
+
+/// Allocate this many channels at a time
+///
+/// With `debug_assertions`, we allocate smaller pages to try and cover more
+/// branches in the allocator logic around growth.
+const PAGE_SIZE: usize = if cfg!(debug_assertions) { 8 } else { 256 };
+
+pub type Error = queue::Error;
+pub type Control = handle::Control<desc::Filled>;
+pub type Stream = handle::Stream<desc::Filled>;
+
+/// A queue allocator for registering a receiver to process packets
+/// for a given ID.
+#[derive(Clone)]
+pub struct Allocator {
+    pool: pool::Pool<desc::Filled, PAGE_SIZE>,
+}
+
+impl Allocator {
+    pub fn new(
+        stream_capacity: impl Into<ring_deque::Capacity>,
+        control_capacity: impl Into<ring_deque::Capacity>,
+    ) -> Self {
+        Self {
+            pool: pool::Pool::new(
+                VarInt::ZERO,
+                stream_capacity.into(),
+                control_capacity.into(),
+            ),
+        }
+    }
+
+    /// Creates an allocator with a non-zero queue id
+    ///
+    /// This is used for patterns where the `queue_id=0` is special and used to
+    /// indicate newly initialized flows waiting to be assigned. For example,
+    /// a client sends a packet with `queue_id=0` to a server and waits for the
+    /// server to respond with an actual `queue_id` for future packets from the client.
+    pub fn new_non_zero(
+        stream_capacity: impl Into<ring_deque::Capacity>,
+        control_capacity: impl Into<ring_deque::Capacity>,
+    ) -> Self {
+        Self {
+            pool: pool::Pool::new(
+                VarInt::from_u8(1),
+                stream_capacity.into(),
+                control_capacity.into(),
+            ),
+        }
+    }
+
+    #[inline]
+    pub fn dispatcher(&self) -> Dispatch {
+        Dispatch {
+            senders: self.pool.senders(),
+            is_open: true,
+        }
+    }
+
+    #[inline]
+    pub fn alloc(&mut self) -> Option<(Control, Stream)> {
+        self.pool.alloc()
+    }
+
+    #[inline]
+    pub fn alloc_or_grow(&mut self) -> (Control, Stream) {
+        self.pool.alloc_or_grow()
+    }
+}
+
+/// A dispatcher which routes packets to the specified queue, if
+/// there is a registered receiver.
+#[derive(Clone)]
+pub struct Dispatch {
+    senders: sender::Senders<desc::Filled, PAGE_SIZE>,
+    is_open: bool,
+}
+
+impl Dispatch {
+    #[inline]
+    pub fn send_control(
+        &mut self,
+        queue_id: VarInt,
+        segment: desc::Filled,
+    ) -> Result<Option<desc::Filled>, Error> {
+        let mut res = Err(Error::Unallocated);
+        self.senders.lookup(queue_id, |sender| {
+            res = sender.send_control(segment);
+        });
+
+        if matches!(res, Err(Error::Closed)) {
+            self.is_open = false;
+        }
+
+        res
+    }
+
+    #[inline]
+    pub fn send_stream(
+        &mut self,
+        queue_id: VarInt,
+        segment: desc::Filled,
+    ) -> Result<Option<desc::Filled>, Error> {
+        let mut res = Err(Error::Unallocated);
+        self.senders.lookup(queue_id, |sender| {
+            res = sender.send_stream(segment);
+        });
+
+        if matches!(res, Err(Error::Closed)) {
+            self.is_open = false;
+        }
+
+        res
+    }
+}
+
+impl crate::socket::recv::router::Router for Dispatch {
+    #[inline(always)]
+    fn is_open(&self) -> bool {
+        self.is_open
+    }
+
+    #[inline(always)]
+    fn tag_len(&self) -> usize {
+        16
+    }
+
+    /// implement this so we don't get warnings about not handling it
+    #[inline(always)]
+    fn handle_control_packet(
+        &mut self,
+        _remote_address: SocketAddress,
+        _ecn: s2n_quic_core::inet::ExplicitCongestionNotification,
+        _packet: packet::control::decoder::Packet,
+    ) {
+    }
+
+    #[inline]
+    fn dispatch_control_packet(
+        &mut self,
+        _tag: packet::control::Tag,
+        id: Option<packet::stream::Id>,
+        credentials: credentials::Credentials,
+        segment: desc::Filled,
+    ) {
+        let Some(id) = id else {
+            return;
+        };
+
+        match self.send_control(id.queue_id, segment) {
+            Ok(None) => {}
+            Ok(Some(_prev)) => {
+                // TODO increment metrics
+                debug!(queue_id = %id.queue_id, "control queue overflow");
+            }
+            Err(_) => {
+                // TODO increment metrics
+                debug!(stream_id = ?id, ?credentials, "unroutable control packet");
+            }
+        }
+    }
+
+    /// implement this so we don't get warnings about not handling it
+    #[inline(always)]
+    fn handle_stream_packet(
+        &mut self,
+        _remote_address: SocketAddress,
+        _ecn: s2n_quic_core::inet::ExplicitCongestionNotification,
+        _packet: packet::stream::decoder::Packet,
+    ) {
+    }
+
+    #[inline]
+    fn dispatch_stream_packet(
+        &mut self,
+        _tag: packet::stream::Tag,
+        id: packet::stream::Id,
+        credentials: credentials::Credentials,
+        segment: desc::Filled,
+    ) {
+        match self.send_stream(id.queue_id, segment) {
+            Ok(None) => {}
+            Ok(Some(_prev)) => {
+                // TODO increment metrics
+                debug!(queue_id = %id.queue_id, "stream queue overflow");
+            }
+            Err(_) => {
+                // TODO increment metrics
+                debug!(stream_id = ?id, ?credentials, "unroutable stream packet");
+            }
+        }
+    }
+}
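Taken together, the allocator hands out receiver pairs while the dispatcher routes by `queue_id`. A minimal wiring sketch built on the signatures above; it assumes crate-internal access, `route_one` is a hypothetical helper, and `segment` stands in for an already-received `desc::Filled`:

```rust
use crate::{
    socket::recv::descriptor::Filled,
    stream::recv::dispatch::{Allocator, Error},
};

async fn route_one(segment: Filled) -> Result<(), Error> {
    // each stream registers a receiver pair and learns its queue id
    let mut alloc = Allocator::new(32, 8);
    let (control, stream) = alloc.alloc_or_grow();
    let queue_id = stream.queue_id();
    debug_assert_eq!(queue_id, control.queue_id());

    // the socket task holds a cheap clone of the dispatcher
    let mut dispatch = alloc.dispatcher();

    // a packet tagged with `queue_id` lands in the stream queue above;
    // `Ok(Some(_))` would mean the ring was full and the oldest item was evicted
    let _prev = dispatch.send_stream(queue_id, segment)?;
    let _packet = stream.recv().await?;
    Ok(())
}
```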
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/__fuzz__/stream__recv__dispatch__tests__model/corpus.tar.gz b/dc/s2n-quic-dc/src/stream/recv/dispatch/__fuzz__/stream__recv__dispatch__tests__model/corpus.tar.gz
new file mode 100644
index 000000000..f1457090c
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/__fuzz__/stream__recv__dispatch__tests__model/corpus.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64dc6c1ff0f7c89de06fcc817360f949a95b2d524afe99789e2429c9e082603f
+size 2488320
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/descriptor.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/descriptor.rs
new file mode 100644
index 000000000..b6e974dbe
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/descriptor.rs
@@ -0,0 +1,195 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{free_list::FreeList, queue::Queue};
+use crate::sync::ring_deque;
+use s2n_quic_core::{ensure, varint::VarInt};
+use std::{
+    marker::PhantomData,
+    ptr::NonNull,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+use tracing::trace;
+
+/// A pointer to a single descriptor in a group
+///
+/// Fundamentally, this is similar to something like `Arc<DescriptorInner>`. However,
+/// unlike [`Arc`], which frees back to the global allocator, a [`Descriptor`] deallocates into
+/// the backing [`FreeList`].
+pub(super) struct Descriptor<T: 'static> {
+    ptr: NonNull<DescriptorInner<T>>,
+    phantom: PhantomData<DescriptorInner<T>>,
+}
+
+impl<T: 'static> Descriptor<T> {
+    #[inline]
+    pub(super) fn new(ptr: NonNull<DescriptorInner<T>>) -> Self {
+        Self {
+            ptr,
+            phantom: PhantomData,
+        }
+    }
+
+    /// # Safety
+    ///
+    /// The caller needs to guarantee the [`Descriptor`] is still allocated. Additionally,
+    /// the [`Self::drop_sender`] method should be used when the cloned descriptor is
+    /// no longer needed.
+    #[inline]
+    pub unsafe fn clone_for_sender(&self) -> Descriptor<T> {
+        self.inner().senders.fetch_add(1, Ordering::Relaxed);
+        Descriptor::new(self.ptr)
+    }
+
+    /// # Safety
+    ///
+    /// This should only be called once the caller can guarantee the descriptor is no longer
+    /// used.
+    #[inline]
+    pub unsafe fn drop_in_place(&self) {
+        core::ptr::drop_in_place(self.ptr.as_ptr());
+    }
+
+    /// # Safety
+    ///
+    /// The caller needs to guarantee the [`Descriptor`] is still allocated.
+    #[inline]
+    pub unsafe fn queue_id(&self) -> VarInt {
+        self.inner().id
+    }
+
+    /// # Safety
+    ///
+    /// The caller needs to guarantee the [`Descriptor`] is still allocated.
+    #[inline]
+    pub unsafe fn stream_queue(&self) -> &Queue<T> {
+        &self.inner().stream
+    }
+
+    /// # Safety
+    ///
+    /// The caller needs to guarantee the [`Descriptor`] is still allocated.
+    #[inline]
+    pub unsafe fn control_queue(&self) -> &Queue<T> {
+        &self.inner().control
+    }
+
+    #[inline]
+    fn inner(&self) -> &DescriptorInner<T> {
+        unsafe { self.ptr.as_ref() }
+    }
+
+    /// # Safety
+    ///
+    /// * The [`Descriptor`] needs to be marked as free of receivers
+    #[inline]
+    pub unsafe fn into_receiver_pair(self) -> (Self, Self) {
+        let inner = self.inner();
+
+        // open the queues back up for receiving
+        inner.control.open_receiver();
+        inner.stream.open_receiver();
+
+        let other = Self {
+            ptr: self.ptr,
+            phantom: PhantomData,
+        };
+
+        (self, other)
+    }
+
+    /// # Safety
+    ///
+    /// This method can be used to drop the Descriptor, but shouldn't be called after the last
+    /// sender Descriptor is released. That implies only calling it once on a given Descriptor
+    /// handle obtained from [`Self::clone_for_sender`].
+    #[inline]
+    pub unsafe fn drop_sender(&self) {
+        let inner = self.inner();
+        let desc_ref = inner.senders.fetch_sub(1, Ordering::Release);
+        debug_assert_ne!(desc_ref, 0, "reference count underflow");
+
+        // based on the implementation in:
+        // https://github.com/rust-lang/rust/blob/28b83ee59698ae069f5355b8e03f976406f410f5/library/alloc/src/sync.rs#L2551
+        if desc_ref != 1 {
+            trace!(id = ?inner.id, "drop_sender");
+            return;
+        }
+
+        core::sync::atomic::fence(Ordering::Acquire);
+
+        // close both of the queues so the receivers are notified
+        inner.control.close();
+        inner.stream.close();
+        trace!(id = ?inner.id, "close_queue");
+    }
+
+    /// # Safety
+    ///
+    /// This method can be used to drop the Descriptor, but shouldn't be called after the last
+    /// receiver Descriptor is released. That implies only calling it once on a given Descriptor
+    /// handle obtained from [`Self::into_receiver_pair`].
+    #[inline]
+    pub unsafe fn drop_stream_receiver(&self) {
+        let inner = self.inner();
+        trace!(id = ?inner.id, "drop_stream_receiver");
+        inner.stream.close_receiver();
+        // don't free the descriptor until the control receiver is also dropped
+        ensure!(!inner.control.has_receiver());
+        let storage = inner.free_list.free(Descriptor {
+            ptr: self.ptr,
+            phantom: PhantomData,
+        });
+        drop(storage);
+    }
+
+    /// # Safety
+    ///
+    /// This method can be used to drop the Descriptor, but shouldn't be called after the last
+    /// receiver Descriptor is released. That implies only calling it once on a given Descriptor
+    /// handle obtained from [`Self::into_receiver_pair`].
+    #[inline]
+    pub unsafe fn drop_control_receiver(&self) {
+        let inner = self.inner();
+        trace!(id = ?inner.id, "drop_control_receiver");
+        inner.control.close_receiver();
+        // don't free the descriptor until the stream receiver is also dropped
+        ensure!(!inner.stream.has_receiver());
+        let storage = inner.free_list.free(Descriptor {
+            ptr: self.ptr,
+            phantom: PhantomData,
+        });
+        drop(storage);
+    }
+}
+
+unsafe impl<T: Send> Send for Descriptor<T> {}
+unsafe impl<T: Send> Sync for Descriptor<T> {}
+
+pub(super) struct DescriptorInner<T: 'static> {
+    id: VarInt,
+    stream: Queue<T>,
+    control: Queue<T>,
+    /// A reference back to the free list
+    free_list: Arc<dyn FreeList<T>>,
+    senders: AtomicUsize,
+}
+
+impl<T: 'static> DescriptorInner<T> {
+    pub(super) fn new(
+        id: VarInt,
+        stream: ring_deque::Capacity,
+        control: ring_deque::Capacity,
+        free_list: Arc<dyn FreeList<T>>,
+    ) -> Self {
+        let stream = Queue::new(stream);
+        let control = Queue::new(control);
+        Self {
+            id,
+            stream,
+            control,
+            senders: AtomicUsize::new(0),
+            free_list,
+        }
+    }
+}
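`drop_sender` reuses the `Arc` teardown recipe it links to: decrement with `Release`, and only the thread that observes the count reach zero performs an `Acquire` fence before touching the shared queues. The same pattern in isolation, with illustrative names:

```rust
use std::sync::atomic::{fence, AtomicUsize, Ordering};

struct Shared {
    senders: AtomicUsize,
}

impl Shared {
    /// Returns true when the caller dropped the last sender and may close the queues.
    fn drop_sender(&self) -> bool {
        // `Release` publishes this sender's writes to whoever frees last
        let prev = self.senders.fetch_sub(1, Ordering::Release);
        debug_assert_ne!(prev, 0, "reference count underflow");
        if prev != 1 {
            return false;
        }
        // pairs with the `Release` decrements above so every other sender's
        // writes are visible before the shared state is torn down
        fence(Ordering::Acquire);
        true
    }
}
```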
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/free_list.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/free_list.rs
new file mode 100644
index 000000000..b825f4bff
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/free_list.rs
@@ -0,0 +1,148 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{
+    descriptor::Descriptor,
+    handle::{Control, Stream},
+    pool::Region,
+};
+use std::sync::{Arc, Mutex};
+
+/// Callback which releases a descriptor back into the free list
+pub(super) trait FreeList<T>: 'static + Send + Sync {
+    /// Frees a descriptor back into the free list
+    ///
+    /// Once the free list has been closed and all descriptors returned, the `free` function
+    /// should return an object that can be dropped to release all of the memory associated
+    /// with the descriptor pool. This works around any issues around the "Stacked Borrows"
+    /// model by deferring freeing memory borrowed by `self`.
+    fn free(&self, descriptor: Descriptor<T>) -> Option<Box<dyn 'static + Send>>;
+}
+
+/// A free list of unfilled descriptors
+///
+/// Note that this uses a [`Vec`] instead of [`std::collections::VecDeque`], so it acts more
+/// like a stack than a queue. This is to prefer more-recently used descriptors, which should
+/// hopefully reduce the number of cache misses.
+pub(super) struct FreeVec<T: 'static>(Mutex<FreeInner<T>>);
+
+impl<T: 'static> FreeVec<T> {
+    #[inline]
+    pub fn new(initial_cap: usize) -> (Arc<Self>, Arc<Memory<T>>) {
+        let descriptors = Vec::with_capacity(initial_cap);
+        let regions = Vec::with_capacity(1);
+        let inner = FreeInner {
+            descriptors,
+            regions,
+            total: 0,
+            open: true,
+        };
+        let free = Arc::new(Self(Mutex::new(inner)));
+        let memory = Arc::new(Memory(free.clone()));
+        (free, memory)
+    }
+
+    #[inline]
+    pub fn alloc(&self) -> Option<(Control<T>, Stream<T>)> {
+        self.0.lock().unwrap().descriptors.pop().map(|v| unsafe {
+            // SAFETY: the descriptor is only owned by the free list
+            let (control, stream) = v.into_receiver_pair();
+            (Control::new(control), Stream::new(stream))
+        })
+    }
+
+    #[inline]
+    pub fn record_region(&self, region: Region<T>, mut descriptors: Vec<Descriptor<T>>) {
+        let mut inner = self.0.lock().unwrap();
+        inner.regions.push(region);
+        inner.total += descriptors.len();
+        inner.descriptors.append(&mut descriptors);
+        // Even though `descriptors` is now empty (`len == 0`), it still owns
+        // capacity and will need to be freed. Drop the lock before interacting
+        // with the global allocator.
+        drop(inner);
+        drop(descriptors);
+    }
+
+    #[inline]
+    fn try_free(&self) -> Option<FreeInner<T>> {
+        let mut inner = self.0.lock().unwrap();
+        inner.open = false;
+        inner.try_free()
+    }
+}
+
+/// A memory reference to the free list
+///
+/// Once dropped, the pool and all associated descriptors will be
+/// freed after the last handle is dropped.
+pub(super) struct Memory<T: 'static>(Arc<FreeVec<T>>);
+
+impl<T: 'static> Drop for Memory<T> {
+    #[inline]
+    fn drop(&mut self) {
+        drop(self.0.try_free());
+    }
+}
+
+impl<T: 'static + Send> FreeList<T> for FreeVec<T> {
+    #[inline]
+    fn free(&self, descriptor: Descriptor<T>) -> Option<Box<dyn 'static + Send>> {
+        let mut inner = self.0.lock().unwrap();
+        inner.descriptors.push(descriptor);
+        if inner.open {
+            return None;
+        }
+        inner
+            .try_free()
+            .map(|to_free| Box::new(to_free) as Box<dyn 'static + Send>)
+    }
+}
+
+struct FreeInner<T: 'static> {
+    descriptors: Vec<Descriptor<T>>,
+    regions: Vec<Region<T>>,
+    total: usize,
+    open: bool,
+}
+
+impl<T: 'static> FreeInner<T> {
+    #[inline(never)] // this is rarely called
+    fn try_free(&mut self) -> Option<Self> {
+        if self.descriptors.len() < self.total {
+            tracing::trace!("waiting for more descriptors to be freed");
+            return None;
+        }
+
+        tracing::trace!("all descriptors freed back to pool");
+
+        // move all of the allocations out of itself, since this is self-referential
+        Some(core::mem::replace(
+            self,
+            FreeInner {
+                descriptors: Vec::new(),
+                regions: Vec::new(),
+                total: 0,
+                open: false,
+            },
+        ))
+    }
+}
+
+impl<T: 'static> Drop for FreeInner<T> {
+    #[inline]
+    fn drop(&mut self) {
+        if self.descriptors.is_empty() {
+            return;
+        }
+
+        tracing::trace!("dropping {} descriptors", self.descriptors.len());
+
+        for descriptor in self.descriptors.drain(..) {
+            unsafe {
+                // SAFETY: the free list is closed and there are no outstanding descriptors
+                descriptor.drop_in_place();
+            }
+        }
+    }
+}
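Both `record_region` and the `free` callback hand owned storage back to the caller so the actual deallocation happens after the mutex is released. A standalone sketch of that deferred-drop pattern — the names here are illustrative, not the crate's:

```rust
use std::sync::Mutex;

struct Pool {
    inner: Mutex<Vec<Box<[u8]>>>,
}

impl Pool {
    /// Returns storage for the caller to drop once the lock is released.
    fn release_all(&self) -> Option<Vec<Box<[u8]>>> {
        let mut inner = self.inner.lock().unwrap();
        if inner.is_empty() {
            return None;
        }
        // move the allocations out; dropping them here would run the global
        // allocator while still holding the lock
        Some(std::mem::take(&mut *inner))
    }
}

fn shutdown(pool: &Pool) {
    let garbage = pool.release_all();
    // the lock is released by now, so freeing can't block other threads
    drop(garbage);
}
```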
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/handle.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/handle.rs
new file mode 100644
index 000000000..34554a8f9
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/handle.rs
@@ -0,0 +1,124 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{descriptor::Descriptor, queue::Error};
+use crate::sync::ring_deque;
+use core::{
+    fmt,
+    task::{Context, Poll},
+};
+use s2n_quic_core::varint::VarInt;
+use std::collections::VecDeque;
+
+macro_rules! impl_recv {
+    ($name:ident, $field:ident, $drop:ident) => {
+        pub struct $name<T: 'static> {
+            descriptor: Descriptor<T>,
+        }
+
+        impl<T: 'static> $name<T> {
+            #[inline]
+            pub(super) fn new(descriptor: Descriptor<T>) -> Self {
+                Self { descriptor }
+            }
+
+            /// Returns the associated `queue_id` for the channel
+            ///
+            /// This can be sent to a peer, which can be used to route packets back to the channel.
+            #[inline]
+            pub fn queue_id(&self) -> VarInt {
+                unsafe { self.descriptor.queue_id() }
+            }
+
+            #[inline]
+            pub fn push(&self, item: T) -> Option<T> {
+                unsafe { self.descriptor.$field().force_push(item) }
+            }
+
+            #[inline]
+            pub fn try_recv(&self) -> Result<Option<T>, ring_deque::Closed> {
+                unsafe { self.descriptor.$field().pop() }
+            }
+
+            #[inline]
+            pub async fn recv(&self) -> Result<T, ring_deque::Closed> {
+                core::future::poll_fn(|cx| self.poll_recv(cx)).await
+            }
+
+            #[inline]
+            pub fn poll_recv(&self, cx: &mut Context) -> Poll<Result<T, ring_deque::Closed>> {
+                unsafe { self.descriptor.$field().poll_pop(cx) }
+            }
+
+            #[inline]
+            pub fn poll_swap(
+                &self,
+                cx: &mut Context,
+                out: &mut VecDeque<T>,
+            ) -> Poll<Result<(), ring_deque::Closed>> {
+                unsafe { self.descriptor.$field().poll_swap(cx, out) }
+            }
+        }
+
+        impl<T: 'static> fmt::Debug for $name<T> {
+            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                f.debug_struct(stringify!($name))
+                    .field("queue_id", &self.queue_id())
+                    .finish()
+            }
+        }
+
+        impl<T: 'static> Drop for $name<T> {
+            #[inline]
+            fn drop(&mut self) {
+                unsafe {
+                    self.descriptor.$drop();
+                }
+            }
+        }
+    };
+}
+
+impl_recv!(Control, control_queue, drop_control_receiver);
+impl_recv!(Stream, stream_queue, drop_stream_receiver);
+
+pub struct Sender<T: 'static> {
+    descriptor: Descriptor<T>,
+}
+
+impl<T: 'static> Clone for Sender<T> {
+    #[inline]
+    fn clone(&self) -> Self {
+        unsafe {
+            Self {
+                descriptor: self.descriptor.clone_for_sender(),
+            }
+        }
+    }
+}
+
+impl<T: 'static> Sender<T> {
+    #[inline]
+    pub(super) fn new(descriptor: Descriptor<T>) -> Self {
+        Self { descriptor }
+    }
+
+    #[inline]
+    pub fn send_stream(&self, item: T) -> Result<Option<T>, Error> {
+        unsafe { self.descriptor.stream_queue().push(item) }
+    }
+
+    #[inline]
+    pub fn send_control(&self, item: T) -> Result<Option<T>, Error> {
+        unsafe { self.descriptor.control_queue().push(item) }
+    }
+}
+
+impl<T: 'static> Drop for Sender<T> {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            self.descriptor.drop_sender();
+        }
+    }
+}
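On the receiving side, `poll_swap` drains a whole batch under a single lock acquisition instead of popping one item per wakeup. A sketch of a driver loop as it might appear in a sibling module of `dispatch` (crate-internal paths assumed):

```rust
use super::Stream; // = handle::Stream<desc::Filled>
use std::collections::VecDeque;

async fn drain(stream: &Stream) -> Result<(), crate::sync::ring_deque::Closed> {
    let mut batch = VecDeque::new();
    // swaps the internal queue with `batch` rather than popping one at a time
    core::future::poll_fn(|cx| stream.poll_swap(cx, &mut batch)).await?;
    for _packet in batch.drain(..) {
        // process each received segment
    }
    Ok(())
}
```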
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/pool.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/pool.rs
new file mode 100644
index 000000000..fc800c609
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/pool.rs
@@ -0,0 +1,211 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{
+    descriptor::{Descriptor, DescriptorInner},
+    free_list::{self, FreeVec},
+    handle::{Control, Sender, Stream},
+    sender::{SenderPages, Senders},
+};
+use crate::sync::ring_deque::Capacity;
+use s2n_quic_core::varint::VarInt;
+use std::{
+    alloc::Layout,
+    marker::PhantomData,
+    ptr::NonNull,
+    sync::{Arc, RwLock},
+};
+
+pub struct Pool<T: 'static, const PAGE_SIZE: usize> {
+    senders: Arc<RwLock<SenderPages<T>>>,
+    free: Arc<FreeVec<T>>,
+    /// Holds the backing memory allocated as long as there's at least one reference
+    memory_handle: Arc<free_list::Memory<T>>,
+    stream_capacity: Capacity,
+    control_capacity: Capacity,
+    epoch: VarInt,
+}
+
+impl<T: 'static, const PAGE_SIZE: usize> Clone for Pool<T, PAGE_SIZE> {
+    #[inline]
+    fn clone(&self) -> Self {
+        Self {
+            free: self.free.clone(),
+            memory_handle: self.memory_handle.clone(),
+            senders: self.senders.clone(),
+            stream_capacity: self.stream_capacity,
+            control_capacity: self.control_capacity,
+            epoch: self.epoch,
+        }
+    }
+}
+
+impl<T: 'static + Send, const PAGE_SIZE: usize> Pool<T, PAGE_SIZE> {
+    #[inline]
+    pub fn new(epoch: VarInt, stream_capacity: Capacity, control_capacity: Capacity) -> Self {
+        let (free, memory_handle) = FreeVec::new(PAGE_SIZE);
+        let senders = Arc::new(RwLock::new(SenderPages::new(epoch)));
+        let mut pool = Pool {
+            free,
+            memory_handle,
+            senders,
+            stream_capacity,
+            control_capacity,
+            epoch,
+        };
+        pool.grow();
+        pool
+    }
+
+    #[inline]
+    pub fn senders(&self) -> Senders<T, PAGE_SIZE> {
+        Senders {
+            senders: self.senders.clone(),
+            // make sure the memory lives as long as this sender is alive
+            memory_handle: self.memory_handle.clone(),
+            local: Default::default(),
+        }
+    }
+
+    #[inline]
+    pub fn alloc(&self) -> Option<(Control<T>, Stream<T>)> {
+        self.free.alloc()
+    }
+
+    #[inline]
+    pub fn alloc_or_grow(&mut self) -> (Control<T>, Stream<T>) {
+        loop {
+            if let Some(descriptor) = self.alloc() {
+                return descriptor;
+            }
+            self.grow();
+        }
+    }
+
+    #[inline(never)] // this should happen rarely
+    fn grow(&mut self) {
+        let (region, layout) = Region::alloc(PAGE_SIZE);
+
+        let ptr = region.ptr;
+
+        let mut pending_desc = vec![];
+        let mut pending_senders = vec![];
+
+        for idx in 0..PAGE_SIZE {
+            let offset = layout.size() * idx;
+
+            unsafe {
+                let descriptor = ptr.as_ptr().add(offset).cast::<DescriptorInner<T>>();
+
+                // Give the descriptor a non-`Strong` reference to the free list, since this will be the
+                // last reference to get dropped.
+                let free_list = self.free.clone();
+
+                // initialize the descriptor with the channels
+                descriptor.write(DescriptorInner::new(
+                    self.epoch + idx,
+                    self.stream_capacity,
+                    self.control_capacity,
+                    free_list,
+                ));
+
+                let descriptor = NonNull::new_unchecked(descriptor);
+                let descriptor = Descriptor::new(descriptor);
+                let sender = Sender::new(descriptor.clone_for_sender());
+
+                // push the descriptor into the free list
+                pending_desc.push(descriptor);
+
+                // push the senders into the sender page
+                pending_senders.push(sender);
+            }
+        }
+
+        let pending_senders: Arc<[_]> = pending_senders.into();
+
+        let mut senders = self.senders.write().unwrap();
+
+        // check if another pool instance already updated the senders list
+        if senders.epoch != self.epoch {
+            // update our local copy
+            self.epoch = senders.epoch;
+
+            // free what we just allocated, since we raced with the other pool instance
+            for desc in pending_desc {
+                unsafe {
+                    desc.drop_in_place();
+                }
+            }
+
+            // return back to the alloc method, which may have a free descriptor now
+            return;
+        }
+
+        // update the epoch with the latest value
+        let target_epoch = self.epoch + PAGE_SIZE;
+        senders.epoch = target_epoch;
+        self.epoch = target_epoch;
+
+        // update the sender list with the newly allocated channels
+        senders.pages.push(pending_senders);
+        // we don't need to synchronize with the senders any more so drop the lock
+        drop(senders);
+
+        // push all of the descriptors into the free list
+        self.free.record_region(region, pending_desc);
+    }
+}
+
+pub(super) struct Region<T: 'static> {
+    ptr: NonNull<u8>,
+    layout: Layout,
+    phantom: PhantomData<DescriptorInner<T>>,
+}
+
+unsafe impl<T: Send> Send for Region<T> {}
+unsafe impl<T: Send> Sync for Region<T> {}
+
+impl<T: 'static> Region<T> {
+    #[inline]
+    fn alloc(page_size: usize) -> (Self, Layout) {
+        debug_assert!(page_size > 0, "need at least 1 entry in page");
+
+        // first create the descriptor layout
+        let descriptor = Layout::new::<DescriptorInner<T>>().pad_to_align();
+
+        let descriptors = {
+            // TODO use `descriptor.repeat(page_size)` once stable
+            // https://doc.rust-lang.org/stable/core/alloc/struct.Layout.html#method.repeat
+            Layout::from_size_align(
+                descriptor.size().checked_mul(page_size).unwrap(),
+                descriptor.align(),
+            )
+            .unwrap()
+        };
+
+        let ptr = unsafe {
+            // SAFETY: the layout is non-zero size
+            debug_assert_ne!(descriptors.size(), 0);
+            // ensure that the allocation is zeroed out so we don't have to worry about MaybeUninit
+            std::alloc::alloc_zeroed(descriptors)
+        };
+        let ptr = NonNull::new(ptr).unwrap_or_else(|| std::alloc::handle_alloc_error(descriptors));
+
+        let region = Self {
+            ptr,
+            layout: descriptors,
+            phantom: PhantomData,
+        };
+
+        (region, descriptor)
+    }
+}
+
+impl<T: 'static> Drop for Region<T> {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            std::alloc::dealloc(self.ptr.as_ptr(), self.layout);
+        }
+    }
+}
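`Region::alloc` sizes a page as the padded entry layout repeated `page_size` times, so every element stays aligned. A standalone check of that arithmetic with a stand-in entry type:

```rust
use std::alloc::Layout;

fn page_layout<T>(page_size: usize) -> (Layout, usize) {
    // pad each entry to its alignment so every element of the page is aligned
    let entry = Layout::new::<T>().pad_to_align();
    let total = Layout::from_size_align(
        entry.size().checked_mul(page_size).unwrap(),
        entry.align(),
    )
    .unwrap();
    (total, entry.size())
}

#[test]
fn page_layout_is_padded() {
    // a 12-byte payload with align 8 occupies 16 bytes per entry
    #[repr(C, align(8))]
    struct Entry([u8; 12]);

    let (total, stride) = page_layout::<Entry>(4);
    assert_eq!(stride, 16);
    assert_eq!(total.size(), 64);
}
```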
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/queue.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/queue.rs
new file mode 100644
index 000000000..69cbdbdeb
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/queue.rs
@@ -0,0 +1,184 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::sync::ring_deque::{Capacity, Closed, RecvWaker};
+use core::task::{Context, Poll};
+use s2n_quic_core::ensure;
+use std::{collections::VecDeque, sync::Mutex, task::Waker};
+use tracing::trace;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Error {
+    /// The queue ID is not associated with a stream
+    Unallocated,
+    /// The queue has been closed and won't reopen
+    Closed,
+}
+
+impl From<Closed> for Error {
+    #[inline]
+    fn from(_: Closed) -> Self {
+        Self::Closed
+    }
+}
+
+struct Inner<T> {
+    queue: VecDeque<T>,
+    capacity: usize,
+    is_open: bool,
+    has_receiver: bool,
+    waker: Option<Waker>,
+}
+
+pub struct Queue<T> {
+    inner: Mutex<Inner<T>>,
+}
+
+impl<T> Queue<T> {
+    #[inline]
+    pub fn new(capacity: Capacity) -> Self {
+        Self {
+            inner: Mutex::new(Inner {
+                queue: VecDeque::with_capacity(capacity.initial),
+                capacity: capacity.max,
+                is_open: true,
+                has_receiver: false,
+                waker: None,
+            }),
+        }
+    }
+
+    #[inline]
+    pub fn push(&self, value: T) -> Result<Option<T>, Error> {
+        let mut inner = self.lock()?;
+        // check if the queue is permanently closed
+        ensure!(inner.is_open, Err(Error::Closed));
+        // check if the queue is temporarily closed
+        ensure!(inner.has_receiver, Err(Error::Unallocated));
+
+        let prev = if inner.capacity == inner.queue.len() {
+            inner.queue.pop_front()
+        } else {
+            None
+        };
+
+        trace!(has_overflow = prev.is_some(), "push");
+
+        inner.queue.push_back(value);
+        let waker = inner.waker.take();
+        drop(inner);
+        if let Some(waker) = waker {
+            waker.wake();
+        }
+
+        Ok(prev)
+    }
+
+    /// Bypasses the closed checks and pushes items into the queue
+    #[inline]
+    pub fn force_push(&self, value: T) -> Option<T> {
+        let Ok(mut inner) = self.lock() else {
+            return Some(value);
+        };
+
+        let prev = if inner.capacity == inner.queue.len() {
+            inner.queue.pop_front()
+        } else {
+            None
+        };
+
+        trace!(has_overflow = prev.is_some(), "push");
+
+        inner.queue.push_back(value);
+        let waker = inner.waker.take();
+        drop(inner);
+        if let Some(waker) = waker {
+            waker.wake();
+        }
+
+        prev
+    }
+
+    #[inline]
+    pub fn pop(&self) -> Result<Option<T>, Closed> {
+        let mut inner = self.lock()?;
+        trace!(has_items = !inner.queue.is_empty(), "pop");
+        if let Some(item) = inner.queue.pop_front() {
+            Ok(Some(item))
+        } else {
+            ensure!(inner.is_open, Err(Closed));
+            Ok(None)
+        }
+    }
+
+    #[inline]
+    pub fn poll_pop(&self, cx: &mut Context) -> Poll<Result<T, Closed>> {
+        let mut inner = self.lock()?;
+        trace!(has_items = !inner.queue.is_empty(), "poll_pop");
+        if let Some(item) = inner.queue.pop_front() {
+            Ok(item).into()
+        } else {
+            ensure!(inner.is_open, Err(Closed).into());
+            inner.waker.update(cx);
+            Poll::Pending
+        }
+    }
+
+    #[inline]
+    pub fn poll_swap(&self, cx: &mut Context, items: &mut VecDeque<T>) -> Poll<Result<(), Closed>> {
+        let mut inner = self.lock()?;
+        trace!(items = inner.queue.len(), "poll_swap");
+        if inner.queue.is_empty() {
+            ensure!(inner.is_open, Err(Closed).into());
+            inner.waker.update(cx);
+            return Poll::Pending;
+        }
+        core::mem::swap(items, &mut inner.queue);
+        Ok(()).into()
+    }
+
+    #[inline]
+    pub fn has_receiver(&self) -> bool {
+        self.lock().map(|inner| inner.has_receiver).unwrap_or(false)
+    }
+
+    #[inline]
+    pub fn open_receiver(&self) {
+        let Ok(mut inner) = self.lock() else {
+            return;
+        };
+        trace!("opening receiver");
+        inner.has_receiver = true;
+    }
+
+    #[inline]
+    pub fn close_receiver(&self) {
+        let Ok(mut inner) = self.lock() else {
+            return;
+        };
+        trace!("closing receiver");
+        inner.has_receiver = false;
+        inner.waker = None;
+        inner.queue.clear();
+    }
+
+    #[inline]
+    pub fn close(&self) {
+        let Ok(mut inner) = self.lock() else {
+            return;
+        };
+        trace!("close queue");
+        inner.is_open = false;
+        // Leave the remaining items in the queue in case the receiver wants them.
+
+        // Notify the receiver that the queue is now closed
+        if let Some(waker) = inner.waker.take() {
+            waker.wake();
+        }
+    }
+
+    #[inline]
+    fn lock(&self) -> Result<std::sync::MutexGuard<'_, Inner<T>>, Closed> {
+        self.inner.lock().map_err(|_| Closed)
+    }
+}
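`push` on a full queue evicts from the front and returns the displaced item as `Ok(Some(prev))`, so the dispatcher can count overflow instead of blocking the socket task. The same bounded-overwrite policy in isolation:

```rust
use std::collections::VecDeque;

fn bounded_push<T>(queue: &mut VecDeque<T>, capacity: usize, value: T) -> Option<T> {
    // at capacity: drop the oldest item to make room for the newest
    let prev = if queue.len() == capacity {
        queue.pop_front()
    } else {
        None
    };
    queue.push_back(value);
    prev
}

#[test]
fn overwrites_oldest() {
    let mut q = VecDeque::new();
    assert_eq!(bounded_push(&mut q, 2, 1), None);
    assert_eq!(bounded_push(&mut q, 2, 2), None);
    // full: pushing 3 evicts 1 and hands it back to the caller
    assert_eq!(bounded_push(&mut q, 2, 3), Some(1));
    assert_eq!(Vec::from(q), vec![2, 3]);
}
```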
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/sender.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/sender.rs
new file mode 100644
index 000000000..62484d992
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/sender.rs
@@ -0,0 +1,68 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{free_list, handle::Sender};
+use s2n_quic_core::varint::VarInt;
+use std::sync::{Arc, RwLock};
+
+pub struct Senders<T: 'static, const PAGE_SIZE: usize> {
+    pub(super) senders: Arc<RwLock<SenderPages<T>>>,
+    pub(super) local: Vec<Arc<[Sender<T>]>>,
+    pub(super) memory_handle: Arc<free_list::Memory<T>>,
+}
+
+impl<T: 'static, const PAGE_SIZE: usize> Clone for Senders<T, PAGE_SIZE> {
+    fn clone(&self) -> Self {
+        Self {
+            senders: self.senders.clone(),
+            memory_handle: self.memory_handle.clone(),
+            local: self.local.clone(),
+        }
+    }
+}
+
+impl<T: 'static, const PAGE_SIZE: usize> Senders<T, PAGE_SIZE> {
+    #[inline]
+    pub fn lookup<F: FnOnce(&Sender<T>)>(&mut self, queue_id: VarInt, f: F) {
+        let queue_id = queue_id.as_u64() as usize;
+        let page = queue_id / PAGE_SIZE;
+        let offset = queue_id % PAGE_SIZE;
+
+        if self.local.len() <= page {
+            let Ok(senders) = self.senders.read() else {
+                return;
+            };
+
+            // the senders haven't been updated
+            if self.local.len() == senders.pages.len() {
+                return;
+            }
+
+            self.local
+                .extend_from_slice(&senders.pages[self.local.len()..]);
+        }
+
+        let Some(page) = self.local.get(page) else {
+            return;
+        };
+        let Some(sender) = page.get(offset) else {
+            return;
+        };
+        f(sender)
+    }
+}
+
+pub(super) struct SenderPages<T: 'static> {
+    pub(super) pages: Vec<Arc<[Sender<T>]>>,
+    pub(super) epoch: VarInt,
+}
+
+impl<T: 'static> SenderPages<T> {
+    #[inline]
+    pub(super) fn new(epoch: VarInt) -> Self {
+        Self {
+            pages: Vec::with_capacity(8),
+            epoch,
+        }
+    }
+}
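`lookup` turns a queue id into a page index and a slot offset, so routing stays O(1) however many pages the pool has grown. The arithmetic, standalone (using the debug-build page size of 8 for illustration):

```rust
const PAGE_SIZE: usize = 8;

fn locate(queue_id: usize) -> (usize, usize) {
    // page selects the Arc<[Sender<T>]> slab; offset selects the slot within it
    (queue_id / PAGE_SIZE, queue_id % PAGE_SIZE)
}

#[test]
fn locate_spans_pages() {
    assert_eq!(locate(0), (0, 0));
    assert_eq!(locate(7), (0, 7));
    // the first id of the second page
    assert_eq!(locate(8), (1, 0));
    assert_eq!(locate(19), (2, 3));
}
```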
diff --git a/dc/s2n-quic-dc/src/stream/recv/dispatch/tests.rs b/dc/s2n-quic-dc/src/stream/recv/dispatch/tests.rs
new file mode 100644
index 000000000..797d89a1a
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/recv/dispatch/tests.rs
@@ -0,0 +1,306 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use crate::{
+    socket::recv,
+    testing::{ext::*, sim},
+};
+use bolero::{check, TypeGenerator};
+use s2n_quic_core::varint::VarInt;
+use std::{collections::BTreeMap, panic::AssertUnwindSafe};
+
+#[derive(Clone, Debug, TypeGenerator)]
+enum Op {
+    Alloc,
+    FreeControl { idx: u16 },
+    FreeStream { idx: u16 },
+    SendControl { idx: u16 },
+    SendStream { idx: u16, inject: bool },
+    DropAllocator,
+    DropDispatcher,
+}
+
+struct Model {
+    oracle: Oracle,
+    alloc: Option<Allocator>,
+    dispatch: Option<Dispatch>,
+}
+
+impl Default for Model {
+    fn default() -> Self {
+        Self::new(Default::default())
+    }
+}
+
+impl Model {
+    fn new(packets: Packets) -> Self {
+        let stream_cap = 32;
+        let control_cap = 8;
+        let alloc = Allocator::new(stream_cap, control_cap);
+        let dispatch = alloc.dispatcher();
+        let oracle = Oracle::new(packets);
+
+        Self {
+            oracle,
+            alloc: Some(alloc),
+            dispatch: Some(dispatch),
+        }
+    }
+
+    fn apply(&mut self, op: &Op) {
+        match op {
+            Op::Alloc => {
+                self.alloc();
+            }
+            Op::FreeControl { idx } => {
+                self.free_control((*idx).into());
+            }
+            Op::FreeStream { idx } => {
+                self.free_stream((*idx).into());
+            }
+            Op::SendControl { idx } => {
+                self.send_control((*idx).into());
+            }
+            Op::SendStream { idx, inject } => {
+                self.send_stream((*idx).into(), *inject);
+            }
+            Op::DropAllocator => {
+                self.alloc = None;
+            }
+            Op::DropDispatcher => {
+                self.dispatch = None;
+            }
+        }
+    }
+
+    fn alloc(&mut self) {
+        let Some(alloc) = self.alloc.as_mut() else {
+            return;
+        };
+        let (control, stream) = alloc.alloc_or_grow();
+        self.oracle.on_alloc(control, stream);
+    }
+
+    fn free_control(&mut self, idx: VarInt) {
+        let _ = self.oracle.control.remove(&idx);
+    }
+
+    fn free_stream(&mut self, idx: VarInt) {
+        let _ = self.oracle.stream.remove(&idx);
+    }
+
+    fn send_control(&mut self, queue_id: VarInt) {
+        let Some(dispatch) = self.dispatch.as_mut() else {
+            return;
+        };
+
+        let (packet_id, packet) = self.oracle.packets.create();
+        let res = dispatch.send_control(queue_id, packet);
+        self.oracle.on_control_dispatch(queue_id, packet_id, res);
+    }
+
+    fn send_stream(&mut self, queue_id: VarInt, inject: bool) {
+        if inject {
+            return self.oracle.send_stream_inject(queue_id);
+        }
+
+        let Some(dispatch) = self.dispatch.as_mut() else {
+            return;
+        };
+
+        let (packet_id, packet) = self.oracle.packets.create();
+        let res = dispatch.send_stream(queue_id, packet);
+        self.oracle.on_stream_dispatch(queue_id, packet_id, res);
+    }
+}
+
+struct Oracle {
+    stream: BTreeMap<VarInt, Stream>,
+    control: BTreeMap<VarInt, Control>,
+    packets: Packets,
+}
+
+impl Oracle {
+    fn new(packets: Packets) -> Self {
+        Self {
+            packets,
+            stream: Default::default(),
+            control: Default::default(),
+        }
+    }
+
+    fn on_alloc(&mut self, control: Control, stream: Stream) {
+        let queue_id = control.queue_id();
+        assert_eq!(queue_id, stream.queue_id(), "queue IDs should match");
+
+        assert!(
+            control.try_recv().unwrap().is_none(),
+            "queue should be empty"
+        );
+        assert!(
+            stream.try_recv().unwrap().is_none(),
+            "queue should be empty"
+        );
+
+        assert!(
+            self.control.insert(queue_id, control).is_none(),
+            "queue ID should be unique"
+        );
+        assert!(
+            self.stream.insert(queue_id, stream).is_none(),
+            "queue ID should be unique"
+        );
+    }
+
+    fn on_control_dispatch(
+        &mut self,
+        idx: VarInt,
+        packet_id: u64,
+        result: Result<Option<recv::descriptor::Filled>, Error>,
+    ) {
+        let Some(channel) = self.control.get(&idx) else {
+            assert!(result.is_err());
+            return;
+        };
+        assert!(result.is_ok());
+        let actual = channel.try_recv().unwrap().unwrap();
+        assert_eq!(
+            actual.payload(),
+            packet_id.to_be_bytes(),
+            "queue should contain expected packet id"
+        );
+        assert!(
+            channel.try_recv().unwrap().is_none(),
+            "queue should be empty now"
+        );
+    }
+
+    fn on_stream_dispatch(
+        &mut self,
+        idx: VarInt,
+        packet_id: u64,
+        result: Result<Option<recv::descriptor::Filled>, Error>,
+    ) {
+        let Some(channel) = self.stream.get(&idx) else {
+            assert!(result.is_err());
+            return;
+        };
+        assert!(result.is_ok());
+        let actual = channel.try_recv().unwrap().unwrap();
+        assert_eq!(
+            actual.payload(),
+            packet_id.to_be_bytes(),
+            "queue should contain expected packet id"
+        );
+        assert!(
+            channel.try_recv().unwrap().is_none(),
+            "queue should be empty now"
+        );
+    }
+
+    fn send_stream_inject(&mut self, idx: VarInt) {
+        let Some(channel) = self
+            .stream
+            .get(&idx)
+            .or_else(|| self.stream.first_key_value().map(|(_k, v)| v))
+        else {
+            return;
+        };
+        let (packet_id, packet) = self.packets.create();
+        assert!(channel.push(packet).is_none(), "queue should accept packet");
+        let actual = channel.try_recv().unwrap().unwrap();
+        assert_eq!(
+            actual.payload(),
+            packet_id.to_be_bytes(),
+            "queue should contain expected packet id"
+        );
+        if matches!(channel.try_recv(), Ok(Some(_))) {
+            panic!("queue should be empty or errored");
+        }
+    }
+}
+
+#[derive(Clone)]
+struct Packets {
+    packets: recv::pool::Pool,
+    packet_id: u64,
+}
+
+impl Default for Packets {
+    fn default() -> Self {
+        Self {
+            packets: recv::pool::Pool::new(8, 8),
+            packet_id: Default::default(),
+        }
+    }
+}
+
+impl Packets {
+    fn create(&mut self) -> (u64, recv::descriptor::Filled) {
+        let packet_id = self.packet_id;
+        self.packet_id += 1;
+        let unfilled = self.packets.alloc_or_grow();
+        let packet = unfilled
+            .recv_with(|_addr, _cmsg, mut payload| {
+                let v = packet_id.to_be_bytes();
+                payload[..v.len()].copy_from_slice(&v);
+                <std::io::Result<_>>::Ok(v.len())
+            })
+            .unwrap()
+            .next()
+            .unwrap();
+        (packet_id, packet)
+    }
+}
+
+#[test]
+fn model_test() {
+    crate::testing::init_tracing();
+
+    // create a Packet allocator once to avoid setup/teardown costs
+    let packets = AssertUnwindSafe(Packets::default());
+
+    check!()
+        .with_type::<Vec<Op>>()
+        .with_test_time(core::time::Duration::from_secs(30))
+        .for_each(move |ops| {
+            let mut model = Model::new(packets.clone());
+            for op in ops {
+                model.apply(op);
+            }
+        });
+}
+
+/// ensure that freeing an allocator notifies all of the open receivers
+#[test]
+fn alloc_drop_notify() {
+    sim(|| {
+        let stream_cap = 1;
+        let control_cap = 1;
+        let mut alloc = Allocator::new(stream_cap, control_cap);
+
+        for _ in 0..2 {
+            let (control, stream) = alloc.alloc_or_grow();
+
+            async move {
+                stream.recv().await.unwrap_err();
+            }
+            .primary()
+            .spawn();
+
+            async move {
+                control.recv().await.unwrap_err();
+            }
+            .primary()
+            .spawn();
+        }
+
+        async move {
+            core::time::Duration::from_millis(100).sleep().await;
+
+            drop(alloc);
+        }
+        .spawn();
+    });
+}
diff --git a/dc/s2n-quic-dc/src/sync/mpsc.rs b/dc/s2n-quic-dc/src/sync/mpsc.rs
index 1dafdd046..7f24a428a 100644
--- a/dc/s2n-quic-dc/src/sync/mpsc.rs
+++ b/dc/s2n-quic-dc/src/sync/mpsc.rs
@@ -11,10 +11,11 @@ use std::{
     task::Waker,
 };
 
-pub use ring_deque::{Closed, Priority};
+pub use ring_deque::{Capacity, Closed, Priority};
 
-pub fn new<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
-    assert!(cap >= 1, "capacity must be at least 2");
+pub fn new<T>(cap: impl Into<Capacity>) -> (Sender<T>, Receiver<T>) {
+    let cap = cap.into();
+    assert!(cap.max >= 1, "capacity must be at least 1");
 
     let channel = Arc::new(Channel {
         queue: RingDeque::new(cap),
@@ -34,6 +35,16 @@ struct Channel<T> {
 }
 
 impl<T> Channel<T> {
+    #[inline]
+    fn clone_for_sender(self: &Arc<Self>) -> Arc<Self> {
+        let count = self.sender_count.fetch_add(1, Ordering::Relaxed);
+
+        // Make sure the count never overflows, even if lots of sender clones are leaked.
+        assert!(count < usize::MAX / 2, "too many senders");
+
+        self.clone()
+    }
+
     /// Closes the channel and notifies all blocked operations.
     ///
     /// Returns `Err` if this call has closed the channel and it was not closed already.
@@ -85,14 +96,10 @@ impl<T> fmt::Debug for Sender<T> {
 }
 
 impl<T> Clone for Sender<T> {
+    #[inline]
     fn clone(&self) -> Sender<T> {
-        let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed);
-
-        // Make sure the count never overflows, even if lots of sender clones are leaked.
-        assert!(count < usize::MAX / 2, "too many senders");
-
         Sender {
-            channel: self.channel.clone(),
+            channel: self.channel.clone_for_sender(),
         }
     }
 }
@@ -114,6 +121,13 @@ impl<T> Drop for Receiver<T> {
 }
 
 impl<T> Receiver<T> {
+    #[inline]
+    pub fn sender(&self) -> Sender<T> {
+        Sender {
+            channel: self.channel.clone_for_sender(),
+        }
+    }
+
     /// Attempts to receive a message from the front of the channel.
     ///
     /// If the channel is empty, or empty and closed, this method returns an error.
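Since `Capacity` converts from a plain `usize`, existing `mpsc::new(n)` callers compile unchanged, while new callers can decouple the initial allocation from the cap; `Receiver::sender` also makes it possible to mint a sender from the receiving side. A hypothetical usage sketch (crate-internal paths assumed):

```rust
use crate::sync::mpsc::{self, Capacity, Receiver, Sender};

fn build_channels() {
    // old style: initial == max == 16
    let (tx, rx): (Sender<u64>, Receiver<u64>) = mpsc::new(16);

    // new style: start with a small allocation, allow growth up to 1024 entries
    let (_tx2, _rx2): (Sender<u64>, Receiver<u64>) = mpsc::new(Capacity {
        max: 1024,
        initial: 16,
    });

    // `Receiver::sender` mints a new sender without cloning an existing one
    let tx_again = rx.sender();
    drop((tx, tx_again, rx));
}
```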
diff --git a/dc/s2n-quic-dc/src/sync/ring_deque.rs b/dc/s2n-quic-dc/src/sync/ring_deque.rs
index 3724f428c..1ef39c321 100644
--- a/dc/s2n-quic-dc/src/sync/ring_deque.rs
+++ b/dc/s2n-quic-dc/src/sync/ring_deque.rs
@@ -14,6 +14,24 @@ mod tests;
 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
 pub struct Closed;
 
+#[derive(Clone, Copy, Debug)]
+pub struct Capacity {
+    /// Set the upper bound of items in the queue
+    pub max: usize,
+    /// Initial allocated capacity
+    pub initial: usize,
+}
+
+impl From<usize> for Capacity {
+    #[inline]
+    fn from(capacity: usize) -> Self {
+        Self {
+            max: capacity,
+            initial: capacity,
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, Default)]
 pub enum Priority {
     #[default]
@@ -36,7 +54,7 @@ impl<T, W: RecvWaker> Clone for RingDeque<T, W> {
 
 impl<T, W: RecvWaker + Default> RingDeque<T, W> {
     #[inline]
-    pub fn new(capacity: usize) -> Self {
+    pub fn new<C: Into<Capacity>>(capacity: C) -> Self {
         let waker = W::default();
         Self::with_waker(capacity, waker)
     }
@@ -44,11 +62,13 @@ impl<T, W: RecvWaker + Default> RingDeque<T, W> {
 
 impl<T, W: RecvWaker> RingDeque<T, W> {
     #[inline]
-    pub fn with_waker(capacity: usize, recv_waker: W) -> Self {
-        let queue = VecDeque::with_capacity(capacity);
+    pub fn with_waker<C: Into<Capacity>>(capacity: C, recv_waker: W) -> Self {
+        let capacity = capacity.into();
+        let queue = VecDeque::with_capacity(capacity.initial);
         let inner = Inner {
             open: true,
             queue,
+            capacity: capacity.max,
             recv_waker,
         };
         let inner = Arc::new(Mutex::new(inner));
@@ -59,7 +79,7 @@ impl<T, W: RecvWaker> RingDeque<T, W> {
     pub fn push_back(&self, value: T) -> Result<Option<T>, Closed> {
         let mut inner = self.lock()?;
 
-        let prev = if inner.queue.capacity() == inner.queue.len() {
+        let prev = if inner.capacity == inner.queue.len() {
             inner.queue.pop_front()
         } else {
             None
@@ -79,7 +99,7 @@ impl<T, W: RecvWaker> RingDeque<T, W> {
     pub fn push_front(&self, value: T) -> Result<Option<T>, Closed> {
         let mut inner = self.lock()?;
 
-        let prev = if inner.queue.capacity() == inner.queue.len() {
+        let prev = if inner.capacity == inner.queue.len() {
             inner.queue.pop_back()
         } else {
             None
@@ -227,6 +247,7 @@ impl<T, W: RecvWaker> RingDeque<T, W> {
 
 struct Inner<T, W> {
     open: bool,
     queue: VecDeque<T>,
+    capacity: usize,
     recv_waker: W,
 }
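Storing the cap in an explicit `capacity` field fixes a subtle behavior: `VecDeque::with_capacity` may allocate more than requested, so the old `queue.capacity()` bound made the effective ring size allocator-dependent. A demonstration:

```rust
use std::collections::VecDeque;

#[test]
fn vecdeque_overallocates() {
    let queue: VecDeque<u8> = VecDeque::with_capacity(6);
    // `capacity()` is only guaranteed to be *at least* the requested size and
    // commonly rounds up, so it cannot serve as a precise ring bound
    assert!(queue.capacity() >= 6);
}
```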
diff --git a/dc/s2n-quic-dc/src/testing.rs b/dc/s2n-quic-dc/src/testing.rs
index b90e7f323..4dc715859 100644
--- a/dc/s2n-quic-dc/src/testing.rs
+++ b/dc/s2n-quic-dc/src/testing.rs
@@ -29,6 +29,7 @@ pub fn init_tracing() {
     let format = tracing_subscriber::fmt::format()
         //.with_level(false) // don't include levels in formatted output
         //.with_ansi(false)
+        .with_timer(Uptime::default())
         .compact(); // Use a less verbose output format.
 
     let default_level = if cfg!(debug_assertions) {
@@ -51,6 +52,20 @@ pub fn init_tracing() {
     });
 }
 
+#[derive(Default)]
+struct Uptime(tracing_subscriber::fmt::time::SystemTime);
+
+// Generate the timestamp from the testing IO provider rather than wall clock.
+impl tracing_subscriber::fmt::time::FormatTime for Uptime {
+    fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result {
+        if bach::is_active() {
+            write!(w, "{}", bach::time::Instant::now())
+        } else {
+            self.0.format_time(w)
+        }
+    }
+}
+
 /// Runs a function in a deterministic, discrete event simulation environment
 pub fn sim(f: impl FnOnce()) {
     init_tracing();
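With the `Uptime` timer in place, log lines carry simulated time whenever a bach simulation is active. Usage mirrors the dispatch tests above: `sim` runs the deterministic event loop, and `.primary().spawn()` marks tasks that must finish before it exits. A small sketch (crate-internal paths assumed):

```rust
use crate::testing::{ext::*, sim};

fn two_tasks() {
    sim(|| {
        async {
            // advances simulated time, not wall-clock time
            core::time::Duration::from_millis(10).sleep().await;
            tracing::info!("logged at +10ms of simulated time");
        }
        .primary()
        .spawn();
    });
}
```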