Skip to content

Commit

Permalink
crypto: poseidon2: Add internal round macro impl
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Jan 26, 2025
1 parent b1db263 commit d87796a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 4 deletions.
16 changes: 16 additions & 0 deletions src/crypto/poseidon2/roundUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@
EXTERNAL_MDS()
}

/// @dev Apply an internal round to the state
/// In an internal round, the round constant is only added to the first element,
/// and the sbox is applied to the first element only
/// @param Takes [a, b, c]
/// @return [a', b', c']
#define macro INTERNAL_ROUND(RC1) = takes(3) returns(7) {
// Add the round constant to the first element
PUSH_PRIME() swap1 // [a, PRIME, b, c]
ADD_RC(<RC1>) // [a + RC1, b, c]
PUSH_PRIME() swap1 // [a + RC1, PRIME, b, c]
SBOX() // [(a + RC1)^5, b, c]

// Apply the internal MDS matrix
INTERNAL_MDS()
}

// --- Core Permutation Methods --- //

/// @dev Add the round constant to the element on top of the stack
Expand Down
49 changes: 46 additions & 3 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,33 @@ contract PoseidonTest is Test {

// Calculate the expected results
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = externalMds(a, b, c);
assertEq(a1, expectedA, "Expected result to match a + sum mod p");
assertEq(b1, expectedB, "Expected result to match b + sum mod p");
assertEq(c1, expectedC, "Expected result to match c + sum mod p");
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// @dev Test the external round function applied to a trio of inputs
function testExternalRound() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testExternalRound(a, b, c);
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = externalRound(a, b, c);
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// @dev Test the internal round function applied to a trio of inputs
function testInternalRound() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testInternalRound(a, b, c);
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = internalRound(a, b, c);
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// --- Helpers --- ///
Expand Down Expand Up @@ -97,6 +121,24 @@ contract PoseidonTest is Test {
return (a1, b1, c1);
}

/// @dev Calculate the result of the external round function applied to the inputs
function externalRound(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 a1 = addmod(a, TEST_RC1, PRIME);
uint256 b1 = addmod(b, TEST_RC2, PRIME);
uint256 c1 = addmod(c, TEST_RC3, PRIME);
uint256 a2 = fifthPower(a1);
uint256 b2 = fifthPower(b1);
uint256 c2 = fifthPower(c1);
return externalMds(a2, b2, c2);
}

/// @dev Calculate the result of the internal round function applied to the inputs
function internalRound(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 a1 = addmod(a, TEST_RC1, PRIME);
uint256 a2 = fifthPower(a1);
return internalMds(a2, b, c);
}

/// @dev Sum the inputs and return the result
function sumInputs(uint256 a, uint256 b, uint256 c) internal view returns (uint256) {
uint256 sum = addmod(a, b, PRIME);
Expand All @@ -111,4 +153,5 @@ interface PoseidonSuite {
function testInternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalRound(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testInternalRound(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
}
21 changes: 20 additions & 1 deletion test/huff/testPoseidonUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define function testExternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the external round function applied to a trio of inputs
#define function testExternalRound(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the internal round function applied to a trio of inputs
#define function testInternalRound(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)

/// @dev The first test round constant
#define constant TEST_RC1 = 0x1337
Expand All @@ -32,6 +34,7 @@
dup1 __FUNC_SIG(testInternalMds) eq testInternalMds jumpi
dup1 __FUNC_SIG(testExternalMds) eq testExternalMds jumpi
dup1 __FUNC_SIG(testExternalRound) eq testExternalRound jumpi
dup1 __FUNC_SIG(testInternalRound) eq testInternalRound jumpi

// Revert if the function selector is not valid
0x1 0x0 mstore
Expand All @@ -47,6 +50,8 @@
TEST_EXTERNAL_MDS()
testExternalRound:
TEST_EXTERNAL_ROUND()
testInternalRound:
TEST_INTERNAL_ROUND(TEST_RC1)
}

// --- Test Cases --- //
Expand Down Expand Up @@ -113,7 +118,21 @@
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Apply the external round function
EXTERNAL_ROUND()
EXTERNAL_ROUND(TEST_RC1, TEST_RC2, TEST_RC3)

// Return the result
RETURN_FIRST_THREE()
}

/// @notice Test the internal round function applied to a trio of inputs
#define macro TEST_INTERNAL_ROUND() = takes(0) returns(0) {
// Get the input from calldata
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Call the internal round function
INTERNAL_ROUND(TEST_RC1)

// Return the result
RETURN_FIRST_THREE()
Expand Down

0 comments on commit d87796a

Please sign in to comment.