@@ -567,9 +567,12 @@ async fn test_policy_with_wrong_url() {
567
567
// helper functions for certificate rotation test, which is a feature supported only on Linux
568
568
#[ cfg( target_os = "linux" ) ]
569
569
mod certificate_reload_helpers {
570
+ use std:: net:: TcpStream ;
571
+
572
+ use anyhow:: anyhow;
570
573
use openssl:: ssl:: { SslConnector , SslMethod , SslVerifyMode } ;
571
574
use rcgen:: { generate_simple_self_signed, CertifiedKey } ;
572
- use std :: net :: TcpStream ;
575
+ use reqwest :: StatusCode ;
573
576
574
577
pub struct TlsData {
575
578
pub key : String ,
@@ -614,48 +617,34 @@ mod certificate_reload_helpers {
614
617
. unwrap ( )
615
618
}
616
619
617
- pub async fn check_tls_san_name ( domain_ip : & str , domain_port : & str , hostname : & str ) -> bool {
618
- let sleep_interval = std:: time:: Duration :: from_secs ( 1 ) ;
619
- let max_retries = 10 ;
620
- let mut failed_retries = 0 ;
620
+ pub async fn check_tls_san_name (
621
+ domain_ip : & str ,
622
+ domain_port : & str ,
623
+ hostname : & str ,
624
+ ) -> anyhow:: Result < ( ) > {
621
625
let hostname = hostname. to_string ( ) ;
622
- loop {
623
- let san_names = get_tls_san_names ( domain_ip, domain_port) . await ;
624
- if san_names. contains ( & hostname) {
625
- return true ;
626
- }
627
- failed_retries += 1 ;
628
- if failed_retries >= max_retries {
629
- return false ;
630
- }
631
- tokio:: time:: sleep ( sleep_interval) . await ;
626
+ let san_names = get_tls_san_names ( domain_ip, domain_port) . await ;
627
+ if san_names. contains ( & hostname) {
628
+ Ok ( ( ) )
629
+ } else {
630
+ Err ( anyhow ! (
631
+ "SAN names do not contain the expected hostname ({}): {:?}" ,
632
+ hostname,
633
+ san_names
634
+ ) )
632
635
}
633
636
}
634
637
635
- pub async fn wait_for_policy_server_to_be_ready ( address : & str ) {
636
- let sleep_interval = std:: time:: Duration :: from_secs ( 1 ) ;
637
- let max_retries = 5 ;
638
- let mut failed_retries = 0 ;
639
-
638
+ pub async fn policy_server_is_ready ( address : & str ) -> anyhow:: Result < StatusCode > {
640
639
// wait for the server to start
641
640
let client = reqwest:: Client :: builder ( )
642
641
. danger_accept_invalid_certs ( true )
643
642
. build ( )
644
643
. unwrap ( ) ;
645
644
646
- loop {
647
- let url = reqwest:: Url :: parse ( & format ! ( "https://{address}/readiness" ) ) . unwrap ( ) ;
648
- match client. get ( url) . send ( ) . await {
649
- Ok ( _) => break ,
650
- Err ( e) => {
651
- failed_retries += 1 ;
652
- if failed_retries >= max_retries {
653
- panic ! ( "failed to start the server: {:?}" , e) ;
654
- }
655
- tokio:: time:: sleep ( sleep_interval) . await ;
656
- }
657
- }
658
- }
645
+ let url = reqwest:: Url :: parse ( & format ! ( "https://{address}/readiness" ) ) . unwrap ( ) ;
646
+ let response = client. get ( url) . send ( ) . await ?;
647
+ Ok ( response. status ( ) )
659
648
}
660
649
}
661
650
@@ -699,9 +688,22 @@ async fn test_detect_certificate_rotation() {
699
688
. unwrap ( ) ;
700
689
api_server. run ( ) . await . unwrap ( ) ;
701
690
} ) ;
702
- wait_for_policy_server_to_be_ready ( format ! ( "{domain_ip}:{domain_port}" ) . as_str ( ) ) . await ;
703
691
704
- assert ! ( check_tls_san_name( & domain_ip, & domain_port, hostname1) . await ) ;
692
+ let exponential_backoff = ExponentialBuilder :: default ( )
693
+ . with_min_delay ( Duration :: from_secs ( 10 ) )
694
+ . with_max_delay ( Duration :: from_secs ( 30 ) )
695
+ . with_max_times ( 5 ) ;
696
+
697
+ let status_code =
698
+ ( || async { policy_server_is_ready ( format ! ( "{domain_ip}:{domain_port}" ) . as_str ( ) ) . await } )
699
+ . retry ( exponential_backoff)
700
+ . await
701
+ . unwrap ( ) ;
702
+ assert_eq ! ( status_code, reqwest:: StatusCode :: OK ) ;
703
+
704
+ check_tls_san_name ( & domain_ip, & domain_port, hostname1)
705
+ . await
706
+ . expect ( "certificate served doesn't use the expected SAN name" ) ;
705
707
706
708
// Generate a new certificate and key, and switch to them
707
709
@@ -715,16 +717,19 @@ async fn test_detect_certificate_rotation() {
715
717
tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 4 ) ) . await ;
716
718
717
719
// the old certificate should still be in use, since we didn't change also the key
718
- assert ! ( check_tls_san_name( & domain_ip, & domain_port, hostname1) . await ) ;
720
+ check_tls_san_name ( & domain_ip, & domain_port, hostname1)
721
+ . await
722
+ . expect ( "certificate should not have been changed" ) ;
719
723
720
724
// write only the key file
721
725
std:: fs:: write ( & key_file, tls_data2. key ) . unwrap ( ) ;
722
726
723
- // give inotify some time to ensure it detected the cert change
727
+ // give inotify some time to ensure it detected the cert change,
728
+ // also give axum some time to complete the certificate reload
724
729
tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 4 ) ) . await ;
725
-
726
- // the new certificate should be in use
727
- assert ! ( check_tls_san_name ( & domain_ip , & domain_port , hostname2 ) . await ) ;
730
+ check_tls_san_name ( & domain_ip , & domain_port , hostname2 )
731
+ . await
732
+ . expect ( "certificate hasn't been reloaded" ) ;
728
733
}
729
734
730
735
#[ tokio:: test]
0 commit comments