Skip to content

Commit

Permalink
Option to change how websocket data frames are returned to the client
Browse files Browse the repository at this point in the history
The websocket specification defines 2 ways how application-layer data transmitted over the socket can be represented: `Text` or `Binary`.
Currently the default data frame representation returned by the RPC server is `Binary` but this is counterintuitive since JSONRPC already is in a text representation.

This PR let the client, per connection, decide in which representation data should be returned by the server. This is achived by adding support for a `frame` query parameter on the `/ws` endpoint.
Examples:
- ws://127.0.0.1:8000/ws		--> data frames are returned as `Binary` (default)
- ws://127.0.0.1:8000/ws?frame=text	--> data frames are returned as `Text`
- ws://127.0.0.1:8000/ws?frame=binary	--> data frames are returned as `Binary`
- ws://127.0.0.1:8000/ws?frame=foo   	--> data frames are returned as `Binary` (Fallback for unsupported values)

To accomplish backwards compatibility `Binary` is still the default returned representation but in a major release this should be changed to `Text`
  • Loading branch information
Eligioo committed Jan 23, 2025
1 parent c8cbf5a commit 9ae342d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
25 changes: 25 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,31 @@ pub enum Error {
InvalidSubscriptionId(Value),
}

/// Indicate if a websocket frame response should be in Binary or Text
#[derive(Copy, Clone)]
pub enum FrameType {
/// Binary frame type
Binary,
/// Text frame type
Text,
}

impl From<&String> for FrameType {
fn from(value: &String) -> Self {
match value.as_str() {
"text" => FrameType::Text,
"binary" => FrameType::Binary,
_ => FrameType::Binary,
}
}
}

impl Default for FrameType {
fn default() -> Self {
FrameType::Binary
}
}

Check warning on line 61 in core/src/lib.rs

View workflow job for this annotation

GitHub Actions / Clippy Report

this `impl` can be derived

warning: this `impl` can be derived --> core/src/lib.rs:57:1 | 57 | / impl Default for FrameType { 58 | | fn default() -> Self { 59 | | FrameType::Binary 60 | | } 61 | | } | |_^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#derivable_impls = note: `#[warn(clippy::derivable_impls)]` on by default = help: remove the manual implementation... help: ...and instead derive it... | 40 + #[derive(Default)] 41 | pub enum FrameType { | help: ...and mark the default variant | 42 ~ #[default] 43 ~ Binary, |

Check warning on line 61 in core/src/lib.rs

View workflow job for this annotation

GitHub Actions / Clippy Report

this `impl` can be derived

warning: this `impl` can be derived --> core/src/lib.rs:57:1 | 57 | / impl Default for FrameType { 58 | | fn default() -> Self { 59 | | FrameType::Binary 60 | | } 61 | | } | |_^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#derivable_impls = note: `#[warn(clippy::derivable_impls)]` on by default = help: remove the manual implementation... help: ...and instead derive it... | 40 + #[derive(Default)] 41 | pub enum FrameType { | help: ...and mark the default variant | 42 ~ #[default] 43 ~ Binary, |

Check warning on line 61 in core/src/lib.rs

View workflow job for this annotation

GitHub Actions / Clippy Report

this `impl` can be derived

warning: this `impl` can be derived --> core/src/lib.rs:57:1 | 57 | / impl Default for FrameType { 58 | | fn default() -> Self { 59 | | FrameType::Binary 60 | | } 61 | | } | |_^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#derivable_impls = note: `#[warn(clippy::derivable_impls)]` on by default = help: remove the manual implementation... help: ...and instead derive it... | 40 + #[derive(Default)] 41 | pub enum FrameType { | help: ...and mark the default variant | 42 ~ #[default] 43 ~ Binary, |

/// A JSON-RPC request or response can either be a single request or response, or a list of the former. This `enum`
/// matches either for serialization and deserialization.
///
Expand Down
2 changes: 1 addition & 1 deletion derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl<'a> RpcMethod<'a> {
move |params: #args_struct_ident| async move {
let stream = self.#method_ident(#(#method_args),*).await?;

let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned());
let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), frame_type);

