Skip to content

Commit

Permalink
crypto: merkle: Fix data pointer misalignment bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Jan 28, 2025
1 parent 07274d8 commit bf7a846
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 32 deletions.
63 changes: 43 additions & 20 deletions src/crypto/merkle/main.huff
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,26 @@
/// 32 bytes offset + 32 bytes length + 32 * 32 bytes data
#define constant RETURN_ARRAY_SIZE = 0x440 // 1088

// --- Entrypoint --- //

/// @notice Entrypoint for the Poseidon2 Merkle hash
#define macro MAIN() = takes(0) returns(0) {
// Currently we only have one function, so we don't need to mux the selector
0x04 calldataload // [idx]
0x24 calldataload // [input, idx]
// Push the data pointer of the return array onto the stack
// The helpers below expect the data pointer to trail the inputs
0x40 // [*returnArr]

0x44 calldataload // [dataOffset, input, idx]
0x64 calldataload // [len(sisterLeaves), dataOffset, input, idx]

// Check that the length of the sister leaves is the correct length
[MERKLE_DEPTH] eq validLen jumpi // [dataOffset, idx, input]
// Load the inputs from calldata
0x04 calldataload // [idx, *returnArr]
0x24 calldataload // [input, idx, *returnArr]

// Load the data offset from the array start, then convert it to a data pointer
0x44 calldataload // [dataOffset, input, idx, *returnArr]
0x24 add // [*sisterLeaves, input, idx, *returnArr]

// Load the length of the sister leaves array, then check its length
0x64 calldataload // [len(sisterLeaves), dataOffset, input, idx, *returnArr]
[MERKLE_DEPTH] eq validLen jumpi // [dataOffset, idx, input, *returnArr]

[SISTER_LEAVES_LENGTH_MISMATCH] 0x00 mstore
0x20 0x00 revert
Expand All @@ -46,16 +55,13 @@
// Setup the return value
// 1. Store the data offset at position 0x00
// Offset is always 32 bytes in our case
0x20 0x00 mstore // [dataOffset, idx, input]
0x20 0x00 mstore // [dataOffset, idx, input, *returnArr]

// 2. Store the length of the return array at position 0x20
// Length is always 32 in our case
0x20 0x20 mstore // [dataOffset, idx, input]
0x20 0x20 mstore // [dataOffset, idx, input, *returnArr]

// 3. Iteratively hash up the tree and store intermediate values in the array
// Iteratively hash the input into the sister leaves
0x40 // [*returnArr, dataOffset, input, idx]

HASH_AND_STORE_MERKLE_LEVEL() // Level 1
HASH_AND_STORE_MERKLE_LEVEL() // Level 2
HASH_AND_STORE_MERKLE_LEVEL() // Level 3
Expand Down Expand Up @@ -90,7 +96,6 @@
HASH_AND_STORE_MERKLE_LEVEL() // Level 32

// Return the array
RETURN_FIRST()
[RETURN_ARRAY_SIZE] 0x00 return
}

Expand Down Expand Up @@ -118,17 +123,17 @@
#define macro HASH_MERKLE_LEVEL() = takes(3) returns(3) {
// Takes [*sister, input, idx]
// Compute the next sister node location then move to the back of the resident stack
dup1 0x20 add // [*nextSister, *sister, input, idx]
swap3 swap1 // [*sister, idx, input, *nextSister]
calldataload // [sister, idx, input, *nextSister]
dup1 0x20 add // [*nextSister, *sister, input, idx]
swap3 swap1 // [*sister, idx, input, *nextSister]
calldataload // [sister, idx, input, *nextSister]

// Compute the next index then move it to the back of the resident stack
dup2 0x01 add shr // [nextIdx, sister, idx, input, *nextSister]
swap3 // [input, sister, idx, *nextSister, nextIdx]
dup2 0x01 shr // [nextIdx, sister, idx, input, *nextSister]
swap4 swap3 // [input, sister, idx, *nextSister, nextIdx]

