|
2 | 2 | //!
|
3 | 3 | //! # Example
|
4 | 4 | //!
|
| 5 | +//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]: |
| 6 | +//! |
5 | 7 | //! ```
|
6 | 8 | //! use tower_http::validate_request::ValidateRequestHeaderLayer;
|
7 | 9 | //! use hyper::{Request, Response, Body, Error};
|
|
50 | 52 | //! # }
|
51 | 53 | //! ```
|
52 | 54 | //!
|
| 55 | +//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]: |
| 56 | +//! |
| 57 | +//! ``` |
| 58 | +//! use tower_http::validate_request::ValidateRequestHeaderLayer; |
| 59 | +//! use hyper::{Request, Response, Body, Error}; |
| 60 | +//! use http::StatusCode; |
| 61 | +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; |
| 62 | +//! |
| 63 | +//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> { |
| 64 | +//! Ok(Response::new(Body::empty())) |
| 65 | +//! } |
| 66 | +//! |
| 67 | +//! # #[tokio::main] |
| 68 | +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { |
| 69 | +//! let mut service = ServiceBuilder::new() |
| 70 | +//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response |
| 71 | +//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN)) |
| 72 | +//! .service_fn(handle); |
| 73 | +//! |
| 74 | +//! // Requests with the correct value are allowed through |
| 75 | +//! let request = Request::builder() |
| 76 | +//! .header("x-custom-header", "random-value-1234567890") |
| 77 | +//! .body(Body::empty()) |
| 78 | +//! .unwrap(); |
| 79 | +//! |
| 80 | +//! let response = service |
| 81 | +//! .ready() |
| 82 | +//! .await? |
| 83 | +//! .call(request) |
| 84 | +//! .await?; |
| 85 | +//! |
| 86 | +//! assert_eq!(StatusCode::OK, response.status()); |
| 87 | +//! |
| 88 | +//! // Requests with an invalid value get a `403 Forbidden` response |
| 89 | +//! let request = Request::builder() |
| 90 | +//! .header("x-custom-header", "wrong-value") |
| 91 | +//! .body(Body::empty()) |
| 92 | +//! .unwrap(); |
| 93 | +//! |
| 94 | +//! let response = service |
| 95 | +//! .ready() |
| 96 | +//! .await? |
| 97 | +//! .call(request) |
| 98 | +//! .await?; |
| 99 | +//! |
| 100 | +//! assert_eq!(StatusCode::FORBIDDEN, response.status()); |
| 101 | +//! # |
| 102 | +//! # // Requests without the expected header also get a `403 Forbidden` response |
| 103 | +//! # let request = Request::builder() |
| 104 | +//! # .body(Body::empty()) |
| 105 | +//! # .unwrap(); |
| 106 | +//! # |
| 107 | +//! # let response = service |
| 108 | +//! # .ready() |
| 109 | +//! # .await? |
| 110 | +//! # .call(request) |
| 111 | +//! # .await?; |
| 112 | +//! # |
| 113 | +//! # assert_eq!(StatusCode::FORBIDDEN, response.status()); |
| 114 | +//! # |
| 115 | +//! # Ok(()) |
| 116 | +//! # } |
| 117 | +//! ``` |
| 118 | +//! |
53 | 119 | //! Custom validation can be made by implementing [`ValidateRequest`]:
|
54 | 120 | //!
|
55 | 121 | //! ```
|
|
112 | 178 | //! # Ok(())
|
113 | 179 | //! # }
|
114 | 180 | //! ```
|
| 181 | +//! |
| 182 | +//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept |
115 | 183 |
|
116 | 184 | use http::{header, Request, Response, StatusCode};
|
117 | 185 | use http_body::Body;
|
@@ -165,6 +233,34 @@ impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
|
165 | 233 | }
|
166 | 234 | }
|
167 | 235 |
|
| 236 | +impl<ResBody> ValidateRequestHeaderLayer<AssertHeaderOrReject<ResBody>> { |
| 237 | + /// Validate requests have a required header with a specific value. |
| 238 | + /// |
| 239 | + /// # Example |
| 240 | + /// |
| 241 | + /// ``` |
| 242 | + /// use http::StatusCode; |
| 243 | + /// use hyper::Body; |
| 244 | + /// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer}; |
| 245 | + /// |
| 246 | + /// let layer = ValidateRequestHeaderLayer::<AssertHeaderOrReject<Body>>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN); |
| 247 | + /// ``` |
| 248 | + pub fn assert( |
| 249 | + expected_header_name: &str, |
| 250 | + expected_header_value: &str, |
| 251 | + response_status_code: StatusCode, |
| 252 | + ) -> Self |
| 253 | + where |
| 254 | + ResBody: Body + Default, |
| 255 | + { |
| 256 | + Self::custom(AssertHeaderOrReject::new( |
| 257 | + expected_header_name, |
| 258 | + expected_header_value, |
| 259 | + response_status_code, |
| 260 | + )) |
| 261 | + } |
| 262 | +} |
| 263 | + |
168 | 264 | impl<T> ValidateRequestHeaderLayer<T> {
|
169 | 265 | /// Validate requests using a custom method.
|
170 | 266 | pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
|
@@ -409,6 +505,76 @@ where
|
409 | 505 | }
|
410 | 506 | }
|
411 | 507 |
|
| 508 | +/// Type that rejects requests if a header is not present or does not have an expected value. |
| 509 | +pub struct AssertHeaderOrReject<ResBody> { |
| 510 | + expected_header_name: String, |
| 511 | + expected_header_value: String, |
| 512 | + response_status_code: StatusCode, |
| 513 | + _ty: PhantomData<fn() -> ResBody>, |
| 514 | +} |
| 515 | + |
| 516 | +impl<ResBody> AssertHeaderOrReject<ResBody> { |
| 517 | + /// Create a new `AssertHeaderOrReject` struct. |
| 518 | + fn new( |
| 519 | + expected_header_name: &str, |
| 520 | + expected_header_value: &str, |
| 521 | + response_status_code: StatusCode, |
| 522 | + ) -> Self |
| 523 | + where |
| 524 | + ResBody: Body + Default, |
| 525 | + { |
| 526 | + Self { |
| 527 | + expected_header_name: expected_header_name.to_string(), |
| 528 | + expected_header_value: expected_header_value.to_string(), |
| 529 | + response_status_code, |
| 530 | + _ty: PhantomData, |
| 531 | + } |
| 532 | + } |
| 533 | +} |
| 534 | + |
| 535 | +impl<ResBody> Clone for AssertHeaderOrReject<ResBody> { |
| 536 | + fn clone(&self) -> Self { |
| 537 | + Self { |
| 538 | + expected_header_name: self.expected_header_name.clone(), |
| 539 | + expected_header_value: self.expected_header_value.clone(), |
| 540 | + response_status_code: self.response_status_code.clone(), |
| 541 | + _ty: PhantomData, |
| 542 | + } |
| 543 | + } |
| 544 | +} |
| 545 | + |
| 546 | +impl<ResBody> fmt::Debug for AssertHeaderOrReject<ResBody> { |
| 547 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 548 | + f.debug_struct("AssertHeaderOrReject") |
| 549 | + .field("expected_header_name", &self.expected_header_name) |
| 550 | + .field("expected_header_value", &self.expected_header_value) |
| 551 | + .field("response_status_code", &self.response_status_code) |
| 552 | + .finish() |
| 553 | + } |
| 554 | +} |
| 555 | + |
| 556 | +impl<B, ResBody> ValidateRequest<B> for AssertHeaderOrReject<ResBody> |
| 557 | +where |
| 558 | + ResBody: Body + Default, |
| 559 | +{ |
| 560 | + type ResponseBody = ResBody; |
| 561 | + |
| 562 | + fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> { |
| 563 | + let request_header_value = req |
| 564 | + .headers() |
| 565 | + .get(&self.expected_header_name) |
| 566 | + .and_then(|v| v.to_str().ok()); |
| 567 | + |
| 568 | + if request_header_value != Some(&self.expected_header_value) { |
| 569 | + let mut res = Response::new(ResBody::default()); |
| 570 | + *res.status_mut() = self.response_status_code; |
| 571 | + return Err(res); |
| 572 | + } |
| 573 | + |
| 574 | + Ok(()) |
| 575 | + } |
| 576 | +} |
| 577 | + |
412 | 578 | #[cfg(test)]
|
413 | 579 | mod tests {
|
414 | 580 | #[allow(unused_imports)]
|
|
0 commit comments