Skip to content

Commit 53fca7f

Browse files
authored
Merge pull request #7 from dfns/state-machine
Add state machine
2 parents fe42d6b + 8f713a9 commit 53fca7f

File tree

21 files changed

+1165
-58
lines changed

21 files changed

+1165
-58
lines changed

.github/workflows/rust.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
with:
6161
cache-on-failure: "true"
6262
- name: Run clippy
63-
run: cargo clippy --all --lib
63+
run: cargo clippy --all --lib --all-features -- -D clippy::all
6464

6565
build-wasm-nostd:
6666
runs-on: ubuntu-latest

Cargo.lock

+12
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/random-generation-protocol/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ thiserror = { version = "1", optional = true }
2020
generic-array = { version = "0.14", features = ["serde"] }
2121

2222
[dev-dependencies]
23-
round-based = { path = "../../round-based", features = ["derive", "dev"] }
23+
round-based = { path = "../../round-based", features = ["derive", "dev", "state-machine"] }
2424
tokio = { version = "1.15", features = ["macros", "rt"] }
2525
futures = "0.3"
2626
hex = "0.4"
2727
rand_dev = "0.1"
28+
rand = "0.8"
2829

2930
[features]
3031
std = ["thiserror"]

examples/random-generation-protocol/src/lib.rs

+150-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#![no_std]
44
#![forbid(unused_crate_dependencies, missing_docs)]
55

6-
#[cfg(feature = "std")]
6+
#[cfg(any(feature = "std", test))]
77
extern crate std;
88