Ok::<_, ::nimiq_jsonrpc_core::RpcError>(subscription)
}
Expand Down
1 change: 1 addition & 0 deletions derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
frame_type: Option<::nimiq_jsonrpc_core::FrameType>,
) -> Option<::nimiq_jsonrpc_core::Response> {
match request.method.as_str() {
#(#match_arms)*
Expand Down
69 changes: 51 additions & 18 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#![warn(rustdoc::missing_doc_code_examples)]

use std::{
collections::HashSet,
collections::{HashMap, HashSet},
error,
fmt::{self, Debug},
future::Future,
Expand All @@ -19,7 +19,7 @@ use std::{
use async_trait::async_trait;
use axum::{
body::{Body, Bytes},
extract::{DefaultBodyLimit, State, WebSocketUpgrade},
extract::{DefaultBodyLimit, Query, State, WebSocketUpgrade},
http::{header::CONTENT_TYPE, response::Builder, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse as _, Response as HttpResponse},
Expand Down Expand Up @@ -47,7 +47,8 @@ use tokio::{
};

use nimiq_jsonrpc_core::{
Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage,
FrameType, Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId,
SubscriptionMessage,
};

pub use axum::extract::ws::Message;
Expand Down Expand Up @@ -291,7 +292,7 @@ impl<D: Dispatcher> Server<D> {
let http_router = Router::new().route(
"/",
post(|body: Bytes| async move {
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
let data = Self::handle_raw_request(inner, &Message::binary(body), None, None)
.await
.unwrap_or(Message::Binary(Bytes::new()));

Expand All @@ -306,7 +307,11 @@ impl<D: Dispatcher> Server<D> {
let inner = Arc::clone(&self.inner);
let ws_router = Router::new().route(
"/ws",
any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }),
any(
|Query(params): Query<HashMap<String, String>>, ws: WebSocketUpgrade| async move {
Self::upgrade_to_ws(inner, ws, params)
},
),
);

let app = Router::new()
Expand Down Expand Up @@ -338,10 +343,18 @@ impl<D: Dispatcher> Server<D> {
///
/// # TODO:
///
/// - This sends stuff as binary websocket frames. It should really use text frames.
/// - Make the queue size configurable
///
fn upgrade_to_ws(inner: Arc<Inner<D>>, ws: WebSocketUpgrade) -> HttpResponse<Body> {
fn upgrade_to_ws(
inner: Arc<Inner<D>>,
ws: WebSocketUpgrade,
query_params: HashMap<String, String>,
) -> HttpResponse<Body> {
let frame_type: Option<FrameType> = query_params
.get("frame")
.map(|frame_type| Some(frame_type.into()))
.unwrap_or_default();

ws.on_upgrade(move |websocket| {
let (mut tx, mut rx) = websocket.split();

Expand Down Expand Up @@ -377,6 +390,7 @@ impl<D: Dispatcher> Server<D> {
Arc::clone(&inner),
&message,
Some(&multiplex_tx),
frame_type,
)
.await
{
Expand All @@ -403,14 +417,16 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The raw request data.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_raw_request(
inner: Arc<Inner<D>>,
request: &Message,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<Message> {
match serde_json::from_slice(request.clone().into_data().as_ref()) {
Ok(request) => Self::handle_request(inner, request, tx).await,
Ok(request) => Self::handle_request(inner, request, tx, frame_type).await,
Err(_e) => {
log::error!("Received invalid JSON from client");
Some(SingleOrBatch::Single(Response::new_error(
Expand Down Expand Up @@ -441,21 +457,27 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The request that was received.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_request(
inner: Arc<Inner<D>>,
request: SingleOrBatch<Request>,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<SingleOrBatch<Response>> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
.await
.map(SingleOrBatch::Single),
SingleOrBatch::Single(request) => {
Self::handle_single_request(inner, request, tx, frame_type)
.await
.map(SingleOrBatch::Single)
}

SingleOrBatch::Batch(requests) => {
let futures = requests
.into_iter()
.map(|request| Self::handle_single_request(Arc::clone(&inner), request, tx))
.map(|request| {
Self::handle_single_request(Arc::clone(&inner), request, tx, frame_type)
})
.collect::<FuturesUnordered<_>>();

let responses = futures
Expand All @@ -477,14 +499,15 @@ impl<D: Dispatcher> Server<D> {
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<Response> {
let mut dispatcher = inner.dispatcher.write().await;
// This ID is only used for streams
let id = inner.next_id.fetch_add(1, Ordering::SeqCst);

log::debug!("request: {:#?}", request);

let response = dispatcher.dispatch(request, tx, id).await;
let response = dispatcher.dispatch(request, tx, id, frame_type).await;

log::debug!("response: {:#?}", response);

Expand All @@ -502,6 +525,7 @@ pub trait Dispatcher: Send + Sync + 'static {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<Response>;

/// Returns whether a method should be dispatched with this dispatcher.
Expand Down Expand Up @@ -542,13 +566,14 @@ impl Dispatcher for ModularDispatcher {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<Response> {
for dispatcher in &mut self.dispatchers {
let m = dispatcher.match_method(&request.method);
log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
log::debug!("Methods: {:?}", dispatcher.method_names());
if m {
return dispatcher.dispatch(request, tx, id).await;
return dispatcher.dispatch(request, tx, id, frame_type).await;
}
}

Expand Down Expand Up @@ -611,10 +636,11 @@ where
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<Response> {
if self.is_allowed(&request.method) {
log::debug!("Dispatching method: {}", request.method);
self.inner.dispatch(request, tx, id).await
self.inner.dispatch(request, tx, id, frame_type).await
} else {
log::debug!("Method not allowed: {}", request.method);
// If the method is not white-listed, pretend it doesn't exist.
Expand Down Expand Up @@ -761,6 +787,7 @@ async fn forward_notification<T>(
tx: &mut mpsc::Sender<Message>,
id: &SubscriptionId,
method: &str,
frame_type: Option<FrameType>,
) -> Result<(), Error>
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -774,8 +801,12 @@ where

log::debug!("Sending notification: {:?}", notification);

tx.send(Message::binary(serde_json::to_vec(&notification)?))
.await?;
let message = match frame_type {
Some(FrameType::Text) => Message::text(serde_json::to_string(&notification)?),
Some(FrameType::Binary) | None => Message::binary(serde_json::to_vec(&notification)?),
};

tx.send(message).await?;

Ok(())
}
Expand All @@ -798,6 +829,7 @@ pub fn connect_stream<T, S>(
tx: &mpsc::Sender<Message>,
stream_id: u64,
method: String,
frame_type: Option<FrameType>,
) -> SubscriptionId
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -812,7 +844,8 @@ where
pin_mut!(stream);

while let Some(item) = stream.next().await {
if let Err(e) = forward_notification(item, &mut tx, &id, &method).await {
if let Err(e) = forward_notification(item, &mut tx, &id, &method, frame_type).await
{
// Break the loop when the channel is closed
if let Error::Mpsc(_) = e {
break;
Expand Down

0 comments on commit 9ae342d

Please sign in to comment.