Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add migration checker cmd #1114

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions cmd/migration-checker/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package main

import (
"bytes"
"encoding/hex"
"flag"
"fmt"
"os"
"runtime"
"sync"
"sync/atomic"

"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
"github.com/scroll-tech/go-ethereum/rlp"
"github.com/scroll-tech/go-ethereum/trie"
)

// Package-level state shared by main and the per-account goroutines.
var (
	// accountsDone counts the accounts whose comparison has completed; it is
	// only used for progress output.
	accountsDone atomic.Uint64

	// trieCheckers is used as a counting semaphore with room for
	// 4*GOMAXPROCS tokens: main fills it, and a token is taken before each
	// storage-trie goroutine is started and put back when that goroutine ends.
	trieCheckers = make(chan struct{}, runtime.GOMAXPROCS(0)*4)
)

// dbs bundles the two node databases that are compared against each other so
// they can be handed down through the checker functions as one value.
type dbs struct {
	// zkDb is the database opened from the -zk-db path.
	zkDb *leveldb.Database
	// mptDb is the database opened from the -mpt-db path.
	mptDb *leveldb.Database
}

func main() {
var (
mptDbPath = flag.String("mpt-db", "", "path to the MPT node DB")
zkDbPath = flag.String("zk-db", "", "path to the ZK node DB")
mptRoot = flag.String("mpt-root", "", "root hash of the MPT node")
zkRoot = flag.String("zk-root", "", "root hash of the ZK node")
paranoid = flag.Bool("paranoid", false, "verifies all node contents against their expected hash")
)
flag.Parse()

zkDb, err := leveldb.New(*zkDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open zk db")
mptDb, err := leveldb.New(*mptDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open mpt db")

zkRootHash := common.HexToHash(*zkRoot)
mptRootHash := common.HexToHash(*mptRoot)

for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
trieCheckers <- struct{}{}
}

checkTrieEquality(&dbs{
zkDb: zkDb,
mptDb: mptDb,
}, zkRootHash, mptRootHash, "", checkAccountEquality, true, *paranoid)

for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
<-trieCheckers
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Update function calls to pass paranoid parameter.

The calls to loadMPT and loadZkTrie need to be updated to match their new function signatures and pass the paranoid parameter.

-	mptLeafCh := loadMPT(mptTrie, top)
-	zkLeafCh := loadZkTrie(zkTrie, top)
+	mptLeafCh := loadMPT(mptTrie, top)
+	zkLeafCh := loadZkTrie(zkTrie, top, paranoid)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
<-trieCheckers
mptLeafCh := loadMPT(mptTrie, top)
zkLeafCh := loadZkTrie(zkTrie, top, paranoid)

}
}

// panicOnError is a no-op for a nil err. Otherwise it panics with a string of
// the form "<label> error: <msg> <err>".
func panicOnError(err error, label, msg string) {
	if err == nil {
		return
	}
	panic(fmt.Sprintf("%s error: %s %v", label, msg, err))
}

