Skip to content

Commit

Permalink
test: Merkle: Test against Arkworks merkle implementation (#11)
Browse files Browse the repository at this point in the history
* test: rust-reference-impls: Refactor reference impls into directory

* test: Merkle: Add test structure with reference impl

* test: rust-reference-impl: merkle: Add Arkworks helper test
  • Loading branch information
joeykraut authored Jan 28, 2025
1 parent 53199dc commit 8a1ef54
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 12 deletions.
101 changes: 97 additions & 4 deletions test/Merkle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,112 @@ contract MerkleTest is TestUtils {
merklePoseidon = MerklePoseidon(HuffDeployer.deploy("crypto/merkle/main"));
}

/// @dev Test the hashMerkle function
/// @dev Test the hashMerkle function with sequential inserts
function testHashMerkle() public {
uint256 input = 1;
uint256 idx = 15;
uint256[] memory sisterLeaves = new uint256[](MERKLE_DEPTH);
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
sisterLeaves[i] = randomFelt();
}
uint256 result = merklePoseidon.hashMerkle(input, idx, sisterLeaves);
console.log("result:", result);
uint256[] memory results = merklePoseidon.hashMerkle(idx, input, sisterLeaves);
uint256[] memory expected = runReferenceImpl(idx, input, sisterLeaves);
assertEq(results.length, MERKLE_DEPTH, "Expected 32 results");

for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
assertEq(results[i], expected[i], string(abi.encodePacked("Result mismatch at index ", vm.toString(i))));
}
}

// --- Helpers --- //

/// @dev Helper to run the reference implementation
function runReferenceImpl(uint256 idx, uint256 input, uint256[] memory sisterLeaves)
internal
returns (uint256[] memory)
{
string[] memory args = new string[](35); // program name + idx + input + 32 sister leaves
args[0] = "./test/rust-reference-impls/target/debug/merkle";
args[1] = vm.toString(idx);
args[2] = vm.toString(input);

// Pass sister leaves as individual arguments
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
args[i + 3] = vm.toString(sisterLeaves[i]);
}

bytes memory res = vm.ffi(args);
string memory str = string(res);

// Split by spaces and parse each value
string[] memory parts = split(str, " ");
require(parts.length == MERKLE_DEPTH, "Expected 32 values");

uint256[] memory values = new uint256[](MERKLE_DEPTH);
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
values[i] = vm.parseUint(parts[i]);
}

return values;
}

/// @dev Helper to split a string by a delimiter
function split(string memory _str, string memory _delim) internal pure returns (string[] memory) {
bytes memory str = bytes(_str);
bytes memory delim = bytes(_delim);

// Count number of delimiters to size array
uint256 count = 1;
for (uint256 i = 0; i < str.length; i++) {
if (str[i] == delim[0]) {
count++;
}
}

string[] memory parts = new string[](count);
count = 0;

// Track start of current part
uint256 start = 0;

// Split into parts
for (uint256 i = 0; i < str.length; i++) {
if (str[i] == delim[0]) {
parts[count] = substring(str, start, i);
start = i + 1;
count++;
}
}
// Add final part
parts[count] = substring(str, start, str.length);

return parts;
}

/// @dev Helper to get a substring
function substring(bytes memory _str, uint256 _start, uint256 _end) internal pure returns (string memory) {
bytes memory result = new bytes(_end - _start);
for (uint256 i = _start; i < _end; i++) {
result[i - _start] = _str[i];
}
return string(result);
}

function arrayToString(uint256[] memory arr) internal pure returns (string memory) {
string memory result = "[";
for (uint256 i = 0; i < arr.length; i++) {
if (i > 0) {
result = string(abi.encodePacked(result, ","));
}
result = string(abi.encodePacked(result, vm.toString(arr[i])));
}
result = string(abi.encodePacked(result, "]"));
return result;
}
}

