Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle request level TlsConfig #996

Merged
merged 2 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Unreleased

* Stop passing internal state in Config (#996)
* Support request level TlsConfig (#996)

# 3.0.5

* Fix incorrect reading of valid utf8 (#992)
Expand Down
5 changes: 0 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ pub struct Config {

// Chain built for middleware.
pub(crate) middleware: MiddlewareChain,

// Techically not config, but here to pass as argument from
// RequestBuilder::force_send_body() to run()
pub(crate) force_send_body: bool,
}

impl Config {
Expand Down Expand Up @@ -849,7 +845,6 @@ impl Default for Config {
max_idle_connections_per_host: 3,
max_idle_age: Duration::from_secs(15),
middleware: MiddlewareChain::default(),
force_send_body: false,
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,9 @@ impl RequestBuilder<WithoutBody> {
/// # Ok::<_, ureq::Error>(())
/// ```
pub fn force_send_body(mut self) -> RequestBuilder<WithBody> {
// This is how we communicate to run() that we want to disable
// the method-body-compliance check.
let config = self.request_level_config();
config.force_send_body = true;
if let Some(exts) = self.extensions_mut() {
exts.insert(ForceSendBody);
}

RequestBuilder {
agent: self.agent,
Expand All @@ -417,6 +416,9 @@ impl RequestBuilder<WithoutBody> {
}
}

#[derive(Debug, Clone)]
pub(crate) struct ForceSendBody;

impl RequestBuilder<WithBody> {
pub(crate) fn new<T>(agent: Agent, method: Method, uri: T) -> Self
where
Expand Down
19 changes: 13 additions & 6 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::body::ResponseInfo;
use crate::config::{Config, RequestLevelConfig, DEFAULT_USER_AGENT};
use crate::http;
use crate::pool::Connection;
use crate::request::ForceSendBody;
use crate::response::{RedirectHistory, ResponseUri};
use crate::timings::{CallTimings, CurrentTime};
use crate::transport::time::{Duration, Instant};
Expand All @@ -33,12 +34,13 @@ pub(crate) fn run(
let mut redirect_count = 0;

// Configuration on the request level overrides the agent level.
let config = request
let (config, request_level) = request
.extensions_mut()
.remove::<RequestLevelConfig>()
.map(|rl| rl.0)
.map(Arc::new)
.unwrap_or_else(|| agent.config.clone());
.map(|rl| (Arc::new(rl.0), true))
.unwrap_or_else(|| (agent.config.clone(), false));

let force_send_body = request.extensions_mut().remove::<ForceSendBody>().is_some();

let mut redirect_history: Option<Vec<Uri>> =
config.save_redirect_history().then_some(Vec::new());
Expand All @@ -49,7 +51,7 @@ pub(crate) fn run(

let mut flow = Flow::new(request)?;

if config.force_send_body {
if force_send_body {
flow.send_body_despite_method();
}

Expand All @@ -66,6 +68,7 @@ pub(crate) fn run(
match flow_run(
agent,
&config,
request_level,
flow,
&mut body,
redirect_count,
Expand Down Expand Up @@ -109,9 +112,11 @@ pub(crate) fn run(
Ok(response)
}

#[allow(clippy::too_many_arguments)]
fn flow_run(
agent: &Agent,
config: &Config,
request_level: bool,
mut flow: Flow<Prepare>,
body: &mut SendBody,
redirect_count: u32,
Expand All @@ -127,7 +132,7 @@ fn flow_run(

add_headers(&mut flow, agent, config, body, &uri)?;

let mut connection = connect(agent, config, &uri, timings)?;
let mut connection = connect(agent, config, request_level, &uri, timings)?;

let mut flow = flow.proceed();

Expand Down Expand Up @@ -336,6 +341,7 @@ fn add_headers(
fn connect(
agent: &Agent,
config: &Config,
request_level: bool,
wanted_uri: &Uri,
timings: &mut CallTimings,
) -> Result<Connection, Error> {
Expand Down Expand Up @@ -363,6 +369,7 @@ fn connect(
addrs,
resolver: &*agent.resolver,
config,
request_level,
now: timings.now(),
timeout: timings.next_timeout(Timeout::Connect),
proxied,
Expand Down
24 changes: 22 additions & 2 deletions src/tls/cert.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::hash::{Hash, Hasher};

use crate::Error;

Expand All @@ -8,7 +9,7 @@ use crate::Error;
///
/// The internal representation is DER form. The provided helpers for PEM
/// translates to DER.
#[derive(Clone)]
#[derive(Clone, Hash)]
pub struct Certificate<'a> {
der: CertDer<'a>,
}
Expand All @@ -20,6 +21,13 @@ enum CertDer<'a> {
Rustls(rustls_pki_types::CertificateDer<'static>),
}

impl Hash for CertDer<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
self.as_ref().hash(state)
}
}

impl<'a> AsRef<[u8]> for CertDer<'a> {
fn as_ref(&self) -> &[u8] {
match self {
Expand Down Expand Up @@ -78,6 +86,7 @@ impl<'a> Certificate<'a> {
/// translates to DER.
///
/// Deliberately not `Clone` to avoid accidental copies in memory.
#[derive(Hash)]
pub struct PrivateKey<'a> {
kind: KeyKind,
der: PrivateKeyDer<'a>,
Expand All @@ -89,6 +98,17 @@ enum PrivateKeyDer<'a> {
Rustls(rustls_pki_types::PrivateKeyDer<'static>),
}

impl Hash for PrivateKeyDer<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
PrivateKeyDer::Borrowed(v) => v.hash(state),
PrivateKeyDer::Owned(v) => v.hash(state),
PrivateKeyDer::Rustls(v) => v.secret_der().as_ref().hash(state),
}
}
}

impl<'a> AsRef<[u8]> for PrivateKey<'a> {
fn as_ref(&self) -> &[u8] {
match &self.der {
Expand All @@ -103,7 +123,7 @@ impl<'a> AsRef<[u8]> for PrivateKey<'a> {
///
/// * For **rustls** any kind is valid.
/// * For **native-tls** the only valid option is [`Pkcs8`](KeyKind::Pkcs8).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KeyKind {
/// An RSA private key
Expand Down
29 changes: 26 additions & 3 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! TLS for handling `https`.

use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

mod cert;
Expand All @@ -17,7 +19,7 @@ pub(crate) mod native_tls;
/// Defaults to [`Rustls`][Self::Rustls] because this has the highest chance
/// to compile and "just work" straight out of the box without installing additional
/// development dependencies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TlsProvider {
/// [Rustls](https://crates.io/crates/rustls) with the
Expand Down Expand Up @@ -81,6 +83,12 @@ impl TlsConfig {
config: TlsConfig::default(),
}
}

pub(crate) fn hash_value(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl TlsConfig {
Expand Down Expand Up @@ -259,7 +267,7 @@ impl TlsConfigBuilder {
}

/// A client certificate.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Hash)]
pub struct ClientCert(Arc<(Vec<Certificate<'static>>, PrivateKey<'static>)>);

impl ClientCert {
Expand All @@ -280,7 +288,7 @@ impl ClientCert {
}

/// Configuration setting for root certs.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Hash)]
#[non_exhaustive]
pub enum RootCerts {
/// Use these specific certificates as root certs.
Expand Down Expand Up @@ -348,6 +356,21 @@ impl fmt::Debug for TlsConfig {
}
}

impl Hash for TlsConfig {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.provider.hash(state);
self.client_cert.hash(state);
self.root_certs.hash(state);
self.use_sni.hash(state);
self.disable_verification.hash(state);

#[cfg(feature = "_rustls")]
if let Some(arc) = &self.rustls_crypto_provider {
(Arc::as_ptr(arc) as usize).hash(state);
}
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
76 changes: 59 additions & 17 deletions src/tls/native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ use super::TlsConfig;
/// Requires feature flag **native-tls**.
#[derive(Default)]
pub struct NativeTlsConnector {
connector: OnceLock<Arc<TlsConnector>>,
connector: OnceLock<CachedNativeTlsConnector>,
}

struct CachedNativeTlsConnector {
config_hash: u64,
native_tls_connector: Arc<TlsConnector>,
}

impl<In: Transport> Connector<In> for NativeTlsConnector {
Expand Down Expand Up @@ -46,20 +51,7 @@ impl<In: Transport> Connector<In> for NativeTlsConnector {

trace!("Try wrap TLS");

let tls_config = &details.config.tls_config();

// Initialize the connector on first run.
let connector_ref = match self.connector.get() {
Some(v) => v,
None => {
// This is unlikely to be racy, but if it is, doesn't matter much.
let c = build_connector(tls_config)?;
// Maybe someone else set it first. Weird, but ok.
let _ = self.connector.set(c);
self.connector.get().unwrap()
}
};
let connector = connector_ref.clone(); // cheap clone due to Arc
let connector = self.get_cached_native_tls_connector(details)?;

let domain = details
.uri
Expand All @@ -84,7 +76,52 @@ impl<In: Transport> Connector<In> for NativeTlsConnector {
}
}

fn build_connector(tls_config: &TlsConfig) -> Result<Arc<TlsConnector>, Error> {
impl NativeTlsConnector {
fn get_cached_native_tls_connector(
&self,
details: &ConnectionDetails,
) -> Result<Arc<TlsConnector>, Error> {
let tls_config = details.config.tls_config();

let connector = if details.request_level {
// If the TlsConfig is request level, it is not allowed to
// initialize the self.config OnceLock, but it should
// reuse the cached value if it is the same TlsConfig
// by comparing the config_hash value.

let is_cached = self
.connector
.get()
.map(|c| c.config_hash == tls_config.hash_value())
.unwrap_or(false);

if is_cached {
// unwrap is ok because if is_cached is true we must have had a value.
self.connector.get().unwrap().native_tls_connector.clone()
} else {
build_connector(tls_config)?.native_tls_connector
}
} else {
// Initialize the connector on first run.
let connector_ref = match self.connector.get() {
Some(v) => v,
None => {
// This is unlikely to be racy, but if it is, doesn't matter much.
let c = build_connector(tls_config)?;
// Maybe someone else set it first. Weird, but ok.
let _ = self.connector.set(c);
self.connector.get().unwrap()
}
};

connector_ref.native_tls_connector.clone() // cheap clone due to Arc
};

Ok(connector)
}
}

fn build_connector(tls_config: &TlsConfig) -> Result<CachedNativeTlsConnector, Error> {
let mut builder = TlsConnector::builder();

if tls_config.disable_verification {
Expand Down Expand Up @@ -136,7 +173,12 @@ fn build_connector(tls_config: &TlsConfig) -> Result<Arc<TlsConnector>, Error> {

let conn = builder.build()?;

Ok(Arc::new(conn))
let cached = CachedNativeTlsConnector {
config_hash: tls_config.hash_value(),
native_tls_connector: Arc::new(conn),
};

Ok(cached)
}

fn add_valid_der<'a, C>(certs: C, builder: &mut TlsConnectorBuilder)
Expand Down
Loading