Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crypto: poseidon2: Implement external MDS matrix multiplication #5

Merged
merged 2 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions src/crypto/poseidon2/roundUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
/// @dev The scalar field modulus of BN254
#define constant PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

/// @dev Push the prime onto the stack
#define macro PUSH_PRIME() = {
[PRIME]
}
// --- Core Permutation Methods --- //

/// @dev Add the round constant to the element on top of the stack
#define macro ADD_RC(RC) = takes(2) returns(1) {
Expand Down Expand Up @@ -48,19 +45,55 @@
/// @return [state'[0], state'[1], state'[2]]
#define macro INTERNAL_MDS() = takes(3) returns(7) {
// Takes [state[0], state[1], state[2]]
PUSH_PRIME() // [PRIME, state[0], state[1], state[2]]
dup4 dup4 dup4 // [state[0], state[1], state[2], PRIME, state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

// Compute the sum of the elements (mod p)
PUSH_PRIME() swap2 // [state[1], state[0], PRIME, state[2], PRIME, state[0], state[1], state[2]]
addmod // [state[0] + state[1] mod p, state[2], PRIME, state[0], state[1], state[2]]
addmod // [sum, state[0], state[1], state[2]]
// Double the last state element and add the sum to each element
PUSH_PRIME() dup2 PUSH_PRIME() // [PRIME, sum, PRIME, sum, state[0], state[1], state[2]]
dup7 dup1 addmod // [state[2] * 2, sum, PRIME, sum, state[0], state[1], state[2]]
addmod // [state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state'[2], state[0], state[1], state[2]]
dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]]
}

/// @dev Apply the external MDS matrix to the sponge state
/// For t = 3, this is the circulant matrix `circ(2, 1, 1)`
///
/// This is equivalent to doubling each element then adding the other two to
/// it, or more efficiently: adding the sum of the elements to each
/// individual element. This efficient structure is borrowed from:
/// https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs#L129-L137
/// @param Takes [state[0], state[1], state[2]]
/// @return [state'[0], state'[1], state'[2]]
#define macro EXTERNAL_MDS() = takes(3) returns(7) {
// Takes [state[0], state[1], state[2]]
SUM_FIRST_THREE() // [sum, state[0], state[1], state[2]]

// Add the sum to each element
PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]]
dup6 addmod // [state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state[2] + sum, sum, state[0], state[1], state[2]]
dup6 addmod // [state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
dup6 addmod // [state[0] + sum, state[1] + sum, state[2] + sum, sum, state[0], state[1], state[2]]
PUSH_PRIME() dup2 // [sum, PRIME, sum, state[0], state[1], state[2]]
dup6 addmod // [state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup3 // [sum, PRIME, state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[1], state'[2], sum, state[0], state[1], state[2]]
PUSH_PRIME() dup4 // [sum, PRIME, state'[1], state'[2], sum, state[0], state[1], state[2]]
dup6 addmod // [state'[0], state'[1], state'[2], sum, state[0], state[1], state[2]]
}

// --- Helper Macros --- //

/// @dev Push the prime onto the stack
#define macro PUSH_PRIME() = {
[PRIME]
}

/// @dev Sum the first three elements of the stack modulo the prime
#define macro SUM_FIRST_THREE() = takes(0) returns(1) {
// Takes [a, b, c]
PUSH_PRIME() // [PRIME, a, b, c]
dup4 dup4 dup4 // [a, b, c, PRIME, a, b, c]

// Compute the sum of the elements (mod p)
PUSH_PRIME() swap2 // [b, a, PRIME, c, PRIME, a, b, c]
addmod // [a + b mod p, c, PRIME, a, b, c]
addmod // [sum, a, b, c]
}
21 changes: 21 additions & 0 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ contract PoseidonTest is Test {
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
uint256 expectedC = addmod(c, addmod(c, sum, PRIME), PRIME);
assertEq(a1, expectedA, "Expected result to match a + sum mod p");
assertEq(b1, expectedB, "Expected result to match b + sum mod p");
assertEq(c1, expectedC, "Expected result to match c + sum mod p");
}

/// @dev Test the external MDS function applied to a trio of inputs
function testExternalMds() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testExternalMds(a, b, c);

// Calculate the sum of the elements
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
Expand All @@ -65,4 +85,5 @@ interface PoseidonSuite {
function testSboxSingle(uint256) external returns (uint256);
function testAddRc(uint256) external returns (uint256);
function testInternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
}
26 changes: 24 additions & 2 deletions test/huff/testPoseidonUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#define function testSboxSingle(uint256) nonpayable returns(uint256)
/// @dev Test the ADD_RC function applied to a single input
#define function testAddRc(uint256) nonpayable returns(uint256)
/// @dev Test the internal MDS function applied to a single input
/// @dev Test the internal MDS function applied to a trio of inputs
#define function testInternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the external MDS function applied to a trio of inputs
#define function testExternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)

/// @dev The test round constant
#define constant TEST_RC = 0x1337
Expand All @@ -21,6 +23,7 @@
dup1 __FUNC_SIG(testSboxSingle) eq testSboxSingle jumpi
dup1 __FUNC_SIG(testAddRc) eq testAddRc jumpi
dup1 __FUNC_SIG(testInternalMds) eq testInternalMds jumpi
dup1 __FUNC_SIG(testExternalMds) eq testExternalMds jumpi

// Revert if the function selector is not valid
0x1 0x0 mstore
Expand All @@ -32,8 +35,12 @@
TEST_ADD_RC()
testInternalMds:
TEST_INTERNAL_MDS()
testExternalMds:
TEST_EXTERNAL_MDS()
}

// --- Test Cases --- //

/// @notice Test the sbox function applied to a single input
#define macro TEST_SBOX_SINGLE() = takes(0) returns(0) {
// Get the input from calldata
Expand All @@ -60,7 +67,7 @@
RETURN_FIRST()
}

/// @notice Test the internal MDS function applied to a single input
/// @notice Test the internal MDS function applied to a trio of inputs
#define macro TEST_INTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
Expand All @@ -75,6 +82,21 @@
RETURN_FIRST_THREE()
}

/// @notice Test the external MDS function applied to a trio of inputs
#define macro TEST_EXTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Call the internal MDS function
EXTERNAL_MDS()

// Return the result
RETURN_FIRST_THREE()
}

// --- Helpers --- //

/// @dev Return the first value on the stack
Expand Down
Loading