Skip to content

Commit

Permalink
Add Unit Tests for Network & Storage. And minor improvements (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
dharanad authored Aug 1, 2024
1 parent 7462d6b commit 50b9f16
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 109 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ slog = "2.7.0"
slog-term = "2.9.1"
chrono = "0.4"

[dev-dependencies]
tempfile = "3.10.1"

[[example]]
name = "simple_run"
path = "examples/simple_run.rs"
9 changes: 4 additions & 5 deletions examples/simple_run.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// Author: Vipul Vaibhaw
// Organization: SpacewalkHq
// License: MIT License

Expand Down Expand Up @@ -52,9 +51,9 @@ async fn main() {
}

// Simulate a client request after some delay
thread::sleep(Duration::from_secs(20));
client_request(1, 42 as u32).await;
thread::sleep(Duration::from_secs(2));
tokio::time::sleep(Duration::from_secs(20)).await;
client_request(1, 42u32).await;
tokio::time::sleep(Duration::from_secs(2)).await;
// Join all server threads
for handle in handles {
handle.join().unwrap();
Expand All @@ -65,7 +64,7 @@ async fn client_request(client_id: u32, data: u32) {
let log = get_logger();

let server_address = "127.0.0.1"; // Assuming server 1 is the leader
let network_manager = TCPManager::new(server_address.to_string(), 5001, log.clone());
let network_manager = TCPManager::new(server_address.to_string(), 5001);

let request_data = vec![
client_id.to_be_bytes().to_vec(),
Expand Down
168 changes: 140 additions & 28 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use crate::parse_ip_address;
use async_trait::async_trait;
use futures::future::join_all;
use slog::info;
use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
Expand All @@ -27,7 +26,7 @@ pub trait NetworkLayer: Send + Sync {
addresses: Vec<String>,
) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn open(&self) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn close(&self) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn close(self) -> Result<(), Box<dyn Error + Send + Sync>>;
}

#[derive(Debug, Clone)]
Expand All @@ -36,17 +35,15 @@ pub struct TCPManager {
port: u16,
listener: Arc<Mutex<Option<TcpListener>>>,
is_open: Arc<Mutex<bool>>,
log: slog::Logger,
}

impl TCPManager {
pub fn new(address: String, port: u16, log: slog::Logger) -> Self {
pub fn new(address: String, port: u16) -> Self {
TCPManager {
address,
port,
listener: Arc::new(Mutex::new(None)),
is_open: Arc::new(Mutex::new(false)),
log,
}
}

Expand Down Expand Up @@ -116,53 +113,168 @@ impl NetworkLayer for TCPManager {
let listener = TcpListener::bind(addr).await?;
*self.listener.lock().await = Some(listener);
*is_open = true;
info!(self.log, "Listening on {}", addr);
Ok(())
}

async fn close(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
async fn close(self) -> Result<(), Box<dyn Error + Send + Sync>> {
let mut is_open = self.is_open.lock().await;
if !*is_open {
return Err("Listener is not open".into());
}
*self.listener.lock().await = None;
*is_open = false;
info!(self.log, "Listener closed");
Ok(())
}
}

#[cfg(test)]
mod tests {
use slog::{o, Drain};
use crate::network::{NetworkLayer, TCPManager};
use tokio::task::JoinSet;

use super::*;

fn get_logger() -> slog::Logger {
let decorator = slog_term::PlainSyncDecorator::new(std::io::stdout());
let drain = slog_term::FullFormat::new(decorator).build().fuse();
let log = slog::Logger::root(drain, o!());
return log;
}
const LOCALHOST: &str = "127.0.0.1";

#[tokio::test]
async fn test_send() {
let network = TCPManager::new("127.0.0.1".to_string(), 8082, get_logger());
let network = TCPManager::new(LOCALHOST.to_string(), 8082);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
let handler = tokio::spawn(async move {
loop {
let data = network_clone.receive().await.unwrap();
if data.is_empty() {
continue;
} else {
assert_eq!(data, vec![1, 2, 3]);
break;
}
}
let _ = network_clone.receive().await.unwrap();
});
network.send("127.0.0.1", "8082", &data).await.unwrap();

let send_result = network.send(LOCALHOST, "8082", &data).await;
assert!(send_result.is_ok());

handler.await.unwrap();
}

#[tokio::test]
async fn test_send_closed_connection() {
let network = TCPManager::new(LOCALHOST.to_string(), 8020);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
tokio::spawn(async move {
let _ = network_clone.receive().await.unwrap();
});

let send_result = network.send(LOCALHOST, "8021", &data).await;
assert!(send_result.is_err());
}

#[tokio::test]
async fn test_receive_happy_case() {
let network = TCPManager::new(LOCALHOST.to_string(), 8030);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
let handler = tokio::spawn(async move { network_clone.receive().await.unwrap() });

network.send(LOCALHOST, "8030", &data).await.unwrap();
let rx_data = handler.await.unwrap();
assert_eq!(rx_data, data)
}

#[tokio::test]
async fn test_open() {
let network = TCPManager::new(LOCALHOST.to_string(), 8040);
let status = network.open().await;
assert!(status.is_ok());
assert!(*network.is_open.lock().await);
}

#[tokio::test]
async fn test_reopen_opened_port() {
let network = TCPManager::new(LOCALHOST.to_string(), 8042);
let status = network.open().await;
assert!(status.is_ok());
let another_network = network.clone();
let status = another_network.open().await;
assert!(status.is_err());
}

#[tokio::test]
async fn test_close() {
let network = TCPManager::new(LOCALHOST.to_string(), 8046);
let _ = network.open().await;

let close_status = network.close().await;
assert!(close_status.is_ok());
}

#[tokio::test]
async fn test_broadcast_happy_case() {
let data = vec![1, 2, 3, 4];
// server which is about to broadcast data
let broadcasting_node = TCPManager::new(LOCALHOST.to_string(), 8050);
broadcasting_node.open().await.unwrap();
assert!(*broadcasting_node.is_open.lock().await);

// vec to keep track of all other server which should be receiving data
let mut receivers = vec![];
// vec to keep track of the address of servers
let mut receiver_addresses = vec![];

for p in 8051..8060 {
// create receiver server
let rx = TCPManager::new(LOCALHOST.to_string(), p);
receiver_addresses.push(format!("{}:{}", LOCALHOST, p));

rx.open().await.unwrap();
assert!(*rx.is_open.lock().await);
receivers.push(rx)
}

let mut s = JoinSet::new();
for rx in receivers {
s.spawn(async move {
let rx_data = rx.receive().await;
assert!(rx_data.is_ok());
// return the received data
rx_data.unwrap()
});
}

// broadcast the message
let broadcast_result = broadcasting_node.broadcast(&data, receiver_addresses).await;
assert!(broadcast_result.is_ok());

// assert the data received on servers
while let Some(res) = s.join_next().await {
let rx_data = res.unwrap();
assert_eq!(data, rx_data)
}
}

#[tokio::test]
async fn test_broadcast_some_nodes_down() {
let data = vec![1, 2, 3, 4];
// server which is about to broadcast data
let broadcasting_node = TCPManager::new(LOCALHOST.to_string(), 8061);
broadcasting_node.open().await.unwrap();
assert!(*broadcasting_node.is_open.lock().await);

// vec to keep track of all servers which should be receiving data
let mut receivers = vec![];
// vec to keep track of the address
let mut receiver_addresses = vec![];
for p in 8062..8070 {
// Create a receiver node
let rx = TCPManager::new(LOCALHOST.to_string(), p);
receiver_addresses.push(format!("{}:{}", LOCALHOST, p));
// open connection for half server
// mocking rest half to be down
if p & 1 == 1 {
rx.open().await.unwrap();
assert!(*rx.is_open.lock().await);
}
receivers.push(rx)
}

// broadcast the data
let broadcast_result = broadcasting_node.broadcast(&data, receiver_addresses).await;
assert!(broadcast_result.is_err());
}
}
61 changes: 31 additions & 30 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum RaftState {
}

#[derive(Debug, Clone)]
enum MesageType {
enum MessageType {
RequestVote,
RequestVoteResponse,
AppendEntries,
Expand Down Expand Up @@ -118,13 +118,14 @@ impl Server {
last_heartbeat: Instant::now(),
votes_received: HashMap::new(),
};
let network_manager = TCPManager::new(config.address.clone(), config.port, log.clone());
let network_manager = TCPManager::new(config.address.clone(), config.port);

// if storage location is provided, use it else set empty string to use default location
let storage = match config.storage_location.clone() {
Some(location) => LocalStorage::new(location + &format!("server_{}.log", id)),
None => LocalStorage::new(format!("server_{}.log", id)),
let storage_location = match config.storage_location.clone() {
Some(location) => location + &format!("server_{}.log", id),
None => format!("server_{}.log", id),
};
let storage = LocalStorage::new(storage_location);

Server {
id,
Expand Down Expand Up @@ -432,44 +433,44 @@ impl Server {
}

let message_type = match message_type {
0 => MesageType::RequestVote,
1 => MesageType::RequestVoteResponse,
2 => MesageType::AppendEntries,
3 => MesageType::AppendEntriesResponse,
4 => MesageType::Heartbeat,
5 => MesageType::HeartbeatResponse,
6 => MesageType::ClientRequest,
7 => MesageType::ClientResponse,
8 => MesageType::RepairRequest,
9 => MesageType::RepairResponse,
10 => MesageType::JoinRequest,
11 => MesageType::JoinResponse,
0 => MessageType::RequestVote,
1 => MessageType::RequestVoteResponse,
2 => MessageType::AppendEntries,
3 => MessageType::AppendEntriesResponse,
4 => MessageType::Heartbeat,
5 => MessageType::HeartbeatResponse,
6 => MessageType::ClientRequest,
7 => MessageType::ClientResponse,
8 => MessageType::RepairRequest,
9 => MessageType::RepairResponse,
10 => MessageType::JoinRequest,
11 => MessageType::JoinResponse,
_ => return,
};

match message_type {
MesageType::RequestVote => {
MessageType::RequestVote => {
self.handle_request_vote(&data).await;
}
MesageType::RequestVoteResponse => {
MessageType::RequestVoteResponse => {
self.handle_request_vote_response(&data).await;
}
MesageType::AppendEntries => {
MessageType::AppendEntries => {
self.handle_append_entries(data).await;
}
MesageType::AppendEntriesResponse => {
MessageType::AppendEntriesResponse => {
self.handle_append_entries_response(&data).await;
}
MesageType::Heartbeat => {
MessageType::Heartbeat => {
self.handle_heartbeat().await;
}
MesageType::HeartbeatResponse => {
MessageType::HeartbeatResponse => {
self.handle_heartbeat_response().await;
}
MesageType::ClientRequest => {
MessageType::ClientRequest => {
self.handle_client_request(data).await;
}
MesageType::ClientResponse => {
MessageType::ClientResponse => {
// TODO: get implementation from user based on the application
info!(self.log, "Received client response: {:?}", data);
let data = u32::from_be_bytes(data[12..16].try_into().unwrap());
Expand All @@ -479,17 +480,17 @@ impl Server {
info!(self.log, "Consensus not reached!");
}
}
MesageType::RepairRequest => {
MessageType::RepairRequest => {
self.handle_repair_request(&data).await;
}
MesageType::RepairResponse => {
MessageType::RepairResponse => {
self.handle_repair_response(&data).await;
}
MesageType::JoinRequest => {
MessageType::JoinRequest => {
info!(self.log, "Received join request: {:?}", data);
self.handle_join_request(&data).await;
}
MesageType::JoinResponse => {
MessageType::JoinResponse => {
self.handle_join_response(&data).await;
}
}
Expand Down Expand Up @@ -918,7 +919,7 @@ impl Server {
}

#[allow(dead_code)]
async fn stop(&self) {
async fn stop(self) {
if let Err(e) = self.network_manager.close().await {
error!(self.log, "Failed to close network manager: {}", e);
}
Expand Down
Loading

0 comments on commit 50b9f16

Please sign in to comment.