From 4f02a05170c297747882c64840572c2d4f3a8831 Mon Sep 17 00:00:00 2001 From: Stefan Date: Fri, 17 Jan 2025 17:04:41 +0100 Subject: [PATCH] Migrate server away from `warp` and start using `axum` --- server/Cargo.toml | 12 +-- server/src/auth_filter.rs | 53 ---------- server/src/lib.rs | 199 +++++++++++++++++++++----------------- 3 files changed, 115 insertions(+), 149 deletions(-) delete mode 100644 server/src/auth_filter.rs diff --git a/server/Cargo.toml b/server/Cargo.toml index d4779cb..5fa8880 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,26 +11,26 @@ categories.workspace = true keywords.workspace = true [dependencies] +axum = { version = "0.8.1", features = ["ws"] } +axum-extra = { version = "0.10.0", features = ["typed-header"] } async-trait = "0.1" blake2 = "0.10" -bytes = "1.4" futures = "0.3" -headers = "0.3" -http = "0.2" +http = "1.2.0" log = "0.4" serde = "1.0" serde_json = "1.0" subtle = "2.5" thiserror = "1.0" -tokio = { version = "1.25", features = ["sync"] } -warp = "0.3" +tokio = { version = "1.43.0", features = ["sync"] } +tower-http = { version = "0.6.2", features = ["auth", "cors"] } nimiq-jsonrpc-core = { workspace = true } [dev-dependencies] anyhow = "1.0" pretty_env_logger = "0.5.0" -tokio = { version = "1.25", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } nimiq-jsonrpc-client = { workspace = true } nimiq-jsonrpc-derive = { workspace = true } diff --git a/server/src/auth_filter.rs b/server/src/auth_filter.rs deleted file mode 100644 index 047dbff..0000000 --- a/server/src/auth_filter.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! HTTP Authentication Filters - -use headers::{authorization::Basic, Authorization, HeaderMap, HeaderMapExt}; -use http::Response; -use warp::{header, http::StatusCode, reject, reject::Rejection, Filter, Reply}; - -/// Custom reject reason when the authorization header is wrong or is not found. -#[derive(Debug)] -pub struct Unauthorized { - pub(crate) realm: String, -} - -impl reject::Reject for Unauthorized {} - -/// Handles a `Unauthorized` rejection to return a 401 Unauthorized status. -pub(crate) async fn handle_auth_rejection( - err: Rejection, -) -> Result { - let response = if err.is_not_found() { - Response::builder().status(StatusCode::NOT_FOUND).body("") - } else if let Some(error) = err.find::() { - Response::builder() - .status(StatusCode::UNAUTHORIZED) - .header( - "WWW-Authenticate", - format!("Basic realm = \"{}\"", error.realm), - ) - .body("") - } else { - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body("") - }; - Ok(response) -} - -/// Creates a `Filter` to the HTTP Basic Authentication header. -/// If none was sent by the client, this filter will reject any request. -/// -/// The `handle_rejection` recover filter can be used along with this filter to -/// properly return a 401 Unauthorized status whenever the header is not found. -/// -pub(crate) fn basic_auth_filter( - realm: &str, -) -> impl Filter,), Error = Rejection> + '_ { - header::headers_cloned().and_then(move |headers: HeaderMap| async move { - headers.typed_get::>().ok_or_else(|| { - reject::custom(Unauthorized { - realm: realm.to_string(), - }) - }) - }) -} diff --git a/server/src/lib.rs b/server/src/lib.rs index 55d1af3..2e5425a 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -4,13 +4,11 @@ #![warn(missing_docs)] #![warn(rustdoc::missing_doc_code_examples)] -mod auth_filter; - use std::{ collections::HashSet, error, fmt::{self, Debug}, - future::{self, Future}, + future::Future, net::{IpAddr, SocketAddr}, sync::{ atomic::{AtomicU64, Ordering}, @@ -19,33 +17,49 @@ use std::{ }; use async_trait::async_trait; +use axum::{ + body::{Body, Bytes}, + extract::{DefaultBodyLimit, State, WebSocketUpgrade}, + http::{header::CONTENT_TYPE, response::Builder, StatusCode}, + middleware::Next, + response::{IntoResponse as _, Response as HttpResponse}, + routing::{any, post}, + Router, +}; +use axum_extra::{ + headers::{authorization::Basic, Authorization}, + TypedHeader, +}; use blake2::{digest::consts::U32, Blake2b, Digest}; -use bytes::Bytes; use futures::{ pin_mut, sink::SinkExt, stream::{FuturesUnordered, StreamExt}, Stream, }; -use headers::{authorization::Basic, Authorization}; +use http::{HeaderValue, Method}; use serde::{de::Deserialize, ser::Serialize}; use serde_json::Value; use subtle::ConstantTimeEq; use thiserror::Error; -use tokio::sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard}; -pub use warp::filters::ws::Message; -use warp::{filters::cors::Builder, Filter}; +use tokio::{ + net::TcpListener, + sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard}, +}; use nimiq_jsonrpc_core::{ Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage, }; +pub use axum::extract::ws::Message; +use tower_http::cors::{Any, CorsLayer}; + /// A server error. #[derive(Debug, Error)] pub enum Error { - /// Error returned by warp + /// Error returned by axum #[error("HTTP error: {0}")] - Warp(#[from] warp::Error), + Axum(#[from] axum::Error), /// Error from the message queues, that are used internally. #[error("Queue error: {0}")] @@ -100,33 +114,63 @@ fn blake2b(bytes: &[u8]) -> [u8; 32] { *Blake2b::::digest(bytes).as_ref() } +async fn basic_auth_middleware( + State(state): State>>, + basic_auth_header: Option>>, + request: axum::extract::Request, + next: Next, +) -> HttpResponse { + if let Some(auth_config) = &state.config.basic_auth { + if let Some(auth_header) = basic_auth_header { + if auth_config + .verify(auth_header.username(), auth_header.password()) + .is_ok() + { + return next.run(request).await; + } + return StatusCode::UNAUTHORIZED.into_response(); + } else { + return StatusCode::UNAUTHORIZED.into_response(); + } + } + + next.run(request).await +} + #[derive(Clone, Debug)] /// CORS configuration -pub struct Cors(Builder); +pub struct Cors(CorsLayer); impl Cors { /// Create a new instance with `Content-Type` as mandatory header and `POST` as mandatory method. pub fn new() -> Self { Self( - warp::cors() - .allow_header("Content-Type") - .allow_method("POST"), + CorsLayer::new() + .allow_headers([CONTENT_TYPE]) + .allow_methods([Method::POST]), ) } /// Configure CORS to only allow specific origins. + /// Note that multiple calls to this method will override any previous origin-related calls. pub fn with_origins(mut self, origins: Vec<&str>) -> Self { - self.0 = self.0.allow_origins(origins); + self.0 = self.0.allow_origin::>( + origins + .iter() + .map(|o| o.parse::().unwrap()) + .collect(), + ); self } /// Configure CORS to allow every origin. Also known as the `*` wildcard. + /// Note that multiple calls to this method will override any previous origin-related calls. pub fn with_any_origin(mut self) -> Self { - self.0 = self.0.allow_any_origin(); + self.0 = self.0.allow_origin(Any); self } - pub(crate) fn into_wrapper(self) -> Builder { + pub(crate) fn into_layer(self) -> CorsLayer { self.0 } } @@ -235,75 +279,48 @@ impl Server { /// Runs the server forever. pub async fn run(&self) { - // Route to use JSON-RPC over websocket let inner = Arc::clone(&self.inner); - let ws_route = warp::path("ws") - .and(warp::path::end()) - .and(warp::ws()) - .map(move |ws| Self::upgrade_to_ws(Arc::clone(&inner), ws)); + let http_router = Router::new().route( + "/", + post(|body: Bytes| async move { + let data = Self::handle_raw_request(inner, &Message::binary(body), None) + .await + .unwrap_or(Message::Binary(Bytes::new())); + + Builder::new() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(data.into_data().to_owned())) + .unwrap() // As long as the hard-coded status code and content-type is correct, this won't fail. + }), + ); - // Route for backwards-compatibility to use JSON-RPC over HTTP at / let inner = Arc::clone(&self.inner); - let post_route = warp::path::end() - .and(warp::post()) - .and(warp::body::content_length_limit(1024 * 1024)) - .and(warp::body::bytes()) - .and_then(move |body: Bytes| { - let inner = Arc::clone(&inner); - async move { - let data = Self::handle_raw_request(inner, &Message::binary(body), None) - .await - .unwrap_or(Message::binary([])); - - let response = http::response::Builder::new() - .status(200) - .header("Content-Type", "application/json") - .body(data.as_bytes().to_owned()) - .unwrap(); // As long as the hard-coded status code and content-type is correct, this won't fail. - - Ok::<_, warp::Rejection>(response) - } - }); - - let json_rpc_route = ws_route.or(post_route); - - let root = if self.inner.config.basic_auth.is_some() { - let inner = Arc::clone(&self.inner); - let realm = "JSON-RPC"; - auth_filter::basic_auth_filter(realm) - .and_then(move |auth_header: Authorization| { - let inner = Arc::clone(&inner); - - let basic_auth = inner.config.basic_auth.as_ref().unwrap(); - future::ready( - basic_auth - .verify(auth_header.0.username(), auth_header.0.password()) - .map_err(|CredentialsVerificationError(())| { - warp::reject::custom(auth_filter::Unauthorized { - realm: realm.to_string(), - }) - }), - ) - }) - .untuple_one() - .boxed() - } else { - warp::any().boxed() - }; - - warp::serve( - root.and(json_rpc_route) - .with( - self.inner - .config - .cors - .clone() - .map_or(warp::cors(), |cors| cors.into_wrapper()), - ) - .recover(auth_filter::handle_auth_rejection), - ) - .run(self.inner.config.bind_to) - .await; + let ws_router = Router::new().route( + "/ws", + any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }), + ); + + let app = Router::new() + .merge(http_router) + .merge(ws_router) + .route_layer(axum::middleware::from_fn_with_state( + Arc::clone(&self.inner), + basic_auth_middleware, + )) + .layer(DefaultBodyLimit::max(1024 * 1024 /* 1MB */)) + .layer( + self.inner + .config + .cors + .clone() + .unwrap_or_default() + .into_layer(), + ) + .with_state(Arc::clone(&self.inner)); + + let listener = TcpListener::bind(self.inner.config.bind_to).await.unwrap(); + axum::serve(listener, app).await.unwrap(); } /// Upgrades a connection to websocket. This creates message queues and tasks to forward messages between them. @@ -316,7 +333,7 @@ impl Server { /// - This sends stuff as binary websocket frames. It should really use text frames. /// - Make the queue size configurable /// - fn upgrade_to_ws(inner: Arc>, ws: warp::ws::Ws) -> impl warp::Reply { + fn upgrade_to_ws(inner: Arc>, ws: WebSocketUpgrade) -> HttpResponse { ws.on_upgrade(move |websocket| { let (mut tx, mut rx) = websocket.split(); @@ -326,7 +343,7 @@ impl Server { let forward_fut = async move { while let Some(data) = multiplex_rx.recv().await { // Close the sink if we get a close message (don't echo the message since this is not permitted) - if data.is_close() { + if matches!(data, Message::Close(_)) { tx.close().await?; } else { tx.send(data).await?; @@ -339,11 +356,13 @@ impl Server { let handle_fut = { async move { while let Some(message) = rx.next().await.transpose()? { - if message.is_ping() || message.is_pong() { + if matches!(message, Message::Ping(_)) + || matches!(message, Message::Pong(_)) + { // Do nothing - these messages are handled automatically - } else if message.is_close() { + } else if matches!(message, Message::Close(_)) { // We received the close message, so we need to send a close message to close the sink - multiplex_tx.send(warp::ws::Message::close()).await?; + multiplex_tx.send(Message::Close(None)).await?; // Then we exit the loop which closes the connection break; } else if let Some(response) = Self::handle_raw_request( @@ -382,7 +401,7 @@ impl Server { request: &Message, tx: Option<&mpsc::Sender>, ) -> Option { - match serde_json::from_slice(request.as_bytes()) { + match serde_json::from_slice(request.clone().into_data().as_ref()) { Ok(request) => Self::handle_request(inner, request, tx).await, Err(_e) => { log::error!("Received invalid JSON from client"); @@ -393,7 +412,7 @@ impl Server { } } .map(|response| { - if request.is_text() { + if matches!(&request, Message::Text(_)) { Message::text( serde_json::to_string(&response) .expect("Failed to serialize JSON RPC response"),