Skip to content

Commit 79bb1ed

Browse files
committed
refactor: added tests to key_manager
1 parent 0afd619 commit 79bb1ed

File tree

4 files changed

+90
-39
lines changed

4 files changed

+90
-39
lines changed

src/http_server.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ impl HttpServer {
173173

174174
// query.info_hash somehow receives a corrupt string
175175
// so we have to get the info_hash manually from the raw query
176-
let info_hashes = HttpServer::get_info_hashes_from_raw_query(&raw_query);
176+
let info_hashes = HttpServer::info_hashes_from_raw_query(&raw_query);
177177
if info_hashes.len() < 1 { return HttpServer::send_error("info_hash not found") }
178178
query.info_hash = info_hashes[0].to_string();
179179
debug!("{:?}", query.info_hash);
@@ -198,7 +198,7 @@ impl HttpServer {
198198
})
199199
.and_then(move |(key, raw_query, http_server): (Option<String>, String, Arc<HttpServer>)| {
200200
async move {
201-
let info_hashes = HttpServer::get_info_hashes_from_raw_query(&raw_query);
201+
let info_hashes = HttpServer::info_hashes_from_raw_query(&raw_query);
202202
if info_hashes.len() < 1 { return HttpServer::send_error("info_hash not found") }
203203
if info_hashes.len() > 50 { return HttpServer::send_error("exceeded the max of 50 info_hashes") }
204204
debug!("{:?}", info_hashes);
@@ -214,7 +214,7 @@ impl HttpServer {
214214
warp::any().and(announce_route.or(scrape_route))
215215
}
216216

217-
fn get_info_hashes_from_raw_query(raw_query: &str) -> Vec<InfoHash> {
217+
fn info_hashes_from_raw_query(raw_query: &str) -> Vec<InfoHash> {
218218
let split_raw_query: Vec<&str> = raw_query.split("&").collect();
219219
let mut info_hashes: Vec<InfoHash> = Vec::new();
220220

src/key_manager.rs

+80-29
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,29 @@ use serde::Serialize;
66
use log::debug;
77
use derive_more::{Display, Error};
88

9+
pub fn generate_auth_key(seconds_valid: u64) -> AuthKey {
10+
let key: String = thread_rng()
11+
.sample_iter(&Alphanumeric)
12+
.take(AUTH_KEY_LENGTH)
13+
.map(char::from)
14+
.collect();
15+
16+
debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid);
17+
18+
AuthKey {
19+
key,
20+
valid_until: Some(current_time() + seconds_valid),
21+
}
22+
}
23+
24+
pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> {
25+
let current_time = current_time();
26+
if auth_key.valid_until.is_none() { return Err(Error::KeyInvalid) }
27+
if auth_key.valid_until.unwrap() < current_time { return Err(Error::KeyExpired) }
28+
29+
Ok(())
30+
}
31+
932
#[derive(Serialize, Debug, Eq, PartialEq, Clone)]
1033
pub struct AuthKey {
1134
pub key: String,
@@ -14,27 +37,35 @@ pub struct AuthKey {
1437

1538
impl AuthKey {
1639
pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option<AuthKey> {
17-
Some(AuthKey {
18-
key: String::from_utf8(Vec::from(key_buffer)).unwrap(),
19-
valid_until: None,
20-
})
40+
if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) {
41+
Some(AuthKey {
42+
key,
43+
valid_until: None,
44+
})
45+
} else {
46+
None
47+
}
2148
}
2249

2350
pub fn from_string(key: &str) -> Option<AuthKey> {
24-
if key.len() != AUTH_KEY_LENGTH { return None }
25-
26-
Some(AuthKey {
27-
key: key.to_string(),
28-
valid_until: None,
29-
})
51+
if key.len() != AUTH_KEY_LENGTH {
52+
None
53+
} else {
54+
Some(AuthKey {
55+
key: key.to_string(),
56+
valid_until: None,
57+
})
58+
}
3059
}
3160
}
3261

3362
#[derive(Debug, Display, PartialEq, Error)]
3463
#[allow(dead_code)]
3564
pub enum Error {
36-
#[display(fmt = "Key is invalid.")]
65+
#[display(fmt = "Key could not be verified.")]
3766
KeyVerificationError,
67+
#[display(fmt = "Key is invalid.")]
68+
KeyInvalid,
3869
#[display(fmt = "Key has expired.")]
3970
KeyExpired
4071
}
@@ -46,29 +77,49 @@ impl From<r2d2_sqlite::rusqlite::Error> for Error {
4677
}
4778
}
4879

49-
pub struct KeyManager;
80+
#[cfg(test)]
81+
mod tests {
82+
use crate::key_manager;
5083

51-
impl KeyManager {
52-
pub fn generate_auth_key(&self, seconds_valid: u64) -> AuthKey {
53-
let key: String = thread_rng()
54-
.sample_iter(&Alphanumeric)
55-
.take(AUTH_KEY_LENGTH)
56-
.map(char::from)
57-
.collect();
84+
#[test]
85+
fn auth_key_from_buffer() {
86+
let auth_key = key_manager::AuthKey::from_buffer(
87+
[
88+
89, 90, 83, 108,
89+
52, 108, 77, 90,
90+
117, 112, 82, 117,
91+
79, 112, 83, 82,
92+
67, 51, 107, 114,
93+
73, 75, 82, 53,
94+
66, 80, 66, 49,
95+
52, 110, 114, 74]
96+
);
5897

59-
debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid);
98+
assert!(auth_key.is_some());
99+
assert_eq!(auth_key.unwrap().key, "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ");
100+
}
60101

61-
AuthKey {
62-
key,
63-
valid_until: Some(current_time() + seconds_valid),
64-
}
102+
#[test]
103+
fn auth_key_from_string() {
104+
let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ";
105+
let auth_key = key_manager::AuthKey::from_string(key_string);
106+
107+
assert!(auth_key.is_some());
108+
assert_eq!(auth_key.unwrap().key, key_string);
109+
}
110+
111+
#[test]
112+
fn generate_valid_auth_key() {
113+
let auth_key = key_manager::generate_auth_key(9999);
114+
115+
assert!(key_manager::verify_auth_key(&auth_key).is_ok());
65116
}
66117

67-
pub async fn verify_auth_key(&self, auth_key: &AuthKey) -> Result<(), Error> {
68-
let current_time = current_time();
69-
if auth_key.valid_until.is_none() { return Err(Error::KeyVerificationError) }
70-
if &auth_key.valid_until.unwrap() < &current_time { return Err(Error::KeyExpired) }
118+
#[test]
119+
fn generate_expired_auth_key() {
120+
let mut auth_key = key_manager::generate_auth_key(0);
121+
auth_key.valid_until = Some(0);
71122

72-
Ok(())
123+
assert!(key_manager::verify_auth_key(&auth_key).is_err());
73124
}
74125
}

src/request.rs

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::io;
44
use std::io::{Cursor, Read};
55
use byteorder::{NetworkEndian, ReadBytesExt};
66
use std::convert::TryInto;
7+
use log::debug;
78
use crate::key_manager::AuthKey;
89

910
#[derive(PartialEq, Eq, Clone, Debug)]
@@ -169,6 +170,7 @@ impl Request {
169170
// key should be the last bytes
170171
cursor.set_position((bytes.len() - AUTH_KEY_LENGTH) as u64);
171172
if cursor.read_exact(&mut key_buffer).is_ok() {
173+
debug!("AuthKey buffer: {:?}", key_buffer);
172174
AuthKey::from_buffer(key_buffer)
173175
} else {
174176
None

src/tracker.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::collections::btree_map::Entry;
1010
use crate::database::SqliteDatabase;
1111
use std::sync::Arc;
1212
use log::debug;
13-
use crate::key_manager::{AuthKey, KeyManager};
13+
use crate::key_manager::{AuthKey};
1414
use r2d2_sqlite::rusqlite;
1515
use crate::http_server::HttpAnnounceRequest;
1616

@@ -250,7 +250,6 @@ pub struct TorrentTracker {
250250
pub config: Arc<Configuration>,
251251
torrents: tokio::sync::RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
252252
database: SqliteDatabase,
253-
key_manager: KeyManager,
254253
}
255254

256255
impl TorrentTracker {
@@ -263,12 +262,11 @@ impl TorrentTracker {
263262
config,
264263
torrents: RwLock::new(std::collections::BTreeMap::new()),
265264
database,
266-
key_manager: KeyManager {}
267265
}
268266
}
269267

270268
pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result<AuthKey, rusqlite::Error> {
271-
let auth_key = self.key_manager.generate_auth_key(seconds_valid);
269+
let auth_key = key_manager::generate_auth_key(seconds_valid);
272270

273271
// add key to database
274272
if let Err(error) = self.database.add_key_to_keys(&auth_key).await { return Err(error) }
@@ -282,7 +280,7 @@ impl TorrentTracker {
282280

283281
pub async fn verify_auth_key(&self, auth_key: &AuthKey) -> Result<(), key_manager::Error> {
284282
let db_key = self.database.get_key_from_keys(&auth_key.key).await?;
285-
self.key_manager.verify_auth_key(&db_key).await
283+
key_manager::verify_auth_key(&db_key)
286284
}
287285

288286
pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option<AuthKey>) -> Result<(), TorrentError> {
@@ -298,7 +296,7 @@ impl TorrentTracker {
298296
TrackerMode::PrivateMode => {
299297
match key {
300298
Some(key) => {
301-
if self.key_manager.verify_auth_key(key).await.is_err() {
299+
if key_manager::verify_auth_key(key).is_err() {
302300
return Err(TorrentError::PeerKeyNotValid)
303301
}
304302

@@ -312,7 +310,7 @@ impl TorrentTracker {
312310
TrackerMode::PrivateListedMode => {
313311
match key {
314312
Some(key) => {
315-
if self.key_manager.verify_auth_key(key).await.is_err() {
313+
if key_manager::verify_auth_key(key).is_err() {
316314
return Err(TorrentError::PeerKeyNotValid)
317315
}
318316

0 commit comments

Comments
 (0)