diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 7572bf7..e128c20 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -36,6 +36,7 @@ /// //TODO: Add usage examples use std::cell::Cell; use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; use std::fmt::Display; use std::rc::Rc; @@ -47,8 +48,10 @@ use super::util::detect_row; use super::util::detwin; use super::util::get_proof_positions; use super::util::is_root_position; +use super::util::left_child; use super::util::max_position_at_row; use super::util::parent; +use super::util::right_child; use super::util::root_position; use super::util::tree_rows; @@ -121,6 +124,20 @@ impl PollardNode { }) } + fn parent(&self) -> Option>> { + let granparent = self.grandparent(); + if granparent.is_none() { + return self.aunt(); + } + + let granparent = granparent.unwrap(); + if granparent.left_niece().eq(&self.aunt()) { + granparent.right_niece() + } else { + granparent.left_niece() + } + } + /// Returns the hash of this node fn hash(&self) -> Hash { self.hash.get() @@ -353,6 +370,7 @@ pub struct Pollard { /// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is /// the number of leaves when it was added, so we can always find a leaf by it's position. leaves: u64, + leaf_map: HashMap>>, } impl PartialEq for Pollard { @@ -432,7 +450,7 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } - pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> { + pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> { let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); let nodes = positions .into_iter() @@ -440,8 +458,9 @@ impl Pollard { .collect::>(); for node in nodes { - let node = node.ok_or("Position not found")?; - node.0.prune(); + let (node, _) = node.ok_or("Position not found")?; + self.leaf_map.remove(&node.hash()); + node.prune(); } Ok(()) @@ -461,8 +480,18 @@ impl Pollard { /// Proves the inclusion of the nodes at the given positions /// /// This function takes a list of positions and returns a list of proofs for each position. - pub fn batch_proof(&self, targets: &[u64]) -> Result, &'static str> { - let targets = detwin(targets.to_vec(), tree_rows(self.leaves)); + pub fn batch_proof(&self, targets: &[Hash]) -> Result, String> { + let mut positions = Vec::new(); + for target in targets { + let node = self + .leaf_map + .get(target) + .ok_or(format!("leaf {target} not found in the forest"))?; + let position = self.get_pos(node)?; + positions.push(position); + } + + let targets = detwin(positions, tree_rows(self.leaves)); let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); let mut proof_hashes = Vec::new(); @@ -482,7 +511,9 @@ impl Pollard { }) } - pub fn prove_single(&self, pos: u64) -> Result, &'static str> { + pub fn prove_single(&self, leaf: Hash) -> Result, String> { + let node = self.leaf_map.get(&leaf).ok_or("Leaf not found")?; + let pos = self.get_pos(node)?; let hashes = self.prove_single_inner(pos)?; let targets = vec![pos]; @@ -506,8 +537,8 @@ impl Pollard { proof: Proof, ) -> Result<(), String> { let targets = proof.targets.clone(); - self.ingest_proof(proof.clone(), del_hashes, &targets) - .unwrap(); + self.ingest_proof(proof.clone(), del_hashes, &targets)?; + let targets = detwin(targets, tree_rows(self.leaves)); let targets = targets .into_iter() @@ -535,7 +566,11 @@ impl Pollard { /// Creates a new empty [Pollard] pub fn new() -> Pollard { let roots: [Option>>; 64] = std::array::from_fn(|_| None); - Pollard:: { roots, leaves: 0 } + Pollard:: { + roots, + leaves: 0, + leaf_map: HashMap::new(), + } } } @@ -632,6 +667,7 @@ impl Pollard { .copied() .collect::>(); + self.map_targets(remembers)?; self.prune(&pruned)?; if recompute { @@ -643,6 +679,17 @@ impl Pollard { Ok(()) } + fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> { + for target in targets { + let node = self + .grab_position(*target) + .ok_or(format!("Position {target} not found"))?; + self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0)); + } + + Ok(()) + } + fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) { let mut tr = tree_rows(num_leaves); let nr = detect_row(pos, tr); @@ -758,6 +805,7 @@ impl Pollard { fn add_single(&mut self, node: PollardAddition) -> Result, String> { let mut row = 0; let mut new_node = PollardNode::new(node.hash, node.remember); + self.leaf_map.insert(node.hash, Rc::downgrade(&new_node)); let mut add_positions = Vec::new(); let mut roots_to_destroy = Vec::new(); @@ -830,6 +878,7 @@ impl Pollard { } fn delete_single(&mut self, node: Rc>) -> Result<(), String> { + self.leaf_map.remove(&node.hash()); // we are deleting a root, just write an empty hash where it was if node.aunt.borrow().is_none() { for i in 0..64 { @@ -866,6 +915,65 @@ impl Pollard { sibling.migrate_up().unwrap(); Ok(()) } + + /// Returns the position in the tree of this node + fn get_pos(&self, node: &Weak>) -> Result { + // This indicates whether the node is a left or right child at each level + // When we go down the tree, we can use the indicator to know which + // child to take. + let mut left_child_indicator = 0_u64; + let mut rows_to_top = 0; + let mut node = node + .upgrade() + .ok_or("Could not upgrade node. Is this reference valid?")?; + + while let Some(aunt) = node.parent() { + let aunt_left = aunt.children().ok_or("Aunt has no children")?.0; + // If the current node is a left child, we left-shift the indicator + // and leave the LSB as 0 + if aunt_left.hash() == node.hash() { + left_child_indicator <<= 1; + } else { + // If the current node is a right child, we left-shift the indicator + // and set the LSB to 1 + left_child_indicator <<= 1; + left_child_indicator |= 1; + } + rows_to_top += 1; + node = aunt; + } + + let root_row = self.roots.iter().position(|root| { + if let Some(root) = root { + return root.hash() == node.hash(); + } + + false + }); + + let forest_rows = tree_rows(self.leaves); + let root_row = root_row.ok_or(format!( + "Could not find the root position for root {:?}", + node.hash() + ))?; + + let mut pos = root_position(self.leaves, root_row as u8, forest_rows); + for _ in 0..rows_to_top { + // If LSB is 0, go left, otherwise go right + match left_child_indicator & 1 { + 0 => { + pos = left_child(pos, forest_rows); + } + 1 => { + pos = right_child(pos, forest_rows); + } + _ => unreachable!(), + } + left_child_indicator >>= 1; + } + + Ok(pos) + } } #[cfg(test)] @@ -1063,7 +1171,7 @@ mod tests { let mut acc = Pollard::::new(); acc.modify(&hashes, &[], Proof::default()).unwrap(); - let proof = acc.prove_single(3).unwrap(); + let proof = acc.prove_single(hashes[3].hash).unwrap(); let expected_hashes = [ "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", @@ -1080,6 +1188,66 @@ mod tests { assert_eq!(proof, expected_proof); } + fn get_hashes_of(values: &[u8]) -> Vec> { + values + .iter() + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect() + } + + #[test] + fn test_get_pos() { + macro_rules! test_get_pos { + ($p:ident, $pos:literal) => { + let node = $p.grab_position($pos).unwrap().0; + assert_eq!( + $p.get_pos(&Rc::downgrade(&node)), + Ok($pos), + "Failed to get position of node {:?}", + node + ); + }; + } + + let hashes = get_hashes_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + let mut p = Pollard::new(); + p.modify(&hashes, &[], Proof::default()) + .expect("Test mem_forests are valid"); + test_get_pos!(p, 0); + test_get_pos!(p, 1); + test_get_pos!(p, 2); + test_get_pos!(p, 3); + test_get_pos!(p, 4); + test_get_pos!(p, 5); + test_get_pos!(p, 6); + test_get_pos!(p, 7); + test_get_pos!(p, 8); + test_get_pos!(p, 9); + test_get_pos!(p, 10); + test_get_pos!(p, 11); + test_get_pos!(p, 12); + + let root = p.roots[3].as_ref().unwrap(); + let left = root.left_niece().unwrap(); + let right = root.right_niece().unwrap(); + + assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28)); + assert_eq!( + p.get_pos(&Rc::downgrade(&left)), + Ok(24) + ); + assert_eq!( + p.get_pos(&Rc::downgrade(&right)), + Ok(25) + ); + } + #[test] fn test_ingest_proof() { let values = [0, 1, 2, 3, 4, 5, 6, 7] @@ -1109,7 +1277,8 @@ mod tests { acc.modify(&values, &[], Proof::default()).unwrap(); acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3]) .unwrap(); - let new_proof = acc.prove_single(3).unwrap(); + + let new_proof = acc.prove_single(values[3].hash).unwrap(); assert_eq!(new_proof, proof); let node = acc.grab_position(3).unwrap().0;