Skip to content

Commit c0f2b59

Browse files
committed
dev: extract config from core::tracker
1 parent b480b0e commit c0f2b59

File tree

17 files changed

+180
-125
lines changed

17 files changed

+180
-125
lines changed

packages/configuration/src/lib.rs

+6-12
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@
229229
//! [health_check_api]
230230
//! bind_address = "127.0.0.1:1313"
231231
//!```
232-
use std::collections::{HashMap, HashSet};
232+
use std::collections::HashMap;
233233
use std::net::IpAddr;
234234
use std::str::FromStr;
235235
use std::sync::Arc;
@@ -336,6 +336,8 @@ pub struct HttpTracker {
336336
pub ssl_key_path: Option<String>,
337337
}
338338

339+
pub type AccessTokens = HashMap<String, String>;
340+
339341
/// Configuration for the HTTP API.
340342
#[serde_as]
341343
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
@@ -359,21 +361,13 @@ pub struct HttpApi {
359361
/// token and the value is the token itself. The token is used to
360362
/// authenticate the user. All tokens are valid for all endpoints and have
361363
/// the all permissions.
362-
pub access_tokens: HashMap<String, String>,
364+
pub access_tokens: AccessTokens,
363365
}
364366

365367
impl HttpApi {
366368
fn override_admin_token(&mut self, api_admin_token: &str) {
367369
self.access_tokens.insert("admin".to_string(), api_admin_token.to_string());
368370
}
369-
370-
/// Checks if the given token is one of the token in the configuration.
371-
#[must_use]
372-
pub fn contains_token(&self, token: &str) -> bool {
373-
let tokens: HashMap<String, String> = self.access_tokens.clone();
374-
let tokens: HashSet<String> = tokens.into_values().collect();
375-
tokens.contains(token)
376-
}
377371
}
378372

379373
/// Configuration for the Health Check API.
@@ -781,7 +775,7 @@ mod tests {
781775
fn http_api_configuration_should_check_if_it_contains_a_token() {
782776
let configuration = Configuration::default();
783777

784-
assert!(configuration.http_api.contains_token("MyAccessToken"));
785-
assert!(!configuration.http_api.contains_token("NonExistingToken"));
778+
assert!(configuration.http_api.access_tokens.values().any(|t| t == "MyAccessToken"));
779+
assert!(!configuration.http_api.access_tokens.values().any(|t| t == "NonExistingToken"));
786780
}
787781
}

src/bootstrap/app.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ pub fn initialize_static() {
6161
/// It's used by other higher-level components like the UDP and HTTP trackers or the tracker API.
6262
#[must_use]
6363
pub fn initialize_tracker(config: &Arc<Configuration>) -> Tracker {
64-
tracker_factory(config.clone())
64+
tracker_factory(&config.clone())
6565
}
6666

6767
/// It initializes the log level, format and channel.

src/bootstrap/jobs/tracker_apis.rs

+11-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::sync::Arc;
2626
use axum_server::tls_rustls::RustlsConfig;
2727
use log::info;
2828
use tokio::task::JoinHandle;
29-
use torrust_tracker_configuration::HttpApi;
29+
use torrust_tracker_configuration::{AccessTokens, HttpApi};
3030

3131
use super::make_rust_tls;
3232
use crate::core;
@@ -64,18 +64,25 @@ pub async fn start_job(config: &HttpApi, tracker: Arc<core::Tracker>, version: V
6464
.await
6565
.map(|tls| tls.expect("it should have a valid tracker api tls configuration"));
6666

67+
let access_tokens = Arc::new(config.access_tokens.clone());
68+
6769
match version {
68-
Version::V1 => Some(start_v1(bind_to, tls, tracker.clone()).await),
70+
Version::V1 => Some(start_v1(bind_to, tls, tracker.clone(), access_tokens).await),
6971
}
7072
} else {
7173
info!("Note: Not loading Http Tracker Service, Not Enabled in Configuration.");
7274
None
7375
}
7476
}
7577

