Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow async predicate for cors AllowOrigin #478

Merged
merged 13 commits into from
Mar 15, 2024
105 changes: 95 additions & 10 deletions tower-http/src/cors/allow_origin.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use std::{array, fmt, sync::Arc};

use http::{
header::{self, HeaderName, HeaderValue},
request::Parts as RequestParts,
};
use pin_project_lite::pin_project;
use std::{
array, fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use super::{Any, WILDCARD};

Expand Down Expand Up @@ -73,6 +79,21 @@ impl AllowOrigin {
Self(OriginInner::Predicate(Arc::new(f)))
}

/// Set the allowed origins from an async predicate
///
/// See [`CorsLayer::allow_origin`] for more details.
///
/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
pub fn async_predicate<F, Fut>(f: F) -> Self
where
F: Fn(&HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static,
Fut: Future<Output = bool> + Send + Sync + 'static,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like users would virtually always have to clone the origin HeaderValue, so what do you think about changing this to

Suggested change
F: Fn(&HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static,
F: Fn(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static,

?

Copy link
Contributor Author

@PoOnesNerfect PoOnesNerfect Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an interesting point; I've thought about it for a little, and, for now, I'm still leaning towards keeping it as is, for more consistent type signatures.

For most, or all, async predicate use-cases, you would be awaiting on some external state, which requires you to follow the pattern of:

let client = get_api_client();

AllowOrigin::async_predicate(move |origin, parts| {
    let origin = origin.clone();
    let client = client.clone();
    async move {
        client.call_api(origin).await?;
        ...
    }
})

I'm not sure if removing the line let origin = origin.clone(); is a big lift in dx, when the pattern itself is already pretty verbose.
I would rather have the type signatures be consistent with the other allow origin functions.

Also, having origin as owned value may make it feel like there are two ways of doing things.

For example, this would work fine:

AllowOrigin::async_predicate(|origin, parts| async move { some_simple_logic(origin) })

So, users may try to do something like this, which will not compile:

AllowOrigin::async_predicate(|origin, parts| async move { some_simple_logic(origin, parts.uri().path()) })

This may give users a hard time with lifetime errors, if they haven't worked much with async closures.

So, I would just like there to be a single way of doing things, although it might be a bit more annoying, which is:

// this compiles
AllowOrigin::async_predicate(|origin, parts| {
    let origin = origin.clone();
    async move { some_simple_logic(origin) }
})

// this compiles
AllowOrigin::async_predicate(|origin, parts| {
    let origin = origin.clone();
    let path = parts.uri().path().to_owned();
    async move { some_simple_logic(origin, path) }
})

However, I just might be overthinking it too much, when it's just not a big deal, and will just be a better dx.

I don't think my opinion on this is strong, so I will just go with your decision.
If you think this change makes sense, I will make the change and re-request review.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting point about captured variables also needing clone. Could you try changing Fn to FnOnce + Clone? I'm pretty sure that then the extra cloning of captured variables can be internalized into the library as one clone of the closure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that works! however, if you want to use parts, you still have to clone it before passing it to the future.
Should we also just provide parts as RequestParts instead of &RequestParts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it becomes too costly at that point, and it'd be better to just let the user decide what to pass into the future.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah for the origin I think it's okay to clone it always, it's extremely likely the user will need it and not that expensive to clone. The RequestParts are much more expensive to clone and virtually never needed in full, or even at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense.
I've made the updates, and fixed the docs. Please take a look!

{
Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
Box::pin(f(v, p))
})))
}

/// Allow any origin, by mirroring the request origin
///
/// This is equivalent to
Expand All @@ -90,18 +111,70 @@ impl AllowOrigin {
matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
}

pub(super) fn to_header(
pub(super) fn to_future(
&self,
origin: Option<&HeaderValue>,
parts: &RequestParts,
) -> Option<(HeaderName, HeaderValue)> {
let allow_origin = match &self.0 {
OriginInner::Const(v) => v.clone(),
OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(),
OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(),
};
) -> AllowOriginFuture {
let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;

Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin))
match &self.0 {
OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
OriginInner::List(l) => {
AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
}
OriginInner::Predicate(c) => AllowOriginFuture::ok(
origin
.filter(|origin| c(origin, parts))
.map(|o| (name, o.clone())),
),
OriginInner::AsyncPredicate(f) => {
if let Some(origin) = origin.cloned() {
let fut = f(&origin, parts);
AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
} else {
AllowOriginFuture::ok(None)
}
}
}
}
}

pin_project! {
#[project = AllowOriginFutureProj]
pub(super) enum AllowOriginFuture {
Ok{
res: Option<(HeaderName, HeaderValue)>
},
Future{
#[pin]
future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
},
}
}

impl AllowOriginFuture {
fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
Self::Ok { res }
}

fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
future: F,
) -> Self {
Self::Future {
future: Box::pin(future),
}
}
}

impl Future for AllowOriginFuture {
type Output = Option<(HeaderName, HeaderValue)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
AllowOriginFutureProj::Future { future } => future.poll(cx),
}
}
}

Expand All @@ -111,6 +184,7 @@ impl fmt::Debug for AllowOrigin {
OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
}
}
}
Expand Down Expand Up @@ -147,6 +221,17 @@ enum OriginInner {
Predicate(
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
),
AsyncPredicate(
Arc<
dyn for<'a> Fn(
&'a HeaderValue,
&'a RequestParts,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
+ Send
+ Sync
+ 'static,
>,
),
}

