Skip to content

Commit

Permalink
crypto: poseidon2: Implement round functions (#6)
Browse files Browse the repository at this point in the history
* crypto: poseidon2: Add external round macro impl

* crypto: poseidon2: Add internal round macro impl
  • Loading branch information
joeykraut authored Jan 26, 2025
1 parent 123bb28 commit b290d59
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 30 deletions.
40 changes: 40 additions & 0 deletions src/crypto/poseidon2/roundUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,46 @@
/// @dev The scalar field modulus of BN254
#define constant PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

/// @dev Apply an external round to the state
/// @param Takes [a, b, c]
/// @return [a', b', c']
#define macro EXTERNAL_ROUND(RC1, RC2, RC3) = takes(3) returns(7) {
// Add the round constants to the state and apply the sbox to individual elements
PUSH_PRIME() dup4 // [c, PRIME, a, b, c]
ADD_RC(<RC3>) // [c + RC3, a, b, c]
PUSH_PRIME() swap1 // [c + RC3, PRIME, a, b, c]
SBOX() // [(c + RC3)^5, a, b, c]

PUSH_PRIME() dup4 // [b, PRIME, (c + RC3)^5, a, b, c]
ADD_RC(<RC2>) // [b + RC2, (c + RC3)^5, a, b, c]
PUSH_PRIME() swap1 // [b + RC2, PRIME, (c + RC3)^5, a, b, c]
SBOX() // [(b + RC2)^5, (c + RC3)^5, a, b, c]

PUSH_PRIME() dup4 // [a, PRIME, (b + RC2)^5, (c + RC3)^5, a, b, c]
ADD_RC(<RC1>) // [a + RC1, (b + RC2)^5, (c + RC3)^5, a, b, c]
PUSH_PRIME() swap1 // [a + RC1, PRIME, (b + RC2)^5, (c + RC3)^5, a, b, c]
SBOX() // [(a + RC1)^5, (b + RC2)^5, (c + RC3)^5, a, b, c]

// Multiply the intermediate state by the external round MDS matrix
EXTERNAL_MDS()
}

/// @dev Apply an internal round to the state
/// In an internal round, the round constant is only added to the first element,
/// and the sbox is applied to the first element only
/// @param Takes [a, b, c]
/// @return [a', b', c']
#define macro INTERNAL_ROUND(RC1) = takes(3) returns(7) {
// Add the round constant to the first element
PUSH_PRIME() swap1 // [a, PRIME, b, c]
ADD_RC(<RC1>) // [a + RC1, b, c]
PUSH_PRIME() swap1 // [a + RC1, PRIME, b, c]
SBOX() // [(a + RC1)^5, b, c]

// Apply the internal MDS matrix
INTERNAL_MDS()
}

// --- Core Permutation Methods --- //

