diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbed213..72a7bb2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Unreleased + * Stop passing internal state in Config (#996) + * Support request level TlsConfig (#996) + # 3.0.5 * Fix incorrect reading of valid utf8 (#992) diff --git a/src/config.rs b/src/config.rs index dd559973..415ff448 100644 --- a/src/config.rs +++ b/src/config.rs @@ -159,10 +159,6 @@ pub struct Config { // Chain built for middleware. pub(crate) middleware: MiddlewareChain, - - // Techically not config, but here to pass as argument from - // RequestBuilder::force_send_body() to run() - pub(crate) force_send_body: bool, } impl Config { @@ -849,7 +845,6 @@ impl Default for Config { max_idle_connections_per_host: 3, max_idle_age: Duration::from_secs(15), middleware: MiddlewareChain::default(), - force_send_body: false, } } } diff --git a/src/request.rs b/src/request.rs index 26c66e9d..62f81c90 100644 --- a/src/request.rs +++ b/src/request.rs @@ -402,10 +402,9 @@ impl RequestBuilder { /// # Ok::<_, ureq::Error>(()) /// ``` pub fn force_send_body(mut self) -> RequestBuilder { - // This is how we communicate to run() that we want to disable - // the method-body-compliance check. - let config = self.request_level_config(); - config.force_send_body = true; + if let Some(exts) = self.extensions_mut() { + exts.insert(ForceSendBody); + } RequestBuilder { agent: self.agent, @@ -417,6 +416,9 @@ impl RequestBuilder { } } +#[derive(Debug, Clone)] +pub(crate) struct ForceSendBody; + impl RequestBuilder { pub(crate) fn new(agent: Agent, method: Method, uri: T) -> Self where diff --git a/src/run.rs b/src/run.rs index e4bf2439..4271c301 100644 --- a/src/run.rs +++ b/src/run.rs @@ -13,6 +13,7 @@ use crate::body::ResponseInfo; use crate::config::{Config, RequestLevelConfig, DEFAULT_USER_AGENT}; use crate::http; use crate::pool::Connection; +use crate::request::ForceSendBody; use crate::response::{RedirectHistory, ResponseUri}; use crate::timings::{CallTimings, CurrentTime}; use crate::transport::time::{Duration, Instant}; @@ -33,12 +34,13 @@ pub(crate) fn run( let mut redirect_count = 0; // Configuration on the request level overrides the agent level. - let config = request + let (config, request_level) = request .extensions_mut() .remove::() - .map(|rl| rl.0) - .map(Arc::new) - .unwrap_or_else(|| agent.config.clone()); + .map(|rl| (Arc::new(rl.0), true)) + .unwrap_or_else(|| (agent.config.clone(), false)); + + let force_send_body = request.extensions_mut().remove::().is_some(); let mut redirect_history: Option> = config.save_redirect_history().then_some(Vec::new()); @@ -49,7 +51,7 @@ pub(crate) fn run( let mut flow = Flow::new(request)?; - if config.force_send_body { + if force_send_body { flow.send_body_despite_method(); } @@ -66,6 +68,7 @@ pub(crate) fn run( match flow_run( agent, &config, + request_level, flow, &mut body, redirect_count, @@ -109,9 +112,11 @@ pub(crate) fn run( Ok(response) } +#[allow(clippy::too_many_arguments)] fn flow_run( agent: &Agent, config: &Config, + request_level: bool, mut flow: Flow, body: &mut SendBody, redirect_count: u32, @@ -127,7 +132,7 @@ fn flow_run( add_headers(&mut flow, agent, config, body, &uri)?; - let mut connection = connect(agent, config, &uri, timings)?; + let mut connection = connect(agent, config, request_level, &uri, timings)?; let mut flow = flow.proceed(); @@ -336,6 +341,7 @@ fn add_headers( fn connect( agent: &Agent, config: &Config, + request_level: bool, wanted_uri: &Uri, timings: &mut CallTimings, ) -> Result { @@ -363,6 +369,7 @@ fn connect( addrs, resolver: &*agent.resolver, config, + request_level, now: timings.now(), timeout: timings.next_timeout(Timeout::Connect), proxied, diff --git a/src/tls/cert.rs b/src/tls/cert.rs index 539df183..0408d15b 100644 --- a/src/tls/cert.rs +++ b/src/tls/cert.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::hash::{Hash, Hasher}; use crate::Error; @@ -8,7 +9,7 @@ use crate::Error; /// /// The internal representation is DER form. The provided helpers for PEM /// translates to DER. -#[derive(Clone)] +#[derive(Clone, Hash)] pub struct Certificate<'a> { der: CertDer<'a>, } @@ -20,6 +21,13 @@ enum CertDer<'a> { Rustls(rustls_pki_types::CertificateDer<'static>), } +impl Hash for CertDer<'_> { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + self.as_ref().hash(state) + } +} + impl<'a> AsRef<[u8]> for CertDer<'a> { fn as_ref(&self) -> &[u8] { match self { @@ -78,6 +86,7 @@ impl<'a> Certificate<'a> { /// translates to DER. /// /// Deliberately not `Clone` to avoid accidental copies in memory. +#[derive(Hash)] pub struct PrivateKey<'a> { kind: KeyKind, der: PrivateKeyDer<'a>, @@ -89,6 +98,17 @@ enum PrivateKeyDer<'a> { Rustls(rustls_pki_types::PrivateKeyDer<'static>), } +impl Hash for PrivateKeyDer<'_> { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + match self { + PrivateKeyDer::Borrowed(v) => v.hash(state), + PrivateKeyDer::Owned(v) => v.hash(state), + PrivateKeyDer::Rustls(v) => v.secret_der().as_ref().hash(state), + } + } +} + impl<'a> AsRef<[u8]> for PrivateKey<'a> { fn as_ref(&self) -> &[u8] { match &self.der { @@ -103,7 +123,7 @@ impl<'a> AsRef<[u8]> for PrivateKey<'a> { /// /// * For **rustls** any kind is valid. /// * For **native-tls** the only valid option is [`Pkcs8`](KeyKind::Pkcs8). -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[non_exhaustive] pub enum KeyKind { /// An RSA private key diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 55610e4a..8b41637c 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,6 +1,8 @@ //! TLS for handling `https`. +use std::collections::hash_map::DefaultHasher; use std::fmt; +use std::hash::{Hash, Hasher}; use std::sync::Arc; mod cert; @@ -17,7 +19,7 @@ pub(crate) mod native_tls; /// Defaults to [`Rustls`][Self::Rustls] because this has the highest chance /// to compile and "just work" straight out of the box without installing additional /// development dependencies. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[non_exhaustive] pub enum TlsProvider { /// [Rustls](https://crates.io/crates/rustls) with the @@ -81,6 +83,12 @@ impl TlsConfig { config: TlsConfig::default(), } } + + pub(crate) fn hash_value(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); + hasher.finish() + } } impl TlsConfig { @@ -259,7 +267,7 @@ impl TlsConfigBuilder { } /// A client certificate. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Hash)] pub struct ClientCert(Arc<(Vec>, PrivateKey<'static>)>); impl ClientCert { @@ -280,7 +288,7 @@ impl ClientCert { } /// Configuration setting for root certs. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Hash)] #[non_exhaustive] pub enum RootCerts { /// Use these specific certificates as root certs. @@ -348,6 +356,21 @@ impl fmt::Debug for TlsConfig { } } +impl Hash for TlsConfig { + fn hash(&self, state: &mut H) { + self.provider.hash(state); + self.client_cert.hash(state); + self.root_certs.hash(state); + self.use_sni.hash(state); + self.disable_verification.hash(state); + + #[cfg(feature = "_rustls")] + if let Some(arc) = &self.rustls_crypto_provider { + (Arc::as_ptr(arc) as usize).hash(state); + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/src/tls/native_tls.rs b/src/tls/native_tls.rs index 61561ea0..ebf1ebc9 100644 --- a/src/tls/native_tls.rs +++ b/src/tls/native_tls.rs @@ -17,7 +17,12 @@ use super::TlsConfig; /// Requires feature flag **native-tls**. #[derive(Default)] pub struct NativeTlsConnector { - connector: OnceLock>, + connector: OnceLock, +} + +struct CachedNativeTlsConnector { + config_hash: u64, + native_tls_connector: Arc, } impl Connector for NativeTlsConnector { @@ -46,20 +51,7 @@ impl Connector for NativeTlsConnector { trace!("Try wrap TLS"); - let tls_config = &details.config.tls_config(); - - // Initialize the connector on first run. - let connector_ref = match self.connector.get() { - Some(v) => v, - None => { - // This is unlikely to be racy, but if it is, doesn't matter much. - let c = build_connector(tls_config)?; - // Maybe someone else set it first. Weird, but ok. - let _ = self.connector.set(c); - self.connector.get().unwrap() - } - }; - let connector = connector_ref.clone(); // cheap clone due to Arc + let connector = self.get_cached_native_tls_connector(details)?; let domain = details .uri @@ -84,7 +76,52 @@ impl Connector for NativeTlsConnector { } } -fn build_connector(tls_config: &TlsConfig) -> Result, Error> { +impl NativeTlsConnector { + fn get_cached_native_tls_connector( + &self, + details: &ConnectionDetails, + ) -> Result, Error> { + let tls_config = details.config.tls_config(); + + let connector = if details.request_level { + // If the TlsConfig is request level, it is not allowed to + // initialize the self.config OnceLock, but it should + // reuse the cached value if it is the same TlsConfig + // by comparing the config_hash value. + + let is_cached = self + .connector + .get() + .map(|c| c.config_hash == tls_config.hash_value()) + .unwrap_or(false); + + if is_cached { + // unwrap is ok because if is_cached is true we must have had a value. + self.connector.get().unwrap().native_tls_connector.clone() + } else { + build_connector(tls_config)?.native_tls_connector + } + } else { + // Initialize the connector on first run. + let connector_ref = match self.connector.get() { + Some(v) => v, + None => { + // This is unlikely to be racy, but if it is, doesn't matter much. + let c = build_connector(tls_config)?; + // Maybe someone else set it first. Weird, but ok. + let _ = self.connector.set(c); + self.connector.get().unwrap() + } + }; + + connector_ref.native_tls_connector.clone() // cheap clone due to Arc + }; + + Ok(connector) + } +} + +fn build_connector(tls_config: &TlsConfig) -> Result { let mut builder = TlsConnector::builder(); if tls_config.disable_verification { @@ -136,7 +173,12 @@ fn build_connector(tls_config: &TlsConfig) -> Result, Error> { let conn = builder.build()?; - Ok(Arc::new(conn)) + let cached = CachedNativeTlsConnector { + config_hash: tls_config.hash_value(), + native_tls_connector: Arc::new(conn), + }; + + Ok(cached) } fn add_valid_der<'a, C>(certs: C, builder: &mut TlsConnectorBuilder) diff --git a/src/tls/rustls.rs b/src/tls/rustls.rs index 5a5f6180..61151307 100644 --- a/src/tls/rustls.rs +++ b/src/tls/rustls.rs @@ -22,7 +22,12 @@ use super::TlsConfig; /// Requires feature flag **rustls**. #[derive(Default)] pub struct RustlsConnector { - config: OnceLock>, + config: OnceLock, +} + +struct CachedRustlConfig { + config_hash: u64, + rustls_config: Arc, } impl Connector for RustlsConnector { @@ -51,11 +56,7 @@ impl Connector for RustlsConnector { trace!("Try wrap in TLS"); - let tls_config = details.config.tls_config(); - - // Initialize the config on first run. - let config_ref = self.config.get_or_init(|| build_config(tls_config)); - let config = config_ref.clone(); // cheap clone due to Arc + let config = self.get_cached_config(details); let name_borrowed: ServerName<'_> = details .uri @@ -89,7 +90,39 @@ impl Connector for RustlsConnector { } } -fn build_config(tls_config: &TlsConfig) -> Arc { +impl RustlsConnector { + fn get_cached_config(&self, details: &ConnectionDetails) -> Arc { + let tls_config = details.config.tls_config(); + + if details.request_level { + // If the TlsConfig is request level, it is not allowed to + // initialize the self.config OnceLock, but it should + // reuse the cached value if it is the same TlsConfig + // by comparing the config_hash value. + + let is_cached = self + .config + .get() + .map(|c| c.config_hash == tls_config.hash_value()) + .unwrap_or(false); + + if is_cached { + // unwrap is ok because if is_cached is true we must have had a value. + self.config.get().unwrap().rustls_config.clone() + } else { + build_config(tls_config).rustls_config + } + } else { + // On agent level, we initialize the config on first run. This is + // the value we want to cache. + let config_ref = self.config.get_or_init(|| build_config(tls_config)); + + config_ref.rustls_config.clone() + } + } +} + +fn build_config(tls_config: &TlsConfig) -> CachedRustlConfig { // 1. Prefer provider set by TlsConfig. // 2. Use process wide default set in rustls library. // 3. Pick ring, if it is enabled (the default behavior). @@ -183,7 +216,10 @@ fn build_config(tls_config: &TlsConfig) -> Arc { debug!("Disable SNI"); } - Arc::new(config) + CachedRustlConfig { + config_hash: tls_config.hash_value(), + rustls_config: Arc::new(config), + } } pub struct RustlsTransport { diff --git a/src/unversioned/transport/mod.rs b/src/unversioned/transport/mod.rs index 96930cce..c0a42f28 100644 --- a/src/unversioned/transport/mod.rs +++ b/src/unversioned/transport/mod.rs @@ -183,9 +183,14 @@ pub struct ConnectionDetails<'a> { /// For CONNECT proxy, this is the address of the proxy server. pub addrs: ResolvedSocketAddrs, - /// The Agent configuration. + /// The configuration. + /// + /// Agent or Request level. pub config: &'a Config, + /// Whether the config is request level. + pub request_level: bool, + /// The resolver configured on [`Agent`](crate::Agent). /// /// Typically the IP address of the host in the uri is already resolved to the `addr`