diff --git a/core/src/lib.rs b/core/src/lib.rs index 202faa1..06c0c66 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -35,6 +35,26 @@ pub enum Error { InvalidSubscriptionId(Value), } +/// Indicate if a websocket frame response should be in Binary or Text +#[derive(Copy, Clone, Default)] +pub enum FrameType { + /// Binary frame type + #[default] + Binary, + /// Text frame type + Text, +} + +impl From<&String> for FrameType { + fn from(value: &String) -> Self { + match value.as_str() { + "text" => FrameType::Text, + "binary" => FrameType::Binary, + _ => FrameType::Binary, + } + } +} + /// A JSON-RPC request or response can either be a single request or response, or a list of the former. This `enum` /// matches either for serialization and deserialization. /// diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 2e5d647..3516d0e 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -152,7 +152,7 @@ impl<'a> RpcMethod<'a> { let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new()); let listener = notifier.clone(); - let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener); + let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener, frame_type); Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier))) } diff --git a/derive/src/service.rs b/derive/src/service.rs index 3dfb8e1..bd7542d 100644 --- a/derive/src/service.rs +++ b/derive/src/service.rs @@ -92,6 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream { request: ::nimiq_jsonrpc_core::Request, tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>, stream_id: u64, + frame_type: Option<::nimiq_jsonrpc_core::FrameType>, ) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> { match request.method.as_str() { #(#match_arms)* diff --git a/server/src/lib.rs b/server/src/lib.rs index 60e1ffe..c185b3c 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -19,7 +19,7 @@ use std::{ use async_trait::async_trait; use axum::{ body::{Body, Bytes}, - extract::{DefaultBodyLimit, State, WebSocketUpgrade}, + extract::{DefaultBodyLimit, Query, State, WebSocketUpgrade}, http::{header::CONTENT_TYPE, response::Builder, HeaderValue, Method, StatusCode}, middleware::Next, response::{IntoResponse as _, Response as HttpResponse}, @@ -47,7 +47,8 @@ use tokio::{ }; use nimiq_jsonrpc_core::{ - Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage, + FrameType, Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, + SubscriptionMessage, }; pub use axum::extract::ws::Message; @@ -297,7 +298,7 @@ impl Server { let http_router = Router::new().route( "/", post(|body: Bytes| async move { - let data = Self::handle_raw_request(inner, &Message::binary(body), None) + let data = Self::handle_raw_request(inner, &Message::binary(body), None, None) .await .unwrap_or(Message::Binary(Bytes::new())); @@ -312,7 +313,11 @@ impl Server { let inner = Arc::clone(&self.inner); let ws_router = Router::new().route( "/ws", - any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }), + any( + |Query(params): Query>, ws: WebSocketUpgrade| async move { + Self::upgrade_to_ws(inner, ws, params) + }, + ), ); let app = Router::new() @@ -344,10 +349,18 @@ impl Server { /// /// # TODO: /// - /// - This sends stuff as binary websocket frames. It should really use text frames. /// - Make the queue size configurable /// - fn upgrade_to_ws(inner: Arc>, ws: WebSocketUpgrade) -> HttpResponse { + fn upgrade_to_ws( + inner: Arc>, + ws: WebSocketUpgrade, + query_params: HashMap, + ) -> HttpResponse { + let frame_type: Option = query_params + .get("frame") + .map(|frame_type| Some(frame_type.into())) + .unwrap_or_default(); + ws.on_upgrade(move |websocket| { let (mut tx, mut rx) = websocket.split(); @@ -383,6 +396,7 @@ impl Server { Arc::clone(&inner), &message, Some(&multiplex_tx), + frame_type, ) .await { @@ -409,14 +423,16 @@ impl Server { /// - `request`: The raw request data. /// - `tx`: If the request was received over websocket, this the message queue over which the called function can /// send notifications to the client (used for subscriptions). + /// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames. /// async fn handle_raw_request( inner: Arc>, request: &Message, tx: Option<&mpsc::Sender>, + frame_type: Option, ) -> Option { match serde_json::from_slice(request.clone().into_data().as_ref()) { - Ok(request) => Self::handle_request(inner, request, tx).await, + Ok(request) => Self::handle_request(inner, request, tx, frame_type).await, Err(_e) => { log::error!("Received invalid JSON from client"); Some(SingleOrBatch::Single(Response::new_error( @@ -447,21 +463,27 @@ impl Server { /// - `request`: The request that was received. /// - `tx`: If the request was received over websocket, this the message queue over which the called function can /// send notifications to the client (used for subscriptions). + /// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames. /// async fn handle_request( inner: Arc>, request: SingleOrBatch, tx: Option<&mpsc::Sender>, + frame_type: Option, ) -> Option> { match request { - SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx) - .await - .map(|(response, _)| SingleOrBatch::Single(response)), + SingleOrBatch::Single(request) => { + Self::handle_single_request(inner, request, tx, frame_type) + .await + .map(|(response, _)| SingleOrBatch::Single(response)) + } SingleOrBatch::Batch(requests) => { let futures = requests .into_iter() - .map(|request| Self::handle_single_request(Arc::clone(&inner), request, tx)) + .map(|request| { + Self::handle_single_request(Arc::clone(&inner), request, tx, frame_type) + }) .collect::>(); let responses = futures @@ -479,6 +501,7 @@ impl Server { inner: Arc>, request: Request, tx: Option<&mpsc::Sender>, + frame_type: Option, ) -> Option { if request.method == "unsubscribe" { return Self::handle_unsubscribe_stream(request, inner).await; @@ -490,7 +513,7 @@ impl Server { log::debug!("request: {:#?}", request); - let response = dispatcher.dispatch(request, tx, id).await; + let response = dispatcher.dispatch(request, tx, id, frame_type).await; log::debug!("response: {:#?}", response); @@ -565,6 +588,7 @@ pub trait Dispatcher: Send + Sync + 'static { request: Request, tx: Option<&mpsc::Sender>, id: u64, + frame_type: Option, ) -> Option; /// Returns whether a method should be dispatched with this dispatcher. @@ -605,13 +629,14 @@ impl Dispatcher for ModularDispatcher { request: Request, tx: Option<&mpsc::Sender>, id: u64, + frame_type: Option, ) -> Option { for dispatcher in &mut self.dispatchers { let m = dispatcher.match_method(&request.method); log::debug!("Matching '{}' against dispatcher -> {}", request.method, m); log::debug!("Methods: {:?}", dispatcher.method_names()); if m { - return dispatcher.dispatch(request, tx, id).await; + return dispatcher.dispatch(request, tx, id, frame_type).await; } } @@ -674,10 +699,11 @@ where request: Request, tx: Option<&mpsc::Sender>, id: u64, + frame_type: Option, ) -> Option { if self.is_allowed(&request.method) { log::debug!("Dispatching method: {}", request.method); - self.inner.dispatch(request, tx, id).await + self.inner.dispatch(request, tx, id, frame_type).await } else { log::debug!("Method not allowed: {}", request.method); // If the method is not white-listed, pretend it doesn't exist. @@ -833,6 +859,7 @@ async fn forward_notification( tx: &mut mpsc::Sender, id: &SubscriptionId, method: &str, + frame_type: Option, ) -> Result<(), Error> where T: Serialize + Debug + Send + Sync, @@ -846,8 +873,12 @@ where log::debug!("Sending notification: {:?}", notification); - tx.send(Message::binary(serde_json::to_vec(¬ification)?)) - .await?; + let message = match frame_type { + Some(FrameType::Text) => Message::text(serde_json::to_string(¬ification)?), + Some(FrameType::Binary) | None => Message::binary(serde_json::to_vec(¬ification)?), + }; + + tx.send(message).await?; Ok(()) } @@ -871,6 +902,7 @@ pub fn connect_stream( stream_id: u64, method: String, notify_handler: Arc, + frame_type: Option, ) -> SubscriptionId where T: Serialize + Debug + Send + Sync, @@ -892,7 +924,7 @@ where item = stream.next() => { match item { Some(notification) => { - if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await { + if let Err(error) = forward_notification(notification, &mut tx, &id, &method, frame_type).await { // Break the loop when the channel is closed if let Error::Mpsc(_) = error { break;