/// @dev Add the round constant to the element on top of the stack
Expand Down
110 changes: 89 additions & 21 deletions test/Poseidon.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ contract PoseidonTest is Test {
/// @dev The BN254 field modulus from roundUtils.huff
uint256 PRIME = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;
/// @dev The round constant used in testing
uint256 TEST_RC = 0x1337;
uint256 TEST_RC1 = 0x1337;
/// @dev The second round constant used in testing
uint256 TEST_RC2 = 0x1338;
/// @dev The third round constant used in testing
uint256 TEST_RC3 = 0x1339;

/// @dev Deploy the PoseidonSuite contract
function setUp() public {
Expand All @@ -25,17 +29,15 @@ contract PoseidonTest is Test {
uint256 result = poseidonSuite.testSboxSingle(testValue);

// Calculate expected x^5 mod p
uint256 x2 = mulmod(testValue, testValue, PRIME);
uint256 x4 = mulmod(x2, x2, PRIME);
uint256 expected = mulmod(testValue, x4, PRIME);
uint256 expected = fifthPower(testValue);
assertEq(result, expected, "Expected result to match x^5 mod p");
}

/// @dev Test the add round constant function applied to a single input
function testAddRcSingle() public {
uint256 testValue = vm.randomUint();
uint256 result = poseidonSuite.testAddRc(testValue);
uint256 expected = addmod(testValue, TEST_RC, PRIME);
uint256 expected = addmod(testValue, TEST_RC1, PRIME);
assertEq(result, expected, "Expected result to match x + RC mod p");
}

Expand All @@ -47,14 +49,8 @@ contract PoseidonTest is Test {
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testInternalMds(a, b, c);

// Calculate the sum of the elements
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
uint256 expectedC = addmod(c, addmod(c, sum, PRIME), PRIME);
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = internalMds(a, b, c);
assertEq(a1, expectedA, "Expected result to match a + sum mod p");
assertEq(b1, expectedB, "Expected result to match b + sum mod p");
assertEq(c1, expectedC, "Expected result to match c + sum mod p");
Expand All @@ -67,17 +63,87 @@ contract PoseidonTest is Test {
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testExternalMds(a, b, c);

// Calculate the sum of the elements
// Calculate the expected results
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = externalMds(a, b, c);
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// @dev Test the external round function applied to a trio of inputs
function testExternalRound() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testExternalRound(a, b, c);
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = externalRound(a, b, c);
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// @dev Test the internal round function applied to a trio of inputs
function testInternalRound() public {
uint256 a = vm.randomUint();
uint256 b = vm.randomUint();
uint256 c = vm.randomUint();
(uint256 a1, uint256 b1, uint256 c1) = poseidonSuite.testInternalRound(a, b, c);
(uint256 expectedA, uint256 expectedB, uint256 expectedC) = internalRound(a, b, c);
assertEq(a1, expectedA, "Expected result to match a");
assertEq(b1, expectedB, "Expected result to match b");
assertEq(c1, expectedC, "Expected result to match c");
}

/// --- Helpers --- ///

/// @dev Calculate the fifth power of an input
function fifthPower(uint256 x) internal view returns (uint256) {
uint256 x2 = mulmod(x, x, PRIME);
uint256 x4 = mulmod(x2, x2, PRIME);
return mulmod(x, x4, PRIME);
}

/// @dev Calculate the result of the internal MDS matrix applied to the inputs
function internalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 sum = sumInputs(a, b, c);
uint256 a1 = addmod(a, sum, PRIME);
uint256 b1 = addmod(b, sum, PRIME);
uint256 c1 = addmod(addmod(c, sum, PRIME), c, PRIME); // c is doubled
return (a1, b1, c1);
}

/// @dev Calculate the result of the external MDS matrix applied to the inputs
function externalMds(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 sum = sumInputs(a, b, c);
uint256 a1 = addmod(a, sum, PRIME);
uint256 b1 = addmod(b, sum, PRIME);
uint256 c1 = addmod(c, sum, PRIME);
return (a1, b1, c1);
}

/// @dev Calculate the result of the external round function applied to the inputs
function externalRound(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 a1 = addmod(a, TEST_RC1, PRIME);
uint256 b1 = addmod(b, TEST_RC2, PRIME);
uint256 c1 = addmod(c, TEST_RC3, PRIME);
uint256 a2 = fifthPower(a1);
uint256 b2 = fifthPower(b1);
uint256 c2 = fifthPower(c1);
return externalMds(a2, b2, c2);
}

/// @dev Calculate the result of the internal round function applied to the inputs
function internalRound(uint256 a, uint256 b, uint256 c) internal view returns (uint256, uint256, uint256) {
uint256 a1 = addmod(a, TEST_RC1, PRIME);
uint256 a2 = fifthPower(a1);
return internalMds(a2, b, c);
}

/// @dev Sum the inputs and return the result
function sumInputs(uint256 a, uint256 b, uint256 c) internal view returns (uint256) {
uint256 sum = addmod(a, b, PRIME);
sum = addmod(sum, c, PRIME);

// Calculate the expected results
uint256 expectedA = addmod(a, sum, PRIME);
uint256 expectedB = addmod(b, sum, PRIME);
uint256 expectedC = addmod(c, sum, PRIME);
assertEq(a1, expectedA, "Expected result to match a + sum mod p");
assertEq(b1, expectedB, "Expected result to match b + sum mod p");
assertEq(c1, expectedC, "Expected result to match c + sum mod p");
return sum;
}
}

Expand All @@ -86,4 +152,6 @@ interface PoseidonSuite {
function testAddRc(uint256) external returns (uint256);
function testInternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalMds(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testExternalRound(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
function testInternalRound(uint256, uint256, uint256) external returns (uint256, uint256, uint256);
}
59 changes: 50 additions & 9 deletions test/huff/testPoseidonUtils.huff
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,29 @@
#define function testInternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the external MDS function applied to a trio of inputs
#define function testExternalMds(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the external round function applied to a trio of inputs
#define function testExternalRound(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)
/// @dev Test the internal round function applied to a trio of inputs
#define function testInternalRound(uint256, uint256, uint256) nonpayable returns(uint256, uint256, uint256)

/// @dev The first test round constant
#define constant TEST_RC1 = 0x1337
/// @dev The second test round constant
#define constant TEST_RC2 = 0x1338
/// @dev The third test round constant
#define constant TEST_RC3 = 0x1339

/// @dev The test round constant
#define constant TEST_RC = 0x1337

/// @dev Entrypoint to the poseidon test suite
#define macro MAIN() = takes(0) returns(0) {
// Get the function selector
0x0 calldataload 0xe0 shr // [SELECTOR]
dup1 __FUNC_SIG(testSboxSingle) eq testSboxSingle jumpi
dup1 __FUNC_SIG(testAddRc) eq testAddRc jumpi
dup1 __FUNC_SIG(testInternalMds) eq testInternalMds jumpi
dup1 __FUNC_SIG(testExternalMds) eq testExternalMds jumpi
dup1 __FUNC_SIG(testSboxSingle) eq testSboxSingle jumpi
dup1 __FUNC_SIG(testAddRc) eq testAddRc jumpi
dup1 __FUNC_SIG(testInternalMds) eq testInternalMds jumpi
dup1 __FUNC_SIG(testExternalMds) eq testExternalMds jumpi
dup1 __FUNC_SIG(testExternalRound) eq testExternalRound jumpi
dup1 __FUNC_SIG(testInternalRound) eq testInternalRound jumpi

// Revert if the function selector is not valid
0x1 0x0 mstore
Expand All @@ -37,6 +48,10 @@
TEST_INTERNAL_MDS()
testExternalMds:
TEST_EXTERNAL_MDS()
testExternalRound:
TEST_EXTERNAL_ROUND()
testInternalRound:
TEST_INTERNAL_ROUND(TEST_RC1)
}

// --- Test Cases --- //
Expand All @@ -61,7 +76,7 @@
0x04 calldataload

// Call the add round constant function
ADD_RC(TEST_RC)
ADD_RC(TEST_RC1)

// Return the result
RETURN_FIRST()
Expand All @@ -70,7 +85,6 @@
/// @notice Test the internal MDS function applied to a trio of inputs
#define macro TEST_INTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]
Expand All @@ -85,7 +99,6 @@
/// @notice Test the external MDS function applied to a trio of inputs
#define macro TEST_EXTERNAL_MDS() = takes(0) returns(0) {
// Get the input from calldata
PUSH_PRIME()
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]
Expand All @@ -97,6 +110,34 @@
RETURN_FIRST_THREE()
}

/// @notice Test the external round function applied to a trio of inputs
#define macro TEST_EXTERNAL_ROUND() = takes(0) returns(0) {
// Get the input from calldata
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Apply the external round function
EXTERNAL_ROUND(TEST_RC1, TEST_RC2, TEST_RC3)

// Return the result
RETURN_FIRST_THREE()
}

/// @notice Test the internal round function applied to a trio of inputs
#define macro TEST_INTERNAL_ROUND() = takes(0) returns(0) {
// Get the input from calldata
0x44 calldataload
0x24 calldataload
0x04 calldataload // [state[0], state[1], state[2], PRIME]

// Call the internal round function
INTERNAL_ROUND(TEST_RC1)

// Return the result
RETURN_FIRST_THREE()
}

// --- Helpers --- //

/// @dev Return the first value on the stack
Expand Down

0 comments on commit b290d59

Please sign in to comment.