Skip to content

Commit

Permalink
test: Merkle: Add test structure with reference impl
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Jan 28, 2025
1 parent 604158d commit d8670cf
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 32 deletions.
92 changes: 75 additions & 17 deletions test/Merkle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,88 @@ contract MerkleTest is TestUtils {
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
sisterLeaves[i] = randomFelt();
}
uint256 result = merklePoseidon.hashMerkle(input, idx, sisterLeaves);
console.log("result:", result);
uint256[] memory results = merklePoseidon.hashMerkle(idx, input, sisterLeaves);
uint256[] memory expected = runReferenceImpl(idx, input, sisterLeaves);
assertEq(results.length, MERKLE_DEPTH, "Expected 32 results");

for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
assertEq(results[i], expected[i], string(abi.encodePacked("Result mismatch at index ", vm.toString(i))));
console.log("result[", i, "]:", results[i]);
}
}

// --- Helpers --- //

/// @dev Helper to run the reference implementation
function runReferenceImpl(uint256 input, uint256 idx, uint256[] memory sisterLeaves) internal returns (uint256) {
string[] memory args = new string[](4);
function runReferenceImpl(uint256 idx, uint256 input, uint256[] memory sisterLeaves)
internal
returns (uint256[] memory)
{
string[] memory args = new string[](35); // program name + idx + input + 32 sister leaves
args[0] = "./test/rust-reference-impls/target/debug/merkle";
args[1] = vm.toString(input);
args[2] = vm.toString(idx);
args[3] = arrayToString(sisterLeaves);
args[1] = vm.toString(idx);
args[2] = vm.toString(input);

// Pass sister leaves as individual arguments
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
args[i + 3] = vm.toString(sisterLeaves[i]);
}

bytes memory res = vm.ffi(args);
string memory str = string(res);

require(
bytes(str).length > 4 && bytes(str)[0] == "R" && bytes(str)[1] == "E" && bytes(str)[2] == "S"
&& bytes(str)[3] == ":",
"Invalid output format"
);
// Split by spaces and parse each value
string[] memory parts = split(str, " ");
require(parts.length == MERKLE_DEPTH, "Expected 32 values");

uint256[] memory values = new uint256[](MERKLE_DEPTH);
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
values[i] = vm.parseUint(parts[i]);
}

return values;
}

/// @dev Helper to split a string by a delimiter
function split(string memory _str, string memory _delim) internal pure returns (string[] memory) {
bytes memory str = bytes(_str);
bytes memory delim = bytes(_delim);

// Count number of delimiters to size array
uint256 count = 1;
for (uint256 i = 0; i < str.length; i++) {
if (str[i] == delim[0]) {
count++;
}
}

string[] memory parts = new string[](count);
count = 0;

// Track start of current part
uint256 start = 0;

// Split into parts
for (uint256 i = 0; i < str.length; i++) {
if (str[i] == delim[0]) {
parts[count] = substring(str, start, i);
start = i + 1;
count++;
}
}
// Add final part
parts[count] = substring(str, start, str.length);

return parts;
}

bytes memory hexBytes = new bytes(bytes(str).length - 4);
for (uint256 i = 4; i < bytes(str).length; i++) {
hexBytes[i - 4] = bytes(str)[i];
/// @dev Helper to get a substring
function substring(bytes memory _str, uint256 _start, uint256 _end) internal pure returns (string memory) {
bytes memory result = new bytes(_end - _start);
for (uint256 i = _start; i < _end; i++) {
result[i - _start] = _str[i];
}
return vm.parseUint(string(hexBytes));
return string(result);
}

function arrayToString(uint256[] memory arr) internal pure returns (string memory) {
Expand All @@ -68,5 +124,7 @@ contract MerkleTest is TestUtils {
}

interface MerklePoseidon {
function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves) external returns (uint256);
function hashMerkle(uint256 idx, uint256 input, uint256[] calldata sisterLeaves)
external
returns (uint256[] memory);
}
44 changes: 29 additions & 15 deletions test/rust-reference-impls/merkle/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,55 @@ use renegade_constants::Scalar;
use renegade_crypto::fields::scalar_to_biguint;
use renegade_crypto::hash::Poseidon2Sponge;

const NUM_VALUES: usize = 32;

fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 4 {
eprintln!("Usage: {} <input> <idx> <sister_leaves>", args[0]);
if args.len() != NUM_VALUES + 3 {
eprintln!(
"Usage: {} <idx> <input> <sister_leaf_1> <sister_leaf_2> ... <sister_leaf_32>",
args[0]
);
std::process::exit(1);
}

let input = Scalar::from_decimal_string(&args[1]).unwrap();
let idx = args[2].parse::<u64>().unwrap();
let sister_leaves: Vec<Scalar> = args[3]
.trim_matches(|c| c == '[' || c == ']')
.split(',')
let idx = args[1].parse::<u64>().unwrap();
let input = Scalar::from_decimal_string(&args[2]).unwrap();

// Parse sister leaves directly from arguments
let sister_leaves: Vec<Scalar> = args[3..NUM_VALUES + 3]
.iter()
.map(|s| Scalar::from_decimal_string(s).unwrap())
.collect();

let result = hash_merkle(input, idx, &sister_leaves);
let res_biguint = scalar_to_biguint(&result);
let res_hex = format!("{res_biguint:x}");
println!("RES:0x{}", res_hex);
let results = hash_merkle(idx, input, &sister_leaves);

// Output results as space-separated decimal values
let result_strings: Vec<String> = results
.iter()
.map(|r| scalar_to_biguint(r).to_string())
.collect();

println!("{}", result_strings.join(" "));
}

fn hash_merkle(input: Scalar, idx: u64, sister_leaves: &[Scalar]) -> Scalar {
fn hash_merkle(idx: u64, input: Scalar, sister_leaves: &[Scalar]) -> Vec<Scalar> {
let mut results = Vec::with_capacity(NUM_VALUES);
let mut current = input;
let mut current_idx = idx;
let mut sponge = Poseidon2Sponge::new();

for sister in sister_leaves {
for i in 0..NUM_VALUES {
let sister = &sister_leaves[i];
let inputs = if current_idx % 2 == 0 {
[current.inner(), sister.inner()]
} else {
[sister.inner(), current.inner()]
};
let mut sponge = Poseidon2Sponge::new();
current = Scalar::new(sponge.hash(&inputs));
results.push(current);
current_idx /= 2;
}

current
results
}

0 comments on commit d8670cf

Please sign in to comment.