Skip to content

Commit

Permalink
Handle request level TlsConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
algesten committed Feb 11, 2025
1 parent ac8065c commit 2fd4c49
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Unreleased

* Support request level TlsConfig (#996)

# 3.0.5

* Fix incorrect reading of valid utf8 (#992)
Expand Down
7 changes: 6 additions & 1 deletion src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ impl Agent {
}

pub(crate) fn new_request_level_config(&self) -> RequestLevelConfig {
RequestLevelConfig(self.config.as_ref().clone())
let mut config = self.config.as_ref().clone();

// Set flag indicating this is request level.
config.request_level = true;

RequestLevelConfig(config)
}

/// Make a GET request using this agent.
Expand Down
4 changes: 4 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ pub struct Config {
// Techically not config, but here to pass as argument from
// RequestBuilder::force_send_body() to run()
pub(crate) force_send_body: bool,

// If this config instance is request level.
pub(crate) request_level: bool,
}

impl Config {
Expand Down Expand Up @@ -850,6 +853,7 @@ impl Default for Config {
max_idle_age: Duration::from_secs(15),
middleware: MiddlewareChain::default(),
force_send_body: false,
request_level: false,
}
}
}
Expand Down
24 changes: 22 additions & 2 deletions src/tls/cert.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::hash::{Hash, Hasher};

use crate::Error;

Expand All @@ -8,7 +9,7 @@ use crate::Error;
///
/// The internal representation is DER form. The provided helpers for PEM
/// translates to DER.
#[derive(Clone)]
#[derive(Clone, Hash)]
pub struct Certificate<'a> {
der: CertDer<'a>,
}
Expand All @@ -20,6 +21,13 @@ enum CertDer<'a> {
Rustls(rustls_pki_types::CertificateDer<'static>),
}

impl Hash for CertDer<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
self.as_ref().hash(state)
}
}

