Skip to content

Commit

Permalink
crypto: merkle: Implement Merkle hash with intermediate values (#10)
Browse files Browse the repository at this point in the history
* crypto: merkle: Implement Merkle hashing entrypoint

* crypto: poseidon2: poseidonUtils: Cleanup stack in each util

* crypto: merkle: Fix data pointer misalignment bugs
  • Loading branch information
joeykraut authored Jan 28, 2025
1 parent d88e5b8 commit 53199dc
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 44 deletions.
3 changes: 0 additions & 3 deletions foundry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@ src = "src"
out = "out"
libs = ["lib"]
remappings = ["forge-std/=lib/forge-std/src/"]

[ffi]
enabled = true
202 changes: 202 additions & 0 deletions src/crypto/merkle/main.huff
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/// @title Poseidon2
/// @author @joeykraut
/// @notice A Poseidon2 implementation in Huff. See https://eprint.iacr.org/2023/323
/// for more details. Inspired by https://github.com/zemse/poseidon2-evm/tree/main

#include "../poseidon2/poseidonUtils.huff"
#include "../poseidon2/poseidonPerm.huff"

// --- Interface --- //

/// @notice Hash an input with the given sister leaves in a Merkle tree
/// @param input The input to hash
/// @param idx The index of the input in the Merkle tree
/// @param sisterLeaves An array of sister leaves to hash with
/// @return The intermediate hashes of the input and sister leaves, with the root at the start
#define function hashMerkle(uint256 idx, uint256 input, uint256[] sisterLeaves) nonpayable returns(uint256)

// --- Constants --- //

/// @notice The depth of the Merkle tree
#define constant MERKLE_DEPTH = 0x20 // 32

/// @dev The revert code for the `sisterLeaves` array length mismatch
#define constant SISTER_LEAVES_LENGTH_MISMATCH = 0x01

/// @dev The total size of the return array (in bytes)
/// 32 bytes offset + 32 bytes length + 32 * 32 bytes data
#define constant RETURN_ARRAY_SIZE = 0x440 // 1088

// --- Entrypoint --- //

/// @notice Entrypoint for the Poseidon2 Merkle hash
#define macro MAIN() = takes(0) returns(0) {
// Currently we only have one function, so we don't need to mux the selector
// Push the data pointer of the return array onto the stack
// The helpers below expect the data pointer to trail the inputs
0x40 // [*returnArr]

// Load the inputs from calldata
0x04 calldataload // [idx, *returnArr]
0x24 calldataload // [input, idx, *returnArr]

// Load the data offset from the array start, then convert it to a data pointer
0x44 calldataload // [dataOffset, input, idx, *returnArr]
0x24 add // [*sisterLeaves, input, idx, *returnArr]

// Load the length of the sister leaves array, then check its length
0x64 calldataload // [len(sisterLeaves), dataOffset, input, idx, *returnArr]
[MERKLE_DEPTH] eq validLen jumpi // [dataOffset, idx, input, *returnArr]

[SISTER_LEAVES_LENGTH_MISMATCH] 0x00 mstore
0x20 0x00 revert

validLen:
// Setup the return value
// 1. Store the data offset at position 0x00
// Offset is always 32 bytes in our case
0x20 0x00 mstore // [dataOffset, idx, input, *returnArr]

// 2. Store the length of the return array at position 0x20
// Length is always 32 in our case
0x20 0x20 mstore // [dataOffset, idx, input, *returnArr]

// 3. Iteratively hash up the tree and store intermediate values in the array
HASH_AND_STORE_MERKLE_LEVEL() // Level 1
HASH_AND_STORE_MERKLE_LEVEL() // Level 2
HASH_AND_STORE_MERKLE_LEVEL() // Level 3
HASH_AND_STORE_MERKLE_LEVEL() // Level 4
HASH_AND_STORE_MERKLE_LEVEL() // Level 5
HASH_AND_STORE_MERKLE_LEVEL() // Level 6
HASH_AND_STORE_MERKLE_LEVEL() // Level 7
HASH_AND_STORE_MERKLE_LEVEL() // Level 8
HASH_AND_STORE_MERKLE_LEVEL() // Level 9
HASH_AND_STORE_MERKLE_LEVEL() // Level 10
HASH_AND_STORE_MERKLE_LEVEL() // Level 11
HASH_AND_STORE_MERKLE_LEVEL() // Level 12
HASH_AND_STORE_MERKLE_LEVEL() // Level 13
HASH_AND_STORE_MERKLE_LEVEL() // Level 14
HASH_AND_STORE_MERKLE_LEVEL() // Level 15
HASH_AND_STORE_MERKLE_LEVEL() // Level 16
HASH_AND_STORE_MERKLE_LEVEL() // Level 17
HASH_AND_STORE_MERKLE_LEVEL() // Level 18
HASH_AND_STORE_MERKLE_LEVEL() // Level 19
HASH_AND_STORE_MERKLE_LEVEL() // Level 20
HASH_AND_STORE_MERKLE_LEVEL() // Level 21
HASH_AND_STORE_MERKLE_LEVEL() // Level 22
HASH_AND_STORE_MERKLE_LEVEL() // Level 23
HASH_AND_STORE_MERKLE_LEVEL() // Level 24
HASH_AND_STORE_MERKLE_LEVEL() // Level 25
HASH_AND_STORE_MERKLE_LEVEL() // Level 26
HASH_AND_STORE_MERKLE_LEVEL() // Level 27
HASH_AND_STORE_MERKLE_LEVEL() // Level 28
HASH_AND_STORE_MERKLE_LEVEL() // Level 29
HASH_AND_STORE_MERKLE_LEVEL() // Level 30
HASH_AND_STORE_MERKLE_LEVEL() // Level 31
HASH_AND_STORE_MERKLE_LEVEL() // Level 32

// Return the array
[RETURN_ARRAY_SIZE] 0x00 return
}

