From 386478471b7f7412c496a2492d49179485286eb0 Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Tue, 21 Jan 2025 15:08:23 -0300 Subject: [PATCH] WIP: add a leaf map When proving, we currently ask for positions. However, positions may change during updates, so externally holding those positions may be cumbersome. The leaf hash, on the other hand never changes. This commit does the same we do for `MemForest`, creating a map from leaf hashes to a reference to the actual leaf in memory. Once we need to prove it, we compute the position on-the-fly. --- src/accumulator/pollard.rs | 191 ++++++++++++++++++++++++++++++++++--- 1 file changed, 180 insertions(+), 11 deletions(-) diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 7572bf7..e128c20 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -36,6 +36,7 @@ /// //TODO: Add usage examples use std::cell::Cell; use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; use std::fmt::Display; use std::rc::Rc; @@ -47,8 +48,10 @@ use super::util::detect_row; use super::util::detwin; use super::util::get_proof_positions; use super::util::is_root_position; +use super::util::left_child; use super::util::max_position_at_row; use super::util::parent; +use super::util::right_child; use super::util::root_position; use super::util::tree_rows; @@ -121,6 +124,20 @@ impl PollardNode { }) } + fn parent(&self) -> Option>> { + let granparent = self.grandparent(); + if granparent.is_none() { + return self.aunt(); + } + + let granparent = granparent.unwrap(); + if granparent.left_niece().eq(&self.aunt()) { + granparent.right_niece() + } else { + granparent.left_niece() + } + } + /// Returns the hash of this node fn hash(&self) -> Hash { self.hash.get() @@ -353,6 +370,7 @@ pub struct Pollard { /// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is /// the number of leaves when it was added, so we can always find a leaf by it's position. leaves: u64, + leaf_map: HashMap>>, } impl PartialEq for Pollard { @@ -432,7 +450,7 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } - pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> { + pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> { let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); let nodes = positions .into_iter() @@ -440,8 +458,9 @@ impl Pollard { .collect::>(); for node in nodes { - let node = node.ok_or("Position not found")?; - node.0.prune(); + let (node, _) = node.ok_or("Position not found")?; + self.leaf_map.remove(&node.hash()); + node.prune(); } Ok(()) @@ -461,8 +480,18 @@ impl Pollard { /// Proves the inclusion of the nodes at the given positions /// /// This function takes a list of positions and returns a list of proofs for each position. - pub fn batch_proof(&self, targets: &[u64]) -> Result, &'static str> { - let targets = detwin(targets.to_vec(), tree_rows(self.leaves)); + pub fn batch_proof(&self, targets: &[Hash]) -> Result, String> { + let mut positions = Vec::new(); + for target in targets { + let node = self + .leaf_map + .get(target) + .ok_or(format!("leaf {target} not found in the forest"))?; + let position = self.get_pos(node)?; + positions.push(position); + } + + let targets = detwin(positions, tree_rows(self.leaves)); let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); let mut proof_hashes = Vec::new(); @@ -482,7 +511,9 @@ impl Pollard { }) } - pub fn prove_single(&self, pos: u64) -> Result, &'static str> { + pub fn prove_single(&self, leaf: Hash) -> Result, String> { + let node = self.leaf_map.get(&leaf).ok_or("Leaf not found")?; + let pos = self.get_pos(node)?; let hashes = self.prove_single_inner(pos)?; let targets = vec![pos]; @@ -506,8 +537,8 @@ impl Pollard { proof: Proof, ) -> Result<(), String> { let targets = proof.targets.clone(); - self.ingest_proof(proof.clone(), del_hashes, &targets) - .unwrap(); + self.ingest_proof(proof.clone(), del_hashes, &targets)?; + let targets = detwin(targets, tree_rows(self.leaves)); let targets = targets .into_iter() @@ -535,7 +566,11 @@ impl Pollard { /// Creates a new empty [Pollard] pub fn new() -> Pollard { let roots: [Option>>; 64] = std::array::from_fn(|_| None); - Pollard:: { roots, leaves: 0 } + Pollard:: { + roots, + leaves: 0, + leaf_map: HashMap::new(), + } } } @@ -632,6 +667,7 @@ impl Pollard { .copied() .collect::>(); + self.map_targets(remembers)?; self.prune(&pruned)?; if recompute { @@ -643,6 +679,17 @@ impl Pollard { Ok(()) } + fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> { + for target in targets { + let node = self + .grab_position(*target) + .ok_or(format!("Position {target} not found"))?; + self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0)); + } + + Ok(()) + } + fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) { let mut tr = tree_rows(num_leaves); let nr = detect_row(pos, tr); @@ -758,6 +805,7 @@ impl Pollard { fn add_single(&mut self, node: PollardAddition) -> Result, String> { let mut row = 0; let mut new_node = PollardNode::new(node.hash, node.remember); + self.leaf_map.insert(node.hash, Rc::downgrade(&new_node)); let mut add_positions = Vec::new(); let mut roots_to_destroy = Vec::new(); @@ -830,6 +878,7 @@ impl Pollard { } fn delete_single(&mut self, node: Rc>) -> Result<(), String> { + self.leaf_map.remove(&node.hash()); // we are deleting a root, just write an empty hash where it was if node.aunt.borrow().is_none() { for i in 0..64 { @@ -866,6 +915,65 @@ impl Pollard { sibling.migrate_up().unwrap(); Ok(()) } + + /// Returns the position in the tree of this node + fn get_pos(&self, node: &Weak>) -> Result { + // This indicates whether the node is a left or right child at each level + // When we go down the tree, we can use the indicator to know which + // child to take. + let mut left_child_indicator = 0_u64; + let mut rows_to_top = 0; + let mut node = node + .upgrade() + .ok_or("Could not upgrade node. Is this reference valid?")?; + + while let Some(aunt) = node.parent() { + let aunt_left = aunt.children().ok_or("Aunt has no children")?.0; + // If the current node is a left child, we left-shift the indicator + // and leave the LSB as 0 + if aunt_left.hash() == node.hash() { + left_child_indicator <<= 1; + } else { + // If the current node is a right child, we left-shift the indicator + // and set the LSB to 1 + left_child_indicator <<= 1; + left_child_indicator |= 1; + } + rows_to_top += 1; + node = aunt; + } + + let root_row = self.roots.iter().position(|root| { + if let Some(root) = root { + return root.hash() == node.hash(); + } + + false + }); + + let forest_rows = tree_rows(self.leaves); + let root_row = root_row.ok_or(format!( + "Could not find the root position for root {:?}", + node.hash() + ))?; + + let mut pos = root_position(self.leaves, root_row as u8, forest_rows); + for _ in 0..rows_to_top { + // If LSB is 0, go left, otherwise go right + match left_child_indicator & 1 { + 0 => { + pos = left_child(pos, forest_rows); + } + 1 => { + pos = right_child(pos, forest_rows); + } + _ => unreachable!(), + } + left_child_indicator >>= 1; + } + + Ok(pos) + } } #[cfg(test)] @@ -1063,7 +1171,7 @@ mod tests { let mut acc = Pollard::::new(); acc.modify(&hashes, &[], Proof::default()).unwrap(); - let proof = acc.prove_single(3).unwrap(); + let proof = acc.prove_single(hashes[3].hash).unwrap(); let expected_hashes = [ "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", @@ -1080,6 +1188,66 @@ mod tests { assert_eq!(proof, expected_proof); } + fn get_hashes_of(values: &[u8]) -> Vec> { + values + .iter() + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect() + } + + #[test] + fn test_get_pos() { + macro_rules! test_get_pos { + ($p:ident, $pos:literal) => { + let node = $p.grab_position($pos).unwrap().0; + assert_eq!( + $p.get_pos(&Rc::downgrade(&node)), + Ok($pos), + "Failed to get position of node {:?}", + node + ); + }; + } + + let hashes = get_hashes_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + let mut p = Pollard::new(); + p.modify(&hashes, &[], Proof::default()) + .expect("Test mem_forests are valid"); + test_get_pos!(p, 0); + test_get_pos!(p, 1); + test_get_pos!(p, 2); + test_get_pos!(p, 3); + test_get_pos!(p, 4); + test_get_pos!(p, 5); + test_get_pos!(p, 6); + test_get_pos!(p, 7); + test_get_pos!(p, 8); + test_get_pos!(p, 9); + test_get_pos!(p, 10); + test_get_pos!(p, 11); + test_get_pos!(p, 12); + + let root = p.roots[3].as_ref().unwrap(); + let left = root.left_niece().unwrap(); + let right = root.right_niece().unwrap(); + + assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28)); + assert_eq!( + p.get_pos(&Rc::downgrade(&left)), + Ok(24) + ); + assert_eq!( + p.get_pos(&Rc::downgrade(&right)), + Ok(25) + ); + } + #[test] fn test_ingest_proof() { let values = [0, 1, 2, 3, 4, 5, 6, 7] @@ -1109,7 +1277,8 @@ mod tests { acc.modify(&values, &[], Proof::default()).unwrap(); acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3]) .unwrap(); - let new_proof = acc.prove_single(3).unwrap(); + + let new_proof = acc.prove_single(values[3].hash).unwrap(); assert_eq!(new_proof, proof); let node = acc.grab_position(3).unwrap().0;