interface MerklePoseidon {
function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves) external returns (uint256);
function hashMerkle(uint256 idx, uint256 input, uint256[] calldata sisterLeaves)
external
returns (uint256[] memory);
}
4 changes: 2 additions & 2 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ contract PoseidonTest is TestUtils {
compileInputs[1] = "build";
compileInputs[2] = "--quiet";
compileInputs[3] = "--manifest-path";
compileInputs[4] = "test/poseidon-reference-implementation/Cargo.toml";
compileInputs[4] = "test/rust-reference-impls/poseidon/Cargo.toml";
vm.ffi(compileInputs);

// Now run the binary directly from target/debug
string[] memory runInputs = new string[](3);
runInputs[0] = "./test/poseidon-reference-implementation/target/debug/poseidon-reference-implementation";
runInputs[0] = "./test/rust-reference-impls/target/debug/poseidon";
runInputs[1] = vm.toString(a);
runInputs[2] = vm.toString(b);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
[package]
name = "poseidon-reference-implementation"
version = "0.1.0"
edition = "2021"
[workspace]
members = ["poseidon", "merkle"]

[dependencies]
[workspace.dependencies]
renegade-constants = { package = "constants", git = "https://github.com/renegade-fi/renegade.git", default-features = false }
renegade-crypto = { git = "https://github.com/renegade-fi/renegade.git" }

num-bigint = "0.4"
18 changes: 18 additions & 0 deletions test/rust-reference-impls/merkle/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "merkle"
version = "0.1.0"
edition = "2021"

[dependencies]
ark-crypto-primitives = { version = "0.4", features = [
"crh",
"merkle_tree",
"sponge",
] }
clap = { version = "4.5.1", features = ["derive"] }
itertools = "0.14"
num-bigint = { workspace = true }
rand = "0.8"

renegade-constants = { workspace = true }
renegade-crypto = { workspace = true }
207 changes: 207 additions & 0 deletions test/rust-reference-impls/merkle/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use clap::Parser;
use renegade_constants::Scalar;
use renegade_crypto::fields::scalar_to_biguint;
use renegade_crypto::hash::compute_poseidon_hash;

/// The height of the Merkle tree
const TREE_HEIGHT: usize = 32;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Index in the Merkle tree
idx: u64,

/// Input value
input: String,

/// Sister leaves (32 values required)
#[arg(num_args = 32)]
sister_leaves: Vec<String>,
}

fn main() {
let args = Args::parse();

if args.sister_leaves.len() != TREE_HEIGHT {
eprintln!(
"Expected {} sister leaves, got {}",
TREE_HEIGHT,
args.sister_leaves.len()
);
std::process::exit(1);
}

let input = Scalar::from_decimal_string(&args.input).unwrap();

// Parse sister leaves directly from arguments
let sister_leaves: Vec<Scalar> = args
.sister_leaves
.iter()
.map(|s| Scalar::from_decimal_string(s).unwrap())
.collect();

let results = hash_merkle(args.idx, input, &sister_leaves);

// Output results as space-separated decimal values
let result_strings: Vec<String> = results
.iter()
.map(|r| scalar_to_biguint(r).to_string())
.collect();

println!("{}", result_strings.join(" "));
}

/// Hash the input through the Merkle tree using the given sister nodes
///
/// Returns the incremental results at each level, representing the updated values to the insertion path
fn hash_merkle(idx: u64, input: Scalar, sister_leaves: &[Scalar]) -> Vec<Scalar> {
let mut results = Vec::with_capacity(TREE_HEIGHT);
let mut current = input;
let mut current_idx = idx;

for sister in sister_leaves.iter().copied() {
// The input is a left-hand node if the index is even at this level
let inputs = if current_idx % 2 == 0 {
[current, sister]
} else {
[sister, current]
};

current = compute_poseidon_hash(&inputs);
results.push(current);
current_idx /= 2;
}

results
}

