-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: general cleanup to facilitate feature development (#18)
- Loading branch information
Showing
8 changed files
with
650 additions
and
478 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use figment::{ | ||
providers::{Env, Format, Yaml}, | ||
Figment, | ||
}; | ||
use serde::{de, Deserialize, Deserializer}; | ||
|
||
#[derive(Clone, Deserialize, Debug)] | ||
pub struct Config { | ||
pub key: String, | ||
pub secret: String, | ||
#[serde(deserialize_with = "from_str_deserialize")] | ||
pub base: reqwest::Url, | ||
#[serde(default = "default_bind")] | ||
pub bind: String, | ||
#[serde(default)] | ||
pub domain_filters: Vec<String>, | ||
#[serde(default)] | ||
pub allow_invalid_certs: bool, | ||
#[serde(deserialize_with = "deserialize_certificate", default)] | ||
pub certificate_bundle: Vec<reqwest::Certificate>, | ||
} | ||
|
||
fn from_str_deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
T: std::str::FromStr, | ||
<T as std::str::FromStr>::Err: std::fmt::Display, | ||
{ | ||
let s = String::deserialize(deserializer)?; | ||
T::from_str(&s).map_err(de::Error::custom) | ||
} | ||
|
||
fn default_bind() -> String { | ||
"127.0.0.1:8800".to_owned() | ||
} | ||
|
||
fn deserialize_certificate<'de, D>(deserializer: D) -> Result<Vec<reqwest::Certificate>, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
{ | ||
Ok(match Option::<String>::deserialize(deserializer)? { | ||
None => Vec::new(), | ||
Some(b) => { | ||
reqwest::Certificate::from_pem_bundle(b.as_bytes()).map_err(de::Error::custom)? | ||
} | ||
}) | ||
} | ||
|
||
impl Config { | ||
pub fn try_from_env() -> figment::Result<Config> { | ||
Figment::new() | ||
.merge(Yaml::file("config.yaml")) | ||
.merge(Env::prefixed("OPNSENSE_")) | ||
.extract() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
use crate::opnsense::unbound; | ||
use axum::{ | ||
http::{header, HeaderValue, StatusCode}, | ||
response::IntoResponse, | ||
}; | ||
use serde::{Deserialize, Deserializer, Serialize}; | ||
|
||
#[derive(Serialize, Debug)] | ||
pub struct Filters { | ||
pub filters: Vec<String>, | ||
} | ||
|
||
#[derive(Serialize, Deserialize, Debug, Default)] | ||
pub struct Endpoints(pub Vec<Endpoint>); | ||
|
||
#[derive(Serialize, Deserialize, Default, Clone, Debug)] | ||
#[serde(rename_all = "camelCase")] | ||
pub struct Endpoint { | ||
pub dns_name: String, | ||
#[serde(default)] | ||
pub targets: Targets, | ||
pub record_type: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub set_identifier: Option<String>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub record_ttl: Option<u64>, | ||
#[serde(skip_serializing_if = "Vec::is_empty", default)] | ||
pub provider_specific: Vec<ProviderSpecificProperty>, | ||
} | ||
|
||
impl From<unbound::Row> for Endpoint { | ||
fn from(value: unbound::Row) -> Endpoint { | ||
Endpoint { | ||
dns_name: value.domain.clone(), | ||
set_identifier: None, | ||
record_type: value | ||
.rr | ||
.split_whitespace() | ||
.next() | ||
.map(|s| s.to_string()) | ||
.unwrap_or("A".to_string()), | ||
targets: Targets(vec![value.server.clone()]), | ||
..Default::default() | ||
} | ||
} | ||
} | ||
|
||
#[derive(Serialize, Deserialize, Default, Clone, Debug)] | ||
pub struct Targets(pub Vec<String>); | ||
|
||
#[derive(Serialize, Deserialize, Clone, Debug)] | ||
pub struct ProviderSpecificProperty { | ||
name: String, | ||
value: String, | ||
} | ||
|
||
#[derive(Deserialize, Debug)] | ||
#[serde(rename_all = "PascalCase")] | ||
pub struct Changes { | ||
#[serde(deserialize_with = "deserialize_null_default")] | ||
pub create: Endpoints, | ||
#[serde(deserialize_with = "deserialize_null_default")] | ||
_update_old: Endpoints, | ||
#[serde(deserialize_with = "deserialize_null_default")] | ||
pub update_new: Endpoints, | ||
#[serde(deserialize_with = "deserialize_null_default")] | ||
pub delete: Endpoints, | ||
} | ||
|
||
fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
T: Default + Deserialize<'de>, | ||
{ | ||
let opt = Option::deserialize(deserializer)?; | ||
Ok(opt.unwrap_or_default()) | ||
} | ||
|
||
pub struct Edns<T>(pub T); | ||
|
||
impl<T> IntoResponse for Edns<T> | ||
where | ||
T: Serialize, | ||
{ | ||
fn into_response(self) -> axum::response::Response { | ||
match serde_json::to_string(&self.0) { | ||
Ok(buf) => ( | ||
[( | ||
header::CONTENT_TYPE, | ||
HeaderValue::from_static("application/external.dns.webhook+json;version=1"), | ||
)], | ||
buf, | ||
) | ||
.into_response(), | ||
Err(err) => ( | ||
StatusCode::INTERNAL_SERVER_ERROR, | ||
[(header::CONTENT_TYPE, HeaderValue::from_static("plain/text"))], | ||
err.to_string(), | ||
) | ||
.into_response(), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
pub mod config; | ||
mod external_dns; | ||
mod opnsense; | ||
|
||
use axum::{ | ||
extract::State, | ||
http::StatusCode, | ||
routing::{get, post}, | ||
Json, Router, | ||
}; | ||
use config::Config; | ||
use external_dns::{Changes, Edns, Endpoint, Endpoints, Filters, Targets}; | ||
use opnsense::unbound; | ||
use opnsense::Opnsense; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
use tokio::sync::Mutex; | ||
use tower_http::trace::{self, TraceLayer}; | ||
use tracing::instrument; | ||
|
||
struct AppState { | ||
config: Config, | ||
opnsense: Opnsense, | ||
uuid_map: Mutex<HashMap<String, String>>, | ||
} | ||
|
||
pub struct Server { | ||
config: Config, | ||
} | ||
|
||
impl From<Config> for Server { | ||
fn from(config: Config) -> Self { | ||
Self { config } | ||
} | ||
} | ||
|
||
impl Server { | ||
pub async fn serve(&self) -> Result<(), Box<dyn std::error::Error>> { | ||
let state = Arc::new(AppState { | ||
opnsense: Opnsense::try_from(&self.config)?, | ||
config: self.config.clone(), | ||
uuid_map: Mutex::new(HashMap::new()), | ||
}); | ||
|
||
let app = Router::new() | ||
.route("/", get(negotiate)) | ||
.route("/healthz", get(healthz)) | ||
.route("/records", get(get_records).post(set_records)) | ||
.route("/adjustendpoints", post(adjust_records)) | ||
.with_state(state) | ||
.layer( | ||
TraceLayer::new_for_http() | ||
.make_span_with(trace::DefaultMakeSpan::new().level(tracing::Level::INFO)) | ||
.on_request(trace::DefaultOnRequest::new().level(tracing::Level::INFO)) | ||
.on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)), | ||
); | ||
|
||
let listener = tokio::net::TcpListener::bind(&self.config.bind).await?; | ||
tracing::info!("listening on {}", self.config.bind); | ||
|
||
Ok(axum::serve(listener, app).await?) | ||
} | ||
} | ||
|
||
#[instrument(skip(state))] | ||
async fn negotiate(State(state): State<Arc<AppState>>) -> Result<Edns<Filters>, StatusCode> { | ||
//TODO: the rest of the implementation doesn't check is domains are valid in regards with those filters | ||
|
||
Ok(Edns(Filters { | ||
filters: state.config.domain_filters.clone(), | ||
})) | ||
} | ||
|
||
#[instrument(skip(_state))] | ||
async fn healthz(State(_state): State<Arc<AppState>>) -> () {} | ||
|
||
#[instrument(skip(state))] | ||
async fn get_records(State(state): State<Arc<AppState>>) -> Result<Edns<Endpoints>, StatusCode> { | ||
let list = state | ||
.opnsense | ||
.unbound() | ||
.settings() | ||
.search_host_override() | ||
.await | ||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
||
let mut guard = state.uuid_map.lock().await; | ||
guard.clear(); | ||
for r in list.rows.iter() { | ||
let _ = guard.insert(r.domain.clone(), r.uuid.clone()); | ||
} | ||
tracing::debug!("new uuid map: {:?}", guard); | ||
drop(guard); | ||
|
||
Ok(Edns(Endpoints( | ||
list.rows | ||
.into_iter() | ||
.filter(|r| r.enabled == "1") | ||
.map(Endpoint::from) | ||
.collect(), | ||
))) | ||
} | ||
|
||
#[instrument(skip(state))] | ||
async fn set_records( | ||
State(state): State<Arc<AppState>>, | ||
Json(changes): Json<Changes>, | ||
) -> Result<StatusCode, StatusCode> { | ||
let mut need_restart = false; | ||
|
||
for ep in changes | ||
.create | ||
.0 | ||
.iter() | ||
.map(|c| unbound::Row::from(c.clone())) | ||
{ | ||
let res = state | ||
.opnsense | ||
.unbound() | ||
.settings() | ||
.add_host_override(&ep) | ||
.await | ||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
||
let mut guard = state.uuid_map.lock().await; | ||
guard.insert(ep.domain.clone(), res.uuid); | ||
need_restart = true; | ||
} | ||
|
||
for ep in changes | ||
.update_new | ||
.0 | ||
.iter() | ||
.map(|c| unbound::Row::from(c.clone())) | ||
{ | ||
let guard = state.uuid_map.lock().await; | ||
if let Some(uuid) = guard.get(&ep.domain) { | ||
if let Err(e) = state | ||
.opnsense | ||
.unbound() | ||
.settings() | ||
.set_host_override(uuid, &ep) | ||
.await | ||
{ | ||
tracing::error!("update: {:?}", e); | ||
return Err(StatusCode::INTERNAL_SERVER_ERROR); | ||
} else { | ||
need_restart = true; | ||
} | ||
} else { | ||
tracing::error!("update: could not find uuid in map: {:?}", &ep.domain); | ||
return Err(StatusCode::INTERNAL_SERVER_ERROR); | ||
} | ||
} | ||
|
||
for ep in changes.delete.0.iter() { | ||
let mut guard = state.uuid_map.lock().await; | ||
if let Some(uuid) = guard.get(&ep.dns_name) { | ||
state | ||
.opnsense | ||
.unbound() | ||
.settings() | ||
.delete_host_override(uuid) | ||
.await | ||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
let _ = guard.remove(&ep.dns_name); | ||
need_restart = true; | ||
} else { | ||
tracing::error!("delete: could not find uuid in map"); | ||
return Err(StatusCode::INTERNAL_SERVER_ERROR); | ||
} | ||
} | ||
|
||
if need_restart { | ||
let _ = state.opnsense.unbound().service().restart().await; | ||
} | ||
|
||
Ok(StatusCode::NO_CONTENT) | ||
} | ||
#[instrument(skip(_state))] | ||
async fn adjust_records( | ||
State(_state): State<Arc<AppState>>, | ||
Json(endpoints): Json<Endpoints>, | ||
) -> Result<Edns<Endpoints>, StatusCode> { | ||
let mut results = Endpoints(Vec::new()); | ||
|
||
for ep in endpoints.0.iter() { | ||
match ep.record_type.as_str() { | ||
"A" | "AAAA" => {} | ||
_ => continue, | ||
}; | ||
|
||
results.0.push(Endpoint { | ||
targets: Targets(ep.targets.0[..1].to_vec()), | ||
record_ttl: None, | ||
..ep.clone() | ||
}); | ||
} | ||
|
||
Ok(Edns(results)) | ||
} |
Oops, something went wrong.