diff --git a/src/client.rs b/src/client.rs
index 96bdbb10..5e900602 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -4,35 +4,32 @@
 //!
 //! Support for connecting to JSONRPC servers over HTTP, sending requests,
 //! and parsing responses
-//!
 
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::fmt;
+use std::hash::{Hash, Hasher};
 use std::sync::atomic;
 
-use serde;
-use serde_json;
 use serde_json::value::RawValue;
+use serde_json::Value;
 
-use super::{Request, Response};
 use crate::error::Error;
-use crate::util::HashableValue;
+use crate::{Request, Response};
 
 /// An interface for a transport over which to use the JSONRPC protocol.
 pub trait Transport: Send + Sync + 'static {
-    /// Send an RPC request over the transport.
+    /// Sends an RPC request over the transport.
     fn send_request(&self, _: Request) -> Result<Response, Error>;
-    /// Send a batch of RPC requests over the transport.
+    /// Sends a batch of RPC requests over the transport.
     fn send_batch(&self, _: &[Request]) -> Result<Vec<Response>, Error>;
-    /// Format the target of this transport.
-    /// I.e. the URL/socket/...
+    /// Formats the target of this transport. I.e. the URL/socket/...
     fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result;
 }
 
 /// A JSON-RPC client.
 ///
-/// Create a new Client using one of the transport-specific constructors e.g.,
+/// Creates a new Client using one of the transport-specific constructors, e.g.,
 /// [`Client::simple_http`] for a bare-minimum HTTP transport.
 pub struct Client {
     pub(crate) transport: Box<dyn Transport>,
@@ -111,7 +108,7 @@ impl Client {
         Ok(results)
     }
 
-    /// Make a request and deserialize the response.
+    /// Makes a request and deserializes the response.
     ///
     /// To construct the arguments, one can use one of the shorthand methods
     /// [`crate::arg`] or [`crate::try_arg`].
@@ -149,9 +146,65 @@ impl<T: Transport> From<T> for Client {
     }
 }
 
+/// Newtype around `Value` which allows hashing for use as hashmap keys,
+/// this is needed for batch requests.
+///
+/// The reason `Value` does not support `Hash` or `Eq` by itself
+/// is that it supports `f64` values; but for batch requests we
+/// will only be hashing the "id" field of the request/response
+/// pair, which should never need decimal precision and therefore
+/// never use `f64`.
+#[derive(Clone, PartialEq, Debug)]
+struct HashableValue<'a>(pub Cow<'a, Value>);
+
+impl<'a> Eq for HashableValue<'a> {}
+
+impl<'a> Hash for HashableValue<'a> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        match *self.0.as_ref() {
+            Value::Null => "null".hash(state),
+            Value::Bool(false) => "false".hash(state),
+            Value::Bool(true) => "true".hash(state),
+            Value::Number(ref n) => {
+                "number".hash(state);
+                if let Some(n) = n.as_i64() {
+                    n.hash(state);
+                } else if let Some(n) = n.as_u64() {
+                    n.hash(state);
+                } else {
+                    n.to_string().hash(state);
+                }
+            }
+            Value::String(ref s) => {
+                "string".hash(state);
+                s.hash(state);
+            }
+            Value::Array(ref v) => {
+                "array".hash(state);
+                v.len().hash(state);
+                for obj in v {
+                    HashableValue(Cow::Borrowed(obj)).hash(state);
+                }
+            }
+            Value::Object(ref m) => {
+                "object".hash(state);
+                m.len().hash(state);
+                for (key, val) in m {
+                    key.hash(state);
+                    HashableValue(Cow::Borrowed(val)).hash(state);
+                }
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
+
+    use std::borrow::Cow;
+    use std::collections::HashSet;
+    use std::str::FromStr;
     use std::sync;
 
     struct DummyTransport;
@@ -177,4 +230,38 @@ mod tests {
         assert_eq!(client.nonce.load(sync::atomic::Ordering::Relaxed), 3);
         assert!(req1.id != req2.id);
     }
+
+    #[test]
+    fn hash_value() {
+        let val = HashableValue(Cow::Owned(Value::from_str("null").unwrap()));
+        let t = HashableValue(Cow::Owned(Value::from_str("true").unwrap()));
+        let f = HashableValue(Cow::Owned(Value::from_str("false").unwrap()));
+        let ns =
+            HashableValue(Cow::Owned(Value::from_str("[0, -0, 123.4567, -100000000]").unwrap()));
+        let m =
+            HashableValue(Cow::Owned(Value::from_str("{ \"field\": 0, \"field\": -0 }").unwrap()));
+
+        let mut coll = HashSet::new();
+
+        assert!(!coll.contains(&val));
+        coll.insert(val.clone());
+        assert!(coll.contains(&val));
+
+        assert!(!coll.contains(&t));
+        assert!(!coll.contains(&f));
+        coll.insert(t.clone());
+        assert!(coll.contains(&t));
+        assert!(!coll.contains(&f));
+        coll.insert(f.clone());
+        assert!(coll.contains(&t));
+        assert!(coll.contains(&f));
+
+        assert!(!coll.contains(&ns));
+        coll.insert(ns.clone());
+        assert!(coll.contains(&ns));
+
+        assert!(!coll.contains(&m));
+        coll.insert(m.clone());
+        assert!(coll.contains(&m));
+    }
 }
diff --git a/src/error.rs b/src/error.rs
index 97b88868..6798f1c1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,13 +2,11 @@
 
 //! # Error handling
 //!
-//! Some useful methods for creating Error objects
-//!
+//! Some useful methods for creating Error objects.
 
 use std::{error, fmt};
 
 use serde::{Deserialize, Serialize};
-use serde_json;
 
 use crate::Response;
 
@@ -50,18 +48,18 @@ impl From<serde_json::Error> for Error {
 impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        use Error::*;
+
         match *self {
-            Error::Transport(ref e) => write!(f, "transport error: {}", e),
-            Error::Json(ref e) => write!(f, "JSON decode error: {}", e),
-            Error::Rpc(ref r) => write!(f, "RPC error response: {:?}", r),
-            Error::BatchDuplicateResponseId(ref v) => {
-                write!(f, "duplicate RPC batch response ID: {}", v)
-            }
-            Error::WrongBatchResponseId(ref v) => write!(f, "wrong RPC batch response ID: {}", v),
-            Error::NonceMismatch => write!(f, "Nonce of response did not match nonce of request"),
-            Error::VersionMismatch => write!(f, "`jsonrpc` field set to non-\"2.0\""),
-            Error::EmptyBatch => write!(f, "batches can't be empty"),
-            Error::WrongBatchResponseSize => write!(f, "too many responses returned in batch"),
+            Transport(ref e) => write!(f, "transport error: {}", e),
+            Json(ref e) => write!(f, "JSON decode error: {}", e),
+            Rpc(ref r) => write!(f, "RPC error response: {:?}", r),
+            BatchDuplicateResponseId(ref v) => write!(f, "duplicate RPC batch response ID: {}", v),
+            WrongBatchResponseId(ref v) => write!(f, "wrong RPC batch response ID: {}", v),
+            NonceMismatch => write!(f, "nonce of response did not match nonce of request"),
+            VersionMismatch => write!(f, "`jsonrpc` field set to non-\"2.0\""),
+            EmptyBatch => write!(f, "batches can't be empty"),
+            WrongBatchResponseSize => write!(f, "too many responses returned in batch"),
         }
     }
 }
@@ -121,8 +119,8 @@ pub enum StandardError {
     InternalError,
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
 /// A JSONRPC error object
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct RpcError {
     /// The integer identifier of the error
     pub code: i32,
diff --git a/src/lib.rs b/src/lib.rs
index 9615ae12..a53e1259 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,14 +3,11 @@
 //! # Rust JSON-RPC Library
 //!
 //! Rust support for the JSON-RPC 2.0 protocol.
-//!
 
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
 // Coding conventions
 #![warn(missing_docs)]
 
-use serde::{Deserialize, Serialize};
-
 /// Re-export `serde` crate.
 pub extern crate serde;
 /// Re-export `serde_json` crate.
 pub extern crate serde_json;
@@ -22,7 +19,6 @@ pub extern crate base64;
 
 pub mod client;
 pub mod error;
-mod util;
 #[cfg(feature = "simple_http")]
 pub mod simple_http;
@@ -33,12 +29,12 @@ pub mod simple_tcp;
 #[cfg(all(feature = "simple_uds", not(windows)))]
 pub mod simple_uds;
 
-// Re-export error type
+use serde::{Deserialize, Serialize};
+use serde_json::value::RawValue;
+
 pub use crate::client::{Client, Transport};
 pub use crate::error::Error;
 
-use serde_json::value::RawValue;
-
 /// Shorthand method to convert an argument into a boxed [`serde_json::value::RawValue`].
 ///
 /// Since serializers rarely fail, it's probably easier to use [`arg`] instead.
@@ -60,27 +56,27 @@ pub fn arg<T: serde::Serialize>(arg: T) -> Box<RawValue> {
     }
 }
 
-#[derive(Debug, Clone, Serialize)]
 /// A JSONRPC request object.
+#[derive(Debug, Clone, Serialize)]
 pub struct Request<'a> {
     /// The name of the RPC call.
     pub method: &'a str,
     /// Parameters to the RPC call.
     pub params: &'a [Box<RawValue>],
-    /// Identifier for this Request, which should appear in the response.
+    /// Identifier for this request, which should appear in the response.
     pub id: serde_json::Value,
     /// jsonrpc field, MUST be "2.0".
     pub jsonrpc: Option<&'a str>,
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
 /// A JSONRPC response object.
+#[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Response {
     /// A result if there is one, or [`None`].
     pub result: Option<Box<RawValue>>,
     /// An error if there is one, or [`None`].
     pub error: Option<error::RpcError>,
-    /// Identifier for this Request, which should match that of the request.
+    /// Identifier for this response, which should match that of the request.
     pub id: serde_json::Value,
     /// jsonrpc field, MUST be "2.0".
     pub jsonrpc: Option<String>,
diff --git a/src/simple_http.rs b/src/simple_http.rs
index cbd9d286..f907ea79 100644
--- a/src/simple_http.rs
+++ b/src/simple_http.rs
@@ -3,7 +3,6 @@
 //! This module implements a minimal and non standard conforming HTTP 1.0
 //! round-tripper that works with the bitcoind RPC server. This can be used
 //! if minimal dependencies are a goal and synchronous communication is ok.
-//!
 
 #[cfg(feature = "proxy")]
 use socks::Socks5Stream;
@@ -18,56 +17,15 @@ use std::{error, fmt, io, net, num};
 use crate::client::Transport;
 use crate::{Request, Response};
 
-#[cfg(fuzzing)]
-/// Global mutex used by the fuzzing harness to inject data into the read
-/// end of the TCP stream.
-pub static FUZZ_TCP_SOCK: Mutex<Option<io::Cursor<Vec<u8>>>> = Mutex::new(None);
-
-#[cfg(fuzzing)]
-#[derive(Clone, Debug)]
-struct TcpStream;
-
-#[cfg(fuzzing)]
-mod impls {
-    use super::*;
-    impl Read for TcpStream {
-        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
-            match *FUZZ_TCP_SOCK.lock().unwrap() {
-                Some(ref mut cursor) => io::Read::read(cursor, buf),
-                None => Ok(0),
-            }
-        }
-    }
-    impl Write for TcpStream {
-        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-            io::sink().write(buf)
-        }
-        fn flush(&mut self) -> io::Result<()> {
-            Ok(())
-        }
-    }
-
-    impl TcpStream {
-        pub fn connect_timeout(_: &SocketAddr, _: Duration) -> io::Result<Self> {
-            Ok(TcpStream)
-        }
-        pub fn set_read_timeout(&self, _: Option<Duration>) -> io::Result<()> {
-            Ok(())
-        }
-        pub fn set_write_timeout(&self, _: Option<Duration>) -> io::Result<()> {
-            Ok(())
-        }
-    }
-}
-
 /// The default TCP port to use for connections.
 /// Set to 8332, the default RPC port for bitcoind.
 pub const DEFAULT_PORT: u16 = 8332;
 
 /// The Default SOCKS5 Port to use for proxy connection.
+/// Set to 9050, the default port of Tor's SOCKS5 proxy.
 pub const DEFAULT_PROXY_PORT: u16 = 9050;
 
-/// Absolute maximum content length we will allow before cutting off the response
+/// Absolute maximum content length allowed before cutting off the response.
 const FINAL_RESP_ALLOC: u64 = 1024 * 1024 * 1024;
 
 /// Simple HTTP transport that implements the necessary subset of HTTP for
@@ -122,7 +80,7 @@ impl SimpleHttpTransport {
         Builder::new()
     }
 
-    /// Replaces the URL of the transport
+    /// Replaces the URL of the transport.
    pub fn set_url(&mut self, url: &str) -> Result<(), Error> {
         let url = check_url(url)?;
         self.addr = url.0;
@@ -130,7 +88,7 @@ impl SimpleHttpTransport {
         Ok(())
     }
 
-    /// Replaces only the path part of the URL
+    /// Replaces only the path part of the URL.
     pub fn set_url_path(&mut self, path: String) {
         self.path = path;
     }
@@ -317,171 +275,6 @@ impl SimpleHttpTransport {
     }
 }
 
-/// Error that can happen when sending requests.
-#[derive(Debug)]
-pub enum Error {
-    /// An invalid URL was passed.
-    InvalidUrl {
-        /// The URL passed.
-        url: String,
-        /// The reason the URL is invalid.
-        reason: &'static str,
-    },
-    /// An error occurred on the socket layer.
-    SocketError(io::Error),
-    /// The HTTP response was too short to even fit a HTTP 1.1 header
-    HttpResponseTooShort {
-        /// The total length of the response
-        actual: usize,
-        /// Minimum length we can parse
-        needed: usize,
-    },
-    /// The HTTP response started with a HTTP/1.1 line which was not ASCII
-    HttpResponseNonAsciiHello(Vec<u8>),
-    /// The HTTP response did not start with HTTP/1.1
-    HttpResponseBadHello {
-        /// Actual HTTP-whatever string
-        actual: String,
-        /// The hello string of the HTTP version we support
-        expected: String,
-    },
-    /// Could not parse the status value as a number
-    HttpResponseBadStatus(String, num::ParseIntError),
-    /// Could not parse the status value as a number
-    HttpResponseBadContentLength(String, num::ParseIntError),
-    /// The indicated content-length header exceeded our maximum
-    HttpResponseContentLengthTooLarge {
-        /// The length indicated in the content-length header
-        length: u64,
-        /// Our hard maximum on number of bytes we'll try to read
-        max: u64,
-    },
-    /// Unexpected HTTP error code (non-200).
-    HttpErrorCode(u16),
-    /// Received EOF before getting as many bytes as were indicated by the
-    /// content-length header
-    IncompleteResponse {
-        /// The content-length header
-        content_length: u64,
-        /// The number of bytes we actually read
-        n_read: u64,
-    },
-    /// JSON parsing error.
-    Json(serde_json::Error),
-}
-
-impl Error {
-    /// Utility method to create [`Error::InvalidUrl`] variants.
-    fn url<U: Into<String>>(url: U, reason: &'static str) -> Error {
-        Error::InvalidUrl {
-            url: url.into(),
-            reason,
-        }
-    }
-}
-
-impl fmt::Display for Error {
-    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
-        match *self {
-            Error::InvalidUrl {
-                ref url,
-                ref reason,
-            } => write!(f, "invalid URL '{}': {}", url, reason),
-            Error::SocketError(ref e) => write!(f, "Couldn't connect to host: {}", e),
-            Error::HttpResponseTooShort {
-                ref actual,
-                ref needed,
-            } => {
-                write!(f, "HTTP response too short: length {}, needed {}.", actual, needed)
-            }
-            Error::HttpResponseNonAsciiHello(ref bytes) => {
-                write!(f, "HTTP response started with non-ASCII {:?}", bytes)
-            }
-            Error::HttpResponseBadHello {
-                ref actual,
-                ref expected,
-            } => {
-                write!(f, "HTTP response started with `{}`; expected `{}`.", actual, expected)
-            }
-            Error::HttpResponseBadStatus(ref status, ref err) => {
-                write!(f, "HTTP response had bad status code `{}`: {}.", status, err)
-            }
-            Error::HttpResponseBadContentLength(ref len, ref err) => {
-                write!(f, "HTTP response had bad content length `{}`: {}.", len, err)
-            }
-            Error::HttpResponseContentLengthTooLarge {
-                length,
-                max,
-            } => {
-                write!(f, "HTTP response content length {} exceeds our max {}.", length, max)
-            }
-            Error::HttpErrorCode(c) => write!(f, "unexpected HTTP code: {}", c),
-            Error::IncompleteResponse {
-                content_length,
-                n_read,
-            } => {
-                write!(
-                    f,
-                    "Read {} bytes but HTTP response content-length header was {}.",
-                    n_read, content_length
-                )
-            }
-            Error::Json(ref e) => write!(f, "JSON error: {}", e),
-        }
-    }
-}
-
-impl error::Error for Error {
-    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
-        use self::Error::*;
-
-        match *self {
-            InvalidUrl {
-                ..
-            }
-            | HttpResponseTooShort {
-                ..
-            }
-            | HttpResponseNonAsciiHello(..)
-            | HttpResponseBadHello {
-                ..
-            }
-            | HttpResponseBadStatus(..)
-            | HttpResponseBadContentLength(..)
-            | HttpResponseContentLengthTooLarge {
-                ..
-            }
-            | HttpErrorCode(_)
-            | IncompleteResponse {
-                ..
-            } => None,
-            SocketError(ref e) => Some(e),
-            Json(ref e) => Some(e),
-        }
-    }
-}
-
-impl From<io::Error> for Error {
-    fn from(e: io::Error) -> Self {
-        Error::SocketError(e)
-    }
-}
-
-impl From<serde_json::Error> for Error {
-    fn from(e: serde_json::Error) -> Self {
-        Error::Json(e)
-    }
-}
-
-impl From<Error> for crate::Error {
-    fn from(e: Error) -> crate::Error {
-        match e {
-            Error::Json(e) => crate::Error::Json(e),
-            e => crate::Error::Transport(Box::new(e)),
-        }
-    }
-}
-
 /// Does some very basic manual URL parsing because the uri/url crates
 /// all have unicode-normalization as a dependency and that's broken.
 fn check_url(url: &str) -> Result<(SocketAddr, String), Error> {
@@ -598,16 +391,16 @@ impl Builder {
         self
     }
 
-    #[cfg(feature = "proxy")]
     /// Adds proxy address to the transport for SOCKS5 proxy.
+    #[cfg(feature = "proxy")]
     pub fn proxy_addr<S: AsRef<str>>(mut self, proxy_addr: S) -> Result<Self, Error> {
         // We don't expect path in proxy address.
         self.tp.proxy_addr = check_url(proxy_addr.as_ref())?.0;
         Ok(self)
     }
 
-    #[cfg(feature = "proxy")]
     /// Adds optional proxy authentication as ('username', 'password').
+    #[cfg(feature = "proxy")]
     pub fn proxy_auth<S: AsRef<str>>(mut self, user: S, pass: S) -> Self {
         self.tp.proxy_auth =
             Some((user, pass)).map(|(u, p)| (u.as_ref().to_string(), p.as_ref().to_string()));
@@ -640,8 +433,8 @@ impl crate::Client {
         Ok(crate::Client::with_transport(builder.build()))
     }
 
-    #[cfg(feature = "proxy")]
     /// Creates a new JSON_RPC client using a HTTP-Socks5 proxy transport.
+    #[cfg(feature = "proxy")]
     pub fn http_proxy(
         url: &str,
         user: Option<String>,
@@ -662,6 +455,213 @@ impl crate::Client {
     }
 }
 
+/// Error that can happen when sending requests.
+#[derive(Debug)]
+pub enum Error {
+    /// An invalid URL was passed.
+    InvalidUrl {
+        /// The URL passed.
+        url: String,
+        /// The reason the URL is invalid.
+        reason: &'static str,
+    },
+    /// An error occurred on the socket layer.
+    SocketError(io::Error),
+    /// The HTTP response was too short to even fit an HTTP 1.1 header.
+    HttpResponseTooShort {
+        /// The total length of the response.
+        actual: usize,
+        /// Minimum length we can parse.
+        needed: usize,
+    },
+    /// The HTTP response started with an HTTP/1.1 line which was not ASCII.
+    HttpResponseNonAsciiHello(Vec<u8>),
+    /// The HTTP response did not start with HTTP/1.1.
+    HttpResponseBadHello {
+        /// Actual HTTP-whatever string.
+        actual: String,
+        /// The hello string of the HTTP version we support.
+        expected: String,
+    },
+    /// Could not parse the status value as a number.
+    HttpResponseBadStatus(String, num::ParseIntError),
+    /// Could not parse the content-length value as a number.
+    HttpResponseBadContentLength(String, num::ParseIntError),
+    /// The indicated content-length header exceeded our maximum.
+    HttpResponseContentLengthTooLarge {
+        /// The length indicated in the content-length header.
+        length: u64,
+        /// Our hard maximum on number of bytes we'll try to read.
+        max: u64,
+    },
+    /// Unexpected HTTP error code (non-200).
+    HttpErrorCode(u16),
+    /// Received EOF before getting as many bytes as were indicated by the content-length header.
+    IncompleteResponse {
+        /// The content-length header.
+        content_length: u64,
+        /// The number of bytes we actually read.
+        n_read: u64,
+    },
+    /// JSON parsing error.
+    Json(serde_json::Error),
+}
+
+impl Error {
+    /// Utility method to create [`Error::InvalidUrl`] variants.
+    fn url<U: Into<String>>(url: U, reason: &'static str) -> Error {
+        Error::InvalidUrl {
+            url: url.into(),
+            reason,
+        }
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+        use Error::*;
+
+        match *self {
+            InvalidUrl {
+                ref url,
+                ref reason,
+            } => write!(f, "invalid URL '{}': {}", url, reason),
+            SocketError(ref e) => write!(f, "couldn't connect to host: {}", e),
+            HttpResponseTooShort {
+                ref actual,
+                ref needed,
+            } => {
+                write!(f, "HTTP response too short: length {}, needed {}.", actual, needed)
+            }
+            HttpResponseNonAsciiHello(ref bytes) => {
+                write!(f, "HTTP response started with non-ASCII {:?}", bytes)
+            }
+            HttpResponseBadHello {
+                ref actual,
+                ref expected,
+            } => {
+                write!(f, "HTTP response started with `{}`; expected `{}`.", actual, expected)
+            }
+            HttpResponseBadStatus(ref status, ref err) => {
+                write!(f, "HTTP response had bad status code `{}`: {}.", status, err)
+            }
+            HttpResponseBadContentLength(ref len, ref err) => {
+                write!(f, "HTTP response had bad content length `{}`: {}.", len, err)
+            }
+            HttpResponseContentLengthTooLarge {
+                length,
+                max,
+            } => {
+                write!(f, "HTTP response content length {} exceeds our max {}.", length, max)
+            }
+            HttpErrorCode(c) => write!(f, "unexpected HTTP code: {}", c),
+            IncompleteResponse {
+                content_length,
+                n_read,
+            } => {
+                write!(
+                    f,
+                    "read {} bytes but HTTP response content-length header was {}.",
+                    n_read, content_length
+                )
+            }
+            Json(ref e) => write!(f, "JSON error: {}", e),
+        }
+    }
+}
+
+impl error::Error for Error {
+    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
+        use self::Error::*;
+
+        match *self {
+            InvalidUrl {
+                ..
+            }
+            | HttpResponseTooShort {
+                ..
+            }
+            | HttpResponseNonAsciiHello(..)
+            | HttpResponseBadHello {
+                ..
+            }
+            | HttpResponseBadStatus(..)
+            | HttpResponseBadContentLength(..)
+            | HttpResponseContentLengthTooLarge {
+                ..
+            }
+            | HttpErrorCode(_)
+            | IncompleteResponse {
+                ..
+            } => None,
+            SocketError(ref e) => Some(e),
+            Json(ref e) => Some(e),
+        }
+    }
+}
+
+impl From<io::Error> for Error {
+    fn from(e: io::Error) -> Self {
+        Error::SocketError(e)
+    }
+}
+
+impl From<serde_json::Error> for Error {
+    fn from(e: serde_json::Error) -> Self {
+        Error::Json(e)
+    }
+}
+
+impl From<Error> for crate::Error {
+    fn from(e: Error) -> crate::Error {
+        match e {
+            Error::Json(e) => crate::Error::Json(e),
+            e => crate::Error::Transport(Box::new(e)),
+        }
+    }
+}
+
+/// Global mutex used by the fuzzing harness to inject data into the read end of the TCP stream.
+#[cfg(fuzzing)]
+pub static FUZZ_TCP_SOCK: Mutex<Option<io::Cursor<Vec<u8>>>> = Mutex::new(None);
+
+#[cfg(fuzzing)]
+#[derive(Clone, Debug)]
+struct TcpStream;
+
+#[cfg(fuzzing)]
+mod impls {
+    use super::*;
+    impl Read for TcpStream {
+        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+            match *FUZZ_TCP_SOCK.lock().unwrap() {
+                Some(ref mut cursor) => io::Read::read(cursor, buf),
+                None => Ok(0),
+            }
+        }
+    }
+    impl Write for TcpStream {
+        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+            io::sink().write(buf)
+        }
+        fn flush(&mut self) -> io::Result<()> {
+            Ok(())
+        }
+    }
+
+    impl TcpStream {
+        pub fn connect_timeout(_: &SocketAddr, _: Duration) -> io::Result<Self> {
+            Ok(TcpStream)
+        }
+        pub fn set_read_timeout(&self, _: Option<Duration>) -> io::Result<()> {
+            Ok(())
+        }
+        pub fn set_write_timeout(&self, _: Option<Duration>) -> io::Result<()> {
+            Ok(())
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #[cfg(not(feature = "proxy"))]
diff --git a/src/simple_tcp.rs b/src/simple_tcp.rs
index acd7c9fb..27130e4e 100644
--- a/src/simple_tcp.rs
+++ b/src/simple_tcp.rs
@@ -1,14 +1,64 @@
 // SPDX-License-Identifier: CC0-1.0
 
-//! This module implements a synchronous transport over a raw TcpListener. Note that
-//! it does not handle TCP over Unix Domain Sockets, see `simple_uds` for this.
-//!
+//! This module implements a synchronous transport over a raw [`std::net::TcpStream`].
+//! Note that it does not handle TCP over Unix Domain Sockets, see `simple_uds` for this.
 
 use std::{error, fmt, io, net, time};
 
 use crate::client::Transport;
 use crate::{Request, Response};
 
+/// Simple synchronous TCP transport.
+#[derive(Debug, Clone)]
+pub struct TcpTransport {
+    /// The internet socket address to connect to.
+    pub addr: net::SocketAddr,
+    /// The read and write timeout to use for this connection.
+    pub timeout: Option<time::Duration>,
+}
+
+impl TcpTransport {
+    /// Creates a new `TcpTransport` without timeouts.
+    pub fn new(addr: net::SocketAddr) -> TcpTransport {
+        TcpTransport {
+            addr,
+            timeout: None,
+        }
+    }
+
+    fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
+    where
+        R: for<'a> serde::de::Deserialize<'a>,
+    {
+        let mut sock = net::TcpStream::connect(self.addr)?;
+        sock.set_read_timeout(self.timeout)?;
+        sock.set_write_timeout(self.timeout)?;
+
+        serde_json::to_writer(&mut sock, &req)?;
+
+        // NOTE: we don't check the id there, so it *must* be synchronous
+        let resp: R = serde_json::Deserializer::from_reader(&mut sock)
+            .into_iter()
+            .next()
+            .ok_or(Error::Timeout)??;
+        Ok(resp)
+    }
+}
+
+impl Transport for TcpTransport {
+    fn send_request(&self, req: Request) -> Result<Response, crate::Error> {
+        Ok(self.request(req)?)
+    }
+
+    fn send_batch(&self, reqs: &[Request]) -> Result<Vec<Response>, crate::Error> {
+        Ok(self.request(reqs)?)
+    }
+
+    fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.addr)
+    }
+}
+
 /// Error that can occur while using the TCP transport.
 #[derive(Debug)]
 pub enum Error {
@@ -22,10 +72,12 @@ pub enum Error {
 
 impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+        use Error::*;
+
         match *self {
-            Error::SocketError(ref e) => write!(f, "Couldn't connect to host: {}", e),
-            Error::Timeout => f.write_str("Didn't receive response data in time, timed out."),
-            Error::Json(ref e) => write!(f, "JSON error: {}", e),
+            SocketError(ref e) => write!(f, "couldn't connect to host: {}", e),
+            Timeout => f.write_str("didn't receive response data in time, timed out."),
+            Json(ref e) => write!(f, "JSON error: {}", e),
         }
     }
 }
@@ -63,57 +115,6 @@ impl From<Error> for crate::Error {
     }
 }
 
-/// Simple synchronous TCP transport.
-#[derive(Debug, Clone)]
-pub struct TcpTransport {
-    /// The internet socket address to connect to.
-    pub addr: net::SocketAddr,
-    /// The read and write timeout to use for this connection.
-    pub timeout: Option<time::Duration>,
-}
-
-impl TcpTransport {
-    /// Creates a new TcpTransport without timeouts.
-    pub fn new(addr: net::SocketAddr) -> TcpTransport {
-        TcpTransport {
-            addr,
-            timeout: None,
-        }
-    }
-
-    fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
-    where
-        R: for<'a> serde::de::Deserialize<'a>,
-    {
-        let mut sock = net::TcpStream::connect(self.addr)?;
-        sock.set_read_timeout(self.timeout)?;
-        sock.set_write_timeout(self.timeout)?;
-
-        serde_json::to_writer(&mut sock, &req)?;
-
-        // NOTE: we don't check the id there, so it *must* be synchronous
-        let resp: R = serde_json::Deserializer::from_reader(&mut sock)
-            .into_iter()
-            .next()
-            .ok_or(Error::Timeout)??;
-        Ok(resp)
-    }
-}
-
-impl Transport for TcpTransport {
-    fn send_request(&self, req: Request) -> Result<Response, crate::Error> {
-        Ok(self.request(req)?)
-    }
-
-    fn send_batch(&self, reqs: &[Request]) -> Result<Vec<Response>, crate::Error> {
-        Ok(self.request(reqs)?)
-    }
-
-    fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "{}", self.addr)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::{
diff --git a/src/simple_uds.rs b/src/simple_uds.rs
index a2fb91ad..a1afdf58 100644
--- a/src/simple_uds.rs
+++ b/src/simple_uds.rs
@@ -1,7 +1,6 @@
 // SPDX-License-Identifier: CC0-1.0
 
-//! This module implements a synchronous transport over a raw TcpListener.
-//!
+//! This module implements a synchronous transport over a raw [`std::os::unix::net::UnixStream`].
 
 use std::os::unix::net::UnixStream;
 use std::{error, fmt, io, path, time};
@@ -9,60 +8,6 @@ use std::{error, fmt, io, path, time};
 use crate::client::Transport;
 use crate::{Request, Response};
 
-/// Error that can occur while using the UDS transport.
-#[derive(Debug)]
-pub enum Error {
-    /// An error occurred on the socket layer.
-    SocketError(io::Error),
-    /// We didn't receive a complete response till the deadline ran out.
-    Timeout,
-    /// JSON parsing error.
-    Json(serde_json::Error),
-}
-
-impl fmt::Display for Error {
-    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
-        match *self {
-            Error::SocketError(ref e) => write!(f, "Couldn't connect to host: {}", e),
-            Error::Timeout => f.write_str("Didn't receive response data in time, timed out."),
-            Error::Json(ref e) => write!(f, "JSON error: {}", e),
-        }
-    }
-}
-
-impl error::Error for Error {
-    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
-        use self::Error::*;
-
-        match *self {
-            SocketError(ref e) => Some(e),
-            Timeout => None,
-            Json(ref e) => Some(e),
-        }
-    }
-}
-
-impl From<io::Error> for Error {
-    fn from(e: io::Error) -> Self {
-        Error::SocketError(e)
-    }
-}
-
-impl From<serde_json::Error> for Error {
-    fn from(e: serde_json::Error) -> Self {
-        Error::Json(e)
-    }
-}
-
-impl From<Error> for crate::error::Error {
-    fn from(e: Error) -> crate::error::Error {
-        match e {
-            Error::Json(e) => crate::error::Error::Json(e),
-            e => crate::error::Error::Transport(Box::new(e)),
-        }
-    }
-}
-
 /// Simple synchronous UDS transport.
 #[derive(Debug, Clone)]
 pub struct UdsTransport {
@@ -114,6 +59,62 @@ impl Transport for UdsTransport {
     }
 }
 
+/// Error that can occur while using the UDS transport.
+#[derive(Debug)]
+pub enum Error {
+    /// An error occurred on the socket layer.
+    SocketError(io::Error),
+    /// We didn't receive a complete response till the deadline ran out.
+    Timeout,
+    /// JSON parsing error.
+    Json(serde_json::Error),
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+        use Error::*;
+
+        match *self {
+            SocketError(ref e) => write!(f, "couldn't connect to host: {}", e),
+            Timeout => f.write_str("didn't receive response data in time, timed out."),
+            Json(ref e) => write!(f, "JSON error: {}", e),
+        }
+    }
+}
+
+impl error::Error for Error {
+    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
+        use self::Error::*;
+
+        match *self {
+            SocketError(ref e) => Some(e),
+            Timeout => None,
+            Json(ref e) => Some(e),
+        }
+    }
+}
+
+impl From<io::Error> for Error {
+    fn from(e: io::Error) -> Self {
+        Error::SocketError(e)
+    }
+}
+
+impl From<serde_json::Error> for Error {
+    fn from(e: serde_json::Error) -> Self {
+        Error::Json(e)
+    }
+}
+
+impl From<Error> for crate::error::Error {
+    fn from(e: Error) -> crate::error::Error {
+        match e {
+            Error::Json(e) => crate::error::Error::Json(e),
+            e => crate::error::Error::Transport(Box::new(e)),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{
diff --git a/src/util.rs b/src/util.rs
deleted file mode 100644
index 7770cce9..00000000
--- a/src/util.rs
+++ /dev/null
@@ -1,101 +0,0 @@
-// SPDX-License-Identifier: CC0-1.0
-
-use std::borrow::Cow;
-use std::hash::{Hash, Hasher};
-
-use serde_json::Value;
-
-/// Newtype around `Value` which allows hashing for use as hashmap keys,
-/// this is needed for batch requests.
-///
-/// The reason `Value` does not support `Hash` or `Eq` by itself
-/// is that it supports `f64` values; but for batch requests we
-/// will only be hashing the "id" field of the request/response
-/// pair, which should never need decimal precision and therefore
-/// never use `f64`.
-#[derive(Clone, PartialEq, Debug)]
-pub struct HashableValue<'a>(pub Cow<'a, Value>);
-
-impl<'a> Eq for HashableValue<'a> {}
-
-impl<'a> Hash for HashableValue<'a> {
-    fn hash<H: Hasher>(&self, state: &mut H) {
-        match *self.0.as_ref() {
-            Value::Null => "null".hash(state),
-            Value::Bool(false) => "false".hash(state),
-            Value::Bool(true) => "true".hash(state),
-            Value::Number(ref n) => {
-                "number".hash(state);
-                if let Some(n) = n.as_i64() {
-                    n.hash(state);
-                } else if let Some(n) = n.as_u64() {
-                    n.hash(state);
-                } else {
-                    n.to_string().hash(state);
-                }
-            }
-            Value::String(ref s) => {
-                "string".hash(state);
-                s.hash(state);
-            }
-            Value::Array(ref v) => {
-                "array".hash(state);
-                v.len().hash(state);
-                for obj in v {
-                    HashableValue(Cow::Borrowed(obj)).hash(state);
-                }
-            }
-            Value::Object(ref m) => {
-                "object".hash(state);
-                m.len().hash(state);
-                for (key, val) in m {
-                    key.hash(state);
-                    HashableValue(Cow::Borrowed(val)).hash(state);
-                }
-            }
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use std::borrow::Cow;
-    use std::collections::HashSet;
-    use std::str::FromStr;
-
-    use super::*;
-
-    #[test]
-    fn hash_value() {
-        let val = HashableValue(Cow::Owned(Value::from_str("null").unwrap()));
-        let t = HashableValue(Cow::Owned(Value::from_str("true").unwrap()));
-        let f = HashableValue(Cow::Owned(Value::from_str("false").unwrap()));
-        let ns =
-            HashableValue(Cow::Owned(Value::from_str("[0, -0, 123.4567, -100000000]").unwrap()));
-        let m =
-            HashableValue(Cow::Owned(Value::from_str("{ \"field\": 0, \"field\": -0 }").unwrap()));
-
-        let mut coll = HashSet::new();
-
-        assert!(!coll.contains(&val));
-        coll.insert(val.clone());
-        assert!(coll.contains(&val));
-
-        assert!(!coll.contains(&t));
-        assert!(!coll.contains(&f));
-        coll.insert(t.clone());
-        assert!(coll.contains(&t));
-        assert!(!coll.contains(&f));
-        coll.insert(f.clone());
-        assert!(coll.contains(&t));
-        assert!(coll.contains(&f));
-
-        assert!(!coll.contains(&ns));
-        coll.insert(ns.clone());
-        assert!(coll.contains(&ns));
-
-        assert!(!coll.contains(&m));
-        coll.insert(m.clone());
-        assert!(coll.contains(&m));
-    }
-}
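
Reviewer context, not part of the patch: the doc comment moved into src/client.rs says that only the "id" field of a request/response pair is ever hashed, which is why `HashableValue` exists at all. The sketch below illustrates that batch-matching idea; it is written as if it lived inside src/client.rs, where the module-private `HashableValue` plus `Cow`, `HashMap`, `Value`, and `Response` are already in scope, and `reorder_by_id` is a hypothetical helper name, not an API this crate exposes.

// Illustrative sketch only -- hypothetical helper, assumed to live in src/client.rs.
fn reorder_by_id<'a>(
    request_ids: &'a [Value],   // the "id" of each request, in send order
    responses: Vec<Response>,   // batch responses, possibly out of order
) -> Option<Vec<Response>> {
    // Key each response by its (hashable) id.
    let mut by_id: HashMap<HashableValue<'a>, Response> = HashMap::new();
    for resp in responses {
        let id = resp.id.clone();
        by_id.insert(HashableValue(Cow::Owned(id)), resp);
    }
    // Hand the responses back in request order; `None` if any id is missing.
    request_ids
        .iter()
        .map(|id| by_id.remove(&HashableValue(Cow::Borrowed(id))))
        .collect()
}

Because the `Hash` impl only falls back to `n.to_string()` for non-integer numbers, the integer and string ids that JSON-RPC clients normally use hash cheaply and consistently.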