Skip to content

Commit 72be335

Browse files
committed
normalize_path: Add Append mode
1 parent d0c522b commit 72be335

File tree

2 files changed

+182
-23
lines changed

2 files changed

+182
-23
lines changed

tower-http/src/builder.rs

+18
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,17 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
364364
fn trim_trailing_slash(
365365
self,
366366
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;
367+
368+
/// Normalize paths based on the specified `mode`.
369+
///
370+
/// See [`tower_http::normalize_path`] for more details.
371+
///
372+
/// [`tower_http::normalize_path`]: crate::normalize_path
373+
#[cfg(feature = "normalize-path")]
374+
fn normalize_path(
375+
self,
376+
mode: crate::normalize_path::NormalizeMode,
377+
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;
367378
}
368379

369380
impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
@@ -594,4 +605,11 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
594605
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
595606
self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash())
596607
}
608+
#[cfg(feature = "normalize-path")]
609+
fn normalize_path(
610+
self,
611+
mode: crate::normalize_path::NormalizeMode,
612+
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
613+
self.layer(crate::normalize_path::NormalizePathLayer::new(mode))
614+
}
597615
}

tower-http/src/normalize_path.rs

+164-23
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
//! Middleware that normalizes paths.
22
//!
3-
//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
4-
//! will be changed to `/foo` before reaching the inner service.
3+
//! Normalizes the request paths based on the provided `NormalizeMode`
54
//!
65
//! # Example
76
//!
87
//! ```
9-
//! use tower_http::normalize_path::NormalizePathLayer;
8+
//! use tower_http::normalize_path::{NormalizePathLayer, NormalizeMode};
109
//! use http::{Request, Response, StatusCode};
1110
//! use http_body_util::Full;
1211
//! use bytes::Bytes;
@@ -22,7 +21,7 @@
2221
//!
2322
//! let mut service = ServiceBuilder::new()
2423
//! // trim trailing slashes from paths
25-
//! .layer(NormalizePathLayer::trim_trailing_slash())
24+
//! .layer(NormalizePathLayer::new(NormalizeMode::Trim))
2625
//! .service_fn(handle);
2726
//!
2827
//! // call the service
@@ -45,27 +44,47 @@ use std::{
4544
use tower_layer::Layer;
4645
use tower_service::Service;
4746

47+
/// Different modes of normalizing paths
48+
#[derive(Debug, Copy, Clone)]
49+
pub enum NormalizeMode {
50+
/// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo
51+
Trim,
52+
/// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/
53+
Append,
54+
}
55+
4856
/// Layer that applies [`NormalizePath`] which normalizes paths.
4957
///
5058
/// See the [module docs](self) for more details.
5159
#[derive(Debug, Copy, Clone)]
52-
pub struct NormalizePathLayer {}
60+
pub struct NormalizePathLayer {
61+
mode: NormalizeMode,
62+
}
5363

5464
impl NormalizePathLayer {
5565
/// Create a new [`NormalizePathLayer`].
5666
///
5767
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
5868
/// will be changed to `/foo` before reaching the inner service.
5969
pub fn trim_trailing_slash() -> Self {
60-
NormalizePathLayer {}
70+
NormalizePathLayer {
71+
mode: NormalizeMode::Trim,
72+
}
73+
}
74+
75+
/// Create a new [`NormalizePathLayer`].
76+
///
77+
/// Creates a new `NormalizePathLayer` with the specified mode.
78+
pub fn new(mode: NormalizeMode) -> Self {
79+
NormalizePathLayer { mode }
6180
}
6281
}
6382

6483
impl<S> Layer<S> for NormalizePathLayer {
6584
type Service = NormalizePath<S>;
6685

6786
fn layer(&self, inner: S) -> Self::Service {
68-
NormalizePath::trim_trailing_slash(inner)
87+
NormalizePath::new(inner, self.mode)
6988
}
7089
}
7190

@@ -74,16 +93,16 @@ impl<S> Layer<S> for NormalizePathLayer {
7493
/// See the [module docs](self) for more details.
7594
#[derive(Debug, Copy, Clone)]
7695
pub struct NormalizePath<S> {
96+
mode: NormalizeMode,
7797
inner: S,
7898
}
7999

80100
impl<S> NormalizePath<S> {
81101
/// Create a new [`NormalizePath`].
82102
///
83-
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
84-
/// will be changed to `/foo` before reaching the inner service.
85-
pub fn trim_trailing_slash(inner: S) -> Self {
86-
Self { inner }
103+
/// Normalize path based on the specified `mode`
104+
pub fn new(inner: S, mode: NormalizeMode) -> Self {
105+
Self { mode, inner }
87106
}
88107

89108
define_inner_service_accessors!();
@@ -103,12 +122,15 @@ where
103122
}
104123

105124
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106-
normalize_trailing_slash(req.uri_mut());
125+
match self.mode {
126+
NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
127+
NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
128+
}
107129
self.inner.call(req)
108130
}
109131
}
110132

111-
fn normalize_trailing_slash(uri: &mut Uri) {
133+
fn trim_trailing_slash(uri: &mut Uri) {
112134
if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
113135
return;
114136
}
@@ -137,14 +159,48 @@ fn normalize_trailing_slash(uri: &mut Uri) {
137159
}
138160
}
139161

162+
fn append_trailing_slash(uri: &mut Uri) {
163+
if uri.path().ends_with("/") && !uri.path().ends_with("//") {
164+
return;
165+
}
166+
167+
let trimmed = uri.path().trim_matches('/');
168+
let new_path = if trimmed.is_empty() {
169+
"/".to_string()
170+
} else {
171+
format!("/{}/", trimmed)
172+
};
173+
174+
let mut parts = uri.clone().into_parts();
175+
176+
let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
177+
let new_path_and_query = if let Some(query) = path_and_query.query() {
178+
Cow::Owned(format!("{}?{}", new_path, query))
179+
} else {
180+
new_path.into()
181+
}
182+
.parse()
183+
.unwrap();
184+
185+
Some(new_path_and_query)
186+
} else {
187+
Some(new_path.parse().unwrap())
188+
};
189+
190+
parts.path_and_query = new_path_and_query;
191+
if let Ok(new_uri) = Uri::from_parts(parts) {
192+
*uri = new_uri;
193+
}
194+
}
195+
140196
#[cfg(test)]
141197
mod tests {
142198
use super::*;
143199
use std::convert::Infallible;
144200
use tower::{ServiceBuilder, ServiceExt};
145201

146202
#[tokio::test]
147-
async fn works() {
203+
async fn trim_works() {
148204
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
149205
Ok(Response::new(request.uri().to_string()))
150206
}
@@ -168,63 +224,148 @@ mod tests {
168224
#[test]
169225
fn is_noop_if_no_trailing_slash() {
170226
let mut uri = "/foo".parse::<Uri>().unwrap();
171-
normalize_trailing_slash(&mut uri);
227+
trim_trailing_slash(&mut uri);
172228
assert_eq!(uri, "/foo");
173229
}
174230

175231
#[test]
176232
fn maintains_query() {
177233
let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
178-
normalize_trailing_slash(&mut uri);
234+
trim_trailing_slash(&mut uri);
179235
assert_eq!(uri, "/foo?a=a");
180236
}
181237

182238
#[test]
183239
fn removes_multiple_trailing_slashes() {
184240
let mut uri = "/foo////".parse::<Uri>().unwrap();
185-
normalize_trailing_slash(&mut uri);
241+
trim_trailing_slash(&mut uri);
186242
assert_eq!(uri, "/foo");
187243
}
188244

189245
#[test]
190246
fn removes_multiple_trailing_slashes_even_with_query() {
191247
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
192-
normalize_trailing_slash(&mut uri);
248+
trim_trailing_slash(&mut uri);
193249
assert_eq!(uri, "/foo?a=a");
194250
}
195251

196252
#[test]
197253
fn is_noop_on_index() {
198254
let mut uri = "/".parse::<Uri>().unwrap();
199-
normalize_trailing_slash(&mut uri);
255+
trim_trailing_slash(&mut uri);
200256
assert_eq!(uri, "/");
201257
}
202258

203259
#[test]
204260
fn removes_multiple_trailing_slashes_on_index() {
205261
let mut uri = "////".parse::<Uri>().unwrap();
206-
normalize_trailing_slash(&mut uri);
262+
trim_trailing_slash(&mut uri);
207263
assert_eq!(uri, "/");
208264
}
209265

210266
#[test]
211267
fn removes_multiple_trailing_slashes_on_index_even_with_query() {
212268
let mut uri = "////?a=a".parse::<Uri>().unwrap();
213-
normalize_trailing_slash(&mut uri);
269+
trim_trailing_slash(&mut uri);
214270
assert_eq!(uri, "/?a=a");
215271
}
216272

217273
#[test]
218274
fn removes_multiple_preceding_slashes_even_with_query() {
219275
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
220-
normalize_trailing_slash(&mut uri);
276+
trim_trailing_slash(&mut uri);
221277
assert_eq!(uri, "/foo?a=a");
222278
}
223279

224280
#[test]
225281
fn removes_multiple_preceding_slashes() {
226282
let mut uri = "///foo".parse::<Uri>().unwrap();
227-
normalize_trailing_slash(&mut uri);
283+
trim_trailing_slash(&mut uri);
228284
assert_eq!(uri, "/foo");
229285
}
286+
287+
#[tokio::test]
288+
async fn append_works() {
289+
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
290+
Ok(Response::new(request.uri().to_string()))
291+
}
292+
293+
let mut svc = ServiceBuilder::new()
294+
.layer(NormalizePathLayer::new(NormalizeMode::Trim))
295+
.service_fn(handle);
296+
297+
let body = svc
298+
.ready()
299+
.await
300+
.unwrap()
301+
.call(Request::builder().uri("/foo").body(()).unwrap())
302+
.await
303+
.unwrap()
304+
.into_body();
305+
306+
assert_eq!(body, "/foo/");
307+
}
308+
309+
#[test]
310+
fn is_noop_if_trailing_slash() {
311+
let mut uri = "/foo/".parse::<Uri>().unwrap();
312+
append_trailing_slash(&mut uri);
313+
assert_eq!(uri, "/foo/");
314+
}
315+
316+
#[test]
317+
fn append_maintains_query() {
318+
let mut uri = "/foo?a=a".parse::<Uri>().unwrap();
319+
append_trailing_slash(&mut uri);
320+
assert_eq!(uri, "/foo/?a=a");
321+
}
322+
323+
#[test]
324+
fn append_only_keeps_one_slash() {
325+
let mut uri = "/foo////".parse::<Uri>().unwrap();
326+
append_trailing_slash(&mut uri);
327+
assert_eq!(uri, "/foo/");
328+
}
329+
330+
#[test]
331+
fn append_only_keeps_one_slash_even_with_query() {
332+
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
333+
append_trailing_slash(&mut uri);
334+
assert_eq!(uri, "/foo/?a=a");
335+
}
336+
337+
#[test]
338+
fn append_is_noop_on_index() {
339+
let mut uri = "/".parse::<Uri>().unwrap();
340+
append_trailing_slash(&mut uri);
341+
assert_eq!(uri, "/");
342+
}
343+
344+
#[test]
345+
fn append_removes_multiple_trailing_slashes_on_index() {
346+
let mut uri = "////".parse::<Uri>().unwrap();
347+
append_trailing_slash(&mut uri);
348+
assert_eq!(uri, "/");
349+
}
350+
351+
#[test]
352+
fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
353+
let mut uri = "////?a=a".parse::<Uri>().unwrap();
354+
append_trailing_slash(&mut uri);
355+
assert_eq!(uri, "/?a=a");
356+
}
357+
358+
#[test]
359+
fn append_removes_multiple_preceding_slashes_even_with_query() {
360+
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
361+
append_trailing_slash(&mut uri);
362+
assert_eq!(uri, "/foo/?a=a");
363+
}
364+
365+
#[test]
366+
fn append_removes_multiple_preceding_slashes() {
367+
let mut uri = "///foo".parse::<Uri>().unwrap();
368+
append_trailing_slash(&mut uri);
369+
assert_eq!(uri, "/foo/");
370+
}
230371
}

0 commit comments

Comments
 (0)