Skip to content

Commit

Permalink
fix: parallelization in node collection
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Feb 26, 2025
1 parent cbd9f7e commit dc3f89e
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 33 deletions.
36 changes: 25 additions & 11 deletions core/trie2/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
"github.com/NethermindEth/juno/core/trie2/trienode"
)

const (
parallelThreshold = 8 // TODO: arbitrary number, configure this based on monitoring
)

// Used as a tool to collect all dirty nodes in a NodeSet
type collector struct {
nodes *trienode.NodeSet
Expand All @@ -18,10 +22,17 @@ func newCollector(nodes *trienode.NodeSet) *collector {

// Collects the nodes in the node set and collapses a node into a hash node
func (c *collector) Collect(n Node, parallel bool) *HashNode {
return c.collect(new(Path), n, parallel).(*HashNode)
return c.collect(new(Path), n, parallel, 0).(*HashNode)
}

func (c *collector) collect(path *Path, n Node, parallel bool) Node {
// Collects all dirty nodes in the trie and converts them to hash nodes.
// Traverses the trie recursively, processing each node type:
// - EdgeNode: Collects its child before processing the edge node itself
// - BinaryNode: Collects both children (potentially in parallel) before processing the node
// - HashNode: Already processed, returns as is
// - ValueNode: Stores as a leaf in the node set
// Returns a HashNode representing the processed node.
func (c *collector) collect(path *Path, n Node, parallel bool, depth int) Node {
// This path has not been modified, just return the cache
hash, dirty := n.cache()
if hash != nil && !dirty {
Expand All @@ -36,11 +47,11 @@ func (c *collector) collect(path *Path, n Node, parallel bool) Node {
// If the child is a binary node, recurse into it.
// Otherwise, it can only be a HashNode or ValueNode.
// Combination of edge (parent) + edge (child) is not possible.
collapsed.Child = c.collect(new(Path).Append(path, cn.Path), cn.Child, parallel)
collapsed.Child = c.collect(new(Path).Append(path, cn.Path), cn.Child, parallel, depth+1)
return c.store(path, collapsed)
case *BinaryNode:
collapsed := cn.copy()
collapsed.Children = c.collectChildren(path, cn, parallel)
collapsed.Children = c.collectChildren(path, cn, parallel, depth+1)
return c.store(path, collapsed)
case *HashNode:
return cn
Expand All @@ -52,9 +63,11 @@ func (c *collector) collect(path *Path, n Node, parallel bool) Node {
}

// Collects the children of a binary node, may apply parallel processing if configured
func (c *collector) collectChildren(path *Path, n *BinaryNode, parallel bool) [2]Node {
func (c *collector) collectChildren(path *Path, n *BinaryNode, parallel bool, depth int) [2]Node {
children := [2]Node{}

var mu sync.Mutex

// Helper function to process a single child
processChild := func(i int) {
child := n.Children[i]
Expand All @@ -68,15 +81,20 @@ func (c *collector) collectChildren(path *Path, n *BinaryNode, parallel bool) [2
childPath := new(Path).AppendBit(path, uint8(i))

if !parallel {
children[i] = c.collect(childPath, child, parallel)
children[i] = c.collect(childPath, child, parallel, depth)
return
}

// Parallel processing
childSet := trienode.NewNodeSet(c.nodes.Owner)
childCollector := newCollector(childSet)
children[i] = childCollector.collect(childPath, child, parallel)
children[i] = childCollector.collect(childPath, child, depth < parallelThreshold, depth)

// Merge the child set into the parent set
// Must be done under the mutex because node set is not thread safe
mu.Lock()
c.nodes.MergeSet(childSet) //nolint:errcheck // guaranteed to succeed because same owner
mu.Unlock()
}

if !parallel {
Expand All @@ -88,15 +106,11 @@ func (c *collector) collectChildren(path *Path, n *BinaryNode, parallel bool) [2

// Parallel processing
var wg sync.WaitGroup
var mu sync.Mutex

for i := range 2 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
mu.Lock()
processChild(idx)
mu.Unlock()
}(i)
}
wg.Wait()
Expand Down
2 changes: 1 addition & 1 deletion core/trie2/hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/NethermindEth/juno/core/crypto"
)

// A tool for shashing nodes in the trie. It supports both sequential and parallel
// A tool for hashing nodes in the trie. It supports both sequential and parallel
// hashing modes.
type hasher struct {
hashFn crypto.HashFn // The hash function to use
Expand Down
34 changes: 15 additions & 19 deletions core/trie2/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,21 @@ type ID struct {
Owner felt.Felt // The contract address which the trie belongs to
}

// Represents a class trie
// Identifier for a class trie
type ClassTrieID struct{}

func NewClassTrieID() *ClassTrieID { return &ClassTrieID{} }
func (ClassTrieID) Bucket() db.Bucket { return db.ClassTrie }
func (ClassTrieID) Owner() felt.Felt { return felt.Zero }
func (ClassTrieID) IsContractStorage() bool { return false }
func NewClassTrieID() *ClassTrieID { return &ClassTrieID{} }
func (ClassTrieID) Bucket() db.Bucket { return db.ClassTrie }
func (ClassTrieID) Owner() felt.Felt { return felt.Zero }

// Represents a contract trie
// Identifier for a contract trie
type ContractTrieID struct{}

func NewContractTrieID() *ContractTrieID { return &ContractTrieID{} }
func (id ContractTrieID) Bucket() db.Bucket { return db.ContractTrieContract }
func (id ContractTrieID) Owner() felt.Felt { return felt.Zero }
func (id ContractTrieID) IsContractStorage() bool { return false }
func NewContractTrieID() *ContractTrieID { return &ContractTrieID{} }
func (id ContractTrieID) Bucket() db.Bucket { return db.ContractTrieContract }
func (id ContractTrieID) Owner() felt.Felt { return felt.Zero }

// Represents a contract storage trie
// Identifier for a contract storage trie
type ContractStorageTrieID struct {
owner felt.Felt
}
Expand All @@ -55,14 +53,12 @@ func NewContractStorageTrieIDFromFelt(owner felt.Felt) *ContractStorageTrieID {
return &ContractStorageTrieID{owner: owner}
}

func (id ContractStorageTrieID) Bucket() db.Bucket { return db.ContractTrieStorage }
func (id ContractStorageTrieID) Owner() felt.Felt { return id.owner }
func (id ContractStorageTrieID) IsContractStorage() bool { return true }
func (id ContractStorageTrieID) Bucket() db.Bucket { return db.ContractTrieStorage }
func (id ContractStorageTrieID) Owner() felt.Felt { return id.owner }

// Represents an empty trie, only used for temporary purposes
// Identifier for an empty trie, only used for temporary purposes
type EmptyTrieID struct{}

func NewEmptyTrieID() *EmptyTrieID { return &EmptyTrieID{} }
func (EmptyTrieID) Bucket() db.Bucket { return db.Bucket(0) }
func (EmptyTrieID) Owner() felt.Felt { return felt.Zero }
func (EmptyTrieID) IsContractStorage() bool { return false }
func NewEmptyTrieID() *EmptyTrieID { return &EmptyTrieID{} }
func (EmptyTrieID) Bucket() db.Bucket { return db.Bucket(0) }
func (EmptyTrieID) Owner() felt.Felt { return felt.Zero }
2 changes: 1 addition & 1 deletion core/trie2/node_enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const (
)

var bufferPool = sync.Pool{
New: func() interface{} {
New: func() any {
return new(bytes.Buffer)
},
}
Expand Down
29 changes: 28 additions & 1 deletion core/trie2/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type TrieID interface {
Owner() felt.Felt
}

// Creates a new trie
func New(id TrieID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, error) {
database := triedb.New(txn, id.Bucket())
tr := &Trie{
Expand Down Expand Up @@ -223,6 +224,17 @@ func (t *Trie) HashFn() crypto.HashFn {
return t.hashFn
}

// Traverses the trie to find a value associated with a given key.
// It handles different node types:
// - EdgeNode: Checks if the path matches the key, then recursively traverses the child node
// - BinaryNode: Determines which child to follow based on the most significant bit of the key
// - HashNode: Resolves the actual node from the database before continuing traversal
// - ValueNode: Returns the stored value when found
// - nil: Returns nil when no value exists
//
// The method returns four values: the found value (or nil), the possibly updated node,
// a flag indicating if node resolution occurred, and any error encountered.
// When nodes are resolved from the database, the trie structure is updated to cache the resolved nodes.
func (t *Trie) get(n Node, prefix, key *Path) (*felt.Felt, Node, bool, error) {
switch n := n.(type) {
case *EdgeNode:
Expand Down Expand Up @@ -279,6 +291,13 @@ func (t *Trie) update(key, value *felt.Felt) error {
return nil
}

// Inserts a value into the trie. Handles different node types:
// - EdgeNode: Creates branch nodes when paths diverge, or updates existing paths
// - BinaryNode: Follows the appropriate child based on the key's MSB
// - HashNode: Resolves the actual node before insertion
// - nil: Creates a new edge or value node depending on key length
// Returns whether the trie was modified, the new/updated node, and any error.
//
//nolint:gocyclo,funlen
func (t *Trie) insert(n Node, prefix, key *Path, value Node) (bool, Node, error) {
// We reach the end of the key
Expand Down Expand Up @@ -371,6 +390,14 @@ func (t *Trie) insert(n Node, prefix, key *Path, value Node) (bool, Node, error)
}
}

// Deletes a key from the trie. Handles different node types:
// - EdgeNode: Removes the node if path matches, or recursively deletes from child
// - BinaryNode: Follows the appropriate child and may collapse the node if a child is removed
// - HashNode: Resolves the actual node before deletion
// - ValueNode: Removes the value node when found
// - nil: Returns false as there's nothing to delete
// Returns whether the trie was modified, the new/updated node, and any error.
//
//nolint:gocyclo,funlen
func (t *Trie) delete(n Node, prefix, key *Path) (bool, Node, error) {
switch n := n.(type) {
Expand Down Expand Up @@ -489,7 +516,7 @@ func (t *Trie) resolveNode(hn *HashNode, path Path) (Node, error) {
return decodeNode(blob, hash, path.Len(), t.height)
}

// Calculate the hash of the root node
// Calculates the hash of the root node
func (t *Trie) hashRoot() (Node, Node) {
if t.root == nil {
return &HashNode{Felt: felt.Zero}, nil
Expand Down
7 changes: 7 additions & 0 deletions core/trie2/triedb/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (d *Database) NewIterator(owner felt.Felt) (db.Iterator, error) {
return d.txn.NewIterator(buffer.Bytes(), true)
}

// Construct key bytes to insert a trie node. The format is as follows:
//
// ClassTrie/ContractTrie:
// [1 byte prefix][1 byte node-type][path]
//
// StorageTrie of a Contract :
// [1 byte prefix][32 bytes owner][1 byte node-type][path]
func (d *Database) dbKey(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray, isLeaf bool) error {
_, err := buf.Write(d.prefix.Key())
if err != nil {
Expand Down

0 comments on commit dc3f89e

Please sign in to comment.