Skip to content

Commit c9410c7

Browse files
authored
Merge pull request #1075 from jvanz/mtls
feat: enable mTLS.
2 parents 8c988fa + 7990214 commit c9410c7

File tree

7 files changed

+191
-35
lines changed

7 files changed

+191
-35
lines changed

Cargo.lock

+2-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ rustls = { version = "0.23", default-features = false, features = [
4242
"tls12",
4343
] }
4444
rustls-pki-types = { version = "1", features = ["alloc"] }
45+
rustls-pemfile = "2.2.0"
4546
rayon = "1.10"
4647
regex = "1.10"
4748
serde_json = "1.0"

cli-docs.md

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ This document contains the help content for the `policy-server` command-line pro
2525
* `--always-accept-admission-reviews-on-namespace <NAMESPACE>` — Always accept AdmissionReviews that target the given namespace
2626
* `--cert-file <CERT_FILE>` — Path to an X.509 certificate file for HTTPS
2727

28+
Default value: ``
29+
* `--client-ca-file <CLIENT_CA_FILE>` — Path to an CA certificate file that issued the client certificate. Required to enable mTLS
30+
2831
Default value: ``
2932
* `--daemon` — If set, runs policy-server in detached mode as a daemon
3033
* `--daemon-pid-file <DAEMON-PID-FILE>` — Path to the PID file, used only when running in daemon mode

src/cli.rs

