Skip to content

Commit dcd6e40

Browse files
committed
feat: multiple client CA.
Updates the policy server to allow loading multiple CA to validate the certificate used by client in a mTLS scenario. Signed-off-by: José Guilherme Vanz <jguilhermevanz@suse.com>
1 parent f8b201b commit dcd6e40

File tree

4 files changed

+153
-72
lines changed

4 files changed

+153
-72
lines changed

src/cli.rs

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ pub(crate) fn build_cli() -> Command {
9595

9696
Arg::new("client-ca-file")
9797
.long("client-ca-file")
98+
.value_delimiter(',')
9899
.value_name("CLIENT_CA_FILE")
99100
.env("KUBEWARDEN_CLIENT_CA_FILE")
100101
.help("Path to an CA certificate file that issued the client certificate. Required to enable mTLS"),

src/config.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pub struct Config {
5454
pub struct TlsConfig {
5555
pub cert_file: String,
5656
pub key_file: String,
57-
pub client_ca_file: Option<String>,
57+
pub client_ca_file: Vec<PathBuf>,
5858
}
5959

6060
impl Config {
@@ -192,14 +192,20 @@ fn readiness_probe_bind_address(matches: &clap::ArgMatches) -> Result<SocketAddr
192192
fn build_tls_config(matches: &clap::ArgMatches) -> Result<Option<TlsConfig>> {
193193
let cert_file = matches.get_one::<String>("cert-file").cloned();
194194
let key_file = matches.get_one::<String>("key-file").cloned();
195-
let client_ca_file = matches.get_one::<String>("client-ca-file").cloned();
195+
let client_ca_file = matches.get_many::<String>("client-ca-file");
196196

197197
match (cert_file, key_file, &client_ca_file) {
198-
(Some(cert_file), Some(key_file), _) => Ok(Some(TlsConfig {
199-
cert_file,
200-
key_file,
201-
client_ca_file,
202-
})),
198+
(Some(cert_file), Some(key_file), _) => {
199+
let client_ca_file = client_ca_file
200+
.unwrap_or_default()
201+
.map(PathBuf::from)
202+
.collect();
203+
Ok(Some(TlsConfig {
204+
cert_file,
205+
key_file,
206+
client_ca_file,
207+
}))
208+
}
203209
// No TLS configuration provided
204210
(None, None, None) => Ok(None),
205211
// Client CA certificate provided without server certificate and key

src/lib.rs

+30-22
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,19 @@ async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::Serve
332332
let key = PrivateKeyDer::try_from(key_vec.pop().unwrap())
333333
.map_err(|e| anyhow!("Cannot parse server key: {e}"))?;
334334

335-
if let Some(client_ca_file) = tls_config.client_ca_file.clone() {
335+
if tls_config.client_ca_file.is_empty() {
336+
return Ok(ServerConfig::builder()
337+
.with_no_client_auth()
338+
.with_single_cert(cert, key)?);
339+
}
340+
341+
let mut store = RootCertStore::empty();
342+
343+
//mTLS enabled
344+
for client_ca_file in tls_config.client_ca_file.clone() {
336345
// we have the client CA. Therefore, we should enable mTLS.
337346
let client_ca_reader = &mut BufReader::new(File::open(client_ca_file)?);
338347

339-
let mut store = RootCertStore::empty();
340348
let client_ca_certs: Vec<_> = rustls_pemfile::certs(client_ca_reader)
341349
.filter_map(|it| {
342350
if let Err(ref e) = it {
@@ -351,15 +359,10 @@ async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::Serve
351359
client_ca_certs_ignored = cert_ignored,
352360
"Loaded client CA certificates"
353361
);
354-
let client_verifier = WebPkiClientVerifier::builder(Arc::new(store)).build()?;
355-
356-
return Ok(ServerConfig::builder()
357-
.with_client_cert_verifier(client_verifier)
358-
.with_single_cert(cert, key)?);
359362
}
360-
363+
let client_verifier = WebPkiClientVerifier::builder(Arc::new(store)).build()?;
361364
Ok(ServerConfig::builder()
362-
.with_no_client_auth()
365+
.with_client_cert_verifier(client_verifier)
363366
.with_single_cert(cert, key)?)
364367
}
365368

@@ -385,11 +388,12 @@ async fn create_tls_config_and_watch_certificate_changes(
385388
) -> Result<RustlsConfig> {
386389
use ::tracing::error;
387390

388-
let config = build_tls_server_config(&tls_config).await?;
389-
390-
let rust_config = RustlsConfig::from_config(Arc::new(config));
391+
// Build initial TLS configuration
392+
let initial_config = build_tls_server_config(&tls_config).await?;
393+
let rust_config = RustlsConfig::from_config(Arc::new(initial_config));
391394
let reloadable_rust_config = rust_config.clone();
392395

396+
// Init inotify to watch for changes in the certificate files
393397
let inotify =
394398
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
395399
let cert_watch = inotify
@@ -404,15 +408,18 @@ async fn create_tls_config_and_watch_certificate_changes(
404408
.add(tls_config.key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
405409
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
406410

407-
let mut client_cert_watch = None;
408-
if let Some(ref client_ca_file) = tls_config.client_ca_file {
409-
client_cert_watch = Some(
411+
let client_cert_watches = tls_config
412+
.client_ca_file
413+
.clone()
414+
.into_iter()
415+
.map(|path| {
410416
inotify
411417
.watches()
412-
.add(client_ca_file, inotify::WatchMask::CLOSE_WRITE)
413-
.map_err(|e| anyhow!("Cannot watch client certificate file: {e}"))?,
414-
);
415-
}
418+
.add(path, inotify::WatchMask::CLOSE_WRITE)
419+
.map_err(|e| anyhow!("Cannot watch client certificate file: {e}"))
420+
.unwrap()
421+
})
422+
.collect::<Vec<_>>();
416423

417424
let buffer = [0; 1024];
418425
let stream = inotify
@@ -442,7 +449,8 @@ async fn create_tls_config_and_watch_certificate_changes(
442449
info!("TLS key file has been modified");
443450
key_changed = true;
444451
}
445-
if let Some(ref client_cert_watch) = client_cert_watch {
452+
453+
for client_cert_watch in client_cert_watches.iter() {
446454
if event.wd == *client_cert_watch {
447455
info!("TLS client certificate file has been modified");
448456
client_cert_changed = true;
@@ -454,11 +462,12 @@ async fn create_tls_config_and_watch_certificate_changes(
454462
if (key_changed && cert_changed)
455463
|| (client_cert_changed && (key_changed == cert_changed))
456464
{
457-
info!("reloading TLS certificates");
465+
info!("Reloading TLS certificates");
458466

459467
cert_changed = false;
460468
key_changed = false;
461469
client_cert_changed = false;
470+
462471
let server_config = build_tls_server_config(&tls_config).await;
463472
if let Err(e) = server_config {
464473
error!("Failed to reload TLS certificate: {}", e);
@@ -468,7 +477,6 @@ async fn create_tls_config_and_watch_certificate_changes(
468477
}
469478
}
470479
});
471-
472480
Ok(rust_config)
473481
}
474482

tests/integration_test.rs

+109-43
Original file line numberDiff line numberDiff line change
@@ -674,21 +674,24 @@ async fn test_detect_certificate_rotation() {
674674
let certs_dir = tempfile::tempdir().unwrap();
675675
let cert_file = certs_dir.path().join("policy-server.pem");
676676
let key_file = certs_dir.path().join("policy-server-key.pem");
677-
let client_ca = certs_dir.path().join("client_cert.pem");
677+
let first_client_ca = certs_dir.path().join("client_cert1.pem");
678+
let second_client_ca = certs_dir.path().join("client_cert2.pem");
678679

679680
let hostname1 = "cert1.example.com";
680681
let tls_data1 = create_cert(hostname1);
681-
let tls_data_client = create_cert(hostname1);
682+
let first_tls_data_client = create_cert(hostname1);
683+
let second_tls_data_client = create_cert(hostname1);
682684

683685
std::fs::write(&cert_file, tls_data1.cert).unwrap();
684686
std::fs::write(&key_file, tls_data1.key).unwrap();
685-
std::fs::write(&client_ca, tls_data_client.cert.clone()).unwrap();
687+
std::fs::write(&first_client_ca, first_tls_data_client.cert.clone()).unwrap();
688+
std::fs::write(&second_client_ca, second_tls_data_client.cert.clone()).unwrap();
686689

687690
let mut config = default_test_config();
688691
config.tls_config = Some(policy_server::config::TlsConfig {
689692
cert_file: cert_file.to_str().unwrap().to_owned(),
690693
key_file: key_file.to_str().unwrap().to_owned(),
691-
client_ca_file: Some(client_ca.to_str().unwrap().to_owned()),
694+
client_ca_file: vec![first_client_ca.clone(), second_client_ca.clone()],
692695
});
693696
config.policies = HashMap::new();
694697

@@ -745,6 +748,22 @@ async fn test_detect_certificate_rotation() {
745748
check_tls_san_name(&host, &port, hostname2)
746749
.await
747750
.expect("certificate hasn't been reloaded");
751+
752+
let first_tls_data_client2 = create_cert(hostname2);
753+
754+
// write only the cert file
755+
std::fs::write(&first_client_ca, first_tls_data_client2.cert).unwrap();
756+
757+
// give inotify some time to ensure it detected the cert change
758+
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
759+
760+
let second_tls_data_client2 = create_cert(hostname2);
761+
762+
// write only the cert file
763+
std::fs::write(&second_client_ca, second_tls_data_client2.cert).unwrap();
764+
765+
// give inotify some time to ensure it detected the cert change
766+
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
748767
}
749768

750769
#[tokio::test]
@@ -954,11 +973,11 @@ fn generate_tls_certs() -> (String, String, String) {
954973
#[case::with_server_tls_config(Some(certificate_reload_helpers::create_cert("127.0.0.1")), None)]
955974
#[case::mtls_config(
956975
Some(certificate_reload_helpers::create_cert("127.0.0.1")),
957-
Some(certificate_reload_helpers::create_cert("127.0.0.1"))
976+
Some(vec![certificate_reload_helpers::create_cert("127.0.0.1")])
958977
)]
959978
async fn test_tls(
960979
#[case] server_tls_data: Option<certificate_reload_helpers::TlsData>,
961-
#[case] client_tls_data: Option<certificate_reload_helpers::TlsData>,
980+
#[case] client_tls_data: Option<Vec<certificate_reload_helpers::TlsData>>,
962981
) {
963982
use certificate_reload_helpers::*;
964983

@@ -967,35 +986,49 @@ async fn test_tls(
967986
let certs_dir = tempfile::tempdir().unwrap();
968987
let cert_file = certs_dir.path().join("policy-server.pem");
969988
let key_file = certs_dir.path().join("policy-server-key.pem");
970-
let client_ca = certs_dir.path().join("client_cert.pem");
971989

972-
let server_cert = if let Some(ref tls_data) = server_tls_data {
990+
if let Some(ref tls_data) = server_tls_data {
973991
std::fs::write(&cert_file, tls_data.cert.clone()).unwrap();
974992
std::fs::write(&key_file, tls_data.key.clone()).unwrap();
975-
tls_data.cert.clone()
976-
} else {
977-
String::new()
978-
};
993+
}
979994

980-
let (client_cert, client_key) = if let Some(ref tls_data) = client_tls_data {
981-
std::fs::write(&client_ca, tls_data.cert.clone()).unwrap();
982-
(tls_data.cert.clone(), tls_data.key.clone())
983-
} else {
984-
(String::new(), String::new())
985-
};
995+
// Client CA pem file, cert data and key data
996+
let clients_cas_info: Vec<(PathBuf, String, String)> =
997+
if let Some(ref tls_data) = client_tls_data {
998+
tls_data
999+
.iter()
1000+
.enumerate()
1001+
.into_iter()
1002+
.map(|(i, tls_data)| {
1003+
let client_ca = certs_dir
1004+
.path()
1005+
.join(format!("client_cert_{}.pem", i))
1006+
.to_owned();
1007+
std::fs::write(&client_ca, tls_data.cert.clone())
1008+
.expect("failed to write client CA file");
1009+
(client_ca, tls_data.cert.clone(), tls_data.key.clone())
1010+
})
1011+
.collect()
1012+
} else {
1013+
vec![]
1014+
};
9861015

9871016
let mut config = default_test_config();
9881017
config.tls_config = match (server_tls_data.as_ref(), client_tls_data.as_ref()) {
9891018
(None, None) => None,
9901019
(Some(_), Some(_)) => Some(policy_server::config::TlsConfig {
9911020
cert_file: cert_file.to_str().unwrap().to_owned(),
9921021
key_file: key_file.to_str().unwrap().to_owned(),
993-
client_ca_file: Some(client_ca.to_str().unwrap().to_owned()),
1022+
client_ca_file: clients_cas_info
1023+
.clone()
1024+
.into_iter()
1025+
.map(|it| it.0)
1026+
.collect(),
9941027
}),
9951028
(Some(_), None) => Some(policy_server::config::TlsConfig {
9961029
cert_file: cert_file.to_str().unwrap().to_owned(),
9971030
key_file: key_file.to_str().unwrap().to_owned(),
998-
client_ca_file: None,
1031+
client_ca_file: vec![],
9991032
}),
10001033
_ => {
10011034
panic!("Invalid test case")
@@ -1028,38 +1061,71 @@ async fn test_tls(
10281061
.expect("policy server is not ready");
10291062
assert_eq!(status_code, reqwest::StatusCode::OK);
10301063

1031-
// Validate TLS communication
1032-
let mut builder = reqwest::Client::builder();
1064+
// Test sending request to policy server using each of the client CA certificates
1065+
let client_to_test = match client_tls_data {
1066+
Some(_) => clients_cas_info
1067+
.iter()
1068+
.map(|(_, client_cert, client_key)| {
1069+
build_request_client(
1070+
server_tls_data.as_ref(),
1071+
Some(client_cert.to_owned()),
1072+
Some(client_key.to_owned()),
1073+
)
1074+
})
1075+
.collect(),
1076+
_ => vec![build_request_client(server_tls_data.as_ref(), None, None)],
1077+
};
10331078

1034-
if server_tls_data.is_some() {
1035-
let certificate = reqwest::Certificate::from_pem(server_cert.as_bytes())
1036-
.expect("Invalid policy server certificate");
1037-
builder = builder.add_root_certificate(certificate);
1079+
for client in client_to_test {
1080+
let response =
1081+
send_validate_request(&client, format!("{host}:{port}"), server_tls_data.as_ref());
1082+
assert_eq!(response.await.status(), reqwest::StatusCode::OK);
10381083
}
1084+
}
10391085

1040-
if client_tls_data.is_some() {
1041-
let identity =
1042-
reqwest::Identity::from_pem(format!("{}\n{}", client_cert, client_key).as_bytes())
1043-
.expect("successfull pem parsing");
1044-
builder = builder.identity(identity);
1045-
};
1046-
let client = builder.build().unwrap();
1047-
1086+
async fn send_validate_request(
1087+
client: &reqwest::Client,
1088+
address: String,
1089+
server_tls_data: Option<&certificate_reload_helpers::TlsData>,
1090+
) -> reqwest::Response {
10481091
let prefix = if server_tls_data.is_some() {
10491092
"https"
10501093
} else {
10511094
"http"
10521095
};
1053-
let url =
1054-
reqwest::Url::parse(&format!("{prefix}://{host}:{port}/validate/pod-privileged")).unwrap();
1055-
let response = client
1056-
.post(url)
1096+
let url = reqwest::Url::parse(&format!("{prefix}://{address}/validate/pod-privileged"))
1097+
.expect("failed to format url");
1098+
client
1099+
.post(url.clone())
10571100
.header(header::CONTENT_TYPE, "application/json")
10581101
.body(include_str!("data/pod_without_privileged_containers.json"))
10591102
.send()
1060-
.await;
1061-
assert_eq!(
1062-
response.expect("successfull request").status(),
1063-
reqwest::StatusCode::OK
1064-
);
1103+
.await
1104+
.expect("successfull request")
1105+
}
1106+
1107+
fn build_request_client(
1108+
server_tls_data: Option<&certificate_reload_helpers::TlsData>,
1109+
client_cert: Option<String>,
1110+
client_key: Option<String>,
1111+
) -> reqwest::Client {
1112+
// Validate TLS communication
1113+
let mut builder = reqwest::Client::builder();
1114+
1115+
if let Some(server_tls_data) = server_tls_data {
1116+
let certificate = reqwest::Certificate::from_pem(server_tls_data.cert.clone().as_bytes())
1117+
.expect("Invalid policy server certificate");
1118+
builder = builder.add_root_certificate(certificate);
1119+
}
1120+
1121+
match (client_cert, client_key) {
1122+
(Some(client_cert), Some(client_key)) => {
1123+
let identity =
1124+
reqwest::Identity::from_pem(format!("{}\n{}", client_cert, client_key).as_bytes())
1125+
.expect("successfull pem parsing");
1126+
builder = builder.identity(identity)
1127+
}
1128+
_ => {}
1129+
}
1130+
builder.build().expect("failed to build client")
10651131
}

0 commit comments

Comments
 (0)