From 53199dc587e2058eba98d87d16e14e6eaf58ee94 Mon Sep 17 00:00:00 2001 From: Joey Kraut <108701651+joeykraut@users.noreply.github.com> Date: Tue, 28 Jan 2025 11:20:12 -0800 Subject: [PATCH] crypto: merkle: Implement Merkle hash with intermediate values (#10) * crypto: merkle: Implement Merkle hashing entrypoint * crypto: poseidon2: poseidonUtils: Cleanup stack in each util * crypto: merkle: Fix data pointer misalignment bugs --- foundry.toml | 3 - src/crypto/merkle/main.huff | 202 ++++++++++++++++++ src/crypto/poseidon2/main.huff | 9 - src/crypto/poseidon2/poseidonPerm.huff | 4 +- .../{roundUtils.huff => poseidonUtils.huff} | 29 +-- test/Merkle.t.sol | 36 ++++ test/Poseidon.t.sol | 20 +- test/huff/testPoseidonUtils.huff | 2 +- .../src/main.rs | 4 +- test/utils/TestUtils.sol | 16 ++ 10 files changed, 281 insertions(+), 44 deletions(-) create mode 100644 src/crypto/merkle/main.huff delete mode 100644 src/crypto/poseidon2/main.huff rename src/crypto/poseidon2/{roundUtils.huff => poseidonUtils.huff} (85%) create mode 100644 test/Merkle.t.sol create mode 100644 test/utils/TestUtils.sol diff --git a/foundry.toml b/foundry.toml index 503d0df..803b416 100644 --- a/foundry.toml +++ b/foundry.toml @@ -5,6 +5,3 @@ src = "src" out = "out" libs = ["lib"] remappings = ["forge-std/=lib/forge-std/src/"] - -[ffi] -enabled = true diff --git a/src/crypto/merkle/main.huff b/src/crypto/merkle/main.huff new file mode 100644 index 0000000..b50ddfe --- /dev/null +++ b/src/crypto/merkle/main.huff @@ -0,0 +1,202 @@ +/// @title Poseidon2 +/// @author @joeykraut +/// @notice A Poseidon2 implementation in Huff. See https://eprint.iacr.org/2023/323 +/// for more details. Inspired by https://github.com/zemse/poseidon2-evm/tree/main + +#include "../poseidon2/poseidonUtils.huff" +#include "../poseidon2/poseidonPerm.huff" + +// --- Interface --- // + +/// @notice Hash an input with the given sister leaves in a Merkle tree +/// @param input The input to hash +/// @param idx The index of the input in the Merkle tree +/// @param sisterLeaves An array of sister leaves to hash with +/// @return The intermediate hashes of the input and sister leaves, with the root at the start +#define function hashMerkle(uint256 idx, uint256 input, uint256[] sisterLeaves) nonpayable returns(uint256) + +// --- Constants --- // + +/// @notice The depth of the Merkle tree +#define constant MERKLE_DEPTH = 0x20 // 32 + +/// @dev The revert code for the `sisterLeaves` array length mismatch +#define constant SISTER_LEAVES_LENGTH_MISMATCH = 0x01 + +/// @dev The total size of the return array (in bytes) +/// 32 bytes offset + 32 bytes length + 32 * 32 bytes data +#define constant RETURN_ARRAY_SIZE = 0x440 // 1088 + +// --- Entrypoint --- // + +/// @notice Entrypoint for the Poseidon2 Merkle hash +#define macro MAIN() = takes(0) returns(0) { + // Currently we only have one function, so we don't need to mux the selector + // Push the data pointer of the return array onto the stack + // The helpers below expect the data pointer to trail the inputs + 0x40 // [*returnArr] + + // Load the inputs from calldata + 0x04 calldataload // [idx, *returnArr] + 0x24 calldataload // [input, idx, *returnArr] + + // Load the data offset from the array start, then convert it to a data pointer + 0x44 calldataload // [dataOffset, input, idx, *returnArr] + 0x24 add // [*sisterLeaves, input, idx, *returnArr] + + // Load the length of the sister leaves array, then check its length + 0x64 calldataload // [len(sisterLeaves), dataOffset, input, idx, *returnArr] + [MERKLE_DEPTH] eq validLen jumpi // [dataOffset, idx, input, *returnArr] + + [SISTER_LEAVES_LENGTH_MISMATCH] 0x00 mstore + 0x20 0x00 revert + + validLen: + // Setup the return value + // 1. Store the data offset at position 0x00 + // Offset is always 32 bytes in our case + 0x20 0x00 mstore // [dataOffset, idx, input, *returnArr] + + // 2. Store the length of the return array at position 0x20 + // Length is always 32 in our case + 0x20 0x20 mstore // [dataOffset, idx, input, *returnArr] + + // 3. Iteratively hash up the tree and store intermediate values in the array + HASH_AND_STORE_MERKLE_LEVEL() // Level 1 + HASH_AND_STORE_MERKLE_LEVEL() // Level 2 + HASH_AND_STORE_MERKLE_LEVEL() // Level 3 + HASH_AND_STORE_MERKLE_LEVEL() // Level 4 + HASH_AND_STORE_MERKLE_LEVEL() // Level 5 + HASH_AND_STORE_MERKLE_LEVEL() // Level 6 + HASH_AND_STORE_MERKLE_LEVEL() // Level 7 + HASH_AND_STORE_MERKLE_LEVEL() // Level 8 + HASH_AND_STORE_MERKLE_LEVEL() // Level 9 + HASH_AND_STORE_MERKLE_LEVEL() // Level 10 + HASH_AND_STORE_MERKLE_LEVEL() // Level 11 + HASH_AND_STORE_MERKLE_LEVEL() // Level 12 + HASH_AND_STORE_MERKLE_LEVEL() // Level 13 + HASH_AND_STORE_MERKLE_LEVEL() // Level 14 + HASH_AND_STORE_MERKLE_LEVEL() // Level 15 + HASH_AND_STORE_MERKLE_LEVEL() // Level 16 + HASH_AND_STORE_MERKLE_LEVEL() // Level 17 + HASH_AND_STORE_MERKLE_LEVEL() // Level 18 + HASH_AND_STORE_MERKLE_LEVEL() // Level 19 + HASH_AND_STORE_MERKLE_LEVEL() // Level 20 + HASH_AND_STORE_MERKLE_LEVEL() // Level 21 + HASH_AND_STORE_MERKLE_LEVEL() // Level 22 + HASH_AND_STORE_MERKLE_LEVEL() // Level 23 + HASH_AND_STORE_MERKLE_LEVEL() // Level 24 + HASH_AND_STORE_MERKLE_LEVEL() // Level 25 + HASH_AND_STORE_MERKLE_LEVEL() // Level 26 + HASH_AND_STORE_MERKLE_LEVEL() // Level 27 + HASH_AND_STORE_MERKLE_LEVEL() // Level 28 + HASH_AND_STORE_MERKLE_LEVEL() // Level 29 + HASH_AND_STORE_MERKLE_LEVEL() // Level 30 + HASH_AND_STORE_MERKLE_LEVEL() // Level 31 + HASH_AND_STORE_MERKLE_LEVEL() // Level 32 + + // Return the array + [RETURN_ARRAY_SIZE] 0x00 return +} + +/// @dev Hash the next level of the Merkle tree and store the result in the output array +/// @param Takes [*sister, input, idx, *returnArr] +/// @return [*nextSister, hash, nextIdx, *nextReturnArr] +#define macro HASH_AND_STORE_MERKLE_LEVEL() = takes(4) returns(4) { + // Takes [*sister, input, idx, *returnArr] + // Hash the input with the sister leaves + HASH_MERKLE_LEVEL() // [*nextSister, hash, nextIdx, *returnArr] + + // Store the hash in the output array + swap3 // [*returnArr, hash, nextIdx, *nextSister] + dup2 dup2 // [*returnArr, hash, *returnArr, hash, nextIdx, *nextSister] + mstore // [*returnArr, hash, nextIdx, *nextSister] + 0x20 add // [*nextReturnArr, hash, nextIdx, *nextSister] + swap3 // [*nextSister, hash, nextIdx, *nextReturnArr] +} + +/// @dev Hash the next level in the Merkle tree using the sister node in the given calldata location +/// @param Takes [*sister, input, idx] +/// @dev The *nextSister returned is the next sister node in the next level of the Merkle tree +/// @dev The idx returned is the index in the next level of the Merkle tree +/// @return [*nextSister, hash, nextIdx] +#define macro HASH_MERKLE_LEVEL() = takes(3) returns(3) { + // Takes [*sister, input, idx] + // Compute the next sister node location then move to the back of the resident stack + dup1 0x20 add // [*nextSister, *sister, input, idx] + swap3 swap1 // [*sister, idx, input, *nextSister] + calldataload // [sister, idx, input, *nextSister] + + // Compute the next index then move it to the back of the resident stack + dup2 0x01 shr // [nextIdx, sister, idx, input, *nextSister] + swap4 swap3 // [input, sister, idx, *nextSister, nextIdx] + + // Hash the input with the sister node + HASH_TWO_LEAVES() // [hash, *nextSister, nextIdx] + swap1 // [*nextSister, hash, nextIdx] +} + +/// @dev Hash the given input with the given sister leaf as if inserting into the given index +/// @param Takes [input, sister, idx] +/// @return [hash] +#define macro HASH_TWO_LEAVES() = takes(3) returns(2) { + // Takes [input, sister, idx] + // Reorder the inputs if the lowest order bit is 1, so that the input is the RHS + swap2 LOWEST_BIT() // [idx_0, sister, input] + noReorder jumpi // [sister, input] + + // Swap the inputs then fallthrough + swap1 // [input, sister] + noReorder: + // Hash the input with the sister leaf + POSEIDON_TWO_TO_ONE() // [hash] +} + +// --- Helpers --- // + +/// @dev Push the lowest order bit of the input onto the stack +#define macro LOWEST_BIT(input) = takes(1) returns(1) { + // Takes [input] + PUSH_ONE() and // [input_0] +} + +/// @dev Push the constant 0x01 onto the stack +#define macro PUSH_ONE() = returns(0) { + 0x01 +} + +/// @dev Return the first value on the stack +#define macro RETURN_FIRST() = returns(0) { + 0x00 mstore + 0x20 0x00 return +} + +/// @dev Return the second value on the stack +#define macro RETURN_SECOND() = returns(0) { + dup2 0x00 mstore + 0x20 0x00 return +} + +/// @dev Return the third value on the stack +#define macro RETURN_THIRD() = returns(0) { + dup3 0x00 mstore + 0x20 0x00 return +} + +/// @dev Return the fourth value on the stack +#define macro RETURN_FOURTH() = returns(0) { + dup4 0x00 mstore + 0x20 0x00 return +} + +/// @notice Write a zero return array to memory then return +#define macro RETURN_ZERO_ARRAY() = takes(0) returns(0) { + // First 32 bytes: array length (number of elements) + 0x20 0x00 mstore // Store offset to data (32) at position 0x00 + 0x01 0x20 mstore // Store length of 1 at position 0x20 + 0x00 0x40 mstore // Store value of 0 at position 0x40 + + // Return the array (offset + length + data) + 0x60 0x00 return // Return 96 bytes (32 for offset + 32 for length + 32 for data) +} + diff --git a/src/crypto/poseidon2/main.huff b/src/crypto/poseidon2/main.huff deleted file mode 100644 index 3a491f6..0000000 --- a/src/crypto/poseidon2/main.huff +++ /dev/null @@ -1,9 +0,0 @@ -/// @title Poseidon2 -/// @author @joeykraut -/// @notice A Poseidon2 implementation in Huff. See https://eprint.iacr.org/2023/323 -/// for more details. Inspired by https://github.com/zemse/poseidon2-evm/tree/main - -#define macro MAIN() = takes(0) returns(0) { -} - - diff --git a/src/crypto/poseidon2/poseidonPerm.huff b/src/crypto/poseidon2/poseidonPerm.huff index 7bc3d55..827e6b3 100644 --- a/src/crypto/poseidon2/poseidonPerm.huff +++ b/src/crypto/poseidon2/poseidonPerm.huff @@ -2,8 +2,8 @@ // Generated by renegade-solidity-contracts/codegen/poseidon-codegen /// @dev Poseidon2 permutation function -#define macro POSEIDON_PERM() = takes(3) returns(3) { - // Takes [state[0], state[1]] +#define fn POSEIDON_PERM() = takes(3) returns(3) { + // Takes [state[0], state[1], state[2]] // Start with the external MDS transformation EXTERNAL_MDS() diff --git a/src/crypto/poseidon2/roundUtils.huff b/src/crypto/poseidon2/poseidonUtils.huff similarity index 85% rename from src/crypto/poseidon2/roundUtils.huff rename to src/crypto/poseidon2/poseidonUtils.huff index dd8716c..1faaee1 100644 --- a/src/crypto/poseidon2/roundUtils.huff +++ b/src/crypto/poseidon2/poseidonUtils.huff @@ -13,13 +13,15 @@ POSEIDON_PERM() // [state'[0], state'[1], state'[2]] // We return the first element past the state capacity, in this case state'[1] - dup2 // [state'[1], state'[0], state'[1], state'[2], ...] + // Cleanup the stack before doing so + pop // [state'[1], state'[2]] + swap1 pop // [state'[1]] } /// @dev Apply an external round to the state /// @param Takes [a, b, c] /// @return [a', b', c'] -#define macro EXTERNAL_ROUND(RC1, RC2, RC3) = takes(3) returns(7) { +#define macro EXTERNAL_ROUND(RC1, RC2, RC3) = takes(3) returns(3) { // Add the round constants to the state and apply the sbox to individual elements PUSH_PRIME() dup4 // [c, PRIME, a, b, c] ADD_RC() // [c + RC3, a, b, c] @@ -99,7 +101,7 @@ /// element /// @param Takes [state[0], state[1], state[2]] /// @return [state'[0], state'[1], state'[2]] -#define macro INTERNAL_MDS() = takes(3) returns(7) { +#define macro INTERNAL_MDS() = takes(3) returns(3) { // Takes [state[0], state[1], state[2]] SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]] @@ -113,9 +115,8 @@ dup5 addmod // [state'[1], sum, state[0], state[1], state'[2]] swap3 pop // [sum, state[0], state'[1], state'[2]] - PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state'[1], state'[2]] - dup4 addmod // [state'[0], sum, state[0], state'[1], state'[2]] - swap2 pop pop // [state'[0], state'[1], state'[2]] + PUSH_PRIME() swap2 // [state[0], sum, PRIME, state'[1], state'[2]] + addmod // [state'[0], state'[1], state'[2]] } /// @dev Apply the external MDS matrix to the sponge state @@ -127,17 +128,21 @@ /// https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs#L129-L137 /// @param Takes [state[0], state[1], state[2]] /// @return [state'[0], state'[1], state'[2]] -#define macro EXTERNAL_MDS() = takes(3) returns(7) { +#define macro EXTERNAL_MDS() = takes(3) returns(3) { // Takes [state[0], state[1], state[2]] SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]] // Add the sum to each element PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]] dup6 addmod // [state'[2], sum, state[0], state[1], state[2]] - PUSH_PRIME() dup3 // [sum, PRIME, state'[2], sum, state[0], state[1], state[2]] - dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]] - PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]] - dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]] + swap4 pop // [sum, state[0], state[1], state'[2]] + + PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state'[2]] + dup5 addmod // [state'[1], sum, state[0], state[1], state'[2]] + swap3 pop // [sum, state[0], state'[1], state'[2]] + + PUSH_PRIME() swap2 // [state[0], sum, PRIME, state'[1], state'[2]] + addmod // [state'[0], state'[1], state'[2]] } // --- Helper Macros --- // @@ -148,7 +153,7 @@ } /// @dev Sum the first three elements of the stack modulo the prime -#define macro SUM_FIRST_THREE() = takes(0) returns(1) { +#define macro SUM_FIRST_THREE() = takes(3) returns(4) { // Takes [a, b, c] PUSH_PRIME() // [PRIME, a, b, c] dup4 dup4 dup4 // [a, b, c, PRIME, a, b, c] diff --git a/test/Merkle.t.sol b/test/Merkle.t.sol new file mode 100644 index 0000000..c21d888 --- /dev/null +++ b/test/Merkle.t.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import {Test} from "forge-std/Test.sol"; +import {console} from "forge-std/console.sol"; +import {HuffDeployer} from "foundry-huff/HuffDeployer.sol"; +import {TestUtils} from "./utils/TestUtils.sol"; + +contract MerkleTest is TestUtils { + /// @dev The Merkle depth + uint256 constant MERKLE_DEPTH = 32; + + /// @dev The MerklePoseidon contract + MerklePoseidon public merklePoseidon; + + /// @dev Deploy the MerklePoseidon contract + function setUp() public { + merklePoseidon = MerklePoseidon(HuffDeployer.deploy("crypto/merkle/main")); + } + + /// @dev Test the hashMerkle function + function testHashMerkle() public { + uint256 input = 1; + uint256 idx = 15; + uint256[] memory sisterLeaves = new uint256[](MERKLE_DEPTH); + for (uint256 i = 0; i < MERKLE_DEPTH; i++) { + sisterLeaves[i] = randomFelt(); + } + uint256 result = merklePoseidon.hashMerkle(input, idx, sisterLeaves); + console.log("result:", result); + } +} + +interface MerklePoseidon { + function hashMerkle(uint256 input, uint256 idx, uint256[] calldata sisterLeaves) external returns (uint256); +} diff --git a/test/Poseidon.t.sol b/test/Poseidon.t.sol index 8d5c670..9155067 100644 --- a/test/Poseidon.t.sol +++ b/test/Poseidon.t.sol @@ -4,13 +4,12 @@ pragma solidity ^0.8.0; import {Test} from "forge-std/Test.sol"; import {console} from "forge-std/console.sol"; import {HuffDeployer} from "foundry-huff/HuffDeployer.sol"; +import {TestUtils} from "./utils/TestUtils.sol"; -contract PoseidonTest is Test { +contract PoseidonTest is TestUtils { /// @dev The Poseidon main contract PoseidonSuite public poseidonSuite; - /// @dev The BN254 field modulus from roundUtils.huff - uint256 PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001; /// @dev The round constant used in testing uint256 TEST_RC1 = 0x1337; /// @dev The second round constant used in testing @@ -143,22 +142,15 @@ contract PoseidonTest is Test { /// --- Helpers --- /// - /// @dev Generates a random input modulo the PRIME - /// Note that this is not uniformly distributed over the prime field, because of the "wraparound" - /// but it suffices for fuzzing test inputs - function randomFelt() internal returns (uint256) { - return vm.randomUint() % PRIME; - } - /// @dev Calculate the fifth power of an input - function fifthPower(uint256 x) internal view returns (uint256) { + function fifthPower(uint256 x) internal pure returns (uint256) { uint256 x2 = mulmod(x, x, PRIME); uint256 x4 = mulmod(x2, x2, PRIME); return mulmod(x, x4, PRIME); } /// @dev Calculate the result of the internal MDS matrix applied to the inputs - function internalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) { + function internalMds(uint256 a, uint256 b, uint256 c) internal pure returns (uint256, uint256, uint256) { uint256 sum = sumInputs(a, b, c); uint256 a1 = addmod(a, sum, PRIME); uint256 b1 = addmod(b, sum, PRIME); @@ -167,7 +159,7 @@ contract PoseidonTest is Test { } /// @dev Calculate the result of the external MDS matrix applied to the inputs - function externalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) { + function externalMds(uint256 a, uint256 b, uint256 c) internal pure returns (uint256, uint256, uint256) { uint256 sum = sumInputs(a, b, c); uint256 a1 = addmod(a, sum, PRIME); uint256 b1 = addmod(b, sum, PRIME); @@ -194,7 +186,7 @@ contract PoseidonTest is Test { } /// @dev Sum the inputs and return the result - function sumInputs(uint256 a, uint256 b, uint256 c) internal view returns (uint256) { + function sumInputs(uint256 a, uint256 b, uint256 c) internal pure returns (uint256) { uint256 sum = addmod(a, b, PRIME); sum = addmod(sum, c, PRIME); return sum; diff --git a/test/huff/testPoseidonUtils.huff b/test/huff/testPoseidonUtils.huff index 34f44cd..68563eb 100644 --- a/test/huff/testPoseidonUtils.huff +++ b/test/huff/testPoseidonUtils.huff @@ -2,7 +2,7 @@ /// @author @joeykraut /// @notice Test the utils for the Poseidon2 permutation -#include "../../src/crypto/poseidon2/roundUtils.huff" +#include "../../src/crypto/poseidon2/poseidonUtils.huff" #include "../../src/crypto/poseidon2/poseidonPerm.huff" /// @dev Test the SBOX function applied to a single input diff --git a/test/poseidon-reference-implementation/src/main.rs b/test/poseidon-reference-implementation/src/main.rs index baa95a2..8732435 100644 --- a/test/poseidon-reference-implementation/src/main.rs +++ b/test/poseidon-reference-implementation/src/main.rs @@ -1,9 +1,7 @@ -use num_bigint::BigUint; use renegade_constants::Scalar; -use renegade_crypto::fields::{biguint_to_scalar, scalar_to_biguint}; +use renegade_crypto::fields::scalar_to_biguint; use renegade_crypto::hash::Poseidon2Sponge; use std::env; -use std::str::FromStr; fn main() { let args: Vec = env::args().collect(); diff --git a/test/utils/TestUtils.sol b/test/utils/TestUtils.sol new file mode 100644 index 0000000..9f0bf5d --- /dev/null +++ b/test/utils/TestUtils.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import {Test} from "forge-std/Test.sol"; + +contract TestUtils is Test { + /// @dev The BN254 field modulus from roundUtils.huff + uint256 constant PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001; + + /// @dev Generates a random input modulo the PRIME + /// Note that this is not uniformly distributed over the prime field, because of the "wraparound" + /// but it suffices for fuzzing test inputs + function randomFelt() internal returns (uint256) { + return vm.randomUint() % PRIME; + } +}