Skip to content

Commit

Permalink
Merge pull request #62 from Davidson-Souza/fix-pollard-crash
Browse files Browse the repository at this point in the history
fix a bug with MemForest due to off-by-one error
  • Loading branch information
Davidson-Souza authored Jan 17, 2025
2 parents 071df44 + c372a3b commit 6da503b
Showing 1 changed file with 34 additions and 26 deletions.
60 changes: 34 additions & 26 deletions src/accumulator/mem_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
let mut positions = Vec::new();
for target in targets {
let node = self.map.get(target).ok_or("Could not find node")?;
let position = self.get_pos(node);
let position = self.get_pos(node)?;

positions.push(position);
}
let needed = get_proof_positions(&positions, self.leaves, tree_rows(self.leaves));
Expand Down Expand Up @@ -414,21 +415,20 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
}

fn del(&mut self, targets: &[Hash]) -> Result<(), String> {
let mut pos = targets
.iter()
.flat_map(|target| self.map.get(target))
.flat_map(|target| target.upgrade())
.map(|target| {
(
self.get_pos(self.map.get(&target.data.get()).unwrap()),
target.data.get(),
)
})
.collect::<Vec<_>>();
let mut nodes = Vec::new();

pos.sort();
let (_, targets): (Vec<u64>, Vec<Hash>) = pos.into_iter().unzip();
for target in targets {
let node_ref = self.map.get(target).ok_or("Could not find node")?;
let pos = self.get_pos(node_ref)?;

let node = node_ref.upgrade().ok_or("Could not upgrade node")?;

nodes.push((pos, node.get_data()));
}

nodes.sort_by(|a, b| a.0.cmp(&b.0));

for (_, target) in nodes {
match self.map.remove(&target) {
Some(target) => {
self.del_single(&target.upgrade().unwrap());
Expand All @@ -450,18 +450,21 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
proof.verify(del_hashes, &roots, self.leaves)
}

fn get_pos(&self, node: &Weak<Node<Hash>>) -> u64 {
fn get_pos(&self, node: &Weak<Node<Hash>>) -> Result<u64, String> {
// This indicates whether the node is a left or right child at each level
// When we go down the tree, we can use the indicator to know which
// child to take.
let mut left_child_indicator = 0_u64;
let mut rows_to_top = 0;
let mut node = node.upgrade().unwrap();
let mut node = node
.upgrade()
.ok_or("Could not upgrade node. Is this reference valid?")?;

while let Some(parent) = node.parent.clone().into_inner() {
let parent_left = parent
.upgrade()
.and_then(|parent| parent.left.clone().into_inner())
.unwrap()
.ok_or("Could not upgrade parent")?
.clone();

// If the current node is a left child, we left-shift the indicator
Expand All @@ -475,22 +478,26 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
left_child_indicator |= 1;
}
rows_to_top += 1;
node = parent.upgrade().unwrap();
node = parent.upgrade().ok_or("could not upgrade parent")?;
}
let mut root_idx = self.roots.len() - 1;
let forest_rows = tree_rows(self.leaves);
let mut root_row = 0;
let mut root_row = None;
// Find the root of the tree that the node belongs to
for row in 0..forest_rows {
for row in 0..=forest_rows {
if is_root_populated(row, self.leaves) {
let root = &self.roots[root_idx];
if root.get_data() == node.get_data() {
root_row = row;
root_row = Some(row);
break;
}
root_idx -= 1;
}
}

let root_row = root_row.ok_or(format!(
"Could not find the root position for row {root_idx}"
))?;
let mut pos = root_position(self.leaves, root_row, forest_rows);
for _ in 0..rows_to_top {
// If LSB is 0, go left, otherwise go right
Expand All @@ -505,7 +512,8 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
}
left_child_indicator >>= 1;
}
pos

Ok(pos)
}

fn del_single(&mut self, node: &Node<Hash>) -> Option<()> {
Expand Down Expand Up @@ -950,7 +958,7 @@ mod test {
($p:ident, $pos:literal) => {
assert_eq!(
$p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)),
$pos
Ok($pos)
);
};
}
Expand All @@ -971,18 +979,18 @@ mod test {
test_get_pos!(p, 11);
test_get_pos!(p, 12);

assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28);
assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), Ok(28));
assert_eq!(
p.get_pos(&Rc::downgrade(
p.get_roots()[0].left.borrow().as_ref().unwrap()
)),
24
Ok(24)
);
assert_eq!(
p.get_pos(&Rc::downgrade(
p.get_roots()[0].right.borrow().as_ref().unwrap()
)),
25
Ok(25)
);
}

Expand Down

0 comments on commit 6da503b

Please sign in to comment.