// dup returns a freshly allocated copy of s. The result never aliases s and is
// non-nil even when s is nil or empty.
func dup(s []byte) []byte {
	clone := make([]byte, len(s))
	copy(clone, s)
	return clone
}
func checkTrieEquality(dbs *dbs, zkRoot, mptRoot common.Hash, label string, leafChecker func(string, *dbs, []byte, []byte, bool), top, paranoid bool) {
zkTrie, err := trie.NewZkTrie(zkRoot, trie.NewZktrieDatabaseFromTriedb(trie.NewDatabaseWithConfig(dbs.zkDb, &trie.Config{Preimages: true})))
panicOnError(err, label, "failed to create zk trie")
mptTrie, err := trie.NewSecureNoTracer(mptRoot, trie.NewDatabaseWithConfig(dbs.mptDb, &trie.Config{Preimages: true}))
panicOnError(err, label, "failed to create mpt trie")

mptLeafCh := loadMPT(mptTrie, top)
zkLeafCh := loadZkTrie(zkTrie, top, paranoid)

mptLeafMap := <-mptLeafCh
zkLeafMap := <-zkLeafCh

if len(mptLeafMap) != len(zkLeafMap) {
panic(fmt.Sprintf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d", label, len(mptLeafMap), len(zkLeafMap)))
}

for preimageKey, zkValue := range zkLeafMap {
if top {
// ZkTrie pads preimages with 0s to make them 32 bytes.
// So we might need to clear those zeroes here since we need 20 byte addresses at top level (ie state trie)
if len(preimageKey) > 20 {
for _, b := range []byte(preimageKey)[20:] {
if b != 0 {
panic(fmt.Sprintf("%s padded byte is not 0 (preimage %s)", label, hex.EncodeToString([]byte(preimageKey))))
}
}
preimageKey = preimageKey[:20]
}
} else if len(preimageKey) != 32 {
// storage leafs should have 32 byte keys, pad them if needed
zeroes := make([]byte, 32)
copy(zeroes, []byte(preimageKey))
preimageKey = string(zeroes)
}

mptKey := crypto.Keccak256([]byte(preimageKey))
mptVal, ok := mptLeafMap[string(mptKey)]
if !ok {
panic(fmt.Sprintf("%s key %s (preimage %s) not found in mpt", label, hex.EncodeToString(mptKey), hex.EncodeToString([]byte(preimageKey))))
}

leafChecker(fmt.Sprintf("%s key: %s", label, hex.EncodeToString([]byte(preimageKey))), dbs, zkValue, mptVal, paranoid)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Update function call to pass paranoid parameter.

The function call to checkTrieEquality within the goroutine needs to include the paranoid parameter.

-			checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false)
+			checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false, paranoid)

Committable suggestion skipped: line range outside the PR's diff.

}
}

func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountBytes []byte, paranoid bool) {
mptAccount := &types.StateAccount{}
panicOnError(rlp.DecodeBytes(mptAccountBytes, mptAccount), label, "failed to decode mpt account")
zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
panicOnError(err, label, "failed to decode zk account")

if mptAccount.Nonce != zkAccount.Nonce {
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Update function signature to include paranoid parameter.

The function signature for checkStorageEquality needs to be updated to include the paranoid parameter for consistency.

-func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
+func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte, paranoid bool) {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte, paranoid bool) {
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
}

}

if mptAccount.Balance.Cmp(zkAccount.Balance) != 0 {
panic(fmt.Sprintf("%s balance mismatch: zk: %s, mpt: %s", label, zkAccount.Balance.String(), mptAccount.Balance.String()))
}

if !bytes.Equal(mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) {
panic(fmt.Sprintf("%s code hash mismatch: zk: %s, mpt: %s", label, hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash)))
}

if (zkAccount.Root == common.Hash{}) != (mptAccount.Root == types.EmptyRootHash) {
panic(fmt.Sprintf("%s empty account root mismatch", label))
} else if zkAccount.Root != (common.Hash{}) {
zkRoot := common.BytesToHash(zkAccount.Root[:])
mptRoot := common.BytesToHash(mptAccount.Root[:])
<-trieCheckers
go func() {
defer func() {
if p := recover(); p != nil {
fmt.Println(p)
os.Exit(1)
}
}()

checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false, paranoid)
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
trieCheckers <- struct{}{}
}()
} else {
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
}
}

// checkStorageEquality compares one storage slot value from both tries. The
// MPT value is RLP-wrapped, so its payload is extracted first; both sides are
// then converted to 32-byte hashes and compared. A decode failure or a
// mismatch panics with a message prefixed by label. The dbs and bool
// arguments are unused and exist only to match the leaf-checker signature.
func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte, _ bool) {
	_, payload, _, err := rlp.Split(mptStorageBytes)
	panicOnError(err, label, "failed to decode mpt storage")

	zkSlot, mptSlot := common.BytesToHash(zkStorageBytes), common.BytesToHash(payload)
	if zkSlot == mptSlot {
		return
	}
	panic(fmt.Sprintf("%s storage mismatch: zk: %s, mpt: %s", label, zkSlot.Hex(), mptSlot.Hex()))
}