/// @dev Hash the next level of the Merkle tree and store the result in the output array
/// @param Takes [*sister, input, idx, *returnArr]
/// @return [*nextSister, hash, nextIdx, *nextReturnArr]
#define macro HASH_AND_STORE_MERKLE_LEVEL() = takes(4) returns(4) {
// Takes [*sister, input, idx, *returnArr]
// Hash the input with the sister leaves
HASH_MERKLE_LEVEL() // [*nextSister, hash, nextIdx, *returnArr]

// Store the hash in the output array
swap3 // [*returnArr, hash, nextIdx, *nextSister]
dup2 dup2 // [*returnArr, hash, *returnArr, hash, nextIdx, *nextSister]
mstore // [*returnArr, hash, nextIdx, *nextSister]
0x20 add // [*nextReturnArr, hash, nextIdx, *nextSister]
swap3 // [*nextSister, hash, nextIdx, *nextReturnArr]
}

/// @dev Hash the next level in the Merkle tree using the sister node in the given calldata location
/// @param Takes [*sister, input, idx]
/// @dev The *nextSister returned is the next sister node in the next level of the Merkle tree
/// @dev The idx returned is the index in the next level of the Merkle tree
/// @return [*nextSister, hash, nextIdx]
#define macro HASH_MERKLE_LEVEL() = takes(3) returns(3) {
// Takes [*sister, input, idx]
// Compute the next sister node location then move to the back of the resident stack
dup1 0x20 add // [*nextSister, *sister, input, idx]
swap3 swap1 // [*sister, idx, input, *nextSister]
calldataload // [sister, idx, input, *nextSister]

// Compute the next index then move it to the back of the resident stack
dup2 0x01 shr // [nextIdx, sister, idx, input, *nextSister]
swap4 swap3 // [input, sister, idx, *nextSister, nextIdx]

// Hash the input with the sister node
HASH_TWO_LEAVES() // [hash, *nextSister, nextIdx]
swap1 // [*nextSister, hash, nextIdx]
}

/// @dev Hash the given input with the given sister leaf as if inserting into the given index
/// @param Takes [input, sister, idx]
/// @return [hash]
#define macro HASH_TWO_LEAVES() = takes(3) returns(2) {
// Takes [input, sister, idx]
// Reorder the inputs if the lowest order bit is 1, so that the input is the RHS
swap2 LOWEST_BIT() // [idx_0, sister, input]
noReorder jumpi // [sister, input]

// Swap the inputs then fallthrough
swap1 // [input, sister]
noReorder:
// Hash the input with the sister leaf
POSEIDON_TWO_TO_ONE() // [hash]
}

// --- Helpers --- //

/// @dev Push the lowest order bit of the input onto the stack
#define macro LOWEST_BIT(input) = takes(1) returns(1) {
// Takes [input]
PUSH_ONE() and // [input_0]
}

/// @dev Push the constant 0x01 onto the stack
#define macro PUSH_ONE() = returns(0) {
0x01
}

/// @dev Return the first value on the stack
#define macro RETURN_FIRST() = returns(0) {
0x00 mstore
0x20 0x00 return
}

/// @dev Return the second value on the stack
#define macro RETURN_SECOND() = returns(0) {
dup2 0x00 mstore
0x20 0x00 return
}

/// @dev Return the third value on the stack
#define macro RETURN_THIRD() = returns(0) {
dup3 0x00 mstore
0x20 0x00 return
}

/// @dev Return the fourth value on the stack
#define macro RETURN_FOURTH() = returns(0) {
dup4 0x00 mstore
0x20 0x00 return
}

/// @notice Write a zero return array to memory then return
#define macro RETURN_ZERO_ARRAY() = takes(0) returns(0) {
// First 32 bytes: array length (number of elements)
0x20 0x00 mstore // Store offset to data (32) at position 0x00
0x01 0x20 mstore // Store length of 1 at position 0x20
0x00 0x40 mstore // Store value of 0 at position 0x40

// Return the array (offset + length + data)
0x60 0x00 return // Return 96 bytes (32 for offset + 32 for length + 32 for data)
}

