Skip to content

Commit 042bbff

Browse files
committed
feat(ValidateRequestHeaderLayer): add assert() function
1 parent 92d1954 commit 042bbff

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

tower-http/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335])
1313
- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336])
14+
- **validate-request:** Add `ValidateRequestHeaderLayer::assert()` function to reject requests when a header does not have an expected value ([#360])
1415

1516
## Changed
1617

tower-http/src/validate_request.rs

+166
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//!
33
//! # Example
44
//!
5+
//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]:
6+
//!
57
//! ```
68
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
79
//! use hyper::{Request, Response, Body, Error};
@@ -50,6 +52,70 @@
5052
//! # }
5153
//! ```
5254
//!
55+
//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]:
56+
//!
57+
//! ```
58+
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
59+
//! use hyper::{Request, Response, Body, Error};
60+
//! use http::StatusCode;
61+
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
62+
//!
63+
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
64+
//! Ok(Response::new(Body::empty()))
65+
//! }
66+
//!
67+
//! # #[tokio::main]
68+
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
69+
//! let mut service = ServiceBuilder::new()
70+
//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response
71+
//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN))
72+
//! .service_fn(handle);
73+
//!
74+
//! // Requests with the correct value are allowed through
75+
//! let request = Request::builder()
76+
//! .header("x-custom-header", "random-value-1234567890")
77+
//! .body(Body::empty())
78+
//! .unwrap();
79+
//!
80+
//! let response = service
81+
//! .ready()
82+
//! .await?
83+
//! .call(request)
84+
//! .await?;
85+
//!
86+
//! assert_eq!(StatusCode::OK, response.status());
87+
//!
88+
//! // Requests with an invalid value get a `403 Forbidden` response
89+
//! let request = Request::builder()
90+
//! .header("x-custom-header", "wrong-value")
91+
//! .body(Body::empty())
92+
//! .unwrap();
93+
//!
94+
//! let response = service
95+
//! .ready()
96+
//! .await?
97+
//! .call(request)
98+
//! .await?;
99+
//!
100+
//! assert_eq!(StatusCode::FORBIDDEN, response.status());
101+
//! #
102+
//! # // Requests without the expected header also get a `403 Forbidden` response
103+
//! # let request = Request::builder()
104+
//! # .body(Body::empty())
105+
//! # .unwrap();
106+
//! #
107+
//! # let response = service
108+
//! # .ready()
109+
//! # .await?
110+
//! # .call(request)
111+
//! # .await?;
112+
//! #
113+
//! # assert_eq!(StatusCode::FORBIDDEN, response.status());
114+
//! #
115+
//! # Ok(())
116+
//! # }
117+
//! ```
118+
//!
53119
//! Custom validation can be made by implementing [`ValidateRequest`]:
54120
//!
55121
//! ```
@@ -112,6 +178,8 @@
112178
//! # Ok(())
113179
//! # }
114180
//! ```
181+
//!
182+
//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
115183
116184
use http::{header, Request, Response, StatusCode};
117185
use http_body::Body;
@@ -165,6 +233,34 @@ impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
165233
}
166234
}
167235

236+
impl<ResBody> ValidateRequestHeaderLayer<AssertHeaderOrReject<ResBody>> {
237+
/// Validate requests have a required header with a specific value.
238+
///
239+
/// # Example
240+
///
241+
/// ```
242+
/// use http::StatusCode;
243+
/// use hyper::Body;
244+
/// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer};
245+
///
246+
/// let layer = ValidateRequestHeaderLayer::<AssertHeaderOrReject<Body>>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN);
247+
/// ```
248+
pub fn assert(
249+
expected_header_name: &str,
250+
expected_header_value: &str,
251+
response_status_code: StatusCode,
252+
) -> Self
253+
where
254+
ResBody: Body + Default,
255+
{
256+
Self::custom(AssertHeaderOrReject::new(
257+
expected_header_name,
258+
expected_header_value,
259+
response_status_code,
260+
))
261+
}
262+
}
263+
168264
impl<T> ValidateRequestHeaderLayer<T> {
169265
/// Validate requests using a custom method.
170266
pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
@@ -409,6 +505,76 @@ where
409505
}
410506
}
411507

508+
/// Type that rejects requests if a header is not present or does not have an expected value.
509+
pub struct AssertHeaderOrReject<ResBody> {
510+
expected_header_name: String,
511+
expected_header_value: String,
512+
response_status_code: StatusCode,
513+
_ty: PhantomData<fn() -> ResBody>,
514+
}
515+
516+
impl<ResBody> AssertHeaderOrReject<ResBody> {
517+
/// Create a new `AssertHeaderOrReject` struct.
518+
fn new(
519+
expected_header_name: &str,
520+
expected_header_value: &str,
521+
response_status_code: StatusCode,
522+
) -> Self
523+
where
524+
ResBody: Body + Default,
525+
{
526+
Self {
527+
expected_header_name: expected_header_name.to_string(),
528+
expected_header_value: expected_header_value.to_string(),
529+
response_status_code,
530+
_ty: PhantomData,
531+
}
532+
}
533+
}
534+
535+
impl<ResBody> Clone for AssertHeaderOrReject<ResBody> {
536+
fn clone(&self) -> Self {
537+
Self {
538+
expected_header_name: self.expected_header_name.clone(),
539+
expected_header_value: self.expected_header_value.clone(),
540+
response_status_code: self.response_status_code.clone(),
541+
_ty: PhantomData,
542+
}
543+
}
544+
}
545+
546+
impl<ResBody> fmt::Debug for AssertHeaderOrReject<ResBody> {
547+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548+
f.debug_struct("AssertHeaderOrReject")
549+
.field("expected_header_name", &self.expected_header_name)
550+
.field("expected_header_value", &self.expected_header_value)
551+
.field("response_status_code", &self.response_status_code)
552+
.finish()
553+
}
554+
}
555+
556+
impl<B, ResBody> ValidateRequest<B> for AssertHeaderOrReject<ResBody>
557+
where
558+
ResBody: Body + Default,
559+
{
560+
type ResponseBody = ResBody;
561+
562+
fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
563+
let request_header_value = req
564+
.headers()
565+
.get(&self.expected_header_name)
566+
.and_then(|v| v.to_str().ok());
567+
568+
if request_header_value != Some(&self.expected_header_value) {
569+
let mut res = Response::new(ResBody::default());
570+
*res.status_mut() = self.response_status_code;
571+
return Err(res);
572+
}
573+
574+
Ok(())
575+
}
576+
}
577+
412578
#[cfg(test)]
413579
mod tests {
414580
#[allow(unused_imports)]

0 commit comments

Comments
 (0)