Skip to content

Commit a93b30d

Browse files
authored
Merge pull request #1094 from fabriziosestito/refactor-cert-hot-reload
fix: remove TLS hot-reload race condition
2 parents c1f0b62 + 097ef01 commit a93b30d

File tree

4 files changed

+312
-261
lines changed

4 files changed

+312
-261
lines changed

src/certs.rs

+258
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
use ::tracing::{info, warn};
2+
use anyhow::{anyhow, Result};
3+
use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
4+
use rustls_pemfile::Item;
5+
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
6+
use std::{io::BufReader, path::Path, sync::Arc};
7+
8+
// This is required by certificate hot reload when using inotify, which is available only on linux
9+
#[cfg(target_os = "linux")]
10+
use tokio_stream::StreamExt;
11+
12+
use crate::config::TlsConfig;
13+
14+
/// There's no watching of the certificate files on non-linux platforms
15+
/// since we rely on inotify to watch for changes
16+
#[cfg(not(target_os = "linux"))]
17+
async fn create_tls_config_and_watch_certificate_changes(
18+
tls_config: TlsConfig,
19+
) -> Result<RustlsConfig> {
20+
let cfg = RustlsConfig::from_pem_file(tls_config.cert_file, tls_config.key_file).await?;
21+
Ok(cfg)
22+
}
23+
24+
/// Return the RustlsConfig and watch for changes in the certificate files
25+
/// using inotify.
26+
/// When both the certificate and its key are changed, the RustlsConfig is reloaded,
27+
/// causing the https server to use the new certificate.
28+
///
29+
/// Relying on inotify is only available on linux
30+
#[cfg(target_os = "linux")]
31+
pub(crate) async fn create_tls_config_and_watch_certificate_changes(
32+
tls_config: TlsConfig,
33+
) -> Result<axum_server::tls_rustls::RustlsConfig> {
34+
use ::tracing::error;
35+
use axum_server::tls_rustls::RustlsConfig;
36+
use inotify::WatchDescriptor;
37+
38+
// Build initial TLS configuration
39+
let (mut cert, mut key) =
40+
load_server_cert_and_key(&tls_config.cert_file, &tls_config.key_file).await?;
41+
let mut client_verifier = if tls_config.client_ca_file.is_empty() {
42+
None
43+
} else {
44+
Some(load_client_ca_certs(tls_config.client_ca_file.clone()).await?)
45+
};
46+
let initial_config =
47+
build_tls_server_config(cert.clone(), key.clone_key(), client_verifier.clone())?;
48+
49+
let rust_config = RustlsConfig::from_config(Arc::new(initial_config));
50+
let reloadable_rust_config = rust_config.clone();
51+
52+
// Init inotify to watch for changes in the certificate files
53+
let inotify =
54+
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
55+
let cert_watch = inotify
56+
.watches()
57+
.add(
58+
tls_config.cert_file.clone(),
59+
inotify::WatchMask::CLOSE_WRITE,
60+
)
61+
.map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
62+
let key_watch = inotify
63+
.watches()
64+
.add(tls_config.key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
65+
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
66+
67+
let client_ca_watches: Result<Vec<WatchDescriptor>, anyhow::Error> = tls_config
68+
.client_ca_file
69+
.clone()
70+
.into_iter()
71+
.map(|path| {
72+
inotify
73+
.watches()
74+
.add(path, inotify::WatchMask::CLOSE_WRITE)
75+
.map_err(|e| anyhow!("Cannot watch client CA file: {e}"))
76+
})
77+
.collect();
78+
79+
let client_ca_watches = client_ca_watches?;
80+
81+
let buffer = [0; 1024];
82+
let stream = inotify
83+
.into_event_stream(buffer)
84+
.map_err(|e| anyhow!("Cannot create inotify event stream: {e}"))?;
85+
86+
tokio::spawn(async move {
87+
tokio::pin!(stream);
88+
let mut cert_changed = false;
89+
let mut key_changed = false;
90+
let mut client_ca_changed = false;
91+
92+
while let Some(event) = stream.next().await {
93+
let event = match event {
94+
Ok(event) => event,
95+
Err(e) => {
96+
warn!("Cannot read inotify event: {e}");
97+
continue;
98+
}
99+
};
100+
101+
if event.wd == cert_watch {
102+
info!("TLS certificate file has been modified");
103+
cert_changed = true;
104+
}
105+
if event.wd == key_watch {
106+
info!("TLS key file has been modified");
107+
key_changed = true;
108+
}
109+
110+
for client_ca_watch in client_ca_watches.iter() {
111+
if event.wd == *client_ca_watch {
112+
info!("TLS client CA file has been modified");
113+
client_ca_changed = true;
114+
}
115+
}
116+
117+
// Reload the client CA certificates if they have changed, keeping the current server certificates unchanged
118+
if client_ca_changed {
119+
info!("Reloading client CA certificates");
120+
121+
client_ca_changed = false;
122+
123+
match load_client_ca_certs(tls_config.client_ca_file.clone()).await {
124+
Ok(cv) => {
125+
client_verifier = Some(cv);
126+
}
127+
Err(e) => {
128+
error!("Failed to reload TLS certificates: {e}");
129+
continue;
130+
}
131+
}
132+
}
133+
134+
// Reload the server certificates if they have changed keeping the current client CA certificates unchanged
135+
if key_changed && cert_changed {
136+
info!("Reloading Server TLS certificates");
137+
138+
cert_changed = false;
139+
key_changed = false;
140+
141+
match load_server_cert_and_key(&tls_config.cert_file, &tls_config.key_file).await {
142+
Ok(ck) => {
143+
(cert, key) = ck;
144+
}
145+
Err(e) => {
146+
error!("Failed to reload TLS certificates: {e}");
147+
continue;
148+
}
149+
}
150+
}
151+
152+
match build_tls_server_config(cert.clone(), key.clone_key(), client_verifier.clone()) {
153+
Ok(server_config) => {
154+
reloadable_rust_config.reload_from_config(Arc::new(server_config));
155+
}
156+
Err(e) => {
157+
error!("Failed to reload TLS certificate: {e}");
158+
}
159+
}
160+
}
161+
});
162+
163+
Ok(rust_config)
164+
}
165+
166+
// Build the TLS server
167+
fn build_tls_server_config(
168+
cert: Vec<CertificateDer<'static>>,
169+
key: PrivateKeyDer<'static>,
170+
client_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
171+
) -> Result<rustls::ServerConfig> {
172+
if let Some(client_verifier) = client_verifier {
173+
return Ok(ServerConfig::builder()
174+
.with_client_cert_verifier(client_verifier)
175+
.with_single_cert(cert, key)?);
176+
}
177+
178+
Ok(ServerConfig::builder()
179+
.with_no_client_auth()
180+
.with_single_cert(cert, key)?)
181+
}
182+
183+
// Load the server certificate and key
184+
async fn load_server_cert_and_key(
185+
cert_file: &Path,
186+
key_file: &Path,
187+
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
188+
let cert_contents = tokio::fs::read(cert_file).await?;
189+
let key_contents = tokio::fs::read(key_file).await?;
190+
191+
let cert_reader = &mut BufReader::new(&cert_contents[..]);
192+
let key_reader = &mut BufReader::new(&key_contents[..]);
193+
194+
let cert: Vec<CertificateDer> = rustls_pemfile::certs(cert_reader)
195+
.filter_map(|it| {
196+
if let Err(ref e) = it {
197+
warn!("Cannot parse certificate: {e}");
198+
return None;
199+
}
200+
it.ok()
201+
})
202+
.collect();
203+
if cert.len() > 1 {
204+
return Err(anyhow!("Multiple certificates provided in cert file"));
205+
}
206+
207+
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(key_reader)
208+
.filter_map(|i| match i.ok()? {
209+
Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
210+
Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec()),
211+
Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec()),
212+
_ => {
213+
info!("Ignoring non-key item in key file");
214+
None
215+
}
216+
})
217+
.collect();
218+
if key_vec.is_empty() {
219+
return Err(anyhow!("No key provided in key file"));
220+
}
221+
if key_vec.len() > 1 {
222+
return Err(anyhow!("Multiple keys provided in key file"));
223+
}
224+
let key = PrivateKeyDer::try_from(key_vec.pop().unwrap())
225+
.map_err(|e| anyhow!("Cannot parse server key: {e}"))?;
226+
227+
Ok((cert, key))
228+
}
229+
230+
// Load the client CA certificates and build the client verifier
231+
async fn load_client_ca_certs(
232+
client_cas: Vec<std::path::PathBuf>,
233+
) -> Result<Arc<dyn rustls::server::danger::ClientCertVerifier>> {
234+
let mut store = RootCertStore::empty();
235+
for client_ca_file in client_cas {
236+
let client_ca_contents = tokio::fs::read(&client_ca_file).await?;
237+
let client_ca_reader = &mut BufReader::new(&client_ca_contents[..]);
238+
239+
let client_ca_certs: Vec<_> = rustls_pemfile::certs(client_ca_reader)
240+
.filter_map(|it| {
241+
if let Err(ref e) = it {
242+
warn!("Cannot parse client CA certificate: {e}");
243+
}
244+
it.ok()
245+
})
246+
.collect();
247+
let (cert_added, cert_ignored) = store.add_parsable_certificates(client_ca_certs);
248+
info!(
249+
client_ca_certs_added = cert_added,
250+
client_ca_certs_ignored = cert_ignored,
251+
"Loaded client CA certificates"
252+
);
253+
}
254+
255+
WebPkiClientVerifier::builder(Arc::new(store))
256+
.build()
257+
.map_err(|e| anyhow!("Cannot build client verifier: {e}"))
258+
}