+7
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ pub(crate) fn build_cli() -> Command {
8888
.env("KUBEWARDEN_KEY_FILE")
8989
.help("Path to an X.509 private key file for HTTPS"),
9090

91+
Arg::new("client-ca-file")
92+
.long("client-ca-file")
93+
.value_name("CLIENT_CA_FILE")
94+
.default_value("")
95+
.env("KUBEWARDEN_CLIENT_CA_FILE")
96+
.help("Path to an CA certificate file that issued the client certificate. Required to enable mTLS"),
97+
9198
Arg::new("policies")
9299
.long("policies")
93100
.value_name("POLICIES_FILE")

src/config.rs

+11-13
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pub struct Config {
5353
pub struct TlsConfig {
5454
pub cert_file: String,
5555
pub key_file: String,
56+
pub client_ca_cert_file: Option<String>,
5657
}
5758

5859
impl Config {
@@ -127,15 +128,8 @@ impl Config {
127128
.expect("clap should have assigned a default value")
128129
.to_owned();
129130

130-
let (cert_file, key_file) = tls_files(matches)?;
131-
let tls_config = if cert_file.is_empty() {
132-
None
133-
} else {
134-
Some(TlsConfig {
135-
cert_file,
136-
key_file,
137-
})
138-
};
131+
let tls_config = Some(build_tls_config(matches)?);
132+
139133
let enable_pprof = matches
140134
.get_one::<bool>("enable-pprof")
141135
.expect("clap should have assigned a default value")
@@ -182,14 +176,18 @@ fn api_bind_address(matches: &clap::ArgMatches) -> Result<SocketAddr> {
182176
.map_err(|e| anyhow!("error parsing arguments: {}", e))
183177
}
184178

185-
fn tls_files(matches: &clap::ArgMatches) -> Result<(String, String)> {
179+
fn build_tls_config(matches: &clap::ArgMatches) -> Result<TlsConfig> {
186180
let cert_file = matches.get_one::<String>("cert-file").unwrap().to_owned();
187181
let key_file = matches.get_one::<String>("key-file").unwrap().to_owned();
182+
let client_ca_cert_file = matches.get_one::<String>("client-ca-file").cloned();
188183
if cert_file.is_empty() != key_file.is_empty() {
189-
Err(anyhow!("error parsing arguments: either both --cert-file and --key-file must be provided, or neither"))
190-
} else {
191-
Ok((cert_file, key_file))
184+
return Err(anyhow!("error parsing arguments: either both --cert-file and --key-file must be provided, or neither"));
192185
}
186+
Ok(TlsConfig {
187+
cert_file,
188+
key_file,
189+
client_ca_cert_file,
190+
})
193191
}
194192

195193
fn policies(matches: &clap::ArgMatches) -> Result<HashMap<String, PolicyOrPolicyGroup>> {

src/lib.rs

+114-12
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,84 @@ impl PolicyServer {
271271
}
272272
}
273273

274+
// Load the ServerConfig to be used by the Policy Server configuring the server
275+
// certificate and mTLS when necessary
276+
//
277+
// RustlsConfig does not offer a function to load the client CA certificate together with the
278+
// service certificates. Therefore, we need to load everything and build the ServerConfig
279+
async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::ServerConfig> {
280+
use std::{fs::File, io::BufReader};
281+
282+
use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
283+
use rustls_pemfile::Item;
284+
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
285+
286+
let cert_file = &mut BufReader::new(File::open(tls_config.cert_file.clone())?);
287+
let key_file = &mut BufReader::new(File::open(tls_config.key_file.clone())?);
288+
let cert: Vec<CertificateDer> = rustls_pemfile::certs(cert_file)
289+
.filter_map(|it| {
290+
if let Err(ref e) = it {
291+
warn!("Cannot parse certificate: {e}");
292+
return None;
293+
}
294+
it.ok()
295+
})
296+
.collect();
297+
if cert.len() > 1 {
298+
return Err(anyhow!("Multiple certificates provided in cert file"));
299+
}
300+
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(key_file)
301+
.filter_map(|i| match i.ok()? {
302+
Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
303+
Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec()),
304+
Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec()),
305+
_ => {
306+
info!("Ignoring non-key item in key file");
307+
None
308+
}
309+
})
310+
.collect();
311+
if key_vec.is_empty() {
312+
return Err(anyhow!("No key provided in key file"));
313+
}
314+
if key_vec.len() > 1 {
315+
return Err(anyhow!("Multiple keys provided in key file"));
316+
}
317+
let key = PrivateKeyDer::try_from(key_vec.pop().unwrap())
318+
.map_err(|e| anyhow!("Cannot parse server key: {e}"))?;
319+
320+
let config = if let Some(client_ca_cert_file_path) = tls_config.client_ca_cert_file.clone() {
321+
// we have the client CA. Therefore, we should enable mTLS.
322+
let client_ca_cert_file = &mut BufReader::new(File::open(client_ca_cert_file_path)?);
323+
324+
let mut ca_certs = RootCertStore::empty();
325+
let client_ca_certs: Vec<_> = rustls_pemfile::certs(client_ca_cert_file)
326+
.filter_map(|it| {
327+
if let Err(ref e) = it {
328+
warn!("Cannot parse client CA certificate: {e}");
329+
}
330+
it.ok()
331+
})
332+
.collect();
333+
let (cert_added, cert_ignored) = ca_certs.add_parsable_certificates(client_ca_certs);
334+
info!(
335+
client_ca_certs_added = cert_added,
336+
client_ca_certs_ignored = cert_ignored,
337+
"Loaded client CA certificates"
338+
);
339+
let client_verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)).build()?;
340+
341+
ServerConfig::builder()
342+
.with_client_cert_verifier(client_verifier)
343+
.with_single_cert(cert, key)?
344+
} else {
345+
ServerConfig::builder()
346+
.with_no_client_auth()
347+
.with_single_cert(cert, key)?
348+
};
349+
Ok(config)
350+
}
351+
274352
/// There's no watching of the certificate files on non-linux platforms
275353
/// since we rely on inotify to watch for changes
276354
#[cfg(not(target_os = "linux"))]
@@ -293,24 +371,36 @@ async fn create_tls_config_and_watch_certificate_changes(
293371
) -> Result<RustlsConfig> {
294372
use ::tracing::error;
295373

296-
let cert_file = tls_config.cert_file.clone();
297-
let key_file = tls_config.key_file.clone();
374+
let cert_file_path = tls_config.cert_file.clone();
375+
let key_file_path = tls_config.key_file.clone();
376+
let client_ca_cert_path = tls_config.client_ca_cert_file.clone();
377+
378+
let config = build_tls_server_config(&tls_config).await?;
298379

299-
let rust_config =
300-
RustlsConfig::from_pem_file(tls_config.cert_file, tls_config.key_file).await?;
380+
let rust_config = RustlsConfig::from_config(Arc::new(config));
301381
let reloadable_rust_config = rust_config.clone();
302382

303383
let inotify =
304384
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
305385
let cert_watch = inotify
306386
.watches()
307-
.add(cert_file.clone(), inotify::WatchMask::CLOSE_WRITE)
387+
.add(cert_file_path.clone(), inotify::WatchMask::CLOSE_WRITE)
308388
.map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
309389
let key_watch = inotify
310390
.watches()
311-
.add(key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
391+
.add(key_file_path.clone(), inotify::WatchMask::CLOSE_WRITE)
312392
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
313393

394+
let mut client_cert_watch = None;
395+
if let Some(ref client_cert_file) = client_ca_cert_path {
396+
client_cert_watch = Some(
397+
inotify
398+
.watches()
399+
.add(client_cert_file.clone(), inotify::WatchMask::CLOSE_WRITE)
400+
.map_err(|e| anyhow!("Cannot watch client certificate file: {e}"))?,
401+
);
402+
}
403+
314404
let buffer = [0; 1024];
315405
let stream = inotify
316406
.into_event_stream(buffer)
@@ -320,6 +410,7 @@ async fn create_tls_config_and_watch_certificate_changes(
320410
tokio::pin!(stream);
321411
let mut cert_changed = false;
322412
let mut key_changed = false;
413+
let mut client_cert_changed = false;
323414

324415
while let Some(event) = stream.next().await {
325416
let event = match event {
@@ -338,18 +429,29 @@ async fn create_tls_config_and_watch_certificate_changes(
338429
info!("TLS key file has been modified");
339430
key_changed = true;
340431
}
432+
if let Some(ref client_cert_watch) = client_cert_watch {
433+
if event.wd == *client_cert_watch {
434+
info!("TLS client certificate file has been modified");
435+
client_cert_changed = true;
436+
}
437+
}
341438

342-
if key_changed && cert_changed {
343-
info!("reloading TLS certificate");
439+
// if both the certificate and the key have been changed or there is no change in the
440+
// server cert and key, but the client cert changed, reload the certificate
441+
if (key_changed && cert_changed)
442+
|| (client_cert_changed && (key_changed == cert_changed))
443+
{
444+
info!("reloading TLS certificates");
344445

345446
cert_changed = false;
346447
key_changed = false;
347-
if let Err(e) = reloadable_rust_config
348-
.reload_from_pem_file(cert_file.clone(), key_file.clone())
349-
.await
350-
{
448+
client_cert_changed = false;
449+
let server_config = build_tls_server_config(&tls_config).await;
450+
if let Err(e) = server_config {
351451
error!("Failed to reload TLS certificate: {}", e);
452+
continue;
352453
}
454+
reloadable_rust_config.reload_from_config(Arc::new(server_config.unwrap()))
353455
}
354456
}
355457
});

tests/integration_test.rs

+53-9
Original file line numberDiff line numberDiff line change
@@ -654,9 +654,18 @@ mod certificate_reload_helpers {
654654
}
655655
}
656656

657-
pub async fn policy_server_is_ready(address: &str) -> anyhow::Result<StatusCode> {
657+
pub async fn policy_server_is_ready(
658+
address: &str,
659+
client_tls_pem_bundle: Option<String>,
660+
) -> anyhow::Result<StatusCode> {
658661
// wait for the server to start
659-
let client = reqwest::Client::builder()
662+
let mut client_builder = reqwest::Client::builder();
663+
664+
if let Some(tls_data) = client_tls_pem_bundle {
665+
let identity = reqwest::Identity::from_pem(tls_data.as_bytes())?;
666+
client_builder = client_builder.identity(identity)
667+
};
668+
let client = client_builder
660669
.danger_accept_invalid_certs(true)
661670
.build()
662671
.unwrap();
@@ -677,17 +686,21 @@ async fn test_detect_certificate_rotation() {
677686
let certs_dir = tempfile::tempdir().unwrap();
678687
let cert_file = certs_dir.path().join("policy-server.pem");
679688
let key_file = certs_dir.path().join("policy-server-key.pem");
689+
let client_ca = certs_dir.path().join("client_cert.pem");
680690

681691
let hostname1 = "cert1.example.com";
682692
let tls_data1 = create_cert(hostname1);
693+
let tls_data_client = create_cert(hostname1);
683694

684695
std::fs::write(&cert_file, tls_data1.cert).unwrap();
685696
std::fs::write(&key_file, tls_data1.key).unwrap();
697+
std::fs::write(&client_ca, tls_data_client.cert.clone()).unwrap();
686698

687699
let mut config = default_test_config();
688700
config.tls_config = Some(policy_server::config::TlsConfig {
689-
cert_file: cert_file.to_str().unwrap().to_string(),
690-
key_file: key_file.to_str().unwrap().to_string(),
701+
cert_file: cert_file.to_str().unwrap().to_owned(),
702+
key_file: key_file.to_str().unwrap().to_owned(),
703+
client_ca_cert_file: Some(client_ca.to_str().unwrap().to_owned()),
691704
});
692705
config.policies = HashMap::new();
693706

@@ -706,11 +719,18 @@ async fn test_detect_certificate_rotation() {
706719
.with_max_delay(Duration::from_secs(30))
707720
.with_max_times(5);
708721

709-
let status_code =
710-
(|| async { policy_server_is_ready(format!("{domain_ip}:{domain_port}").as_str()).await })
711-
.retry(exponential_backoff)
712-
.await
713-
.unwrap();
722+
let client_cert = tls_data_client.cert.clone();
723+
let client_key = tls_data_client.key.clone();
724+
let status_code = (|| async {
725+
policy_server_is_ready(
726+
format!("{domain_ip}:{domain_port}").as_str(),
727+
Some(format!("{client_cert}\n{client_key}")),
728+
)
729+
.await
730+
})
731+
.retry(exponential_backoff)
732+
.await
733+
.unwrap();
714734
assert_eq!(status_code, reqwest::StatusCode::OK);
715735

716736
check_tls_san_name(&domain_ip, &domain_port, hostname1)
@@ -721,6 +741,7 @@ async fn test_detect_certificate_rotation() {
721741

722742
let hostname2 = "cert2.example.com";
723743
let tls_data2 = create_cert(hostname2);
744+
let client_ca2 = create_cert(hostname2);
724745

725746
// write only the cert file
726747
std::fs::write(&cert_file, tls_data2.cert).unwrap();
@@ -742,6 +763,29 @@ async fn test_detect_certificate_rotation() {
742763
check_tls_san_name(&domain_ip, &domain_port, hostname2)
743764
.await
744765
.expect("certificate hasn't been reloaded");
766+
767+
// Let test if the server is reloading client certificate
768+
std::fs::write(&client_ca, client_ca2.cert.clone()).unwrap();
769+
770+
// give inotify some time to ensure it detected the cert change
771+
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
772+
773+
assert!(policy_server_is_ready(
774+
format!("{domain_ip}:{domain_port}").as_str(),
775+
Some(format!("{client_cert}\n{client_key}")),
776+
)
777+
.await
778+
.is_err());
779+
780+
let client_cert = client_ca2.cert.clone();
781+
let client_key = client_ca2.key.clone();
782+
let status_code = policy_server_is_ready(
783+
format!("{domain_ip}:{domain_port}").as_str(),
784+
Some(format!("{client_cert}\n{client_key}")),
785+
)
786+
.await
787+
.unwrap();
788+
assert_eq!(status_code, reqwest::StatusCode::OK);
745789
}
746790

747791
#[tokio::test]

0 commit comments

Comments
 (0)