impl<'a> AsRef<[u8]> for CertDer<'a> {
fn as_ref(&self) -> &[u8] {
match self {
Expand Down Expand Up @@ -78,6 +86,7 @@ impl<'a> Certificate<'a> {
/// translates to DER.
///
/// Deliberately not `Clone` to avoid accidental copies in memory.
#[derive(Hash)]
pub struct PrivateKey<'a> {
kind: KeyKind,
der: PrivateKeyDer<'a>,
Expand All @@ -89,6 +98,17 @@ enum PrivateKeyDer<'a> {
Rustls(rustls_pki_types::PrivateKeyDer<'static>),
}

impl Hash for PrivateKeyDer<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
PrivateKeyDer::Borrowed(v) => v.hash(state),
PrivateKeyDer::Owned(v) => v.hash(state),
PrivateKeyDer::Rustls(v) => v.secret_der().as_ref().hash(state),
}
}
}

impl<'a> AsRef<[u8]> for PrivateKey<'a> {
fn as_ref(&self) -> &[u8] {
match &self.der {
Expand All @@ -103,7 +123,7 @@ impl<'a> AsRef<[u8]> for PrivateKey<'a> {
///
/// * For **rustls** any kind is valid.
/// * For **native-tls** the only valid option is [`Pkcs8`](KeyKind::Pkcs8).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KeyKind {
/// An RSA private key
Expand Down
28 changes: 25 additions & 3 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! TLS for handling `https`.
use std::fmt;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

mod cert;
Expand All @@ -17,7 +18,7 @@ pub(crate) mod native_tls;
/// Defaults to [`Rustls`][Self::Rustls] because this has the highest chance
/// to compile and "just work" straight out of the box without installing additional
/// development dependencies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TlsProvider {
/// [Rustls](https://crates.io/crates/rustls) with the
Expand Down Expand Up @@ -81,6 +82,12 @@ impl TlsConfig {
config: TlsConfig::default(),
}
}

pub(crate) fn hash_value(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl TlsConfig {
Expand Down Expand Up @@ -259,7 +266,7 @@ impl TlsConfigBuilder {
}

/// A client certificate.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Hash)]
pub struct ClientCert(Arc<(Vec<Certificate<'static>>, PrivateKey<'static>)>);

impl ClientCert {
Expand All @@ -280,7 +287,7 @@ impl ClientCert {
}

/// Configuration setting for root certs.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Hash)]
#[non_exhaustive]
pub enum RootCerts {
/// Use these specific certificates as root certs.
Expand Down Expand Up @@ -348,6 +355,21 @@ impl fmt::Debug for TlsConfig {
}
}

impl Hash for TlsConfig {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.provider.hash(state);
self.client_cert.hash(state);
self.root_certs.hash(state);
self.use_sni.hash(state);
self.disable_verification.hash(state);

#[cfg(feature = "_rustls")]
if let Some(arc) = &self.rustls_crypto_provider {
Arc::as_ptr(arc).addr().hash(state);
}
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
74 changes: 57 additions & 17 deletions src/tls/native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use std::io::{Read, Write};
use std::sync::{Arc, OnceLock};

use crate::config::Config;
use crate::tls::{RootCerts, TlsProvider};
use crate::{transport::*, Error};
use der::pem::LineEnding;
Expand All @@ -17,7 +18,12 @@ use super::TlsConfig;
/// Requires feature flag **native-tls**.
#[derive(Default)]
pub struct NativeTlsConnector {
connector: OnceLock<Arc<TlsConnector>>,
connector: OnceLock<CachedNativeTlsConnector>,
}

struct CachedNativeTlsConnector {
config_hash: u64,
native_tls_connector: Arc<TlsConnector>,
}

impl<In: Transport> Connector<In> for NativeTlsConnector {
Expand Down Expand Up @@ -46,20 +52,7 @@ impl<In: Transport> Connector<In> for NativeTlsConnector {

trace!("Try wrap TLS");

let tls_config = &details.config.tls_config();

// Initialize the connector on first run.
let connector_ref = match self.connector.get() {
Some(v) => v,
None => {
// This is unlikely to be racy, but if it is, doesn't matter much.
let c = build_connector(tls_config)?;
// Maybe someone else set it first. Weird, but ok.
let _ = self.connector.set(c);
self.connector.get().unwrap()
}
};
let connector = connector_ref.clone(); // cheap clone due to Arc
let connector = self.get_cached_native_tls_connector(details.config)?;

let domain = details
.uri
Expand All @@ -84,7 +77,49 @@ impl<In: Transport> Connector<In> for NativeTlsConnector {
}
}

fn build_connector(tls_config: &TlsConfig) -> Result<Arc<TlsConnector>, Error> {
impl NativeTlsConnector {
fn get_cached_native_tls_connector(&self, config: &Config) -> Result<Arc<TlsConnector>, Error> {
let tls_config = config.tls_config();

let connector = if config.request_level {
// If the TlsConfig is request level, it is not allowed to
// initialize the self.config OnceLock, but it should
// reuse the cached value if it is the same TlsConfig
// by comparing the config_hash value.

let is_cached = self
.connector
.get()
.map(|c| c.config_hash == tls_config.hash_value())
.unwrap_or(false);

if is_cached {
// unwrap is ok because if is_cached is true we must have had a value.
self.connector.get().unwrap().native_tls_connector.clone()
} else {
build_connector(tls_config)?.native_tls_connector
}
} else {
// Initialize the connector on first run.
let connector_ref = match self.connector.get() {
Some(v) => v,
None => {
// This is unlikely to be racy, but if it is, doesn't matter much.
let c = build_connector(tls_config)?;
// Maybe someone else set it first. Weird, but ok.
let _ = self.connector.set(c);
self.connector.get().unwrap()
}
};

connector_ref.native_tls_connector.clone() // cheap clone due to Arc
};

Ok(connector)
}
}

fn build_connector(tls_config: &TlsConfig) -> Result<CachedNativeTlsConnector, Error> {
let mut builder = TlsConnector::builder();

if tls_config.disable_verification {
Expand Down Expand Up @@ -136,7 +171,12 @@ fn build_connector(tls_config: &TlsConfig) -> Result<Arc<TlsConnector>, Error> {

let conn = builder.build()?;

Ok(Arc::new(conn))
let cached = CachedNativeTlsConnector {
config_hash: tls_config.hash_value(),
native_tls_connector: Arc::new(conn),
};

Ok(cached)
}

fn add_valid_der<'a, C>(certs: C, builder: &mut TlsConnectorBuilder)
Expand Down
53 changes: 45 additions & 8 deletions src/tls/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned, ALL_VER
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
use rustls_pki_types::{PrivateSec1KeyDer, ServerName};

use crate::config::Config;
use crate::tls::cert::KeyKind;
use crate::tls::{RootCerts, TlsProvider};
use crate::transport::{Buffers, ConnectionDetails, Connector, LazyBuffers};
Expand All @@ -22,7 +23,12 @@ use super::TlsConfig;
/// Requires feature flag **rustls**.
#[derive(Default)]
pub struct RustlsConnector {
config: OnceLock<Arc<ClientConfig>>,
config: OnceLock<CachedRustlConfig>,
}

struct CachedRustlConfig {
config_hash: u64,
rustls_config: Arc<ClientConfig>,
}

impl<In: Transport> Connector<In> for RustlsConnector {
Expand Down Expand Up @@ -51,11 +57,7 @@ impl<In: Transport> Connector<In> for RustlsConnector {

trace!("Try wrap in TLS");

let tls_config = details.config.tls_config();

// Initialize the config on first run.
let config_ref = self.config.get_or_init(|| build_config(tls_config));
let config = config_ref.clone(); // cheap clone due to Arc
let config = self.get_cached_config(details.config);

let name_borrowed: ServerName<'_> = details
.uri
Expand Down Expand Up @@ -89,7 +91,39 @@ impl<In: Transport> Connector<In> for RustlsConnector {
}
}

fn build_config(tls_config: &TlsConfig) -> Arc<ClientConfig> {
impl RustlsConnector {
fn get_cached_config(&self, config: &Config) -> Arc<ClientConfig> {
let tls_config = config.tls_config();

if config.request_level {
// If the TlsConfig is request level, it is not allowed to
// initialize the self.config OnceLock, but it should
// reuse the cached value if it is the same TlsConfig
// by comparing the config_hash value.

let is_cached = self
.config
.get()
.map(|c| c.config_hash == tls_config.hash_value())
.unwrap_or(false);

if is_cached {
// unwrap is ok because if is_cached is true we must have had a value.
self.config.get().unwrap().rustls_config.clone()
} else {
build_config(tls_config).rustls_config
}
} else {
// On agent level, we initialize the config on first run. This is
// the value we want to cache.
let config_ref = self.config.get_or_init(|| build_config(tls_config));

config_ref.rustls_config.clone()
}
}
}

fn build_config(tls_config: &TlsConfig) -> CachedRustlConfig {
// 1. Prefer provider set by TlsConfig.
// 2. Use process wide default set in rustls library.
// 3. Pick ring, if it is enabled (the default behavior).
Expand Down Expand Up @@ -183,7 +217,10 @@ fn build_config(tls_config: &TlsConfig) -> Arc<ClientConfig> {
debug!("Disable SNI");
}

Arc::new(config)
CachedRustlConfig {
config_hash: tls_config.hash_value(),
rustls_config: Arc::new(config),
}
}

pub struct RustlsTransport {
Expand Down

0 comments on commit 2fd4c49

Please sign in to comment.