Skip to content

Commit

Permalink
refactor: general cleanup to facilitate feature development (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajpantuso authored Jun 1, 2024
1 parent bcca644 commit 6659619
Show file tree
Hide file tree
Showing 8 changed files with 650 additions and 478 deletions.
56 changes: 56 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use figment::{
providers::{Env, Format, Yaml},
Figment,
};
use serde::{de, Deserialize, Deserializer};

#[derive(Clone, Deserialize, Debug)]
pub struct Config {
pub key: String,
pub secret: String,
#[serde(deserialize_with = "from_str_deserialize")]
pub base: reqwest::Url,
#[serde(default = "default_bind")]
pub bind: String,
#[serde(default)]
pub domain_filters: Vec<String>,
#[serde(default)]
pub allow_invalid_certs: bool,
#[serde(deserialize_with = "deserialize_certificate", default)]
pub certificate_bundle: Vec<reqwest::Certificate>,
}

fn from_str_deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
T: std::str::FromStr,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
let s = String::deserialize(deserializer)?;
T::from_str(&s).map_err(de::Error::custom)
}

fn default_bind() -> String {
"127.0.0.1:8800".to_owned()
}

fn deserialize_certificate<'de, D>(deserializer: D) -> Result<Vec<reqwest::Certificate>, D::Error>
where
D: Deserializer<'de>,
{
Ok(match Option::<String>::deserialize(deserializer)? {
None => Vec::new(),
Some(b) => {
reqwest::Certificate::from_pem_bundle(b.as_bytes()).map_err(de::Error::custom)?
}
})
}

impl Config {
pub fn try_from_env() -> figment::Result<Config> {
Figment::new()
.merge(Yaml::file("config.yaml"))
.merge(Env::prefixed("OPNSENSE_"))
.extract()
}
}
103 changes: 103 additions & 0 deletions src/external_dns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use crate::opnsense::unbound;
use axum::{
http::{header, HeaderValue, StatusCode},
response::IntoResponse,
};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Serialize, Debug)]
pub struct Filters {
pub filters: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug, Default)]
pub struct Endpoints(pub Vec<Endpoint>);

#[derive(Serialize, Deserialize, Default, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Endpoint {
pub dns_name: String,
#[serde(default)]
pub targets: Targets,
pub record_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub set_identifier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub record_ttl: Option<u64>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub provider_specific: Vec<ProviderSpecificProperty>,
}

impl From<unbound::Row> for Endpoint {
fn from(value: unbound::Row) -> Endpoint {
Endpoint {
dns_name: value.domain.clone(),
set_identifier: None,
record_type: value
.rr
.split_whitespace()
.next()
.map(|s| s.to_string())
.unwrap_or("A".to_string()),
targets: Targets(vec![value.server.clone()]),
..Default::default()
}
}
}

#[derive(Serialize, Deserialize, Default, Clone, Debug)]
pub struct Targets(pub Vec<String>);

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ProviderSpecificProperty {
name: String,
value: String,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "PascalCase")]
pub struct Changes {
#[serde(deserialize_with = "deserialize_null_default")]
pub create: Endpoints,
#[serde(deserialize_with = "deserialize_null_default")]
_update_old: Endpoints,
#[serde(deserialize_with = "deserialize_null_default")]
pub update_new: Endpoints,
#[serde(deserialize_with = "deserialize_null_default")]
pub delete: Endpoints,
}

fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
T: Default + Deserialize<'de>,
{
let opt = Option::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
}

pub struct Edns<T>(pub T);

impl<T> IntoResponse for Edns<T>
where
T: Serialize,
{
fn into_response(self) -> axum::response::Response {
match serde_json::to_string(&self.0) {
Ok(buf) => (
[(
header::CONTENT_TYPE,
HeaderValue::from_static("application/external.dns.webhook+json;version=1"),
)],
buf,
)
.into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
[(header::CONTENT_TYPE, HeaderValue::from_static("plain/text"))],
err.to_string(),
)
.into_response(),
}
}
}
201 changes: 201 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
pub mod config;
mod external_dns;
mod opnsense;

use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use config::Config;
use external_dns::{Changes, Edns, Endpoint, Endpoints, Filters, Targets};
use opnsense::unbound;
use opnsense::Opnsense;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tower_http::trace::{self, TraceLayer};
use tracing::instrument;

