Skip to content

Commit

Permalink
WIP: add a leaf map
Browse files Browse the repository at this point in the history
When proving, we currently ask for positions. However, positions may
change during updates, so externally holding those positions may be
cumbersome. The leaf hash, on the other hand never changes.

This commit does the same we do for `MemForest`, creating a map from
leaf hashes to a reference to the actual leaf in memory. Once we need to
prove it, we compute the position on-the-fly.
  • Loading branch information
Davidson-Souza committed Jan 21, 2025
1 parent 6da503b commit 3864784
Showing 1 changed file with 180 additions and 11 deletions.
191 changes: 180 additions & 11 deletions src/accumulator/pollard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
/// //TODO: Add usage examples
use std::cell::Cell;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fmt::Display;
use std::rc::Rc;
Expand All @@ -47,8 +48,10 @@ use super::util::detect_row;
use super::util::detwin;
use super::util::get_proof_positions;
use super::util::is_root_position;
use super::util::left_child;
use super::util::max_position_at_row;
use super::util::parent;
use super::util::right_child;
use super::util::root_position;
use super::util::tree_rows;

Expand Down Expand Up @@ -121,6 +124,20 @@ impl<Hash: AccumulatorHash> PollardNode<Hash> {
})
}

fn parent(&self) -> Option<Rc<PollardNode<Hash>>> {
let granparent = self.grandparent();
if granparent.is_none() {
return self.aunt();
}

let granparent = granparent.unwrap();
if granparent.left_niece().eq(&self.aunt()) {
granparent.right_niece()
} else {
granparent.left_niece()
}
}

/// Returns the hash of this node
fn hash(&self) -> Hash {
self.hash.get()
Expand Down Expand Up @@ -353,6 +370,7 @@ pub struct Pollard<Hash: AccumulatorHash> {
/// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is
/// the number of leaves when it was added, so we can always find a leaf by it's position.
leaves: u64,
leaf_map: HashMap<Hash, Weak<PollardNode<Hash>>>,
}

impl<Hash: AccumulatorHash> PartialEq for Pollard<Hash> {
Expand Down Expand Up @@ -432,16 +450,17 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
self.do_ingest_proof(proof, del_hashes, remembers, false)
}

pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> {
pub fn prune(&mut self, positions: &[u64]) -> Result<(), &'static str> {
let positions = detwin(positions.to_vec(), tree_rows(self.leaves));
let nodes = positions
.into_iter()
.map(|pos| self.grab_position(pos))
.collect::<Vec<_>>();

for node in nodes {
let node = node.ok_or("Position not found")?;
node.0.prune();
let (node, _) = node.ok_or("Position not found")?;
self.leaf_map.remove(&node.hash());
node.prune();
}

Ok(())
Expand All @@ -461,8 +480,18 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
/// Proves the inclusion of the nodes at the given positions
///
/// This function takes a list of positions and returns a list of proofs for each position.
pub fn batch_proof(&self, targets: &[u64]) -> Result<Proof<Hash>, &'static str> {
let targets = detwin(targets.to_vec(), tree_rows(self.leaves));
pub fn batch_proof(&self, targets: &[Hash]) -> Result<Proof<Hash>, String> {
let mut positions = Vec::new();
for target in targets {
let node = self
.leaf_map
.get(target)
.ok_or(format!("leaf {target} not found in the forest"))?;
let position = self.get_pos(node)?;
positions.push(position);
}

let targets = detwin(positions, tree_rows(self.leaves));
let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves));
let mut proof_hashes = Vec::new();

Expand All @@ -482,7 +511,9 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
})
}

pub fn prove_single(&self, pos: u64) -> Result<Proof<Hash>, &'static str> {
pub fn prove_single(&self, leaf: Hash) -> Result<Proof<Hash>, String> {
let node = self.leaf_map.get(&leaf).ok_or("Leaf not found")?;
let pos = self.get_pos(node)?;
let hashes = self.prove_single_inner(pos)?;
let targets = vec![pos];

Expand All @@ -506,8 +537,8 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
proof: Proof<Hash>,
) -> Result<(), String> {
let targets = proof.targets.clone();
self.ingest_proof(proof.clone(), del_hashes, &targets)
.unwrap();
self.ingest_proof(proof.clone(), del_hashes, &targets)?;

let targets = detwin(targets, tree_rows(self.leaves));
let targets = targets
.into_iter()
Expand Down Expand Up @@ -535,7 +566,11 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
/// Creates a new empty [Pollard]
pub fn new() -> Pollard<Hash> {
let roots: [Option<Rc<PollardNode<Hash>>>; 64] = std::array::from_fn(|_| None);
Pollard::<Hash> { roots, leaves: 0 }
Pollard::<Hash> {
roots,
leaves: 0,
leaf_map: HashMap::new(),
}
}
}

Expand Down Expand Up @@ -632,6 +667,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
.copied()
.collect::<Vec<_>>();

self.map_targets(remembers)?;
self.prune(&pruned)?;

if recompute {
Expand All @@ -643,6 +679,17 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
Ok(())
}

fn map_targets(&mut self, targets: &[u64]) -> Result<(), String> {
for target in targets {
let node = self
.grab_position(*target)
.ok_or(format!("Position {target} not found"))?;
self.leaf_map.insert(node.0.hash(), Rc::downgrade(&node.0));
}

Ok(())
}

fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) {
let mut tr = tree_rows(num_leaves);
let nr = detect_row(pos, tr);
Expand Down Expand Up @@ -758,6 +805,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
fn add_single(&mut self, node: PollardAddition<Hash>) -> Result<AddSingleResult<Hash>, String> {
let mut row = 0;
let mut new_node = PollardNode::new(node.hash, node.remember);
self.leaf_map.insert(node.hash, Rc::downgrade(&new_node));

let mut add_positions = Vec::new();
let mut roots_to_destroy = Vec::new();
Expand Down Expand Up @@ -830,6 +878,7 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
}

fn delete_single(&mut self, node: Rc<PollardNode<Hash>>) -> Result<(), String> {
self.leaf_map.remove(&node.hash());
// we are deleting a root, just write an empty hash where it was
if node.aunt.borrow().is_none() {
for i in 0..64 {
Expand Down Expand Up @@ -866,6 +915,65 @@ impl<Hash: AccumulatorHash> Pollard<Hash> {
sibling.migrate_up().unwrap();
Ok(())
}

/// Returns the position in the tree of this node
fn get_pos(&self, node: &Weak<PollardNode<Hash>>) -> Result<u64, String> {
// This indicates whether the node is a left or right child at each level
// When we go down the tree, we can use the indicator to know which
// child to take.
let mut left_child_indicator = 0_u64;
let mut rows_to_top = 0;
let mut node = node
.upgrade()
.ok_or("Could not upgrade node. Is this reference valid?")?;

while let Some(aunt) = node.parent() {
let aunt_left = aunt.children().ok_or("Aunt has no children")?.0;
// If the current node is a left child, we left-shift the indicator
// and leave the LSB as 0
if aunt_left.hash() == node.hash() {
left_child_indicator <<= 1;
} else {
// If the current node is a right child, we left-shift the indicator
// and set the LSB to 1
left_child_indicator <<= 1;
left_child_indicator |= 1;
}
rows_to_top += 1;
node = aunt;
}

let root_row = self.roots.iter().position(|root| {
if let Some(root) = root {
return root.hash() == node.hash();
}

false
});

let forest_rows = tree_rows(self.leaves);
let root_row = root_row.ok_or(format!(
"Could not find the root position for root {:?}",
node.hash()
))?;

let mut pos = root_position(self.leaves, root_row as u8, forest_rows);
for _ in 0..rows_to_top {
// If LSB is 0, go left, otherwise go right
match left_child_indicator & 1 {
0 => {
pos = left_child(pos, forest_rows);
}
1 => {
pos = right_child(pos, forest_rows);
}
_ => unreachable!(),
}
left_child_indicator >>= 1;
}

Ok(pos)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1063,7 +1171,7 @@ mod tests {
let mut acc = Pollard::<BitcoinNodeHash>::new();
acc.modify(&hashes, &[], Proof::default()).unwrap();

let proof = acc.prove_single(3).unwrap();
let proof = acc.prove_single(hashes[3].hash).unwrap();
let expected_hashes = [
"dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986",
"02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de",
Expand All @@ -1080,6 +1188,66 @@ mod tests {
assert_eq!(proof, expected_proof);
}

fn get_hashes_of(values: &[u8]) -> Vec<PollardAddition<BitcoinNodeHash>> {
values
.iter()
.map(|preimage| {
let hash = hash_from_u8(*preimage);
PollardAddition {
hash,
remember: true,
}
})
.collect()
}

#[test]
fn test_get_pos() {
macro_rules! test_get_pos {
($p:ident, $pos:literal) => {
let node = $p.grab_position($pos).unwrap().0;
assert_eq!(
$p.get_pos(&Rc::downgrade(&node)),
Ok($pos),
"Failed to get position of node {:?}",
node
);
};
}

let hashes = get_hashes_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
let mut p = Pollard::new();
p.modify(&hashes, &[], Proof::default())
.expect("Test mem_forests are valid");
test_get_pos!(p, 0);
test_get_pos!(p, 1);
test_get_pos!(p, 2);
test_get_pos!(p, 3);
test_get_pos!(p, 4);
test_get_pos!(p, 5);
test_get_pos!(p, 6);
test_get_pos!(p, 7);
test_get_pos!(p, 8);
test_get_pos!(p, 9);
test_get_pos!(p, 10);
test_get_pos!(p, 11);
test_get_pos!(p, 12);

let root = p.roots[3].as_ref().unwrap();
let left = root.left_niece().unwrap();
let right = root.right_niece().unwrap();

assert_eq!(p.get_pos(&Rc::downgrade(&root)), Ok(28));
assert_eq!(
p.get_pos(&Rc::downgrade(&left)),
Ok(24)
);
assert_eq!(
p.get_pos(&Rc::downgrade(&right)),
Ok(25)
);
}

#[test]
fn test_ingest_proof() {
let values = [0, 1, 2, 3, 4, 5, 6, 7]
Expand Down Expand Up @@ -1109,7 +1277,8 @@ mod tests {
acc.modify(&values, &[], Proof::default()).unwrap();
acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3])
.unwrap();
let new_proof = acc.prove_single(3).unwrap();

let new_proof = acc.prove_single(values[3].hash).unwrap();
assert_eq!(new_proof, proof);

let node = acc.grab_position(3).unwrap().0;
Expand Down

0 comments on commit 3864784

Please sign in to comment.