Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unit Tests for Network & Storage. And minor improvements #11

Merged
merged 13 commits into from
Aug 1, 2024
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ slog = "2.7.0"
slog-term = "2.9.1"
chrono = "0.4"

[dev-dependencies]
tempfile = "3.10.1"

[[example]]
name = "simple_run"
path = "examples/simple_run.rs"
9 changes: 4 additions & 5 deletions examples/simple_run.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// Author: Vipul Vaibhaw
// Organization: SpacewalkHq
// License: MIT License

Expand Down Expand Up @@ -52,9 +51,9 @@ async fn main() {
}

// Simulate a client request after some delay
thread::sleep(Duration::from_secs(20));
client_request(1, 42 as u32).await;
thread::sleep(Duration::from_secs(2));
tokio::time::sleep(Duration::from_secs(20)).await;
client_request(1, 42u32).await;
tokio::time::sleep(Duration::from_secs(2)).await;
// Join all server threads
for handle in handles {
handle.join().unwrap();
Expand All @@ -65,7 +64,7 @@ async fn client_request(client_id: u32, data: u32) {
let log = get_logger();

let server_address = "127.0.0.1"; // Assuming server 1 is the leader
let network_manager = TCPManager::new(server_address.to_string(), 5001, log.clone());
let network_manager = TCPManager::new(server_address.to_string(), 5001);

let request_data = vec![
client_id.to_be_bytes().to_vec(),
Expand Down
168 changes: 140 additions & 28 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use crate::parse_ip_address;
use async_trait::async_trait;
use futures::future::join_all;
use slog::info;
use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
Expand All @@ -27,7 +26,7 @@ pub trait NetworkLayer: Send + Sync {
addresses: Vec<String>,
) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn open(&self) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn close(&self) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn close(self) -> Result<(), Box<dyn Error + Send + Sync>>;
}

#[derive(Debug, Clone)]
Expand All @@ -36,17 +35,15 @@ pub struct TCPManager {
port: u16,
listener: Arc<Mutex<Option<TcpListener>>>,
is_open: Arc<Mutex<bool>>,
log: slog::Logger,
}

impl TCPManager {
pub fn new(address: String, port: u16, log: slog::Logger) -> Self {
pub fn new(address: String, port: u16) -> Self {
TCPManager {
address,
port,
listener: Arc::new(Mutex::new(None)),
is_open: Arc::new(Mutex::new(false)),
log,
}
}

Expand Down Expand Up @@ -116,53 +113,168 @@ impl NetworkLayer for TCPManager {
let listener = TcpListener::bind(addr).await?;
*self.listener.lock().await = Some(listener);
*is_open = true;
info!(self.log, "Listening on {}", addr);
Ok(())
}

async fn close(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
async fn close(self) -> Result<(), Box<dyn Error + Send + Sync>> {
let mut is_open = self.is_open.lock().await;
if !*is_open {
return Err("Listener is not open".into());
}
*self.listener.lock().await = None;
*is_open = false;
info!(self.log, "Listener closed");
Ok(())
}
}

#[cfg(test)]
mod tests {
use slog::{o, Drain};
use crate::network::{NetworkLayer, TCPManager};
use tokio::task::JoinSet;

use super::*;

fn get_logger() -> slog::Logger {
let decorator = slog_term::PlainSyncDecorator::new(std::io::stdout());
let drain = slog_term::FullFormat::new(decorator).build().fuse();
let log = slog::Logger::root(drain, o!());
return log;
}
const LOCALHOST: &str = "127.0.0.1";

#[tokio::test]
async fn test_send() {
let network = TCPManager::new("127.0.0.1".to_string(), 8082, get_logger());
let network = TCPManager::new(LOCALHOST.to_string(), 8082);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
let handler = tokio::spawn(async move {
loop {
let data = network_clone.receive().await.unwrap();
if data.is_empty() {
continue;
} else {
assert_eq!(data, vec![1, 2, 3]);
break;
}
}
let _ = network_clone.receive().await.unwrap();
});
network.send("127.0.0.1", "8082", &data).await.unwrap();

let send_result = network.send(LOCALHOST, "8082", &data).await;
assert!(send_result.is_ok());

handler.await.unwrap();
}

#[tokio::test]
async fn test_send_closed_connection() {
let network = TCPManager::new(LOCALHOST.to_string(), 8020);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
tokio::spawn(async move {
let _ = network_clone.receive().await.unwrap();
});

let send_result = network.send(LOCALHOST, "8021", &data).await;
assert!(send_result.is_err());
}

#[tokio::test]
async fn test_receive_happy_case() {
let network = TCPManager::new(LOCALHOST.to_string(), 8030);
let data = vec![1, 2, 3];
network.open().await.unwrap();
let network_clone = network.clone();
let handler = tokio::spawn(async move { network_clone.receive().await.unwrap() });

network.send(LOCALHOST, "8030", &data).await.unwrap();
let rx_data = handler.await.unwrap();
assert_eq!(rx_data, data)
}

#[tokio::test]
async fn test_open() {
let network = TCPManager::new(LOCALHOST.to_string(), 8040);
let status = network.open().await;
assert!(status.is_ok());
assert!(*network.is_open.lock().await);
}

#[tokio::test]
async fn test_reopen_opened_port() {
let network = TCPManager::new(LOCALHOST.to_string(), 8042);
let status = network.open().await;
assert!(status.is_ok());
let another_network = network.clone();
let status = another_network.open().await;
assert!(status.is_err());
}

#[tokio::test]
async fn test_close() {
let network = TCPManager::new(LOCALHOST.to_string(), 8046);
let _ = network.open().await;

let close_status = network.close().await;
assert!(close_status.is_ok());
}

#[tokio::test]
async fn test_broadcast_happy_case() {
let data = vec![1, 2, 3, 4];
// server which is about to broadcast data
let broadcasting_node = TCPManager::new(LOCALHOST.to_string(), 8050);
broadcasting_node.open().await.unwrap();
assert!(*broadcasting_node.is_open.lock().await);

// vec to keep track of all other server which should be receiving data
let mut receivers = vec![];
// vec to keep track of the address of servers
let mut receiver_addresses = vec![];

for p in 8051..8060 {
// create receiver server
let rx = TCPManager::new(LOCALHOST.to_string(), p);
receiver_addresses.push(format!("{}:{}", LOCALHOST, p));

rx.open().await.unwrap();
assert!(*rx.is_open.lock().await);
receivers.push(rx)
}

let mut s = JoinSet::new();
for rx in receivers {
s.spawn(async move {
let rx_data = rx.receive().await;
assert!(rx_data.is_ok());
// return the received data
rx_data.unwrap()
});
}

// broadcast the message
let broadcast_result = broadcasting_node.broadcast(&data, receiver_addresses).await;
assert!(broadcast_result.is_ok());

// assert the data received on servers
while let Some(res) = s.join_next().await {
let rx_data = res.unwrap();
assert_eq!(data, rx_data)
}
}

#[tokio::test]
async fn test_broadcast_some_nodes_down() {
let data = vec![1, 2, 3, 4];
// server which is about to broadcast data
let broadcasting_node = TCPManager::new(LOCALHOST.to_string(), 8061);
broadcasting_node.open().await.unwrap();
assert!(*broadcasting_node.is_open.lock().await);

// vec to keep track of all servers which should be receiving data
let mut receivers = vec![];
// vec to keep track of the address
let mut receiver_addresses = vec![];
for p in 8062..8070 {
// Create a receiver node
let rx = TCPManager::new(LOCALHOST.to_string(), p);
receiver_addresses.push(format!("{}:{}", LOCALHOST, p));
// open connection for half server
// mocking rest half to be down
if p & 1 == 1 {
rx.open().await.unwrap();
assert!(*rx.is_open.lock().await);
}
receivers.push(rx)
}

// broadcast the data
let broadcast_result = broadcasting_node.broadcast(&data, receiver_addresses).await;
assert!(broadcast_result.is_err());
}
}
61 changes: 31 additions & 30 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum RaftState {
}

#[derive(Debug, Clone)]
enum MesageType {
enum MessageType {
RequestVote,
RequestVoteResponse,
AppendEntries,
Expand Down Expand Up @@ -118,13 +118,14 @@ impl Server {
last_heartbeat: Instant::now(),
votes_received: HashMap::new(),
};
let network_manager = TCPManager::new(config.address.clone(), config.port, log.clone());
let network_manager = TCPManager::new(config.address.clone(), config.port);

// if storage location is provided, use it else set empty string to use default location
let storage = match config.storage_location.clone() {
Some(location) => LocalStorage::new(location + &format!("server_{}.log", id)),
None => LocalStorage::new(format!("server_{}.log", id)),
let storage_location = match config.storage_location.clone() {
Some(location) => location + &format!("server_{}.log", id),
None => format!("server_{}.log", id),
};
let storage = LocalStorage::new(storage_location);

Server {
id,
Expand Down Expand Up @@ -432,44 +433,44 @@ impl Server {
}

let message_type = match message_type {
0 => MesageType::RequestVote,
1 => MesageType::RequestVoteResponse,
2 => MesageType::AppendEntries,
3 => MesageType::AppendEntriesResponse,
4 => MesageType::Heartbeat,
5 => MesageType::HeartbeatResponse,
6 => MesageType::ClientRequest,
7 => MesageType::ClientResponse,
8 => MesageType::RepairRequest,
9 => MesageType::RepairResponse,
10 => MesageType::JoinRequest,
11 => MesageType::JoinResponse,
0 => MessageType::RequestVote,
1 => MessageType::RequestVoteResponse,
2 => MessageType::AppendEntries,
3 => MessageType::AppendEntriesResponse,
4 => MessageType::Heartbeat,
5 => MessageType::HeartbeatResponse,
6 => MessageType::ClientRequest,
7 => MessageType::ClientResponse,
8 => MessageType::RepairRequest,
9 => MessageType::RepairResponse,
10 => MessageType::JoinRequest,
11 => MessageType::JoinResponse,
_ => return,
};

match message_type {
MesageType::RequestVote => {
MessageType::RequestVote => {
self.handle_request_vote(&data).await;
}
MesageType::RequestVoteResponse => {
MessageType::RequestVoteResponse => {
self.handle_request_vote_response(&data).await;
}
MesageType::AppendEntries => {
MessageType::AppendEntries => {
self.handle_append_entries(data).await;
}
MesageType::AppendEntriesResponse => {
MessageType::AppendEntriesResponse => {
self.handle_append_entries_response(&data).await;
}
MesageType::Heartbeat => {
MessageType::Heartbeat => {
self.handle_heartbeat().await;
}
MesageType::HeartbeatResponse => {
MessageType::HeartbeatResponse => {
self.handle_heartbeat_response().await;
}
MesageType::ClientRequest => {
MessageType::ClientRequest => {
self.handle_client_request(data).await;
}
MesageType::ClientResponse => {
MessageType::ClientResponse => {
// TODO: get implementation from user based on the application
info!(self.log, "Received client response: {:?}", data);
let data = u32::from_be_bytes(data[12..16].try_into().unwrap());
Expand All @@ -479,17 +480,17 @@ impl Server {
info!(self.log, "Consensus not reached!");
}
}
MesageType::RepairRequest => {
MessageType::RepairRequest => {
self.handle_repair_request(&data).await;
}
MesageType::RepairResponse => {
MessageType::RepairResponse => {
self.handle_repair_response(&data).await;
}
MesageType::JoinRequest => {
MessageType::JoinRequest => {
info!(self.log, "Received join request: {:?}", data);
self.handle_join_request(&data).await;
}
MesageType::JoinResponse => {
MessageType::JoinResponse => {
self.handle_join_response(&data).await;
}
}
Expand Down Expand Up @@ -918,7 +919,7 @@ impl Server {
}

#[allow(dead_code)]
async fn stop(&self) {
async fn stop(self) {
if let Err(e) = self.network_manager.close().await {
error!(self.log, "Failed to close network manager: {}", e);
}
Expand Down
Loading
Loading