struct AppState {
config: Config,
opnsense: Opnsense,
uuid_map: Mutex<HashMap<String, String>>,
}

pub struct Server {
config: Config,
}

impl From<Config> for Server {
fn from(config: Config) -> Self {
Self { config }
}
}

impl Server {
pub async fn serve(&self) -> Result<(), Box<dyn std::error::Error>> {
let state = Arc::new(AppState {
opnsense: Opnsense::try_from(&self.config)?,
config: self.config.clone(),
uuid_map: Mutex::new(HashMap::new()),
});

let app = Router::new()
.route("/", get(negotiate))
.route("/healthz", get(healthz))
.route("/records", get(get_records).post(set_records))
.route("/adjustendpoints", post(adjust_records))
.with_state(state)
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(tracing::Level::INFO))
.on_request(trace::DefaultOnRequest::new().level(tracing::Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)),
);

let listener = tokio::net::TcpListener::bind(&self.config.bind).await?;
tracing::info!("listening on {}", self.config.bind);

Ok(axum::serve(listener, app).await?)
}
}

#[instrument(skip(state))]
async fn negotiate(State(state): State<Arc<AppState>>) -> Result<Edns<Filters>, StatusCode> {
//TODO: the rest of the implementation doesn't check is domains are valid in regards with those filters

Ok(Edns(Filters {
filters: state.config.domain_filters.clone(),
}))
}

#[instrument(skip(_state))]
async fn healthz(State(_state): State<Arc<AppState>>) -> () {}

#[instrument(skip(state))]
async fn get_records(State(state): State<Arc<AppState>>) -> Result<Edns<Endpoints>, StatusCode> {
let list = state
.opnsense
.unbound()
.settings()
.search_host_override()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let mut guard = state.uuid_map.lock().await;
guard.clear();
for r in list.rows.iter() {
let _ = guard.insert(r.domain.clone(), r.uuid.clone());
}
tracing::debug!("new uuid map: {:?}", guard);
drop(guard);

Ok(Edns(Endpoints(
list.rows
.into_iter()
.filter(|r| r.enabled == "1")
.map(Endpoint::from)
.collect(),
)))
}

#[instrument(skip(state))]
async fn set_records(
State(state): State<Arc<AppState>>,
Json(changes): Json<Changes>,
) -> Result<StatusCode, StatusCode> {
let mut need_restart = false;

for ep in changes
.create
.0
.iter()
.map(|c| unbound::Row::from(c.clone()))
{
let res = state
.opnsense
.unbound()
.settings()
.add_host_override(&ep)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let mut guard = state.uuid_map.lock().await;
guard.insert(ep.domain.clone(), res.uuid);
need_restart = true;
}

for ep in changes
.update_new
.0
.iter()
.map(|c| unbound::Row::from(c.clone()))
{
let guard = state.uuid_map.lock().await;
if let Some(uuid) = guard.get(&ep.domain) {
if let Err(e) = state
.opnsense
.unbound()
.settings()
.set_host_override(uuid, &ep)
.await
{
tracing::error!("update: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
} else {
need_restart = true;
}
} else {
tracing::error!("update: could not find uuid in map: {:?}", &ep.domain);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}

for ep in changes.delete.0.iter() {
let mut guard = state.uuid_map.lock().await;
if let Some(uuid) = guard.get(&ep.dns_name) {
state
.opnsense
.unbound()
.settings()
.delete_host_override(uuid)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let _ = guard.remove(&ep.dns_name);
need_restart = true;
} else {
tracing::error!("delete: could not find uuid in map");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}

if need_restart {
let _ = state.opnsense.unbound().service().restart().await;
}

Ok(StatusCode::NO_CONTENT)
}
#[instrument(skip(_state))]
async fn adjust_records(
State(_state): State<Arc<AppState>>,
Json(endpoints): Json<Endpoints>,
) -> Result<Edns<Endpoints>, StatusCode> {
let mut results = Endpoints(Vec::new());

for ep in endpoints.0.iter() {
match ep.record_type.as_str() {
"A" | "AAAA" => {}
_ => continue,
};

results.0.push(Endpoint {
targets: Targets(ep.targets.0[..1].to_vec()),
record_ttl: None,
..ep.clone()
});
}

Ok(Edns(results))
}
Loading

0 comments on commit 6659619

Please sign in to comment.