99
extern crate alloc;
@@ -173,14 +173,16 @@ pub struct Blame {
173173

174174
#[cfg(test)]
175175
mod tests {
176-
use alloc::vec;
176+
use alloc::{vec, vec::Vec};
177177

178+
use rand::Rng;
178179
use round_based::simulation::Simulation;
180+
use sha2::{Digest, Sha256};
179181

180182
use super::{protocol_of_random_generation, Msg};
181183

182184
#[tokio::test]
183-
async fn main() {
185+
async fn simulation_async() {
184186
let mut rng = rand_dev::DevRng::new();
185187

186188
let n: u16 = 5;
@@ -203,4 +205,149 @@ mod tests {
203205

204206
std::println!("Output randomness: {}", hex::encode(output[0]));
205207
}
208+
209+
#[test]
210+
fn simulation_sync() {
211+
let mut rng = rand_dev::DevRng::new();
212+
213+
let simulation = round_based::simulation::SimulationSync::from_async_fn(5, |i, party| {
214+
protocol_of_random_generation(party, i, 5, rng.fork())
215+
});
216+
217+
let outputs = simulation
218+
.run()
219+
.unwrap()
220+
.into_iter()
221+
.collect::<Result<Vec<_>, _>>()
222+
.unwrap();
223+
for output_i in &outputs {
224+
assert_eq!(*output_i, outputs[0]);
225+
}
226+
}
227+
228+
// Emulate the protocol using the state machine interface
229+
#[test]
230+
fn state_machine() {
231+
use super::{CommitMsg, DecommitMsg, Msg};
232+
use round_based::{
233+
state_machine::{ProceedResult, StateMachine},
234+
Incoming, Outgoing,
235+
};
236+
237+
let mut rng = rand_dev::DevRng::new();
238+
239+
let party1_rng: [u8; 32] = rng.gen();
240+
let party1_com = Sha256::digest(party1_rng);
241+
242+
let party2_rng: [u8; 32] = rng.gen();
243+
let party2_com = Sha256::digest(party2_rng);
244+
245+
// Start the protocol
246+
let mut party0 = round_based::state_machine::wrap_protocol(|party| async {
247+
protocol_of_random_generation(party, 0, 3, rng).await
248+
});
249+
250+
// Round 1
251+
252+
// Party sends its commitment
253+
let ProceedResult::SendMsg(Outgoing {
254+
msg: Msg::CommitMsg(party0_com),
255+
..
256+
}) = party0.proceed()
257+
else {
258+
panic!("unexpected response")
259+
};
260+
261+
// Round 2
262+
263+
// Party needs messages sent by other parties in round 1
264+
let ProceedResult::NeedsOneMoreMessage = party0.proceed() else {
265+
panic!("unexpected response")
266+
};
267+
// Provide message from party 1
268+
party0
269+
.received_msg(Incoming {
270+
id: 0,
271+
sender: 1,
272+
msg_type: round_based::MessageType::Broadcast,
273+
msg: Msg::CommitMsg(CommitMsg {
274+
commitment: party1_com,
275+
}),
276+
})
277+
.unwrap();
278+
let ProceedResult::NeedsOneMoreMessage = party0.proceed() else {
279+
panic!("unexpected response")
280+
};
281+
// Provide message from party 2
282+
party0
283+
.received_msg(Incoming {
284+
id: 1,
285+
sender: 2,
286+
msg_type: round_based::MessageType::Broadcast,
287+
msg: Msg::CommitMsg(CommitMsg {
288+
commitment: party2_com,
289+
}),
290+
})
291+
.unwrap();
292+
293+
// Party sends message in round 2
294+
let ProceedResult::SendMsg(Outgoing {
295+
msg: Msg::DecommitMsg(party0_rng),
296+
..
297+
}) = party0.proceed()
298+
else {
299+
panic!("unexpected response")
300+
};
301+
302+
{
303+
// Check that commitment matches the revealed randomness
304+
let expected = Sha256::digest(party0_rng.randomness);
305+
assert_eq!(party0_com.commitment, expected);
306+
}
307+
308+
// Final round
309+
310+
// Party needs messages sent by other parties in round 2
311+
let ProceedResult::NeedsOneMoreMessage = party0.proceed() else {
312+
panic!("unexpected response")
313+
};
314+
// Provide message from party 1
315+
party0
316+
.received_msg(Incoming {
317+
id: 3,
318+
sender: 1,
319+
msg_type: round_based::MessageType::Broadcast,
320+
msg: Msg::DecommitMsg(DecommitMsg {
321+
randomness: party1_rng,
322+
}),
323+
})
324+
.unwrap();
325+
let ProceedResult::NeedsOneMoreMessage = party0.proceed() else {
326+
panic!("unexpected response")
327+
};
328+
// Provide message from party 2
329+
party0
330+
.received_msg(Incoming {
331+
id: 3,
332+
sender: 2,
333+
msg_type: round_based::MessageType::Broadcast,
334+
msg: Msg::DecommitMsg(DecommitMsg {
335+
randomness: party2_rng,
336+
}),
337+
})
338+
.unwrap();
339+
// Obtain the protocol result
340+
let ProceedResult::Output(Ok(output_rng)) = party0.proceed() else {
341+
panic!("unexpected response")
342+
};
343+
344+
let output_expected = party0_rng
345+
.randomness
346+
.iter()
347+
.zip(&party1_rng)
348+
.zip(&party2_rng)
349+
.map(|((a, b), c)| a ^ b ^ c)
350+
.collect::<alloc::vec::Vec<_>>();
351+
assert_eq!(output_rng, output_expected.as_slice());
352+
}
206353
}

round-based/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
## v0.3.0
22
* Add no_std and wasm support [#6]
3+
* Add state machine wrapper that provides sync API to carry out the protocol defined as async function [#7]
34

45
[#6]: https://github.com/dfns/round-based/pull/6
6+
[#7]: https://github.com/dfns/round-based/pull/7
57

68
## v0.2.2
79

round-based/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ tokio = { version = "1", features = ["macros"] }
3535

3636
[features]
3737
default = ["std"]
38-
dev = ["tokio/sync", "tokio-stream"]
38+
state-machine = []
39+
dev = ["std", "tokio/sync", "tokio-stream"]
3940
derive = ["round-based-derive"]
4041
runtime-tokio = ["tokio"]
4142
std = ["thiserror"]

round-based/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ mod delivery;
6060
pub mod party;
6161
pub mod rounds_router;
6262
pub mod runtime;
63+
#[cfg(feature = "state-machine")]
64+
pub mod state_machine;
6365

6466
#[cfg(feature = "dev")]
6567
pub mod simulation;

round-based/src/party.rs

-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ pub struct MpcParty<M, D, R = runtime::DefaultRuntime> {
106106

107107
impl<M, D> MpcParty<M, D>
108108
where
109-
M: Send + 'static,
110109
D: Delivery<M>,
111110
{
112111
/// Party connected to the network
@@ -123,7 +122,6 @@ where
123122

124123
impl<M, D, X> MpcParty<M, D, X>
125124
where
126-
M: Send + 'static,
127125
D: Delivery<M>,
128126
{
129127
/// Specifies a [async runtime](runtime)

round-based/src/rounds_router/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ where
169169
}
170170
}
171171

172+
#[allow(clippy::type_complexity)]
172173
fn retrieve_round_output_if_its_completed<R>(
173174
&mut self,
174175
) -> Option<Result<R::Output, CompleteRoundError<R::Error, E>>>
@@ -321,6 +322,7 @@ trait ProcessRoundMessage {
321322
/// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output`
322323
/// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError<MessageStore::Error>`
323324
/// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`]
325+
#[allow(clippy::type_complexity)]
324326
fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError>;
325327
}
326328

round-based/src/runtime.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
/// function.
1010
pub trait AsyncRuntime {
1111
/// Future type returned by [yield_now](Self::yield_now)
12-
type YieldNowFuture: core::future::Future<Output = ()> + Send + 'static;
12+
type YieldNowFuture: core::future::Future<Output = ()>;
1313

1414
/// Yields the execution back to the runtime
1515
///

round-based/src/simulation/mod.rs

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//! Multiparty protocol simulation
2+
//!
3+
//! [`Simulation`] is an essential developer tool for testing the multiparty protocol locally.
4+
//! It covers most of the boilerplate by mocking networking.
5+
//!
6+
//! ## Example
7+
//!
8+
//! ```rust
9+
//! use round_based::{Mpc, PartyIndex};
10+
//! use round_based::simulation::Simulation;
11+
//! use futures::future::try_join_all;
12+
//!
13+
//! # type Result<T, E = ()> = std::result::Result<T, E>;
14+
//! # type Randomness = [u8; 32];
15+
//! # type Msg = ();
16+
//! // Any MPC protocol you want to test
17+
//! pub async fn protocol_of_random_generation<M>(party: M, i: PartyIndex, n: u16) -> Result<Randomness>
18+
//! where M: Mpc<ProtocolMessage = Msg>
19+
//! {
20+
//! // ...
21+
//! # todo!()
22+
//! }
23+
//!
24+
//! async fn test_randomness_generation() {
25+
//! let n = 3;
26+
//!
27+
//! let mut simulation = Simulation::<Msg>::new();
28+
//! let mut outputs = vec![];
29+
//! for i in 0..n {
30+
//! let party = simulation.add_party();
31+
//! outputs.push(protocol_of_random_generation(party, i, n));
32+
//! }
33+
//!
34+
//! // Waits each party to complete the protocol
35+
//! let outputs = try_join_all(outputs).await.expect("protocol wasn't completed successfully");
36+
//! // Asserts that all parties output the same randomness
37+
//! for output in outputs.iter().skip(1) {
38+
//! assert_eq!(&outputs[0], output);
39+
//! }
40+
//! }
41+
//! ```
42+
43+
mod sim_async;
44+
#[cfg(feature = "state-machine")]
45+
mod sim_sync;
46+
47+
pub use sim_async::*;
48+
#[cfg(feature = "state-machine")]
49+
pub use sim_sync::*;

0 commit comments

Comments
 (0)