From 76c3bf32f28a25eb95bf62b9f635b1a95093d217 Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Tue, 21 Jan 2025 15:08:23 -0300 Subject: [PATCH 1/2] Pollard: add a leaf map When proving, we currently ask for positions. However, positions may change during updates, so externally holding those positions may be cumbersome. The leaf hash, on the other hand never changes. This commit does the same we do for `MemForest`, creating a map from leaf hashes to a reference to the actual leaf in memory. Once we need to prove it, we compute the position on-the-fly. --- src/accumulator/pollard.rs | 191 ++++++++++++++++++++++++++++++++++--- 1 file changed, 180 insertions(+), 11 deletions(-) diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 7572bf7..e128c20 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -36,6 +36,7 @@ /// //TODO: Add usage examples use std::cell::Cell; use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; use std::fmt::Display; use std::rc::Rc; @@ -47,8 +48,10 @@ use super::util::detect_row; use super::util::detwin; use super::util::get_proof_positions; use super::util::is_root_position; +use super::util::left_child; use super::util::max_position_at_row; use super::util::parent; +use super::util::right_child; use super::util::root_position; use super::util::tree_rows; @@ -121,6 +124,20 @@ impl PollardNode { }) } + fn parent(&self) -> Option>> { + let granparent = self.grandparent(); + if granparent.is_none() { + return self.aunt(); + } + + let granparent = granparent.unwrap(); + if granparent.left_niece().eq(&self.aunt()) { + granparent.right_niece() + } else { + granparent.left_niece() + } + } + /// Returns the hash of this node fn hash(&self) -> Hash { self.hash.get() @@ -353,6 +370,7 @@ pub struct Pollard { /// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is /// the number of leaves when it was added, so we can always find a leaf by it's position. leaves: u64, + leaf_map: HashMap>>, } impl PartialEq for Pollard { @@ -432,7 +450,7 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } - pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> { + pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> { let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); let nodes = positions .into_iter() @@ -440,8 +458,9 @@ impl Pollard { .collect::>(); for node in nodes { - let node = node.ok_or("Position not found")?; - node.0.prune(); + let (node, _) = node.ok_or("Position not found")?; + self.leaf_map.remove(&node.hash()); + node.prune(); } Ok(()) @@ -461,8 +480,18 @@ impl Pollard { /// Proves the inclusion of the nodes at the given positions /// /// This function takes a list of positions and returns a list of proofs for each position. - pub fn batch_proof(&self, targets: &[u64]) -> Result, &'static str> { - let targets = detwin(targets.to_vec(), tree_rows(self.leaves)); + pub fn batch_proof(&self, targets: &[Hash]) -> Result, String> { + let mut positions = Vec::new(); + for target in targets { + let node = self + .leaf_map + .get(target) + .ok_or(format!("leaf {target} not found in the forest"))?; + let position = self.get_pos(node)?; + positions.push(position); + } + + let targets = detwin(positions, tree_rows(self.leaves)); let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); let mut proof_hashes = Vec::new(); @@ -482,7 +511,9 @@ impl Pollard { }) } - pub fn prove_single(&self, pos: u64) -> Result, &'static str> { + pub fn prove_single(&self, leaf: Hash) -> Result, String> { + let node = self.leaf_map.get(&leaf).ok_or("Leaf not found")?; + let pos = self.get_pos(node)?; let hashes = self.prove_single_inner(pos)?; let targets = vec![pos]; @@ -506,8 +537,8 @@ impl Pollard { proof: Proof, ) -> Result<(), String> { let targets = proof.targets.clone(); - self.ingest_proof(proof.clone(), del_hashes, &targets) - .unwrap(); + self.ingest_proof(proof.clone(), del_hashes, &targets)?; + let targets = detwin(targets, tree_rows(self.leaves)); let targets = targets .into_iter() @@ -535,7 +566,11 @@ impl Pollard { /// Creates a new empty [Pollard] pub fn new() -> Pollard { let roots: [Option>>; 64] = std::array::from_fn(|_| None); - Pollard:: { roots, leaves: 0 } + Pollard:: { + roots, + leaves: 0, + leaf_map: HashMap::new(), + } } } @@ -632,6 +667,7 @@ impl Pollard { .copied() .collect::>(); + self.map_targets(remembers)?; self.prune(&pruned)?; if recompute { @@ -643,6 +679,17 @@ impl Pollard { Ok(()) } + fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> { + for target in targets { + let node = self + .grab_position(*target) + .ok_or(format!("Position {target} not found"))?; + self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0)); + } + + Ok(()) + } + fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) { let mut tr = tree_rows(num_leaves); let nr = detect_row(pos, tr); @@ -758,6 +805,7 @@ impl Pollard { fn add_single(&mut self, node: PollardAddition) -> Result, String> { let mut row = 0; let mut new_node = PollardNode::new(node.hash, node.remember); + self.leaf_map.insert(node.hash, Rc::downgrade(&new_node)); let mut add_positions = Vec::new(); let mut roots_to_destroy = Vec::new(); @@ -830,6 +878,7 @@ impl Pollard { } fn delete_single(&mut self, node: Rc>) -> Result<(), String> { + self.leaf_map.remove(&node.hash()); // we are deleting a root, just write an empty hash where it was if node.aunt.borrow().is_none() { for i in 0..64 { @@ -866,6 +915,65 @@ impl Pollard { sibling.migrate_up().unwrap(); Ok(()) } + + /// Returns the position in the tree of this node + fn get_pos(&self, node: &Weak>) -> Result { + // This indicates whether the node is a left or right child at each level + // When we go down the tree, we can use the indicator to know which + // child to take. + let mut left_child_indicator = 0_u64; + let mut rows_to_top = 0; + let mut node = node + .upgrade() + .ok_or("Could not upgrade node. Is this reference valid?")?; + + while let Some(aunt) = node.parent() { + let aunt_left = aunt.children().ok_or("Aunt has no children")?.0; + // If the current node is a left child, we left-shift the indicator + // and leave the LSB as 0 + if aunt_left.hash() == node.hash() { + left_child_indicator <<= 1; + } else { + // If the current node is a right child, we left-shift the indicator + // and set the LSB to 1 + left_child_indicator <<= 1; + left_child_indicator |= 1; + } + rows_to_top += 1; + node = aunt; + } + + let root_row = self.roots.iter().position(|root| { + if let Some(root) = root { + return root.hash() == node.hash(); + } + + false + }); + + let forest_rows = tree_rows(self.leaves); + let root_row = root_row.ok_or(format!( + "Could not find the root position for root {:?}", + node.hash() + ))?; + + let mut pos = root_position(self.leaves, root_row as u8, forest_rows); + for _ in 0..rows_to_top { + // If LSB is 0, go left, otherwise go right + match left_child_indicator & 1 { + 0 => { + pos = left_child(pos, forest_rows); + } + 1 => { + pos = right_child(pos, forest_rows); + } + _ => unreachable!(), + } + left_child_indicator >>= 1; + } + + Ok(pos) + } } #[cfg(test)] @@ -1063,7 +1171,7 @@ mod tests { let mut acc = Pollard::::new(); acc.modify(&hashes, &[], Proof::default()).unwrap(); - let proof = acc.prove_single(3).unwrap(); + let proof = acc.prove_single(hashes[3].hash).unwrap(); let expected_hashes = [ "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", @@ -1080,6 +1188,66 @@ mod tests { assert_eq!(proof, expected_proof); } + fn get_hashes_of(values: &[u8]) -> Vec> { + values + .iter() + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect() + } + + #[test] + fn test_get_pos() { + macro_rules! test_get_pos { + ($p:ident, $pos:literal) => { + let node = $p.grab_position($pos).unwrap().0; + assert_eq!( + $p.get_pos(&Rc::downgrade(&node)), + Ok($pos), + "Failed to get position of node {:?}", + node + ); + }; + } + + let hashes = get_hashes_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + let mut p = Pollard::new(); + p.modify(&hashes, &[], Proof::default()) + .expect("Test mem_forests are valid"); + test_get_pos!(p, 0); + test_get_pos!(p, 1); + test_get_pos!(p, 2); + test_get_pos!(p, 3); + test_get_pos!(p, 4); + test_get_pos!(p, 5); + test_get_pos!(p, 6); + test_get_pos!(p, 7); + test_get_pos!(p, 8); + test_get_pos!(p, 9); + test_get_pos!(p, 10); + test_get_pos!(p, 11); + test_get_pos!(p, 12); + + let root = p.roots[3].as_ref().unwrap(); + let left = root.left_niece().unwrap(); + let right = root.right_niece().unwrap(); + + assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28)); + assert_eq!( + p.get_pos(&Rc::downgrade(&left)), + Ok(24) + ); + assert_eq!( + p.get_pos(&Rc::downgrade(&right)), + Ok(25) + ); + } + #[test] fn test_ingest_proof() { let values = [0, 1, 2, 3, 4, 5, 6, 7] @@ -1109,7 +1277,8 @@ mod tests { acc.modify(&values, &[], Proof::default()).unwrap(); acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3]) .unwrap(); - let new_proof = acc.prove_single(3).unwrap(); + + let new_proof = acc.prove_single(values[3].hash).unwrap(); assert_eq!(new_proof, proof); let node = acc.grab_position(3).unwrap().0; From 35602afd6b35752ba808db930ed567c126783939 Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Wed, 22 Jan 2025 11:44:01 -0300 Subject: [PATCH 2/2] Pol: ask for hashes instead of positions + tests This commit makes all proving-related methods take leaf hashes instead of positions, and adds additional tests to proving and ingesting proofs --- src/accumulator/pollard.rs | 140 +++++++++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 28 deletions(-) diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index e128c20..12853ad 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -430,6 +430,11 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { + let roots = self.roots(); + proof.verify(del_hashes, &roots, self.leaves) + } + pub fn verify_and_ingest( &mut self, proof: Proof, @@ -451,6 +456,8 @@ impl Pollard { } pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> { + self.prune_map(positions); + let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); let nodes = positions .into_iter() @@ -481,21 +488,21 @@ impl Pollard { /// /// This function takes a list of positions and returns a list of proofs for each position. pub fn batch_proof(&self, targets: &[Hash]) -> Result, String> { - let mut positions = Vec::new(); + let mut target_positions = Vec::new(); for target in targets { let node = self .leaf_map .get(target) .ok_or(format!("leaf {target} not found in the forest"))?; let position = self.get_pos(node)?; - positions.push(position); + target_positions.push(position); } - let targets = detwin(positions, tree_rows(self.leaves)); - let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); + let proof_positions = + get_proof_positions(&target_positions, self.leaves, tree_rows(self.leaves)); let mut proof_hashes = Vec::new(); - for pos in positions.iter() { + for pos in proof_positions.iter() { let hash = self .grab_position(*pos) .ok_or("Position not found")? @@ -507,7 +514,7 @@ impl Pollard { Ok(Proof:: { hashes: proof_hashes, - targets: positions, + targets: target_positions, }) } @@ -581,6 +588,13 @@ type AddSingleResult = (Vec<(u64, T)>, Vec); type ChildrenTuple = (Rc>, Rc>); impl Pollard { + fn prune_map(&mut self, positions: &[u64]) { + for pos in positions { + let node = self.grab_position(*pos).unwrap().0; + self.leaf_map.remove(&node.hash()); + } + } + fn grab_position(&self, pos: u64) -> Option> { let (root, depth, bits) = Self::detect_offset(pos, self.leaves); let mut node = self.roots[root as usize].clone()?; @@ -608,6 +622,7 @@ impl Pollard { fn ingest_positions( &mut self, mut iter: impl Iterator, + remembers: &[u64], ) -> Result<(), String> { let forest_rows = tree_rows(self.leaves); while let Some((pos1, hash1)) = iter.next() { @@ -638,6 +653,11 @@ impl Pollard { new_node.set_aunt(Rc::downgrade(&aunt)); new_sibling.set_aunt(Rc::downgrade(&aunt)); + if remembers.contains(&pos1) || remembers.contains(&pos2) { + self.leaf_map.insert(hash1, Rc::downgrade(&new_node)); + self.leaf_map.insert(hash2, Rc::downgrade(&new_sibling)); + } + aunt.set_niece(Some(new_sibling), Some(new_node)); } @@ -658,7 +678,7 @@ impl Pollard { all_nodes.extend(proof_positions.into_iter().zip(proof.hashes.clone())); all_nodes.sort(); let iter = all_nodes.into_iter().rev(); - self.ingest_positions(iter)?; + self.ingest_positions(iter, remembers)?; let pruned = proof .targets @@ -667,7 +687,6 @@ impl Pollard { .copied() .collect::>(); - self.map_targets(remembers)?; self.prune(&pruned)?; if recompute { @@ -679,17 +698,6 @@ impl Pollard { Ok(()) } - fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> { - for target in targets { - let node = self - .grab_position(*target) - .ok_or(format!("Position {target} not found"))?; - self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0)); - } - - Ok(()) - } - fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) { let mut tr = tree_rows(num_leaves); let nr = detect_row(pos, tr); @@ -1156,6 +1164,88 @@ mod tests { assert_eq!(root.unwrap().hash(), hashes[0].hash); } + #[test] + fn test_ingest_proof_and_prove() { + // this test will create a forest, prove a few leaves, prune all leaves, ingest the proof + // and prove the same leaves + siblings again + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect(); + + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); + + let del_hashes = [ + hash_from_u8(2), + hash_from_u8(1), + hash_from_u8(4), + hash_from_u8(6), + ]; + let proof = acc.batch_proof(&del_hashes).unwrap(); + + acc.prune(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + acc.ingest_proof(proof, &del_hashes, &[2, 1, 4, 6]).unwrap(); + + let del_hashes = [0, 1, 4, 5, 6, 7] + .iter() + .map(|x| hash_from_u8(*x)) + .collect::>(); + let proof = acc.batch_proof(&del_hashes).unwrap(); + assert!(acc.verify(&proof, &del_hashes).unwrap()); + } + #[test] + fn test_prove() { + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect(); + + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); + let del_hashes = [ + hash_from_u8(2), + hash_from_u8(1), + hash_from_u8(4), + hash_from_u8(6), + ]; + let proof = acc.batch_proof(&del_hashes).unwrap(); + let expected_proof = Proof::new( + [2, 1, 4, 6].to_vec(), + vec![ + "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d" + .parse() + .unwrap(), + "084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5" + .parse() + .unwrap(), + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db" + .parse() + .unwrap(), + "ca358758f6d27e6cf45272937977a748fd88391db679ceda7dc7bf1f005ee879" + .parse() + .unwrap(), + ], + ); + + assert_eq!(proof, expected_proof); + assert!(acc.verify(&proof, &del_hashes).unwrap()); + } + #[test] fn test_prove_single() { let values = vec![0, 1, 2, 3, 4, 5]; @@ -1232,20 +1322,14 @@ mod tests { test_get_pos!(p, 10); test_get_pos!(p, 11); test_get_pos!(p, 12); - + let root = p.roots[3].as_ref().unwrap(); let left = root.left_niece().unwrap(); let right = root.right_niece().unwrap(); assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28)); - assert_eq!( - p.get_pos(&Rc::downgrade(&left)), - Ok(24) - ); - assert_eq!( - p.get_pos(&Rc::downgrade(&right)), - Ok(25) - ); + assert_eq!(p.get_pos(&Rc::downgrade(&left)), Ok(24)); + assert_eq!(p.get_pos(&Rc::downgrade(&right)), Ok(25)); } #[test]