// Hash the input with the sister node
HASH_TWO_LEAVES() // [hash, *nextSister, nextIdx]
swap1 // [*nextSister, hash, nextIdx]
HASH_TWO_LEAVES() // [hash, *nextSister, nextIdx]
swap1 // [*nextSister, hash, nextIdx]
}

/// @dev Hash the given input with the given sister leaf as if inserting into the given index
Expand Down Expand Up @@ -166,6 +171,24 @@
0x20 0x00 return
}

/// @dev Return the second value on the stack
#define macro RETURN_SECOND() = returns(0) {
dup2 0x00 mstore
0x20 0x00 return
}

/// @dev Return the third value on the stack
#define macro RETURN_THIRD() = returns(0) {
dup3 0x00 mstore
0x20 0x00 return
}

/// @dev Return the fourth value on the stack
#define macro RETURN_FOURTH() = returns(0) {
dup4 0x00 mstore
0x20 0x00 return
}

/// @notice Write a zero return array to memory then return
#define macro RETURN_ZERO_ARRAY() = takes(0) returns(0) {
// First 32 bytes: array length (number of elements)
Expand Down
8 changes: 3 additions & 5 deletions test/Merkle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ contract MerkleTest is TestUtils {
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
sisterLeaves[i] = randomFelt();
}
uint256[] memory result = merklePoseidon.hashMerkle(input, idx, sisterLeaves);
console.log("result length:", result.length);
uint256 result = merklePoseidon.hashMerkle(input, idx, sisterLeaves);
console.log("result:", result);
}
}

interface MerklePoseidon {
function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves)
external
returns (uint256[] memory);
function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves) external returns (uint256);
}
8 changes: 4 additions & 4 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ contract PoseidonTest is TestUtils {

/// @dev Deploy the PoseidonSuite contract
function setUp() public {
poseidonSuite = PoseidonSuite(HuffDeployer.deploy("test/huff/testPoseidonUtils"));
poseidonSuite = PoseidonSuite(HuffDeployer.deploy("../test/huff/testPoseidonUtils"));
}

/// @dev Test the sbox function applied to a single input
Expand Down Expand Up @@ -150,7 +150,7 @@ contract PoseidonTest is TestUtils {
}

/// @dev Calculate the result of the internal MDS matrix applied to the inputs
function internalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
function internalMds(uint256 a, uint256 b, uint256 c) internal pure returns (uint256, uint256, uint256) {
uint256 sum = sumInputs(a, b, c);
uint256 a1 = addmod(a, sum, PRIME);
uint256 b1 = addmod(b, sum, PRIME);
Expand All @@ -159,7 +159,7 @@ contract PoseidonTest is TestUtils {
}

/// @dev Calculate the result of the external MDS matrix applied to the inputs
function externalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
function externalMds(uint256 a, uint256 b, uint256 c) internal pure returns (uint256, uint256, uint256) {
uint256 sum = sumInputs(a, b, c);
uint256 a1 = addmod(a, sum, PRIME);
uint256 b1 = addmod(b, sum, PRIME);
Expand All @@ -186,7 +186,7 @@ contract PoseidonTest is TestUtils {
}

/// @dev Sum the inputs and return the result
function sumInputs(uint256 a, uint256 b, uint256 c) internal view returns (uint256) {
function sumInputs(uint256 a, uint256 b, uint256 c) internal pure returns (uint256) {
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);
return sum;
Expand Down
4 changes: 1 addition & 3 deletions test/poseidon-reference-implementation/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use num_bigint::BigUint;
use renegade_constants::Scalar;
use renegade_crypto::fields::{biguint_to_scalar, scalar_to_biguint};
use renegade_crypto::fields::scalar_to_biguint;
use renegade_crypto::hash::Poseidon2Sponge;
use std::env;
use std::str::FromStr;

fn main() {
let args: Vec<String> = env::args().collect();
Expand Down

0 comments on commit bf7a846

Please sign in to comment.