impl Default for OriginInner {
Expand Down
72 changes: 68 additions & 4 deletions tower-http/src/cors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

#![allow(clippy::enum_variant_names)]

use allow_origin::AllowOriginFuture;
use bytes::{BufMut, BytesMut};
use http::{
header::{self, HeaderName},
Expand Down Expand Up @@ -326,6 +327,42 @@ impl CorsLayer {
/// ));
/// ```
///
/// Additionally, you can use a closure that returns a future:
///
/// Because the future must be static, you must only pass owned values into it.
///
/// ```
/// # #[derive(Clone)]
/// # struct Client;
/// # fn get_api_client() -> Client {
/// # Client
/// # }
/// # impl Client {
/// # async fn fetch_allowed_origins(&self, path: String) -> Vec<HeaderValue> {
/// # vec![]
/// # }
/// # }
/// use tower_http::cors::{CorsLayer, AllowOrigin};
/// use http::{request::Parts as RequestParts, HeaderValue};
///
/// let client = get_api_client();
///
/// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
/// move |origin: &HeaderValue, request_parts: &RequestParts| {
/// let client = client.clone();
/// let origin = origin.clone();
/// let path = request_parts.uri.path().to_owned();
///
/// async move {
/// // fetch list of origins that are allowed for this path
/// let origins = client.fetch_allowed_origins(path).await;
///
/// origins.contains(&origin)
/// }
/// },
/// ));
/// ```
///
/// Note that multiple calls to this method will override any previous
/// calls.
///
Expand Down Expand Up @@ -621,11 +658,13 @@ where

// These headers are applied to both preflight and subsequent regular CORS requests:
// https://fetch.spec.whatwg.org/#http-responses
headers.extend(self.layer.allow_origin.to_header(origin, &parts));

headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
headers.extend(self.layer.vary.to_header());

let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);

// Return results immediately upon preflight request
if parts.method == Method::OPTIONS {
// These headers are applied only to preflight requests
Expand All @@ -634,7 +673,10 @@ where
headers.extend(self.layer.max_age.to_header(origin, &parts));

ResponseFuture {
inner: Kind::PreflightCall { headers },
inner: Kind::PreflightCall {
allow_origin_future,
headers,
},
}
} else {
// This header is applied only to non-preflight requests
Expand All @@ -643,6 +685,8 @@ where
let req = Request::from_parts(parts, body);
ResponseFuture {
inner: Kind::CorsCall {
allow_origin_future,
allow_origin_complete: false,
future: self.inner.call(req),
headers,
},
Expand All @@ -663,11 +707,16 @@ pin_project! {
#[project = KindProj]
enum Kind<F> {
CorsCall {
#[pin]
allow_origin_future: AllowOriginFuture,
allow_origin_complete: bool,
#[pin]
future: F,
headers: HeaderMap,
},
PreflightCall {
#[pin]
allow_origin_future: AllowOriginFuture,
headers: HeaderMap,
},
}
Expand All @@ -682,7 +731,17 @@ where

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::CorsCall { future, headers } => {
KindProj::CorsCall {
allow_origin_future,
allow_origin_complete,
future,
headers,
} => {
if !*allow_origin_complete {
headers.extend(ready!(allow_origin_future.poll(cx)));
*allow_origin_complete = true;
}

let mut response: Response<B> = ready!(future.poll(cx))?;

let response_headers = response.headers_mut();
Expand All @@ -697,7 +756,12 @@ where

Poll::Ready(Ok(response))
}
KindProj::PreflightCall { headers } => {
KindProj::PreflightCall {
allow_origin_future,
headers,
} => {
headers.extend(ready!(allow_origin_future.poll(cx)));

let mut response = Response::new(B::default());
mem::swap(response.headers_mut(), headers);

Expand Down
44 changes: 43 additions & 1 deletion tower-http/src/cors/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::test_helpers::Body;
use http::{header, HeaderValue, Request, Response};
use tower::{service_fn, util::ServiceExt, Layer};

use crate::cors::CorsLayer;
use crate::cors::{AllowOrigin, CorsLayer};

#[tokio::test]
#[allow(
Expand All @@ -31,3 +31,45 @@ async fn vary_set_by_inner_service() {
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
assert_eq!(vary_headers.next(), None);
}

#[tokio::test]
async fn test_allow_origin_async_predicate() {
#[derive(Clone)]
struct Client;

impl Client {
async fn fetch_allowed_origins(&self, _path: String) -> Vec<HeaderValue> {
vec![HeaderValue::from_static("http://example.com")]
}
}

let client = Client;

let allow_origin = AllowOrigin::async_predicate(move |origin, parts| {
let origin = origin.clone();
let client = client.clone();
let path = parts.uri.path().to_owned();

async move {
let origins = client.fetch_allowed_origins(path).await;

origins.contains(&origin)
}
});

let valid_origin = HeaderValue::from_static("http://example.com");
let parts = http::Request::new("hello world").into_parts().0;

let header = allow_origin
.to_future(Some(&valid_origin), &parts)
.await
.unwrap();
assert_eq!(header.0, header::ACCESS_CONTROL_ALLOW_ORIGIN);
assert_eq!(header.1, valid_origin);

let invalid_origin = HeaderValue::from_static("http://example.org");
let parts = http::Request::new("hello world").into_parts().0;

let res = allow_origin.to_future(Some(&invalid_origin), &parts).await;
assert!(res.is_none());
}
Loading