Skip to content

Commit d5c629b

Browse files
Merge pull request #913 from flavio/fix-certificate-rotation-detect-changes
fix: make cert rotation detection more reliable
2 parents 9cd3c64 + fb9cbf1 commit d5c629b

File tree

3 files changed

+48
-43
lines changed

3 files changed

+48
-43
lines changed

Cargo.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ tempfile = "3.12.0"
7272
tower = { version = "0.5", features = ["util"] }
7373
http-body-util = "0.1.1"
7474
testcontainers = { version = "0.22", features = ["watchdog"] }
75-
backon = { version = "1.1.0", features = ["tokio-sleep"] }
75+
backon = { version = "1.2", features = ["tokio-sleep"] }
7676

7777
[target.'cfg(target_os = "linux")'.dev-dependencies]
7878
rcgen = { version = "0.13", features = ["crypto"] }

src/lib.rs

+2-2
Original file line number | Diff line number | Diff line change
@@ -302,11 +302,11 @@ async fn create_tls_config_and_watch_certificate_changes(
302302
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
303303
let cert_watch = inotify
304304
.watches()
305-
.add(cert_file.clone(), inotify::WatchMask::MODIFY)
305+
.add(cert_file.clone(), inotify::WatchMask::CLOSE_WRITE)
306306
.map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
307307
let key_watch = inotify
308308
.watches()
309-
.add(key_file.clone(), inotify::WatchMask::MODIFY)
309+
.add(key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
310310
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
311311

312312
let buffer = [0; 1024];

tests/integration_test.rs

+45-40
Original file line number | Diff line number | Diff line change
@@ -567,9 +567,12 @@ async fn test_policy_with_wrong_url() {
567567
// helper functions for certificate rotation test, which is a feature supported only on Linux
568568
#[cfg(target_os = "linux")]
569569
mod certificate_reload_helpers {
570+
use std::net::TcpStream;
571+
572+
use anyhow::anyhow;
570573
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
571574
use rcgen::{generate_simple_self_signed, CertifiedKey};
572-
use std::net::TcpStream;
575+
use reqwest::StatusCode;
573576

574577
pub struct TlsData {
575578
pub key: String,
@@ -614,48 +617,34 @@ mod certificate_reload_helpers {
614617
.unwrap()
615618
}
616619

617-
pub async fn check_tls_san_name(domain_ip: &str, domain_port: &str, hostname: &str) -> bool {
618-
let sleep_interval = std::time::Duration::from_secs(1);
619-
let max_retries = 10;
620-
let mut failed_retries = 0;
620+
pub async fn check_tls_san_name(
621+
domain_ip: &str,
622+
domain_port: &str,
623+
hostname: &str,
624+
) -> anyhow::Result<()> {
621625
let hostname = hostname.to_string();
622-
loop {
623-
let san_names = get_tls_san_names(domain_ip, domain_port).await;
624-
if san_names.contains(&hostname) {
625-
return true;
626-
}
627-
failed_retries += 1;
628-
if failed_retries >= max_retries {
629-
return false;
630-
}
631-
tokio::time::sleep(sleep_interval).await;
626+
let san_names = get_tls_san_names(domain_ip, domain_port).await;
627+
if san_names.contains(&hostname) {
628+
Ok(())
629+
} else {
630+
Err(anyhow!(
631+
"SAN names do not contain the expected hostname ({}): {:?}",
632+
hostname,
633+
san_names
634+
))
632635
}
633636
}
634637

635-
pub async fn wait_for_policy_server_to_be_ready(address: &str) {
636-
let sleep_interval = std::time::Duration::from_secs(1);
637-
let max_retries = 5;
638-
let mut failed_retries = 0;
639-
638+
pub async fn policy_server_is_ready(address: &str) -> anyhow::Result<StatusCode> {
640639
// wait for the server to start
641640
let client = reqwest::Client::builder()
642641
.danger_accept_invalid_certs(true)
643642
.build()
644643
.unwrap();
645644

646-
loop {
647-
let url = reqwest::Url::parse(&format!("https://{address}/readiness")).unwrap();
648-
match client.get(url).send().await {
649-
Ok(_) => break,
650-
Err(e) => {
651-
failed_retries += 1;
652-
if failed_retries >= max_retries {
653-
panic!("failed to start the server: {:?}", e);
654-
}
655-
tokio::time::sleep(sleep_interval).await;
656-
}
657-
}
658-
}
645+
let url = reqwest::Url::parse(&format!("https://{address}/readiness")).unwrap();
646+
let response = client.get(url).send().await?;
647+
Ok(response.status())
659648
}
660649
}
661650

@@ -699,9 +688,22 @@ async fn test_detect_certificate_rotation() {
699688
.unwrap();
700689
api_server.run().await.unwrap();
701690
});
702-
wait_for_policy_server_to_be_ready(format!("{domain_ip}:{domain_port}").as_str()).await;
703691

704-
assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
692+
let exponential_backoff = ExponentialBuilder::default()
693+
.with_min_delay(Duration::from_secs(10))
694+
.with_max_delay(Duration::from_secs(30))
695+
.with_max_times(5);
696+
697+
let status_code =
698+
(|| async { policy_server_is_ready(format!("{domain_ip}:{domain_port}").as_str()).await })
699+
.retry(exponential_backoff)
700+
.await
701+
.unwrap();
702+
assert_eq!(status_code, reqwest::StatusCode::OK);
703+
704+
check_tls_san_name(&domain_ip, &domain_port, hostname1)
705+
.await
706+
.expect("certificate served doesn't use the expected SAN name");
705707

706708
// Generate a new certificate and key, and switch to them
707709

@@ -715,16 +717,19 @@ async fn test_detect_certificate_rotation() {
715717
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
716718

717719
// the old certificate should still be in use, since we didn't change also the key
718-
assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
720+
check_tls_san_name(&domain_ip, &domain_port, hostname1)
721+
.await
722+
.expect("certificate should not have been changed");
719723

720724
// write only the key file
721725
std::fs::write(&key_file, tls_data2.key).unwrap();
722726

723-
// give inotify some time to ensure it detected the cert change
727+
// give inotify some time to ensure it detected the cert change,
728+
// also give axum some time to complete the certificate reload
724729
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
725-
726-
// the new certificate should be in use
727-
assert!(check_tls_san_name(&domain_ip, &domain_port, hostname2).await);
730+
check_tls_san_name(&domain_ip, &domain_port, hostname2)
731+
.await
732+
.expect("certificate hasn't been reloaded");
728733
}
729734

730735
#[tokio::test]

0 commit comments

Comments
 (0)