From c372a3b4bf4cabd94ffc19c91e858eb9c64eed8f Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Wed, 15 Jan 2025 16:14:37 -0300 Subject: [PATCH] fix a bug with MemForest due to off-by-one error When we try to find in what subtree our node is, we loop over all existing roots. However, this loop is not inclusive, and will never reach the highest tree. If we don't find the tree, it continues with `0` as the tree position, but this is wrong, and will cause problems downstream when we try to use that position. This commit updates the loop and make it inclusive --- src/accumulator/mem_forest.rs | 60 ++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/src/accumulator/mem_forest.rs b/src/accumulator/mem_forest.rs index b777291..cd5cae8 100644 --- a/src/accumulator/mem_forest.rs +++ b/src/accumulator/mem_forest.rs @@ -320,7 +320,8 @@ impl MemForest { let mut positions = Vec::new(); for target in targets { let node = self.map.get(target).ok_or("Could not find node")?; - let position = self.get_pos(node); + let position = self.get_pos(node)?; + positions.push(position); } let needed = get_proof_positions(&positions, self.leaves, tree_rows(self.leaves)); @@ -414,21 +415,20 @@ impl MemForest { } fn del(&mut self, targets: &[Hash]) -> Result<(), String> { - let mut pos = targets - .iter() - .flat_map(|target| self.map.get(target)) - .flat_map(|target| target.upgrade()) - .map(|target| { - ( - self.get_pos(self.map.get(&target.data.get()).unwrap()), - target.data.get(), - ) - }) - .collect::>(); + let mut nodes = Vec::new(); - pos.sort(); - let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); for target in targets { + let node_ref = self.map.get(target).ok_or("Could not find node")?; + let pos = self.get_pos(node_ref)?; + + let node = node_ref.upgrade().ok_or("Could not upgrade node")?; + + nodes.push((pos, node.get_data())); + } + + nodes.sort_by(|a, b| a.0.cmp(&b.0)); + + for (_, target) in nodes { match self.map.remove(&target) { Some(target) => { self.del_single(&target.upgrade().unwrap()); @@ -450,18 +450,21 @@ impl MemForest { proof.verify(del_hashes, &roots, self.leaves) } - fn get_pos(&self, node: &Weak>) -> u64 { + fn get_pos(&self, node: &Weak>) -> Result { // This indicates whether the node is a left or right child at each level // When we go down the tree, we can use the indicator to know which // child to take. let mut left_child_indicator = 0_u64; let mut rows_to_top = 0; - let mut node = node.upgrade().unwrap(); + let mut node = node + .upgrade() + .ok_or("Could not upgrade node. Is this reference valid?")?; + while let Some(parent) = node.parent.clone().into_inner() { let parent_left = parent .upgrade() .and_then(|parent| parent.left.clone().into_inner()) - .unwrap() + .ok_or("Could not upgrade parent")? .clone(); // If the current node is a left child, we left-shift the indicator @@ -475,22 +478,26 @@ impl MemForest { left_child_indicator |= 1; } rows_to_top += 1; - node = parent.upgrade().unwrap(); + node = parent.upgrade().ok_or("could not upgrade parent")?; } let mut root_idx = self.roots.len() - 1; let forest_rows = tree_rows(self.leaves); - let mut root_row = 0; + let mut root_row = None; // Find the root of the tree that the node belongs to - for row in 0..forest_rows { + for row in 0..=forest_rows { if is_root_populated(row, self.leaves) { let root = &self.roots[root_idx]; if root.get_data() == node.get_data() { - root_row = row; + root_row = Some(row); break; } root_idx -= 1; } } + + let root_row = root_row.ok_or(format!( + "Could not find the root position for row {root_idx}" + ))?; let mut pos = root_position(self.leaves, root_row, forest_rows); for _ in 0..rows_to_top { // If LSB is 0, go left, otherwise go right @@ -505,7 +512,8 @@ impl MemForest { } left_child_indicator >>= 1; } - pos + + Ok(pos) } fn del_single(&mut self, node: &Node) -> Option<()> { @@ -950,7 +958,7 @@ mod test { ($p:ident, $pos:literal) => { assert_eq!( $p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)), - $pos + Ok($pos) ); }; } @@ -971,18 +979,18 @@ mod test { test_get_pos!(p, 11); test_get_pos!(p, 12); - assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28); + assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), Ok(28)); assert_eq!( p.get_pos(&Rc::downgrade( p.get_roots()[0].left.borrow().as_ref().unwrap() )), - 24 + Ok(24) ); assert_eq!( p.get_pos(&Rc::downgrade( p.get_roots()[0].right.borrow().as_ref().unwrap() )), - 25 + Ok(25) ); }