src/config.rs

+12-15
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ pub struct Config {
5252
}
5353

5454
pub struct TlsConfig {
55-
pub cert_file: String,
56-
pub key_file: String,
55+
pub cert_file: PathBuf,
56+
pub key_file: PathBuf,
5757
pub client_ca_file: Vec<PathBuf>,
5858
}
5959

@@ -190,22 +190,19 @@ fn readiness_probe_bind_address(matches: &clap::ArgMatches) -> Result<SocketAddr
190190
}
191191

192192
fn build_tls_config(matches: &clap::ArgMatches) -> Result<Option<TlsConfig>> {
193-
let cert_file = matches.get_one::<String>("cert-file").cloned();
194-
let key_file = matches.get_one::<String>("key-file").cloned();
195-
let client_ca_file = matches.get_many::<String>("client-ca-file");
193+
let cert_file = matches.get_one::<PathBuf>("cert-file").cloned();
194+
let key_file = matches.get_one::<PathBuf>("key-file").cloned();
195+
let client_ca_file = matches.get_many::<PathBuf>("client-ca-file");
196196

197197
match (cert_file, key_file, &client_ca_file) {
198-
(Some(cert_file), Some(key_file), _) => {
199-
let client_ca_file = client_ca_file
198+
(Some(cert_file), Some(key_file), _) => Ok(Some(TlsConfig {
199+
cert_file,
200+
key_file,
201+
client_ca_file: client_ca_file
200202
.unwrap_or_default()
201-
.map(PathBuf::from)
202-
.collect();
203-
Ok(Some(TlsConfig {
204-
cert_file,
205-
key_file,
206-
client_ca_file,
207-
}))
208-
}
203+
.map(|p| p.to_owned())
204+
.collect::<Vec<PathBuf>>(),
205+
})),
209206
// No TLS configuration provided
210207
(None, None, None) => Ok(None),
211208
// Client CA certificate provided without server certificate and key

0 commit comments

Comments
 (0)