Skip to content

Commit

Permalink
verifier: Transcript: Re-implement with more efficient Scalar parse (#14
Browse files Browse the repository at this point in the history
)

* verifier: Transcript: Re-implement with efficient scalar parse

* test: Transcript.t: Add multiple challenge transcript test
  • Loading branch information
joeykraut authored Feb 4, 2025
1 parent 1428a1d commit d3185ba
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 61 deletions.
64 changes: 40 additions & 24 deletions src/verifier/Transcript.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ uint256 constant SECOND_CHUNK_OFFSET = 0x20;
/// @title The Fiat-Shamir transcript used by the verifier
/// @dev The underlying hash function is keccak256
struct Transcript {
/// @dev The hash state of the transcript as a fixed-size byte array
bytes hashState;
/// @dev The low 32 bytes of the hash state
uint256 hashStateLow;
/// @dev The high 32 bytes of the hash state
uint256 hashStateHigh;
/// @dev The concatenated bytes of all elements
bytes elements;
}
Expand All @@ -49,7 +51,8 @@ struct Transcript {
library TranscriptLib {
/// @dev Creates a new transcript in memory
function new_transcript() internal pure returns (Transcript memory t) {
t.hashState = new bytes(HASH_STATE_SIZE);
t.hashStateLow = 0;
t.hashStateHigh = 0;
t.elements = new bytes(0);
return t;
}
Expand All @@ -64,41 +67,54 @@ library TranscriptLib {
/// @dev Gets the current challenge from the transcript
/// @param self The transcript
/// @return Challenge The Fiat-Shamir challenge
function getChallenge(Transcript memory self) internal view returns (BN254.ScalarField) {
function getChallenge(Transcript memory self) internal pure returns (BN254.ScalarField) {
// Concatenate state, transcript elements, and 0/1 bytes
bytes memory input0 = abi.encodePacked(self.hashState, self.elements, uint8(0));
bytes memory input1 = abi.encodePacked(self.hashState, self.elements, uint8(1));
bytes memory input0 = abi.encodePacked(self.hashStateLow, self.hashStateHigh, self.elements, uint8(0));
bytes memory input1 = abi.encodePacked(self.hashStateLow, self.hashStateHigh, self.elements, uint8(1));

// Hash inputs
// Hash inputs and update the hash state
bytes32 low = keccak256(input0);
bytes32 high = keccak256(input1);
self.hashStateLow = uint256(low);
self.hashStateHigh = uint256(high);

// Extract challenge bytes
bytes memory lowBytes = new bytes(32);
bytes memory highBytes = new bytes(32);
// Extract challenge bytes, we wish to interpret keccak output in little-endian order,
// so we need to reverse the bytes when converting to the scalar type
bytes32 lowBytes;
bytes32 highBytes;
assembly {
// Store the low and high bytes in the hash state
let hashStateBase := mload(add(self, STRUCT_HEADER_OFFSET))
let statePtr := add(hashStateBase, ARRAY_LENGTH_OFFSET)
mstore(statePtr, low) // Store low 32 bytes
mstore(add(statePtr, SECOND_CHUNK_OFFSET), high) // Store high 32 bytes

// Get the data pointer for the bytes arrays
let lowBytesPtr := add(lowBytes, 0x20)
let highBytesPtr := add(highBytes, 0x20)
let lowBytesStatePtr := self
let highBytesStatePtr := add(lowBytesStatePtr, CHALLENGE_LOW_BYTES)

// Mask and store the values
mstore(lowBytesPtr, and(mload(statePtr), LOW_BYTES_MASK))
mstore(highBytesPtr, and(mload(add(statePtr, 31)), HIGH_BYTES_MASK))
lowBytes := and(mload(lowBytesStatePtr), LOW_BYTES_MASK)
highBytes := and(mload(highBytesStatePtr), HIGH_BYTES_MASK)
}

// Convert from bytes
uint256 lowUint = BN254.fromLeBytesModOrder(lowBytes);
uint256 highUint = BN254.fromLeBytesModOrder(highBytes);
BN254.ScalarField lowScalar = BN254.ScalarField.wrap(lowUint);
BN254.ScalarField highScalar = BN254.ScalarField.wrap(highUint);
BN254.ScalarField lowScalar = scalarFromLeBytes(lowBytes);
BN254.ScalarField highScalar = scalarFromLeBytes(highBytes);

BN254.ScalarField shiftedHigh = BN254.mul(highScalar, BN254.ScalarField.wrap(CHALLENGE_HIGH_SHIFT));
return BN254.add(lowScalar, shiftedHigh);
}

/// @dev Converts a little-endian bytes array to a uint256
function scalarFromLeBytes(bytes32 buf) internal pure returns (BN254.ScalarField) {
// Reverse the byte order
bytes32 reversedBuf;
assembly {
for { let i := 0 } lt(i, 32) { i := add(i, 1) } {
// Copy the next byte into the reversed buffer
let shift := mul(sub(31, i), 8)
reversedBuf := or(shl(shift, and(buf, 0xff)), reversedBuf)
buf := shr(8, buf)
}
}

// Convert to uint256, reduce via the modulus, and return
uint256 reduced = uint256(reversedBuf) % BN254.R_MOD;
return BN254.ScalarField.wrap(reduced);
}
}
66 changes: 52 additions & 14 deletions test/Transcript.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ import {console2} from "forge-std/console2.sol";
contract TranscriptTest is TestUtils {
using TranscriptLib for Transcript;

/// @dev Number of random bytes to generate for test data
uint256 constant TEST_DATA_BYTES = 1024;

/// @notice Test the basic flow of transcript operations
/// @notice Test the basic flow of transcript operations with a single input
function testTranscriptBasic() public {
uint256 TEST_DATA_BYTES = 1024;

// Create a new transcript
Transcript memory transcript = TranscriptLib.new_transcript();

Expand All @@ -27,30 +26,69 @@ contract TranscriptTest is TestUtils {

// Get a challenge from our implementation
BN254.ScalarField challenge = transcript.getChallenge();
uint256 expectedChallenge = runReferenceImpl(testData);

// Get challenge from reference implementation
bytes[] memory inputs = new bytes[](1);
inputs[0] = testData;
uint256[] memory expectedChallenges = runReferenceImpl(inputs);

// Compare results
assertEq(
BN254.ScalarField.unwrap(challenge),
expectedChallenge,
expectedChallenges[0],
"Challenge mismatch between Solidity and reference implementation"
);
}

/// @notice Test the basic flow of transcript operations with multiple inputs
function testTranscriptMultiple() public {
uint256 TEST_DATA_BYTES = 1024;
uint256 NUM_TEST_INPUTS = 5;

// Create a new transcript
Transcript memory transcript = TranscriptLib.new_transcript();

// Generate multiple random test inputs
bytes[] memory testInputs = new bytes[](NUM_TEST_INPUTS);
for (uint256 i = 0; i < NUM_TEST_INPUTS; i++) {
testInputs[i] = vm.randomBytes(TEST_DATA_BYTES);
}

// Get challenges from our implementation
BN254.ScalarField[] memory challenges = new BN254.ScalarField[](NUM_TEST_INPUTS);
for (uint256 i = 0; i < NUM_TEST_INPUTS; i++) {
transcript.appendMessage(testInputs[i]);
challenges[i] = transcript.getChallenge();
}

// Get challenges from reference implementation
uint256[] memory expectedChallenges = runReferenceImpl(testInputs);

// Compare results
for (uint256 i = 0; i < NUM_TEST_INPUTS; i++) {
assertEq(
BN254.ScalarField.unwrap(challenges[i]),
expectedChallenges[i],
string(abi.encodePacked("Challenge mismatch at index ", vm.toString(i)))
);
}
}

/// @dev Helper to run the reference implementation
function runReferenceImpl(bytes memory data) internal returns (uint256) {
function runReferenceImpl(bytes[] memory inputs) internal returns (uint256[] memory) {
// First compile the binary
compileRustBinary("test/rust-reference-impls/transcript/Cargo.toml");

// Convert input bytes to hex string without 0x prefix
string memory hexData = bytesToHexString(data);

// Prepare arguments for the Rust binary
string[] memory args = new string[](2);
string[] memory args = new string[](inputs.length + 1);
args[0] = "./test/rust-reference-impls/target/debug/transcript";
args[1] = string(abi.encodePacked("0x", hexData));

// Run the reference implementation and parse result
return vm.parseUint(runBinaryGetResponse(args));
// Convert each input to hex string and add as argument
for (uint256 i = 0; i < inputs.length; i++) {
args[i + 1] = string(abi.encodePacked("0x", bytesToHexString(inputs[i])));
}

// Run the reference implementation and parse space-separated array output
return runBinaryGetArray(args, " ");
}
}
2 changes: 1 addition & 1 deletion test/rust-reference-impls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["merkle", "poseidon", "transcript"]

[workspace.dependencies]
# === Renegade Dependencies === #
renegade-constants = { package = "constants", git = "https://github.com/renegade-fi/renegade.git", default-features=false }
renegade-constants = { package = "constants", git = "https://github.com/renegade-fi/renegade.git", default-features = false }
renegade-crypto = { git = "https://github.com/renegade-fi/renegade.git" }
mpc-plonk = { git = "https://github.com/renegade-fi/mpc-jellyfish.git" }

Expand Down
61 changes: 39 additions & 22 deletions test/rust-reference-impls/transcript/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,47 @@ use renegade_constants::Scalar;
type BaseField = <Bn254 as Pairing>::BaseField;

fn main() -> Result<(), Box<dyn Error>> {
// Skip the first argument
// Skip the first argument (program name)
let args: Vec<String> = env::args().skip(1).collect();
assert!(args.len() == 1, "Expected 1 argument, got {}", args.len());
if args.is_empty() {
eprintln!("Expected at least 1 argument, got 0");
std::process::exit(1);
}

// Parse the hex string
let trimmed = args[0].trim_start_matches("0x");
let bytes = hex::decode(trimmed)?;

// Create a transcript, append the input, and compute the challenge
// Create a single transcript instance to use for all inputs
let mut ts = <SolidityTranscript as PlonkTranscript<BaseField>>::new(b"unused_label");
<SolidityTranscript as PlonkTranscript<BaseField>>::append_message(
&mut ts,
b"unused_label",
&bytes,
)?;
let challenge = <SolidityTranscript as PlonkTranscript<BaseField>>::get_and_append_challenge::<
Bn254,
>(&mut ts, b"unused_label")?;

// Return the challenge as a hex string
let scalar = Scalar::new(challenge);
let hex_str = format!("{:#x}", scalar.to_biguint());

// Prefix with RES: to ensure consistent string parsing
println!("RES:{hex_str}");
let mut challenges = Vec::new();

// For each argument, append it to the transcript and get a challenge
for arg in args {
// Parse the hex string
let trimmed = arg.trim_start_matches("0x");
let bytes = hex::decode(trimmed)?;

// Append message and get challenge
<SolidityTranscript as PlonkTranscript<BaseField>>::append_message(
&mut ts,
b"unused_label",
&bytes,
)?;
let challenge =
<SolidityTranscript as PlonkTranscript<BaseField>>::get_and_append_challenge::<Bn254>(
&mut ts,
b"unused_label",
)?;

// Convert challenge to scalar and store
let scalar = Scalar::new(challenge);
challenges.push(scalar);
}

// Convert challenges to decimal strings and join with spaces
let challenge_strings: Vec<String> = challenges
.iter()
.map(|c| c.to_biguint().to_string())
.collect();

// Output results as space-separated values with RES: prefix
println!("RES:{}", challenge_strings.join(" "));
Ok(())
}

0 comments on commit d3185ba

Please sign in to comment.