9 changes: 0 additions & 9 deletions src/crypto/poseidon2/main.huff

This file was deleted.

4 changes: 2 additions & 2 deletions src/crypto/poseidon2/poseidonPerm.huff
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// Generated by renegade-solidity-contracts/codegen/poseidon-codegen

/// @dev Poseidon2 permutation function
#define macro POSEIDON_PERM() = takes(3) returns(3) {
// Takes [state[0], state[1]]
#define fn POSEIDON_PERM() = takes(3) returns(3) {
// Takes [state[0], state[1], state[2]]

// Start with the external MDS transformation
EXTERNAL_MDS()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
POSEIDON_PERM() // [state'[0], state'[1], state'[2]]

// We return the first element past the state capacity, in this case state'[1]
dup2 // [state'[1], state'[0], state'[1], state'[2], ...]
// Cleanup the stack before doing so
pop // [state'[1], state'[2]]
swap1 pop // [state'[1]]
}

/// @dev Apply an external round to the state
/// @param Takes [a, b, c]
/// @return [a', b', c']
#define macro EXTERNAL_ROUND(RC1, RC2, RC3) = takes(3) returns(7) {
#define macro EXTERNAL_ROUND(RC1, RC2, RC3) = takes(3) returns(3) {
// Add the round constants to the state and apply the sbox to individual elements
PUSH_PRIME() dup4 // [c, PRIME, a, b, c]
ADD_RC(<RC3>) // [c + RC3, a, b, c]
Expand Down Expand Up @@ -99,7 +101,7 @@
/// element
/// @param Takes [state[0], state[1], state[2]]
/// @return [state'[0], state'[1], state'[2]]
#define macro INTERNAL_MDS() = takes(3) returns(7) {
#define macro INTERNAL_MDS() = takes(3) returns(3) {
// Takes [state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

Expand All @@ -113,9 +115,8 @@
dup5 addmod // [state'[1], sum, state[0], state[1], state'[2]]
swap3 pop // [sum, state[0], state'[1], state'[2]]

PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state'[1], state'[2]]
dup4 addmod // [state'[0], sum, state[0], state'[1], state'[2]]
swap2 pop pop // [state'[0], state'[1], state'[2]]
PUSH_PRIME() swap2 // [state[0], sum, PRIME, state'[1], state'[2]]
addmod // [state'[0], state'[1], state'[2]]
}

/// @dev Apply the external MDS matrix to the sponge state
Expand All @@ -127,17 +128,21 @@
/// https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs#L129-L137
/// @param Takes [state[0], state[1], state[2]]
/// @return [state'[0], state'[1], state'[2]]
#define macro EXTERNAL_MDS() = takes(3) returns(7) {
#define macro EXTERNAL_MDS() = takes(3) returns(3) {
// Takes [state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

// Add the sum to each element
PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]]
dup6 addmod // [state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]]
swap4 pop // [sum, state[0], state[1], state'[2]]

PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state'[2]]
dup5 addmod // [state'[1], sum, state[0], state[1], state'[2]]
swap3 pop // [sum, state[0], state'[1], state'[2]]

PUSH_PRIME() swap2 // [state[0], sum, PRIME, state'[1], state'[2]]
addmod // [state'[0], state'[1], state'[2]]
}

// --- Helper Macros --- //
Expand All @@ -148,7 +153,7 @@
}

/// @dev Sum the first three elements of the stack modulo the prime
#define macro SUM_FIRST_THREE() = takes(0) returns(1) {
#define macro SUM_FIRST_THREE() = takes(3) returns(4) {
// Takes [a, b, c]
PUSH_PRIME() // [PRIME, a, b, c]
dup4 dup4 dup4 // [a, b, c, PRIME, a, b, c]
Expand Down
36 changes: 36 additions & 0 deletions test/Merkle.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import {Test} from "forge-std/Test.sol";
import {console} from "forge-std/console.sol";
import {HuffDeployer} from "foundry-huff/HuffDeployer.sol";
import {TestUtils} from "./utils/TestUtils.sol";

contract MerkleTest is TestUtils {
/// @dev The Merkle depth
uint256 constant MERKLE_DEPTH = 32;

/// @dev The MerklePoseidon contract
MerklePoseidon public merklePoseidon;

/// @dev Deploy the MerklePoseidon contract
function setUp() public {
merklePoseidon = MerklePoseidon(HuffDeployer.deploy("crypto/merkle/main"));
}

/// @dev Test the hashMerkle function
function testHashMerkle() public {
uint256 input = 1;
uint256 idx = 15;
uint256[] memory sisterLeaves = new uint256[](MERKLE_DEPTH);
for (uint256 i = 0; i < MERKLE_DEPTH; i++) {
sisterLeaves[i] = randomFelt();
}
uint256 result = merklePoseidon.hashMerkle(input, idx, sisterLeaves);
console.log("result:", result);
}
}

interface MerklePoseidon {
function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves) external returns (uint256);
}
Loading

0 comments on commit 53199dc

Please sign in to comment.