Skip to content

Commit

Permalink
Get rid of sleep in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xd009642 committed Mar 3, 2025
1 parent 23c90ce commit 26a8f19
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
43 changes: 36 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::{
Arc,
};
use std::time::Instant;
use tokio::sync::{broadcast, oneshot, RwLock};
use tokio::sync::{broadcast, oneshot, watch, Mutex, RwLock};
use tracing::{debug, error, Instrument};
use tungstenite::{
protocol::{frame::Utf8Bytes, CloseFrame},
Expand All @@ -47,6 +47,7 @@ pub struct MockServer {
addr: String,
shutdown: Option<oneshot::Sender<()>>,
mocks: MockList,
active_requests: Mutex<watch::Receiver<usize>>,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
Expand Down Expand Up @@ -237,8 +238,17 @@ async fn ws_handler_pathless(
headers: HeaderMap,
params: Query<HashMap<String, String>>,
mocks: Extension<MockList>,
request_counter: Extension<watch::Sender<usize>>,
) -> Response {
ws_handler(ws, Path(String::new()), headers, params, mocks).await
ws_handler(
ws,
Path(String::new()),
headers,
params,
mocks,
request_counter,
)
.await
}

#[inline(always)]
Expand All @@ -259,7 +269,9 @@ async fn ws_handler(
headers: HeaderMap,
Query(params): Query<HashMap<String, String>>,
mocks: Extension<MockList>,
request_counter: Extension<watch::Sender<usize>>,
) -> Response {
request_counter.send_modify(|x| *x += 1);
let mut active_mocks = vec![];
let mut matched_mock: Option<ActiveMockCandidate> = None;
let mut current_mask = 0;
Expand Down Expand Up @@ -303,7 +315,12 @@ async fn ws_handler(

debug!("about to upgrade websocket connection");
ws.on_upgrade(move |socket| async move {
handle_socket(socket, mocks.0, active_mocks, current_mask).await
let res =
tokio::task::spawn(handle_socket(socket, mocks.0, active_mocks, current_mask)).await;
if let Err(res) = res {
error!("Task panicked: {}", res);
}
request_counter.send_modify(|x| *x -= 1);
})
}

Expand Down Expand Up @@ -476,9 +493,6 @@ async fn handle_socket(
if mask == active_mocks[0].expected_mask() && no_mismatch {
active_mocks[0].register_hit();
}
if let Some(hnd) = sender_task {
hnd.await.unwrap().unwrap();
}
}

impl MockServer {
Expand All @@ -493,10 +507,13 @@ impl MockServer {
pub async fn start() -> Self {
let mocks: MockList = Default::default();

let (tx, active_requests) = watch::channel(0);

let router = Router::new()
.route("/{*path}", any(ws_handler))
.route("/", any(ws_handler_pathless))
.layer(Extension(mocks.clone()));
.layer(Extension(mocks.clone()))
.layer(Extension(tx));
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr = format!("ws://{}", listener.local_addr().unwrap());

Expand All @@ -513,6 +530,7 @@ impl MockServer {
addr,
shutdown: Some(tx),
mocks,
active_requests: Mutex::new(active_requests),
}
}

Expand Down Expand Up @@ -541,6 +559,17 @@ impl MockServer {
}

pub async fn mocks_pass(&self) -> bool {
let mut active_requests = self.active_requests.lock().await;
// If there's no more senders then in
if let Err(e) = active_requests
.wait_for(|x| {
debug!("Current active requests: {}", x);
*x == 0
})
.await
{
unreachable!("There should always be a sender while the server is running");
}
let mut res = true;
for (index, mock) in self.mocks.read().await.iter().enumerate() {
let mock_res = mock.verify();
Expand Down
37 changes: 17 additions & 20 deletions tests/advanced_matchers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use std::time::Duration;
use tokio::time::sleep;
use tokio_tungstenite::connect_async;
use tracing_test::traced_test;
use tungstenite::client::IntoClientRequest;
Expand All @@ -27,15 +26,23 @@ impl Match for BinaryStreamMatcher {
Some(false)
}
} else if last.is_close() {
let message = match_state.get_message(len - 2).unwrap();
let res = json.unary_match(message);
println!(
"Got close frame, checking {}: {:?}\ti{:?}",
len - 2,
res,
message
);
res
if len == 1 {
None
} else {
let message = match_state.get_message(len - 2);
if let Some(message) = message {
let res = json.unary_match(message);
println!(
"Got close frame, checking {}: {:?}\ti{:?}",
len - 2,
res,
message
);
res
} else {
Some(false)
}
}
} else if last.is_text() {
let res = json.unary_match(last);
match_state.keep_message(len - 1);
Expand Down Expand Up @@ -74,8 +81,6 @@ async fn binary_stream_matcher_passes() {
stream.send(Message::text(val.to_string())).await.unwrap();
stream.send(Message::Close(None)).await.unwrap();

sleep(Duration::from_millis(100)).await;

std::mem::drop(stream);

server.verify().await;
Expand All @@ -93,8 +98,6 @@ async fn binary_stream_matcher_passes() {
stream.send(Message::text(val.to_string())).await.unwrap();
stream.send(Message::Close(None)).await.unwrap();

sleep(Duration::from_millis(100)).await;

std::mem::drop(stream);

server.verify().await;
Expand Down Expand Up @@ -135,8 +138,6 @@ async fn binary_stream_matcher_fails() {

std::mem::drop(stream);

sleep(Duration::from_millis(100)).await;

assert!(!server.mocks_pass().await);

println!("Testing no end message");
Expand All @@ -154,8 +155,6 @@ async fn binary_stream_matcher_fails() {

std::mem::drop(stream);

sleep(Duration::from_millis(100)).await;

assert!(!server.mocks_pass().await);

let (mut stream, response) = connect_async(format!("{}/api/binary_stream", server.uri()))
Expand All @@ -172,7 +171,5 @@ async fn binary_stream_matcher_fails() {

std::mem::drop(stream);

sleep(Duration::from_millis(100)).await;

assert!(!server.mocks_pass().await);
}
17 changes: 1 addition & 16 deletions tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use std::time::Duration;
use tokio::time::sleep;
use tokio_tungstenite::connect_async;
use tracing_test::traced_test;
use tungstenite::client::IntoClientRequest;
Expand Down Expand Up @@ -35,7 +34,6 @@ async fn no_matches() {
let (mut stream, response) = connect_async(server.uri()).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
assert!(logs_contain("mock[0]"));
Expand Down Expand Up @@ -65,7 +63,6 @@ async fn only_json_matcher() {

std::mem::drop(stream);
// todo there should be a better way than this
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -87,9 +84,7 @@ async fn deny_invalid_json() {
let val = json!({"hello": "world"}).to_string().as_bytes().to_vec();
stream.send(Message::Ping(val.into())).await.unwrap();

// TODO there should be a better way than this
std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -112,7 +107,6 @@ async fn match_path() {
stream.send(Message::text(val.to_string())).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -138,7 +132,6 @@ async fn header_exists() {
stream.send(Message::text(val.to_string())).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -159,7 +152,6 @@ async fn header_doesnt_exist() {
stream.send(Message::text(val.to_string())).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand Down Expand Up @@ -188,7 +180,6 @@ async fn header_exactly_matches() {
let (stream, response) = connect_async(request).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -211,7 +202,6 @@ async fn header_doesnt_match() {
let (stream, response) = connect_async(request).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -233,7 +223,6 @@ async fn query_param_matchers() {
let (stream, response) = connect_async(uri).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

server.verify().await;
}
Expand All @@ -258,7 +247,7 @@ async fn combine_request_and_content_matchers() {
// Send a message just to show it doesn't change anything.
let val = json!({"hello": "world"});
stream.send(Message::text(val.to_string())).await.unwrap();
sleep(Duration::from_millis(100)).await;
std::mem::drop(stream);

assert!(!server.mocks_pass().await);

Expand All @@ -271,7 +260,6 @@ async fn combine_request_and_content_matchers() {
stream.send(Message::text(val.to_string())).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

assert!(server.mocks_pass().await);
}
Expand Down Expand Up @@ -306,7 +294,6 @@ async fn echo_response_test() {
assert_eq!(sent_message, echoed);

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

assert!(server.mocks_pass().await);
}
Expand All @@ -331,7 +318,6 @@ async fn ensure_close_frame_sent() {
// Send a message just to show it doesn't change anything.
let val = json!({"hello": "world"});
stream.send(Message::text(val.to_string())).await.unwrap();
sleep(Duration::from_millis(100)).await;

std::mem::drop(stream);

Expand All @@ -347,7 +333,6 @@ async fn ensure_close_frame_sent() {
stream.send(Message::Close(None)).await.unwrap();

std::mem::drop(stream);
sleep(Duration::from_millis(100)).await;

assert!(server.mocks_pass().await);
}

0 comments on commit 26a8f19

Please sign in to comment.