Skip to content

Commit 77add8f

Browse files
chore: fix naming and cleanup
Signed-off-by: Fabrizio Sestito <fabrizio.sestito@suse.com>
1 parent f5dadfe commit 77add8f

File tree

3 files changed

+40
-42
lines changed

3 files changed

+40
-42
lines changed

src/config.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pub struct Config {
5454
pub struct TlsConfig {
5555
pub cert_file: String,
5656
pub key_file: String,
57-
pub client_ca_cert_file: Option<String>,
57+
pub client_ca_file: Option<String>,
5858
}
5959

6060
impl Config {
@@ -192,13 +192,13 @@ fn readiness_probe_bind_address(matches: &clap::ArgMatches) -> Result<SocketAddr
192192
fn build_tls_config(matches: &clap::ArgMatches) -> Result<Option<TlsConfig>> {
193193
let cert_file = matches.get_one::<String>("cert-file").cloned();
194194
let key_file = matches.get_one::<String>("key-file").cloned();
195-
let client_ca_cert_file = matches.get_one::<String>("client-ca-file").cloned();
195+
let client_ca_file = matches.get_one::<String>("client-ca-file").cloned();
196196

197-
match (cert_file, key_file, &client_ca_cert_file) {
197+
match (cert_file, key_file, &client_ca_file) {
198198
(Some(cert_file), Some(key_file), _) => Ok(Some(TlsConfig {
199199
cert_file,
200200
key_file,
201-
client_ca_cert_file,
201+
client_ca_file,
202202
})),
203203
// No TLS configuration provided
204204
(None, None, None) => Ok(None),

src/lib.rs

+35-37
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,17 @@ use policy_evaluator::{
3333
use profiling::activate_memory_profiling;
3434
use rayon::prelude::*;
3535
use std::{fs, net::SocketAddr, sync::Arc};
36+
use std::{fs::File, io::BufReader};
3637
use tokio::{
3738
sync::{oneshot, Notify, Semaphore},
3839
time,
3940
};
4041
use tower_http::trace::{self, TraceLayer};
4142

43+
use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
44+
use rustls_pemfile::Item;
45+
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
46+
4247
// This is required by certificate hot reload when using inotify, which is available only on linux
4348
#[cfg(target_os = "linux")]
4449
use tokio_stream::StreamExt;
@@ -286,21 +291,14 @@ impl PolicyServer {
286291
}
287292
}
288293

289-
// Load the ServerConfig to be used by the Policy Server configuring the server
290-
// certificate and mTLS when necessary
291-
//
292-
// RustlsConfig does not offer a function to load the client CA certificate together with the
293-
// service certificates. Therefore, we need to load everything and build the ServerConfig
294+
/// Load the ServerConfig to be used by the Policy Server configuring the server
295+
/// certificate and mTLS when necessary
296+
///
297+
/// RustlsConfig does not offer a function to load the client CA certificate together with the
298+
/// service certificates. Therefore, we need to load everything and build the ServerConfig
294299
async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::ServerConfig> {
295-
use std::{fs::File, io::BufReader};
296-
297-
use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
298-
use rustls_pemfile::Item;
299-
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
300-
301-
let cert_file = &mut BufReader::new(File::open(tls_config.cert_file.clone())?);
302-
let key_file = &mut BufReader::new(File::open(tls_config.key_file.clone())?);
303-
let cert: Vec<CertificateDer> = rustls_pemfile::certs(cert_file)
300+
let cert_reader = &mut BufReader::new(File::open(tls_config.cert_file.clone())?);
301+
let cert: Vec<CertificateDer> = rustls_pemfile::certs(cert_reader)
304302
.filter_map(|it| {
305303
if let Err(ref e) = it {
306304
warn!("Cannot parse certificate: {e}");
@@ -312,7 +310,9 @@ async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::Serve
312310
if cert.len() > 1 {
313311
return Err(anyhow!("Multiple certificates provided in cert file"));
314312
}
315-
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(key_file)
313+
314+
let key_file_reader = &mut BufReader::new(File::open(tls_config.key_file.clone())?);
315+
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(key_file_reader)
316316
.filter_map(|i| match i.ok()? {
317317
Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
318318
Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec()),
@@ -332,36 +332,35 @@ async fn build_tls_server_config(tls_config: &TlsConfig) -> Result<rustls::Serve
332332
let key = PrivateKeyDer::try_from(key_vec.pop().unwrap())
333333
.map_err(|e| anyhow!("Cannot parse server key: {e}"))?;
334334

335-
let config = if let Some(client_ca_cert_file_path) = tls_config.client_ca_cert_file.clone() {
335+
if let Some(client_ca_file) = tls_config.client_ca_file.clone() {
336336
// we have the client CA. Therefore, we should enable mTLS.
337-
let client_ca_cert_file = &mut BufReader::new(File::open(client_ca_cert_file_path)?);
337+
let client_ca_reader = &mut BufReader::new(File::open(client_ca_file)?);
338338

339-
let mut ca_certs = RootCertStore::empty();
340-
let client_ca_certs: Vec<_> = rustls_pemfile::certs(client_ca_cert_file)
339+
let mut store = RootCertStore::empty();
340+
let client_ca_certs: Vec<_> = rustls_pemfile::certs(client_ca_reader)
341341
.filter_map(|it| {
342342
if let Err(ref e) = it {
343343
warn!("Cannot parse client CA certificate: {e}");
344344
}
345345
it.ok()
346346
})
347347
.collect();
348-
let (cert_added, cert_ignored) = ca_certs.add_parsable_certificates(client_ca_certs);
348+
let (cert_added, cert_ignored) = store.add_parsable_certificates(client_ca_certs);
349349
info!(
350350
client_ca_certs_added = cert_added,
351351
client_ca_certs_ignored = cert_ignored,
352352
"Loaded client CA certificates"
353353
);
354-
let client_verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)).build()?;
354+
let client_verifier = WebPkiClientVerifier::builder(Arc::new(store)).build()?;
355355

356-
ServerConfig::builder()
356+
return Ok(ServerConfig::builder()
357357
.with_client_cert_verifier(client_verifier)
358-
.with_single_cert(cert, key)?
359-
} else {
360-
ServerConfig::builder()
361-
.with_no_client_auth()
362-
.with_single_cert(cert, key)?
363-
};
364-
Ok(config)
358+
.with_single_cert(cert, key)?);
359+
}
360+
361+
Ok(ServerConfig::builder()
362+
.with_no_client_auth()
363+
.with_single_cert(cert, key)?)
365364
}
366365

367366
/// There's no watching of the certificate files on non-linux platforms
@@ -386,10 +385,6 @@ async fn create_tls_config_and_watch_certificate_changes(
386385
) -> Result<RustlsConfig> {
387386
use ::tracing::error;
388387

389-
let cert_file_path = tls_config.cert_file.clone();
390-
let key_file_path = tls_config.key_file.clone();
391-
let client_ca_cert_path = tls_config.client_ca_cert_file.clone();
392-
393388
let config = build_tls_server_config(&tls_config).await?;
394389

395390
let rust_config = RustlsConfig::from_config(Arc::new(config));
@@ -399,19 +394,22 @@ async fn create_tls_config_and_watch_certificate_changes(
399394
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
400395
let cert_watch = inotify
401396
.watches()
402-
.add(cert_file_path.clone(), inotify::WatchMask::CLOSE_WRITE)
397+
.add(
398+
tls_config.cert_file.clone(),
399+
inotify::WatchMask::CLOSE_WRITE,
400+
)
403401
.map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
404402
let key_watch = inotify
405403
.watches()
406-
.add(key_file_path.clone(), inotify::WatchMask::CLOSE_WRITE)
404+
.add(tls_config.key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
407405
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
408406

409407
let mut client_cert_watch = None;
410-
if let Some(ref client_cert_file) = client_ca_cert_path {
408+
if let Some(ref client_ca_file) = tls_config.client_ca_file {
411409
client_cert_watch = Some(
412410
inotify
413411
.watches()
414-
.add(client_cert_file.clone(), inotify::WatchMask::CLOSE_WRITE)
412+
.add(client_ca_file, inotify::WatchMask::CLOSE_WRITE)
415413
.map_err(|e| anyhow!("Cannot watch client certificate file: {e}"))?,
416414
);
417415
}

tests/integration_test.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ async fn test_detect_certificate_rotation() {
688688
config.tls_config = Some(policy_server::config::TlsConfig {
689689
cert_file: cert_file.to_str().unwrap().to_owned(),
690690
key_file: key_file.to_str().unwrap().to_owned(),
691-
client_ca_cert_file: Some(client_ca.to_str().unwrap().to_owned()),
691+
client_ca_file: Some(client_ca.to_str().unwrap().to_owned()),
692692
});
693693
config.policies = HashMap::new();
694694

0 commit comments

Comments
 (0)