#[cfg(test)]
mod tests {
//! We test the Merkle tree helper above against the reference implementation in Arkworks
//! for a known reference implementation.
//!
//! It is difficult to test the huff contracts against the Arkworks impl because the Arkworks impl
//! handles deep trees very inefficiently, making a 32-depth tree impossible to run.
//!
//! Instead, we opt to test our helper against Arkworks on a shallower tree, thereby testing the
//! huff implementation only transitively.
use std::borrow::Borrow;

use ark_crypto_primitives::{
crh::{CRHScheme, TwoToOneCRHScheme},
merkle_tree::{Config, IdentityDigestConverter, MerkleTree},
};
use rand::{thread_rng, Rng};
use renegade_constants::{Scalar, ScalarField};
use renegade_crypto::hash::compute_poseidon_hash;

use crate::hash_merkle;

/// The height of the Merkle tree
const TEST_TREE_HEIGHT: usize = 10;
/// The number of leaves in the tree
const N_LEAVES: usize = 1 << (TEST_TREE_HEIGHT - 1);

// --- Hash Impls --- //

struct IdentityHasher;
impl CRHScheme for IdentityHasher {
type Input = ScalarField;
type Output = ScalarField;
type Parameters = ();

fn setup<R: Rng>(_: &mut R) -> Result<Self::Parameters, ark_crypto_primitives::Error> {
Ok(())
}

fn evaluate<T: Borrow<Self::Input>>(
_parameters: &Self::Parameters,
input: T,
) -> Result<Self::Output, ark_crypto_primitives::Error> {
Ok(*input.borrow())
}
}

/// A dummy hasher to build an arkworks Merkle tree on top of
struct Poseidon2Hasher;
impl TwoToOneCRHScheme for Poseidon2Hasher {
type Input = ScalarField;
type Output = ScalarField;
type Parameters = ();

fn setup<R: Rng>(_: &mut R) -> Result<Self::Parameters, ark_crypto_primitives::Error> {
Ok(())
}

fn evaluate<T: Borrow<Self::Input>>(
_parameters: &Self::Parameters,
left_input: T,
right_input: T,
) -> Result<Self::Output, ark_crypto_primitives::Error> {
let lhs = Scalar::new(*left_input.borrow());
let rhs = Scalar::new(*right_input.borrow());
let res = compute_poseidon_hash(&[lhs, rhs]);

Ok(res.inner())
}

fn compress<T: Borrow<Self::Output>>(
parameters: &Self::Parameters,
left_input: T,
right_input: T,
) -> Result<Self::Output, ark_crypto_primitives::Error> {
<Self as TwoToOneCRHScheme>::evaluate(parameters, left_input, right_input)
}
}

struct MerkleConfig {}
impl Config for MerkleConfig {
type Leaf = ScalarField;
type LeafDigest = ScalarField;
type InnerDigest = ScalarField;

type LeafHash = IdentityHasher;
type TwoToOneHash = Poseidon2Hasher;
type LeafInnerDigestConverter = IdentityDigestConverter<ScalarField>;
}

/// Build an arkworks tree and fill it with random values
fn build_arkworks_tree() -> MerkleTree<MerkleConfig> {
let mut rng = thread_rng();

let mut tree = MerkleTree::<MerkleConfig>::blank(&(), &(), TEST_TREE_HEIGHT).unwrap();
for i in 0..N_LEAVES {
let leaf = Scalar::random(&mut rng);
tree.update(i, &leaf.inner()).unwrap();
}

tree
}

/// Test the Merkle helper against an arkworks tree
#[test]
fn test_merkle_tree() {
// Build an arkworks tree and fill it with random values
let mut rng = thread_rng();
let mut tree = build_arkworks_tree();

// Choose a random index to update into
let idx = rng.gen_range(0..N_LEAVES);
let input = Scalar::random(&mut rng);

// Get a sibling path for the input
let path = tree.generate_proof(idx).unwrap();
let mut sister_scalars = vec![Scalar::new(path.leaf_sibling_hash)];
sister_scalars.extend(path.auth_path.into_iter().rev().map(Scalar::new));

// Get the updated path
let res = hash_merkle(idx as u64, input, &sister_scalars);
let new_root = res.last().unwrap();

// Update the tree with the input
tree.update(idx, &input.inner()).unwrap();
assert_eq!(tree.root(), new_root.inner());
}
}
9 changes: 9 additions & 0 deletions test/rust-reference-impls/poseidon/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "poseidon"
version = "0.1.0"
edition = "2021"

[dependencies]
renegade-constants.workspace = true
renegade-crypto.workspace = true
num-bigint.workspace = true

0 comments on commit 8a1ef54

Please sign in to comment.