76-
async fn start_v1(socket: SocketAddr, tls: Option<RustlsConfig>, tracker: Arc<core::Tracker>) -> JoinHandle<()> {
78+
async fn start_v1(
79+
socket: SocketAddr,
80+
tls: Option<RustlsConfig>,
81+
tracker: Arc<core::Tracker>,
82+
access_tokens: Arc<AccessTokens>,
83+
) -> JoinHandle<()> {
7784
let server = ApiServer::new(Launcher::new(socket, tls))
78-
.start(tracker)
85+
.start(tracker, access_tokens)
7986
.await
8087
.expect("it should be able to start to the tracker api");
8188

src/core/mod.rs

+49-17
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ use std::time::Duration;
447447

448448
use derive_more::Constructor;
449449
use futures::future::join_all;
450+
use log::debug;
450451
use tokio::sync::mpsc::error::SendError;
451452
use torrust_tracker_configuration::Configuration;
452453
use torrust_tracker_primitives::TrackerMode;
@@ -472,17 +473,19 @@ pub const TORRENT_PEERS_LIMIT: usize = 74;
472473
/// Typically, the `Tracker` is used by a higher application service that handles
473474
/// the network layer.
474475
pub struct Tracker {
475-
/// `Tracker` configuration. See [`torrust-tracker-configuration`](torrust_tracker_configuration)
476-
pub config: Arc<Configuration>,
476+
announce_policy: AnnouncePolicy,
477477
/// A database driver implementation: [`Sqlite3`](crate::core::databases::sqlite)
478478
/// or [`MySQL`](crate::core::databases::mysql)
479479
pub database: Arc<Box<dyn Database>>,
480480
mode: TrackerMode,
481+
policy: TrackerPolicy,
481482
keys: tokio::sync::RwLock<std::collections::HashMap<Key, auth::ExpiringKey>>,
482483
whitelist: tokio::sync::RwLock<std::collections::HashSet<InfoHash>>,
483484
pub torrents: Arc<RepositoryAsyncSingle>,
484485
stats_event_sender: Option<Box<dyn statistics::EventSender>>,
485486
stats_repository: statistics::Repo,
487+
external_ip: Option<IpAddr>,
488+
on_reverse_proxy: bool,
486489
}
487490

488491
/// Structure that holds general `Tracker` torrents metrics.
@@ -500,6 +503,13 @@ pub struct TorrentsMetrics {
500503
pub torrents: u64,
501504
}
502505

506+
#[derive(Copy, Clone, Debug, PartialEq, Default, Constructor)]
507+
pub struct TrackerPolicy {
508+
pub remove_peerless_torrents: bool,
509+
pub max_peer_timeout: u32,
510+
pub persistent_torrent_completed_stat: bool,
511+
}
512+
503513
/// Tracker policy for Announcements
504514
#[derive(Copy, Clone, Debug, PartialEq, Default, Constructor)]
505515
pub struct AnnouncePolicy {
@@ -569,7 +579,7 @@ impl Tracker {
569579
///
570580
/// Will return a `databases::error::Error` if unable to connect to database. The `Tracker` is responsible for the persistence.
571581
pub fn new(
572-
config: Arc<Configuration>,
582+
config: &Arc<Configuration>,
573583
stats_event_sender: Option<Box<dyn statistics::EventSender>>,
574584
stats_repository: statistics::Repo,
575585
) -> Result<Tracker, databases::error::Error> {
@@ -578,14 +588,22 @@ impl Tracker {
578588
let mode = config.mode;
579589

580590
Ok(Tracker {
581-
config,
591+
//config,
592+
announce_policy: AnnouncePolicy::new(config.announce_interval, config.min_announce_interval),
582593
mode,
583594
keys: tokio::sync::RwLock::new(std::collections::HashMap::new()),
584595
whitelist: tokio::sync::RwLock::new(std::collections::HashSet::new()),
585596
torrents: Arc::new(RepositoryAsyncSingle::new()),
586597
stats_event_sender,
587598
stats_repository,
588599
database,
600+
external_ip: config.get_ext_ip(),
601+
policy: TrackerPolicy::new(
602+
config.remove_peerless_torrents,
603+
config.max_peer_timeout,
604+
config.persistent_torrent_completed_stat,
605+
),
606+
on_reverse_proxy: config.on_reverse_proxy,
589607
})
590608
}
591609

@@ -609,6 +627,19 @@ impl Tracker {
609627
self.is_private()
610628
}
611629

630+
/// Returns `true` is the tracker is in whitelisted mode.
631+
pub fn is_behind_reverse_proxy(&self) -> bool {
632+
self.on_reverse_proxy
633+
}
634+
635+
pub fn get_announce_policy(&self) -> AnnouncePolicy {
636+
self.announce_policy
637+
}
638+
639+
pub fn get_maybe_external_ip(&self) -> Option<IpAddr> {
640+
self.external_ip
641+
}
642+
612643
/// It handles an announce request.
613644
///
614645
/// # Context: Tracker
@@ -630,16 +661,17 @@ impl Tracker {
630661
// we are actually handling authentication at the handlers level. So I would extract that
631662
// responsibility into another authentication service.
632663

633-
peer.change_ip(&assign_ip_address_to_peer(remote_client_ip, self.config.get_ext_ip()));
664+
debug!("Before: {peer:?}");
665+
peer.change_ip(&assign_ip_address_to_peer(remote_client_ip, self.external_ip));
666+
debug!("After: {peer:?}");
634667

668+
// we should update the torrent and get the stats before we get the peer list.
635669
let swarm_stats = self.update_torrent_with_peer_and_get_stats(info_hash, peer).await;
636670

637671
let peers = self.get_torrent_peers_for_peer(info_hash, peer).await;
638672

639-
let policy = AnnouncePolicy::new(self.config.announce_interval, self.config.min_announce_interval);
640-
641673
AnnounceData {
642-
policy,
674+
policy: self.announce_policy,
643675
peers,
644676
swarm_stats,
645677
}
@@ -740,7 +772,7 @@ impl Tracker {
740772

741773
let (stats, stats_updated) = self.torrents.update_torrent_with_peer_and_get_stats(info_hash, peer).await;
742774

743-
if self.config.persistent_torrent_completed_stat && stats_updated {
775+
if self.policy.persistent_torrent_completed_stat && stats_updated {
744776
let completed = stats.completed;
745777
let info_hash = *info_hash;
746778

@@ -801,17 +833,17 @@ impl Tracker {
801833
let mut torrents_lock = self.torrents.get_torrents_mut().await;
802834

803835
// If we don't need to remove torrents we will use the faster iter
804-
if self.config.remove_peerless_torrents {
836+
if self.policy.remove_peerless_torrents {
805837
let mut cleaned_torrents_map: BTreeMap<InfoHash, torrent::Entry> = BTreeMap::new();
806838

807839
for (info_hash, torrent_entry) in &mut *torrents_lock {
808-
torrent_entry.remove_inactive_peers(self.config.max_peer_timeout);
840+
torrent_entry.remove_inactive_peers(self.policy.max_peer_timeout);
809841

810842
if torrent_entry.peers.is_empty() {
811843
continue;
812844
}
813845

814-
if self.config.persistent_torrent_completed_stat && torrent_entry.completed == 0 {
846+
if self.policy.persistent_torrent_completed_stat && torrent_entry.completed == 0 {
815847
continue;
816848
}
817849

@@ -821,7 +853,7 @@ impl Tracker {
821853
*torrents_lock = cleaned_torrents_map;
822854
} else {
823855
for torrent_entry in (*torrents_lock).values_mut() {
824-
torrent_entry.remove_inactive_peers(self.config.max_peer_timeout);
856+
torrent_entry.remove_inactive_peers(self.policy.max_peer_timeout);
825857
}
826858
}
827859
}
@@ -1086,21 +1118,21 @@ mod tests {
10861118
use crate::shared::clock::DurationSinceUnixEpoch;
10871119

10881120
fn public_tracker() -> Tracker {
1089-
tracker_factory(configuration::ephemeral_mode_public().into())
1121+
tracker_factory(&configuration::ephemeral_mode_public().into())
10901122
}
10911123

10921124
fn private_tracker() -> Tracker {
1093-
tracker_factory(configuration::ephemeral_mode_private().into())
1125+
tracker_factory(&configuration::ephemeral_mode_private().into())
10941126
}
10951127

10961128
fn whitelisted_tracker() -> Tracker {
1097-
tracker_factory(configuration::ephemeral_mode_whitelisted().into())
1129+
tracker_factory(&configuration::ephemeral_mode_whitelisted().into())
10981130
}
10991131

11001132
pub fn tracker_persisting_torrents_in_database() -> Tracker {
11011133
let mut configuration = configuration::ephemeral();
11021134
configuration.persistent_torrent_completed_stat = true;
1103-
tracker_factory(Arc::new(configuration))
1135+
tracker_factory(&Arc::new(configuration))
11041136
}
11051137

11061138
fn sample_info_hash() -> InfoHash {

src/core/services/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::core::Tracker;
1919
///
2020
/// Will panic if tracker cannot be instantiated.
2121
#[must_use]
22-
pub fn tracker_factory(config: Arc<Configuration>) -> Tracker {
22+
pub fn tracker_factory(config: &Arc<Configuration>) -> Tracker {
2323
// Initialize statistics
2424
let (stats_event_sender, stats_repository) = statistics::setup::factory(config.tracker_usage_statistics);
2525

src/core/services/statistics/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ mod tests {
9898

9999
#[tokio::test]
100100
async fn the_statistics_service_should_return_the_tracker_metrics() {
101-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
101+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
102102

103103
let tracker_metrics = get_metrics(tracker.clone()).await;
104104

src/core/services/torrent.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ mod tests {
174174

175175
#[tokio::test]
176176
async fn should_return_none_if_the_tracker_does_not_have_the_torrent() {
177-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
177+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
178178

179179
let torrent_info = get_torrent_info(
180180
tracker.clone(),
@@ -187,7 +187,7 @@ mod tests {
187187

188188
#[tokio::test]
189189
async fn should_return_the_torrent_info_if_the_tracker_has_the_torrent() {
190-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
190+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
191191

192192
let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();
193193
let info_hash = InfoHash::from_str(&hash).unwrap();
@@ -229,7 +229,7 @@ mod tests {
229229

230230
#[tokio::test]
231231
async fn should_return_an_empty_result_if_the_tracker_does_not_have_any_torrent() {
232-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
232+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
233233

234234
let torrents = get_torrents(tracker.clone(), &Pagination::default()).await;
235235

@@ -238,7 +238,7 @@ mod tests {
238238

239239
#[tokio::test]
240240
async fn should_return_a_summarized_info_for_all_torrents() {
241-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
241+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
242242

243243
let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();
244244
let info_hash = InfoHash::from_str(&hash).unwrap();
@@ -262,7 +262,7 @@ mod tests {
262262

263263
#[tokio::test]
264264
async fn should_allow_limiting_the_number_of_torrents_in_the_result() {
265-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
265+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
266266

267267
let hash1 = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();
268268
let info_hash1 = InfoHash::from_str(&hash1).unwrap();
@@ -286,7 +286,7 @@ mod tests {
286286

287287
#[tokio::test]
288288
async fn should_allow_using_pagination_in_the_result() {
289-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
289+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
290290

291291
let hash1 = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();
292292
let info_hash1 = InfoHash::from_str(&hash1).unwrap();
@@ -319,7 +319,7 @@ mod tests {
319319

320320
#[tokio::test]
321321
async fn should_return_torrents_ordered_by_info_hash() {
322-
let tracker = Arc::new(tracker_factory(tracker_configuration()));
322+
let tracker = Arc::new(tracker_factory(&tracker_configuration()));
323323

324324
let hash1 = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();
325325
let info_hash1 = InfoHash::from_str(&hash1).unwrap();

src/servers/apis/routes.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,27 @@ use std::sync::Arc;
99

1010
use axum::routing::get;
1111
use axum::{middleware, Router};
12+
use torrust_tracker_configuration::AccessTokens;
1213
use tower_http::compression::CompressionLayer;
1314

1415
use super::v1;
1516
use super::v1::context::health_check::handlers::health_check_handler;
17+
use super::v1::middlewares::auth::State;
1618
use crate::core::Tracker;
1719

1820
/// Add all API routes to the router.
1921
#[allow(clippy::needless_pass_by_value)]
20-
pub fn router(tracker: Arc<Tracker>) -> Router {
22+
pub fn router(tracker: Arc<Tracker>, access_tokens: Arc<AccessTokens>) -> Router {
2123
let router = Router::new();
2224

2325
let api_url_prefix = "/api";
2426

2527
let router = v1::routes::add(api_url_prefix, router, tracker.clone());
2628

29+
let state = State { access_tokens };
30+
2731
router
28-
.layer(middleware::from_fn_with_state(
29-
tracker.config.clone(),
30-
v1::middlewares::auth::auth,
31-
))
32+
.layer(middleware::from_fn_with_state(state, v1::middlewares::auth::auth))
3233
.route(&format!("{api_url_prefix}/health_check"), get(health_check_handler))
3334
.layer(CompressionLayer::new())
3435
}

0 commit comments

Comments
 (0)