Skip to content

Commit

Permalink
crypto: poseidon2: Implement external MDS matrix multiplication (#5)
Browse files Browse the repository at this point in the history
* crypto: poseidon2: Double last element in internal MDS

* crypto: poseidon2: Implement external MDS matrix multiplication
  • Loading branch information
joeykraut authored Jan 26, 2025
1 parent 48903cd commit 123bb28
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 18 deletions.
65 changes: 49 additions & 16 deletions src/crypto/poseidon2/roundUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
/// @dev The scalar field modulus of BN254
#define constant PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

/// @dev Push the prime onto the stack
#define macro PUSH_PRIME() = {
[PRIME]
}
// --- Core Permutation Methods --- //

/// @dev Add the round constant to the element on top of the stack
#define macro ADD_RC(RC) = takes(2) returns(1) {
Expand Down Expand Up @@ -48,19 +45,55 @@
/// @return [state'[0], state'[1], state'[2]]
#define macro INTERNAL_MDS() = takes(3) returns(7) {
// Takes [state[0], state[1], state[2]]
PUSH_PRIME() // [PRIME, state[0], state[1], state[2]]
dup4 dup4 dup4 // [state[0], state[1], state[2], PRIME, state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

// Compute the sum of the elements (mod p)
PUSH_PRIME() swap2 // [state[1], state[0], PRIME, state[2], PRIME, state[0], state[1], state[2]]
addmod // [state[0] + state[1] mod p, state[2], PRIME, state[0], state[1], state[2]]
addmod // [sum, state[0], state[1], state[2]]
// Double the last state element and add the sum to each element
PUSH_PRIME() dup2 PUSH_PRIME() // [PRIME, sum, PRIME, sum, state[0], state[1], state[2]]
dup7 dup1 addmod // [state[2] * 2, sum, PRIME, sum, state[0], state[1], state[2]]
addmod // [state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state'[2], state[0], state[1], state[2]]
dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]]
}

/// @dev Apply the external MDS matrix to the sponge state
/// For t = 3, this is the circulant matrix `circ(2, 1, 1)`
///
/// This is equivalent to doubling each element then adding the other two to
/// it, or more efficiently: adding the sum of the elements to each
/// individual element. This efficient structure is borrowed from:
/// https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs#L129-L137
/// @param Takes [state[0], state[1], state[2]]
/// @return [state'[0], state'[1], state'[2]]
#define macro EXTERNAL_MDS() = takes(3) returns(7) {
// Takes [state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

// Add the sum to each element
PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]]
dup6 addmod // [state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state[2] + sum, sum, state[0], state[1], state[2]]
dup6 addmod // [state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
dup6 addmod // [state[0] + sum, state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]]
dup6 addmod // [state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]]
}

// --- Helper Macros --- //

/// @dev Push the prime onto the stack
#define macro PUSH_PRIME() = {
[PRIME]
}

/// @dev Sum the first three elements of the stack modulo the prime
#define macro SUM_FIRST_THREE() = takes(0) returns(1) {
// Takes [a, b, c]
PUSH_PRIME() // [PRIME, a, b, c]
dup4 dup4 dup4 // [a, b, c, PRIME, a, b, c]

// Compute the sum of the elements (mod p)
PUSH_PRIME() swap2 // [b, a, PRIME, c, PRIME, a, b, c]
addmod // [a + b mod p, c, PRIME, a, b, c]
addmod // [sum, a, b, c]
}
21 changes: 21 additions & 0 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ contract PoseidonTest is Test {
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
uint256 expectedC = addmod(c, addmod(c, sum, PRIME), PRIME);
assertEq(a1, expectedA, "Expected result to match a + sum mod p");
assertEq(b1, expectedB, "Expected result to match b + sum mod p");
assertEq(c1, expectedC, "Expected result to match c + sum mod p");
}

/// @dev Test the external MDS function applied to a trio of inputs
function testExternalMds() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testExternalMds(a, b, c);

// Calculate the sum of the elements
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
Expand All @@ -65,4 +85,5 @@ interface PoseidonSuite {
function testSboxSingle(uint256) external returns (uint256);
function testAddRc(uint256) external returns (uint256);
function testInternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
}
26 changes: 24 additions & 2 deletions test/huff/testPoseidonUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#define function testSboxSingle(uint256) nonpayable returns(uint256)
/// @dev Test the ADD_RC function applied to a single input
#define function testAddRc(uint256) nonpayable returns(uint256)
/// @dev Test the internal MDS function applied to a single input
/// @dev Test the internal MDS function applied to a trio of inputs
#define function testInternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the external MDS function applied to a trio of inputs
#define function testExternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)

/// @dev The test round constant
#define constant TEST_RC = 0x1337
Expand All @@ -21,6 +23,7 @@
dup1 __FUNC_SIG(testSboxSingle) eq testSboxSingle jumpi
dup1 __FUNC_SIG(testAddRc) eq testAddRc jumpi
dup1 __FUNC_SIG(testInternalMds) eq testInternalMds jumpi
dup1 __FUNC_SIG(testExternalMds) eq testExternalMds jumpi

// Revert if the function selector is not valid
0x1 0x0 mstore
Expand All @@ -32,8 +35,12 @@
TEST_ADD_RC()
testInternalMds:
TEST_INTERNAL_MDS()
testExternalMds:
TEST_EXTERNAL_MDS()
}

// --- Test Cases --- //

/// @notice Test the sbox function applied to a single input
#define macro TEST_SBOX_SINGLE() = takes(0) returns(0) {
// Get the input from calldata
Expand All @@ -60,7 +67,7 @@
RETURN_FIRST()
}

/// @notice Test the internal MDS function applied to a single input
/// @notice Test the internal MDS function applied to a trio of inputs
#define macro TEST_INTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
Expand All @@ -75,6 +82,21 @@
RETURN_FIRST_THREE()
}

/// @notice Test the external MDS function applied to a trio of inputs
#define macro TEST_EXTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Call the internal MDS function
EXTERNAL_MDS()

// Return the result
RETURN_FIRST_THREE()
}

// --- Helpers --- //

/// @dev Return the first value on the stack
Expand Down

0 comments on commit 123bb28

Please sign in to comment.