Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

normalize_path: Add Append mode #547

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,17 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
fn trim_trailing_slash(
self,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;

/// Normalize paths based on the specified `mode`.
///
/// See [`tower_http::normalize_path`] for more details.
///
/// [`tower_http::normalize_path`]: crate::normalize_path
#[cfg(feature = "normalize-path")]
fn normalize_path(
self,
mode: crate::normalize_path::NormalizeMode,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;
}

impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -594,4 +605,11 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash())
}
#[cfg(feature = "normalize-path")]
fn normalize_path(
self,
mode: crate::normalize_path::NormalizeMode,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::new(mode))
}
}
187 changes: 164 additions & 23 deletions tower-http/src/normalize_path.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Middleware that normalizes paths.
//!
//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
//! will be changed to `/foo` before reaching the inner service.
//! Normalizes the request paths based on the provided `NormalizeMode`
//!
//! # Example
//!
//! ```
//! use tower_http::normalize_path::NormalizePathLayer;
//! use tower_http::normalize_path::{NormalizePathLayer, NormalizeMode};
//! use http::{Request, Response, StatusCode};
//! use http_body_util::Full;
//! use bytes::Bytes;
Expand All @@ -22,7 +21,7 @@
//!
//! let mut service = ServiceBuilder::new()
//! // trim trailing slashes from paths
//! .layer(NormalizePathLayer::trim_trailing_slash())
//! .layer(NormalizePathLayer::new(NormalizeMode::Trim))
//! .service_fn(handle);
//!
//! // call the service
Expand All @@ -45,27 +44,47 @@ use std::{
use tower_layer::Layer;
use tower_service::Service;

/// Different modes of normalizing paths
#[derive(Debug, Copy, Clone)]
pub enum NormalizeMode {
/// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo
Trim,
/// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/
Append,
}

/// Layer that applies [`NormalizePath`] which normalizes paths.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}
pub struct NormalizePathLayer {
mode: NormalizeMode,
}

impl NormalizePathLayer {
/// Create a new [`NormalizePathLayer`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
pub fn trim_trailing_slash() -> Self {
NormalizePathLayer {}
NormalizePathLayer {
mode: NormalizeMode::Trim,
}
}

/// Create a new [`NormalizePathLayer`].
///
/// Creates a new `NormalizePathLayer` with the specified mode.
pub fn new(mode: NormalizeMode) -> Self {
NormalizePathLayer { mode }
}
}

impl<S> Layer<S> for NormalizePathLayer {
type Service = NormalizePath<S>;

fn layer(&self, inner: S) -> Self::Service {
NormalizePath::trim_trailing_slash(inner)
NormalizePath::new(inner, self.mode)
}
}

Expand All @@ -74,16 +93,16 @@ impl<S> Layer<S> for NormalizePathLayer {
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePath<S> {
mode: NormalizeMode,
inner: S,
}

impl<S> NormalizePath<S> {
/// Create a new [`NormalizePath`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
pub fn trim_trailing_slash(inner: S) -> Self {
Self { inner }
/// Normalize path based on the specified `mode`
pub fn new(inner: S, mode: NormalizeMode) -> Self {
Self { mode, inner }
}

define_inner_service_accessors!();
Expand All @@ -103,12 +122,15 @@ where
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
normalize_trailing_slash(req.uri_mut());
match self.mode {
NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
}
self.inner.call(req)
}
}

fn normalize_trailing_slash(uri: &mut Uri) {
fn trim_trailing_slash(uri: &mut Uri) {
if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
return;
}
Expand Down Expand Up @@ -137,14 +159,48 @@ fn normalize_trailing_slash(uri: &mut Uri) {
}
}

fn append_trailing_slash(uri: &mut Uri) {
if uri.path().ends_with("/") && !uri.path().ends_with("//") {
return;
}

let trimmed = uri.path().trim_matches('/');
let new_path = if trimmed.is_empty() {
"/".to_string()
} else {
format!("/{}/", trimmed)
};

let mut parts = uri.clone().into_parts();

let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
let new_path_and_query = if let Some(query) = path_and_query.query() {
Cow::Owned(format!("{}?{}", new_path, query))
} else {
new_path.into()
}
.parse()
.unwrap();

Some(new_path_and_query)
} else {
Some(new_path.parse().unwrap())
};

parts.path_and_query = new_path_and_query;
if let Ok(new_uri) = Uri::from_parts(parts) {
*uri = new_uri;
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};

#[tokio::test]
async fn works() {
async fn trim_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}
Expand All @@ -168,63 +224,148 @@ mod tests {
#[test]
fn is_noop_if_no_trailing_slash() {
let mut uri = "/foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn maintains_query() {
let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_trailing_slashes() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn removes_multiple_trailing_slashes_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[tokio::test]
async fn append_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}

let mut svc = ServiceBuilder::new()
.layer(NormalizePathLayer::new(NormalizeMode::Append))
.service_fn(handle);

let body = svc
.ready()
.await
.unwrap()
.call(Request::builder().uri("/foo").body(()).unwrap())
.await
.unwrap()
.into_body();

assert_eq!(body, "/foo/");
}

#[test]
fn is_noop_if_trailing_slash() {
let mut uri = "/foo/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_maintains_query() {
let mut uri = "/foo?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_only_keeps_one_slash() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_only_keeps_one_slash_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}
}
Loading