diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 7572bf7..12853ad 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -36,6 +36,7 @@ /// //TODO: Add usage examples use std::cell::Cell; use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; use std::fmt::Display; use std::rc::Rc; @@ -47,8 +48,10 @@ use super::util::detect_row; use super::util::detwin; use super::util::get_proof_positions; use super::util::is_root_position; +use super::util::left_child; use super::util::max_position_at_row; use super::util::parent; +use super::util::right_child; use super::util::root_position; use super::util::tree_rows; @@ -121,6 +124,20 @@ impl PollardNode { }) } + fn parent(&self) -> Option>> { + let granparent = self.grandparent(); + if granparent.is_none() { + return self.aunt(); + } + + let granparent = granparent.unwrap(); + if granparent.left_niece().eq(&self.aunt()) { + granparent.right_niece() + } else { + granparent.left_niece() + } + } + /// Returns the hash of this node fn hash(&self) -> Hash { self.hash.get() @@ -353,6 +370,7 @@ pub struct Pollard { /// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is /// the number of leaves when it was added, so we can always find a leaf by it's position. leaves: u64, + leaf_map: HashMap>>, } impl PartialEq for Pollard { @@ -412,6 +430,11 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { + let roots = self.roots(); + proof.verify(del_hashes, &roots, self.leaves) + } + pub fn verify_and_ingest( &mut self, proof: Proof, @@ -432,7 +455,9 @@ impl Pollard { self.do_ingest_proof(proof, del_hashes, remembers, false) } - pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> { + pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> { + self.prune_map(positions); + let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); let nodes = positions .into_iter() @@ -440,8 +465,9 @@ impl Pollard { .collect::>(); for node in nodes { - let node = node.ok_or("Position not found")?; - node.0.prune(); + let (node, _) = node.ok_or("Position not found")?; + self.leaf_map.remove(&node.hash()); + node.prune(); } Ok(()) @@ -461,12 +487,22 @@ impl Pollard { /// Proves the inclusion of the nodes at the given positions /// /// This function takes a list of positions and returns a list of proofs for each position. - pub fn batch_proof(&self, targets: &[u64]) -> Result, &'static str> { - let targets = detwin(targets.to_vec(), tree_rows(self.leaves)); - let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); + pub fn batch_proof(&self, targets: &[Hash]) -> Result, String> { + let mut target_positions = Vec::new(); + for target in targets { + let node = self + .leaf_map + .get(target) + .ok_or(format!("leaf {target} not found in the forest"))?; + let position = self.get_pos(node)?; + target_positions.push(position); + } + + let proof_positions = + get_proof_positions(&target_positions, self.leaves, tree_rows(self.leaves)); let mut proof_hashes = Vec::new(); - for pos in positions.iter() { + for pos in proof_positions.iter() { let hash = self .grab_position(*pos) .ok_or("Position not found")? @@ -478,11 +514,13 @@ impl Pollard { Ok(Proof:: { hashes: proof_hashes, - targets: positions, + targets: target_positions, }) } - pub fn prove_single(&self, pos: u64) -> Result, &'static str> { + pub fn prove_single(&self, leaf: Hash) -> Result, String> { + let node = self.leaf_map.get(&leaf).ok_or("Leaf not found")?; + let pos = self.get_pos(node)?; let hashes = self.prove_single_inner(pos)?; let targets = vec![pos]; @@ -506,8 +544,8 @@ impl Pollard { proof: Proof, ) -> Result<(), String> { let targets = proof.targets.clone(); - self.ingest_proof(proof.clone(), del_hashes, &targets) - .unwrap(); + self.ingest_proof(proof.clone(), del_hashes, &targets)?; + let targets = detwin(targets, tree_rows(self.leaves)); let targets = targets .into_iter() @@ -535,7 +573,11 @@ impl Pollard { /// Creates a new empty [Pollard] pub fn new() -> Pollard { let roots: [Option>>; 64] = std::array::from_fn(|_| None); - Pollard:: { roots, leaves: 0 } + Pollard:: { + roots, + leaves: 0, + leaf_map: HashMap::new(), + } } } @@ -546,6 +588,13 @@ type AddSingleResult = (Vec<(u64, T)>, Vec); type ChildrenTuple = (Rc>, Rc>); impl Pollard { + fn prune_map(&mut self, positions: &[u64]) { + for pos in positions { + let node = self.grab_position(*pos).unwrap().0; + self.leaf_map.remove(&node.hash()); + } + } + fn grab_position(&self, pos: u64) -> Option> { let (root, depth, bits) = Self::detect_offset(pos, self.leaves); let mut node = self.roots[root as usize].clone()?; @@ -573,6 +622,7 @@ impl Pollard { fn ingest_positions( &mut self, mut iter: impl Iterator, + remembers: &[u64], ) -> Result<(), String> { let forest_rows = tree_rows(self.leaves); while let Some((pos1, hash1)) = iter.next() { @@ -603,6 +653,11 @@ impl Pollard { new_node.set_aunt(Rc::downgrade(&aunt)); new_sibling.set_aunt(Rc::downgrade(&aunt)); + if remembers.contains(&pos1) || remembers.contains(&pos2) { + self.leaf_map.insert(hash1, Rc::downgrade(&new_node)); + self.leaf_map.insert(hash2, Rc::downgrade(&new_sibling)); + } + aunt.set_niece(Some(new_sibling), Some(new_node)); } @@ -623,7 +678,7 @@ impl Pollard { all_nodes.extend(proof_positions.into_iter().zip(proof.hashes.clone())); all_nodes.sort(); let iter = all_nodes.into_iter().rev(); - self.ingest_positions(iter)?; + self.ingest_positions(iter, remembers)?; let pruned = proof .targets @@ -758,6 +813,7 @@ impl Pollard { fn add_single(&mut self, node: PollardAddition) -> Result, String> { let mut row = 0; let mut new_node = PollardNode::new(node.hash, node.remember); + self.leaf_map.insert(node.hash, Rc::downgrade(&new_node)); let mut add_positions = Vec::new(); let mut roots_to_destroy = Vec::new(); @@ -830,6 +886,7 @@ impl Pollard { } fn delete_single(&mut self, node: Rc>) -> Result<(), String> { + self.leaf_map.remove(&node.hash()); // we are deleting a root, just write an empty hash where it was if node.aunt.borrow().is_none() { for i in 0..64 { @@ -866,6 +923,65 @@ impl Pollard { sibling.migrate_up().unwrap(); Ok(()) } + + /// Returns the position in the tree of this node + fn get_pos(&self, node: &Weak>) -> Result { + // This indicates whether the node is a left or right child at each level + // When we go down the tree, we can use the indicator to know which + // child to take. + let mut left_child_indicator = 0_u64; + let mut rows_to_top = 0; + let mut node = node + .upgrade() + .ok_or("Could not upgrade node. Is this reference valid?")?; + + while let Some(aunt) = node.parent() { + let aunt_left = aunt.children().ok_or("Aunt has no children")?.0; + // If the current node is a left child, we left-shift the indicator + // and leave the LSB as 0 + if aunt_left.hash() == node.hash() { + left_child_indicator <<= 1; + } else { + // If the current node is a right child, we left-shift the indicator + // and set the LSB to 1 + left_child_indicator <<= 1; + left_child_indicator |= 1; + } + rows_to_top += 1; + node = aunt; + } + + let root_row = self.roots.iter().position(|root| { + if let Some(root) = root { + return root.hash() == node.hash(); + } + + false + }); + + let forest_rows = tree_rows(self.leaves); + let root_row = root_row.ok_or(format!( + "Could not find the root position for root {:?}", + node.hash() + ))?; + + let mut pos = root_position(self.leaves, root_row as u8, forest_rows); + for _ in 0..rows_to_top { + // If LSB is 0, go left, otherwise go right + match left_child_indicator & 1 { + 0 => { + pos = left_child(pos, forest_rows); + } + 1 => { + pos = right_child(pos, forest_rows); + } + _ => unreachable!(), + } + left_child_indicator >>= 1; + } + + Ok(pos) + } } #[cfg(test)] @@ -1048,6 +1164,88 @@ mod tests { assert_eq!(root.unwrap().hash(), hashes[0].hash); } + #[test] + fn test_ingest_proof_and_prove() { + // this test will create a forest, prove a few leaves, prune all leaves, ingest the proof + // and prove the same leaves + siblings again + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect(); + + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); + + let del_hashes = [ + hash_from_u8(2), + hash_from_u8(1), + hash_from_u8(4), + hash_from_u8(6), + ]; + let proof = acc.batch_proof(&del_hashes).unwrap(); + + acc.prune(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + acc.ingest_proof(proof, &del_hashes, &[2, 1, 4, 6]).unwrap(); + + let del_hashes = [0, 1, 4, 5, 6, 7] + .iter() + .map(|x| hash_from_u8(*x)) + .collect::>(); + let proof = acc.batch_proof(&del_hashes).unwrap(); + assert!(acc.verify(&proof, &del_hashes).unwrap()); + } + #[test] + fn test_prove() { + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect(); + + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); + let del_hashes = [ + hash_from_u8(2), + hash_from_u8(1), + hash_from_u8(4), + hash_from_u8(6), + ]; + let proof = acc.batch_proof(&del_hashes).unwrap(); + let expected_proof = Proof::new( + [2, 1, 4, 6].to_vec(), + vec![ + "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d" + .parse() + .unwrap(), + "084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5" + .parse() + .unwrap(), + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db" + .parse() + .unwrap(), + "ca358758f6d27e6cf45272937977a748fd88391db679ceda7dc7bf1f005ee879" + .parse() + .unwrap(), + ], + ); + + assert_eq!(proof, expected_proof); + assert!(acc.verify(&proof, &del_hashes).unwrap()); + } + #[test] fn test_prove_single() { let values = vec![0, 1, 2, 3, 4, 5]; @@ -1063,7 +1261,7 @@ mod tests { let mut acc = Pollard::::new(); acc.modify(&hashes, &[], Proof::default()).unwrap(); - let proof = acc.prove_single(3).unwrap(); + let proof = acc.prove_single(hashes[3].hash).unwrap(); let expected_hashes = [ "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", @@ -1080,6 +1278,60 @@ mod tests { assert_eq!(proof, expected_proof); } + fn get_hashes_of(values: &[u8]) -> Vec> { + values + .iter() + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect() + } + + #[test] + fn test_get_pos() { + macro_rules! test_get_pos { + ($p:ident, $pos:literal) => { + let node = $p.grab_position($pos).unwrap().0; + assert_eq!( + $p.get_pos(&Rc::downgrade(&node)), + Ok($pos), + "Failed to get position of node {:?}", + node + ); + }; + } + + let hashes = get_hashes_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + let mut p = Pollard::new(); + p.modify(&hashes, &[], Proof::default()) + .expect("Test mem_forests are valid"); + test_get_pos!(p, 0); + test_get_pos!(p, 1); + test_get_pos!(p, 2); + test_get_pos!(p, 3); + test_get_pos!(p, 4); + test_get_pos!(p, 5); + test_get_pos!(p, 6); + test_get_pos!(p, 7); + test_get_pos!(p, 8); + test_get_pos!(p, 9); + test_get_pos!(p, 10); + test_get_pos!(p, 11); + test_get_pos!(p, 12); + + let root = p.roots[3].as_ref().unwrap(); + let left = root.left_niece().unwrap(); + let right = root.right_niece().unwrap(); + + assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28)); + assert_eq!(p.get_pos(&Rc::downgrade(&left)), Ok(24)); + assert_eq!(p.get_pos(&Rc::downgrade(&right)), Ok(25)); + } + #[test] fn test_ingest_proof() { let values = [0, 1, 2, 3, 4, 5, 6, 7] @@ -1109,7 +1361,8 @@ mod tests { acc.modify(&values, &[], Proof::default()).unwrap(); acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3]) .unwrap(); - let new_proof = acc.prove_single(3).unwrap(); + + let new_proof = acc.prove_single(values[3].hash).unwrap(); assert_eq!(new_proof, proof); let node = acc.grab_position(3).unwrap().0;