Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to change how websocket data frames are returned to the client #39

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ pub enum Error {
InvalidSubscriptionId(Value),
}

/// Indicate if a websocket frame response should be in Binary or Text
#[derive(Copy, Clone, Default)]
pub enum FrameType {
/// Binary frame type
#[default]
Binary,
/// Text frame type
Text,
}

impl From<&String> for FrameType {
fn from(value: &String) -> Self {
match value.as_str() {
"text" => FrameType::Text,
"binary" => FrameType::Binary,
_ => FrameType::Binary,
}
}
}

/// A JSON-RPC request or response can either be a single request or response, or a list of the former. This `enum`
/// matches either for serialization and deserialization.
///
Expand Down
2 changes: 1 addition & 1 deletion derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<'a> RpcMethod<'a> {
let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new());
let listener = notifier.clone();

let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener);
let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener, frame_type);

Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier)))
}
Expand Down
1 change: 1 addition & 0 deletions derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
frame_type: Option<::nimiq_jsonrpc_core::FrameType>,
) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> {
match request.method.as_str() {
#(#match_arms)*
Expand Down
66 changes: 49 additions & 17 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::{
use async_trait::async_trait;
use axum::{
body::{Body, Bytes},
extract::{DefaultBodyLimit, State, WebSocketUpgrade},
extract::{DefaultBodyLimit, Query, State, WebSocketUpgrade},
http::{header::CONTENT_TYPE, response::Builder, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse as _, Response as HttpResponse},
Expand Down Expand Up @@ -47,7 +47,8 @@ use tokio::{
};

use nimiq_jsonrpc_core::{
Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage,
FrameType, Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId,
SubscriptionMessage,
};

pub use axum::extract::ws::Message;
Expand Down Expand Up @@ -297,7 +298,7 @@ impl<D: Dispatcher> Server<D> {
let http_router = Router::new().route(
"/",
post(|body: Bytes| async move {
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
let data = Self::handle_raw_request(inner, &Message::binary(body), None, None)
.await
.unwrap_or(Message::Binary(Bytes::new()));

Expand All @@ -312,7 +313,11 @@ impl<D: Dispatcher> Server<D> {
let inner = Arc::clone(&self.inner);
let ws_router = Router::new().route(
"/ws",
any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }),
any(
|Query(params): Query<HashMap<String, String>>, ws: WebSocketUpgrade| async move {
Self::upgrade_to_ws(inner, ws, params)
},
),
);

let app = Router::new()
Expand Down Expand Up @@ -344,10 +349,18 @@ impl<D: Dispatcher> Server<D> {
///
/// # TODO:
///
/// - This sends stuff as binary websocket frames. It should really use text frames.
/// - Make the queue size configurable
///
fn upgrade_to_ws(inner: Arc<Inner<D>>, ws: WebSocketUpgrade) -> HttpResponse<Body> {
fn upgrade_to_ws(
inner: Arc<Inner<D>>,
ws: WebSocketUpgrade,
query_params: HashMap<String, String>,
) -> HttpResponse<Body> {
let frame_type: Option<FrameType> = query_params
.get("frame")
.map(|frame_type| Some(frame_type.into()))
.unwrap_or_default();

ws.on_upgrade(move |websocket| {
let (mut tx, mut rx) = websocket.split();

Expand Down Expand Up @@ -383,6 +396,7 @@ impl<D: Dispatcher> Server<D> {
Arc::clone(&inner),
&message,
Some(&multiplex_tx),
frame_type,
)
.await
{
Expand All @@ -409,14 +423,16 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The raw request data.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_raw_request(
inner: Arc<Inner<D>>,
request: &Message,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<Message> {
match serde_json::from_slice(request.clone().into_data().as_ref()) {
Ok(request) => Self::handle_request(inner, request, tx).await,
Ok(request) => Self::handle_request(inner, request, tx, frame_type).await,
Err(_e) => {
log::error!("Received invalid JSON from client");
Some(SingleOrBatch::Single(Response::new_error(
Expand Down Expand Up @@ -447,21 +463,27 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The request that was received.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_request(
inner: Arc<Inner<D>>,
request: SingleOrBatch<Request>,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<SingleOrBatch<Response>> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
.await
.map(|(response, _)| SingleOrBatch::Single(response)),
SingleOrBatch::Single(request) => {
Self::handle_single_request(inner, request, tx, frame_type)
.await
.map(|(response, _)| SingleOrBatch::Single(response))
}

SingleOrBatch::Batch(requests) => {
let futures = requests
.into_iter()
.map(|request| Self::handle_single_request(Arc::clone(&inner), request, tx))
.map(|request| {
Self::handle_single_request(Arc::clone(&inner), request, tx, frame_type)
})
.collect::<FuturesUnordered<_>>();

let responses = futures
Expand All @@ -479,6 +501,7 @@ impl<D: Dispatcher> Server<D> {
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
if request.method == "unsubscribe" {
return Self::handle_unsubscribe_stream(request, inner).await;
Expand All @@ -490,7 +513,7 @@ impl<D: Dispatcher> Server<D> {

log::debug!("request: {:#?}", request);

let response = dispatcher.dispatch(request, tx, id).await;
let response = dispatcher.dispatch(request, tx, id, frame_type).await;

log::debug!("response: {:#?}", response);

Expand Down Expand Up @@ -565,6 +588,7 @@ pub trait Dispatcher: Send + Sync + 'static {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier>;

/// Returns whether a method should be dispatched with this dispatcher.
Expand Down Expand Up @@ -605,13 +629,14 @@ impl Dispatcher for ModularDispatcher {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
for dispatcher in &mut self.dispatchers {
let m = dispatcher.match_method(&request.method);
log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
log::debug!("Methods: {:?}", dispatcher.method_names());
if m {
return dispatcher.dispatch(request, tx, id).await;
return dispatcher.dispatch(request, tx, id, frame_type).await;
}
}

Expand Down Expand Up @@ -674,10 +699,11 @@ where
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
if self.is_allowed(&request.method) {
log::debug!("Dispatching method: {}", request.method);
self.inner.dispatch(request, tx, id).await
self.inner.dispatch(request, tx, id, frame_type).await
} else {
log::debug!("Method not allowed: {}", request.method);
// If the method is not white-listed, pretend it doesn't exist.
Expand Down Expand Up @@ -833,6 +859,7 @@ async fn forward_notification<T>(
tx: &mut mpsc::Sender<Message>,
id: &SubscriptionId,
method: &str,
frame_type: Option<FrameType>,
) -> Result<(), Error>
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -846,8 +873,12 @@ where

log::debug!("Sending notification: {:?}", notification);

tx.send(Message::binary(serde_json::to_vec(&notification)?))
.await?;
let message = match frame_type {
Some(FrameType::Text) => Message::text(serde_json::to_string(&notification)?),
Some(FrameType::Binary) | None => Message::binary(serde_json::to_vec(&notification)?),
};

tx.send(message).await?;

Ok(())
}
Expand All @@ -871,6 +902,7 @@ pub fn connect_stream<T, S>(
stream_id: u64,
method: String,
notify_handler: Arc<Notify>,
frame_type: Option<FrameType>,
) -> SubscriptionId
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -892,7 +924,7 @@ where
item = stream.next() => {
match item {
Some(notification) => {
if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await {
if let Err(error) = forward_notification(notification, &mut tx, &id, &method, frame_type).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = error {
break;
Expand Down
Loading