@@ -674,21 +674,24 @@ async fn test_detect_certificate_rotation() {
674
674
let certs_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
675
675
let cert_file = certs_dir. path ( ) . join ( "policy-server.pem" ) ;
676
676
let key_file = certs_dir. path ( ) . join ( "policy-server-key.pem" ) ;
677
- let client_ca = certs_dir. path ( ) . join ( "client_cert.pem" ) ;
677
+ let first_client_ca = certs_dir. path ( ) . join ( "client_cert1.pem" ) ;
678
+ let second_client_ca = certs_dir. path ( ) . join ( "client_cert2.pem" ) ;
678
679
679
680
let hostname1 = "cert1.example.com" ;
680
681
let tls_data1 = create_cert ( hostname1) ;
681
- let tls_data_client = create_cert ( hostname1) ;
682
+ let first_tls_data_client = create_cert ( hostname1) ;
683
+ let second_tls_data_client = create_cert ( hostname1) ;
682
684
683
685
std:: fs:: write ( & cert_file, tls_data1. cert ) . unwrap ( ) ;
684
686
std:: fs:: write ( & key_file, tls_data1. key ) . unwrap ( ) ;
685
- std:: fs:: write ( & client_ca, tls_data_client. cert . clone ( ) ) . unwrap ( ) ;
687
+ std:: fs:: write ( & first_client_ca, first_tls_data_client. cert . clone ( ) ) . unwrap ( ) ;
688
+ std:: fs:: write ( & second_client_ca, second_tls_data_client. cert . clone ( ) ) . unwrap ( ) ;
686
689
687
690
let mut config = default_test_config ( ) ;
688
691
config. tls_config = Some ( policy_server:: config:: TlsConfig {
689
692
cert_file : cert_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
690
693
key_file : key_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
691
- client_ca_file : Some ( client_ca . to_str ( ) . unwrap ( ) . to_owned ( ) ) ,
694
+ client_ca_file : vec ! [ first_client_ca . clone ( ) , second_client_ca . clone ( ) ] ,
692
695
} ) ;
693
696
config. policies = HashMap :: new ( ) ;
694
697
@@ -745,6 +748,22 @@ async fn test_detect_certificate_rotation() {
745
748
check_tls_san_name ( & host, & port, hostname2)
746
749
. await
747
750
. expect ( "certificate hasn't been reloaded" ) ;
751
+
752
+ let first_tls_data_client2 = create_cert ( hostname2) ;
753
+
754
+ // write only the cert file
755
+ std:: fs:: write ( & first_client_ca, first_tls_data_client2. cert ) . unwrap ( ) ;
756
+
757
+ // give inotify some time to ensure it detected the cert change
758
+ tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 4 ) ) . await ;
759
+
760
+ let second_tls_data_client2 = create_cert ( hostname2) ;
761
+
762
+ // write only the cert file
763
+ std:: fs:: write ( & second_client_ca, second_tls_data_client2. cert ) . unwrap ( ) ;
764
+
765
+ // give inotify some time to ensure it detected the cert change
766
+ tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 4 ) ) . await ;
748
767
}
749
768
750
769
#[ tokio:: test]
@@ -954,11 +973,11 @@ fn generate_tls_certs() -> (String, String, String) {
954
973
#[ case:: with_server_tls_config( Some ( certificate_reload_helpers:: create_cert( "127.0.0.1" ) ) , None ) ]
955
974
#[ case:: mtls_config(
956
975
Some ( certificate_reload_helpers:: create_cert( "127.0.0.1" ) ) ,
957
- Some ( certificate_reload_helpers:: create_cert( "127.0.0.1" ) )
976
+ Some ( vec! [ certificate_reload_helpers:: create_cert( "127.0.0.1" ) ] )
958
977
) ]
959
978
async fn test_tls (
960
979
#[ case] server_tls_data : Option < certificate_reload_helpers:: TlsData > ,
961
- #[ case] client_tls_data : Option < certificate_reload_helpers:: TlsData > ,
980
+ #[ case] client_tls_data : Option < Vec < certificate_reload_helpers:: TlsData > > ,
962
981
) {
963
982
use certificate_reload_helpers:: * ;
964
983
@@ -967,35 +986,49 @@ async fn test_tls(
967
986
let certs_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
968
987
let cert_file = certs_dir. path ( ) . join ( "policy-server.pem" ) ;
969
988
let key_file = certs_dir. path ( ) . join ( "policy-server-key.pem" ) ;
970
- let client_ca = certs_dir. path ( ) . join ( "client_cert.pem" ) ;
971
989
972
- let server_cert = if let Some ( ref tls_data) = server_tls_data {
990
+ if let Some ( ref tls_data) = server_tls_data {
973
991
std:: fs:: write ( & cert_file, tls_data. cert . clone ( ) ) . unwrap ( ) ;
974
992
std:: fs:: write ( & key_file, tls_data. key . clone ( ) ) . unwrap ( ) ;
975
- tls_data. cert . clone ( )
976
- } else {
977
- String :: new ( )
978
- } ;
993
+ }
979
994
980
- let ( client_cert, client_key) = if let Some ( ref tls_data) = client_tls_data {
981
- std:: fs:: write ( & client_ca, tls_data. cert . clone ( ) ) . unwrap ( ) ;
982
- ( tls_data. cert . clone ( ) , tls_data. key . clone ( ) )
983
- } else {
984
- ( String :: new ( ) , String :: new ( ) )
985
- } ;
995
+ // Client CA pem file, cert data and key data
996
+ let clients_cas_info: Vec < ( PathBuf , String , String ) > =
997
+ if let Some ( ref tls_data) = client_tls_data {
998
+ tls_data
999
+ . iter ( )
1000
+ . enumerate ( )
1001
+ . into_iter ( )
1002
+ . map ( |( i, tls_data) | {
1003
+ let client_ca = certs_dir
1004
+ . path ( )
1005
+ . join ( format ! ( "client_cert_{}.pem" , i) )
1006
+ . to_owned ( ) ;
1007
+ std:: fs:: write ( & client_ca, tls_data. cert . clone ( ) )
1008
+ . expect ( "failed to write client CA file" ) ;
1009
+ ( client_ca, tls_data. cert . clone ( ) , tls_data. key . clone ( ) )
1010
+ } )
1011
+ . collect ( )
1012
+ } else {
1013
+ vec ! [ ]
1014
+ } ;
986
1015
987
1016
let mut config = default_test_config ( ) ;
988
1017
config. tls_config = match ( server_tls_data. as_ref ( ) , client_tls_data. as_ref ( ) ) {
989
1018
( None , None ) => None ,
990
1019
( Some ( _) , Some ( _) ) => Some ( policy_server:: config:: TlsConfig {
991
1020
cert_file : cert_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
992
1021
key_file : key_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
993
- client_ca_file : Some ( client_ca. to_str ( ) . unwrap ( ) . to_owned ( ) ) ,
1022
+ client_ca_file : clients_cas_info
1023
+ . clone ( )
1024
+ . into_iter ( )
1025
+ . map ( |it| it. 0 )
1026
+ . collect ( ) ,
994
1027
} ) ,
995
1028
( Some ( _) , None ) => Some ( policy_server:: config:: TlsConfig {
996
1029
cert_file : cert_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
997
1030
key_file : key_file. to_str ( ) . unwrap ( ) . to_owned ( ) ,
998
- client_ca_file : None ,
1031
+ client_ca_file : vec ! [ ] ,
999
1032
} ) ,
1000
1033
_ => {
1001
1034
panic ! ( "Invalid test case" )
@@ -1028,38 +1061,71 @@ async fn test_tls(
1028
1061
. expect ( "policy server is not ready" ) ;
1029
1062
assert_eq ! ( status_code, reqwest:: StatusCode :: OK ) ;
1030
1063
1031
- // Validate TLS communication
1032
- let mut builder = reqwest:: Client :: builder ( ) ;
1064
+ // Test sending request to policy server using each of the client CA certificates
1065
+ let client_to_test = match client_tls_data {
1066
+ Some ( _) => clients_cas_info
1067
+ . iter ( )
1068
+ . map ( |( _, client_cert, client_key) | {
1069
+ build_request_client (
1070
+ server_tls_data. as_ref ( ) ,
1071
+ Some ( client_cert. to_owned ( ) ) ,
1072
+ Some ( client_key. to_owned ( ) ) ,
1073
+ )
1074
+ } )
1075
+ . collect ( ) ,
1076
+ _ => vec ! [ build_request_client( server_tls_data. as_ref( ) , None , None ) ] ,
1077
+ } ;
1033
1078
1034
- if server_tls_data . is_some ( ) {
1035
- let certificate = reqwest :: Certificate :: from_pem ( server_cert . as_bytes ( ) )
1036
- . expect ( "Invalid policy server certificate" ) ;
1037
- builder = builder . add_root_certificate ( certificate ) ;
1079
+ for client in client_to_test {
1080
+ let response =
1081
+ send_validate_request ( & client , format ! ( "{host}:{port}" ) , server_tls_data . as_ref ( ) ) ;
1082
+ assert_eq ! ( response . await . status ( ) , reqwest :: StatusCode :: OK ) ;
1038
1083
}
1084
+ }
1039
1085
1040
- if client_tls_data. is_some ( ) {
1041
- let identity =
1042
- reqwest:: Identity :: from_pem ( format ! ( "{}\n {}" , client_cert, client_key) . as_bytes ( ) )
1043
- . expect ( "successfull pem parsing" ) ;
1044
- builder = builder. identity ( identity) ;
1045
- } ;
1046
- let client = builder. build ( ) . unwrap ( ) ;
1047
-
1086
+ async fn send_validate_request (
1087
+ client : & reqwest:: Client ,
1088
+ address : String ,
1089
+ server_tls_data : Option < & certificate_reload_helpers:: TlsData > ,
1090
+ ) -> reqwest:: Response {
1048
1091
let prefix = if server_tls_data. is_some ( ) {
1049
1092
"https"
1050
1093
} else {
1051
1094
"http"
1052
1095
} ;
1053
- let url =
1054
- reqwest :: Url :: parse ( & format ! ( "{prefix}://{host}:{port}/validate/pod-privileged" ) ) . unwrap ( ) ;
1055
- let response = client
1056
- . post ( url)
1096
+ let url = reqwest :: Url :: parse ( & format ! ( "{prefix}://{address}/validate/pod-privileged" ) )
1097
+ . expect ( "failed to format url" ) ;
1098
+ client
1099
+ . post ( url. clone ( ) )
1057
1100
. header ( header:: CONTENT_TYPE , "application/json" )
1058
1101
. body ( include_str ! ( "data/pod_without_privileged_containers.json" ) )
1059
1102
. send ( )
1060
- . await ;
1061
- assert_eq ! (
1062
- response. expect( "successfull request" ) . status( ) ,
1063
- reqwest:: StatusCode :: OK
1064
- ) ;
1103
+ . await
1104
+ . expect ( "successfull request" )
1105
+ }
1106
+
1107
+ fn build_request_client (
1108
+ server_tls_data : Option < & certificate_reload_helpers:: TlsData > ,
1109
+ client_cert : Option < String > ,
1110
+ client_key : Option < String > ,
1111
+ ) -> reqwest:: Client {
1112
+ // Validate TLS communication
1113
+ let mut builder = reqwest:: Client :: builder ( ) ;
1114
+
1115
+ if let Some ( server_tls_data) = server_tls_data {
1116
+ let certificate = reqwest:: Certificate :: from_pem ( server_tls_data. cert . clone ( ) . as_bytes ( ) )
1117
+ . expect ( "Invalid policy server certificate" ) ;
1118
+ builder = builder. add_root_certificate ( certificate) ;
1119
+ }
1120
+
1121
+ match ( client_cert, client_key) {
1122
+ ( Some ( client_cert) , Some ( client_key) ) => {
1123
+ let identity =
1124
+ reqwest:: Identity :: from_pem ( format ! ( "{}\n {}" , client_cert, client_key) . as_bytes ( ) )
1125
+ . expect ( "successfull pem parsing" ) ;
1126
+ builder = builder. identity ( identity)
1127
+ }
1128
+ _ => { }
1129
+ }
1130
+ builder. build ( ) . expect ( "failed to build client" )
1065
1131
}
0 commit comments