func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
startKey := make([]byte, 32)
workers := 1 << 5
if !parallel {
workers = 1
}
step := byte(0xFF) / byte(workers)

mptLeafMap := make(map[string][]byte, 1000)
var mptLeafMutex sync.Mutex

var mptWg sync.WaitGroup
for i := 0; i < workers; i++ {
startKey[0] = byte(i) * step
trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))

mptWg.Add(1)
go func() {
defer mptWg.Done()
for trieIt.Next() {
if parallel {
mptLeafMutex.Lock()
}

if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
mptLeafMutex.Unlock()
break
}

mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)

if parallel {
mptLeafMutex.Unlock()
}

if parallel && len(mptLeafMap)%10000 == 0 {
fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
}
}
}()
}

respChan := make(chan map[string][]byte)
Comment on lines +224 to +252
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Resolve conflict in loadZkTrie function definitions.

There appears to be two definitions of the loadZkTrie function. The first version (lines 213-241) doesn't have a paranoid parameter, while the second version (lines 249-277) does. You need to resolve this conflict and keep only the second version.

Remove lines 213-241 and keep the version that includes the paranoid parameter (lines 249-277).

Also applies to: 242-277

go func() {
mptWg.Wait()
respChan <- mptLeafMap
}()
return respChan
}

// loadZkTrie walks zkTrie with CountLeaves and collects every leaf into a map
// from the key preimage to the leaf value. The returned channel delivers that
// map once the walk has finished.
//
// parallel and paranoid are passed through to CountLeaves. With parallel set
// the callback can run on several goroutines at once, so all map access is
// guarded by a mutex. A leaf whose preimage cannot be resolved panics.
func loadZkTrie(zkTrie *trie.ZkTrie, parallel, paranoid bool) chan map[string][]byte {
	zkLeafMap := make(map[string][]byte, 1000)
	var zkLeafMutex sync.Mutex
	zkDone := make(chan map[string][]byte)
	go func() {
		zkTrie.CountLeaves(func(key, value []byte) {
			preimageKey := zkTrie.GetKey(key)
			if len(preimageKey) == 0 {
				panic(fmt.Sprintf("preimage not found zk trie %s", hex.EncodeToString(key)))
			}

			if parallel {
				zkLeafMutex.Lock()
			}

			// The string conversion already copies the preimage bytes.
			zkLeafMap[string(preimageKey)] = value
			// Capture the size inside the critical section; reading the map
			// length after unlocking would race with concurrent inserts.
			loaded := len(zkLeafMap)

			if parallel {
				zkLeafMutex.Unlock()
			}

			if parallel && loaded%10000 == 0 {
				fmt.Println("ZK Accounts Loaded:", loaded)
			}
		}, parallel, paranoid)
		zkDone <- zkLeafMap
	}()
	return zkDone
}
10 changes: 10 additions & 0 deletions trie/secure_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
}

// NewSecureNoTracer creates a SecureTrie exactly like NewSecure and then sets
// the underlying trie's tracer to nil. Any error from NewSecure is returned
// unchanged together with a nil trie.
func NewSecureNoTracer(root common.Hash, db *Database) (*SecureTrie, error) {
	secure, err := NewSecure(root, db)
	if err != nil {
		return nil, err
	}

	// The tracer methods return early on a nil receiver, so a nil tracer
	// simply records nothing.
	secure.trie.tracer = nil
	return secure, nil
}

// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
Expand Down
24 changes: 24 additions & 0 deletions trie/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ func newTracer() *tracer {
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
	// A nil tracer records nothing. Otherwise val is stored as-is (no copy)
	// under the string form of path.
	if t != nil {
		t.accessList[string(path)] = val
	}
}

// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
if t == nil {
return
}

if _, present := t.deletes[string(path)]; present {
delete(t.deletes, string(path))
return
Expand All @@ -78,6 +86,10 @@ func (t *tracer) onInsert(path []byte) {
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
if t == nil {
return
}

if _, present := t.inserts[string(path)]; present {
delete(t.inserts, string(path))
return
Expand All @@ -87,13 +99,21 @@ func (t *tracer) onDelete(path []byte) {

// reset clears the content tracked by tracer by replacing the insert, delete
// and access-list maps with fresh empty ones. It does nothing on a nil tracer.
func (t *tracer) reset() {
	if t != nil {
		t.inserts, t.deletes = make(map[string]struct{}), make(map[string]struct{})
		t.accessList = make(map[string][]byte)
	}
}

// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
if t == nil {
return nil
}

accessList := make(map[string][]byte, len(t.accessList))
for path, blob := range t.accessList {
accessList[path] = common.CopyBytes(blob)
Expand All @@ -107,6 +127,10 @@ func (t *tracer) copy() *tracer {

// deletedNodes returns a list of node paths which are deleted from the trie.
func (t *tracer) deletedNodes() []string {
if t == nil {
return nil
}

var paths []string
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
Expand Down
49 changes: 49 additions & 0 deletions trie/zk_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,52 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
func (t *ZkTrie) Witness() map[string]struct{} {
	// Always panics; no witness set is ever returned from this method.
	panic("not implemented")
}

// CountLeaves walks the trie from its root, invokes cb once for every leaf with
// copies of the leaf's key and data, and returns the number of leaves visited.
//
// With parallel set, the upper levels of the trie are walked on separate
// goroutines, so cb may be called concurrently and must be safe for that.
// With verifyNodeHashes set, each leaf's hash is recomputed and compared with
// the hash it was referenced by. Any failure panics.
func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel, verifyNodeHashes bool) uint64 {
	root, err := t.ZkTrie.Tree().Root()
	if err != nil {
		// Keep the underlying cause in the panic message.
		panic("CountLeaves cannot get root: " + err.Error())
	}
	return t.countLeaves(root, cb, 0, parallel, verifyNodeHashes)
}

// countLeaves recursively visits the subtree referenced by root and returns the
// number of leaves in it. A nil root counts as zero. For a leaf node, cb is
// called with copies of the node key and data; every other node type is
// descended into through its left and right children.
//
// depth is the distance from the trie root. With parallel set, the two children
// of a node at depth < 5 are walked on their own goroutines, each using its own
// copy of the trie; below that the walk is sequential. With verifyNodeHashes
// set, a leaf's recomputed hash must equal root. Any failure panics.
func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int, parallel, verifyNodeHashes bool) uint64 {
	if root == nil {
		return 0
	}

	rootNode, err := t.ZkTrie.Tree().GetNode(root)
	if err != nil {
		// Keep the underlying cause in the panic message.
		panic("countLeaves cannot get rootNode: " + err.Error())
	}

	if rootNode.Type == zktrie.NodeTypeLeaf_New {
		if verifyNodeHashes {
			calculatedNodeHash, err := rootNode.NodeHash()
			if err != nil {
				panic("countLeaves cannot get calculatedNodeHash: " + err.Error())
			}
			if *calculatedNodeHash != *root {
				panic("countLeaves node hash mismatch")
			}
		}

		// Hand out copies so the callback can keep the slices.
		cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
		return 1
	}

	if !parallel || depth >= 5 {
		left := t.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes)
		right := t.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
		return left + right
	}

	// Fan out: each child subtree is walked on its own goroutine with its own
	// trie copy, and the two partial counts are summed as they arrive.
	count := make(chan uint64)
	leftT := t.Copy()
	rightT := t.Copy()
	go func() {
		count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes)
	}()
	go func() {
		count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
	}()
	return <-count + <-count
}
Loading