Skip to content

Commit a9d9968

Browse files
committed
feat: add migration checker cmd
1 parent b56dd9f commit a9d9968

File tree

4 files changed

+314
-0
lines changed

4 files changed

+314
-0
lines changed

cmd/migration-checker/main.go

+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"encoding/hex"
6+
"flag"
7+
"fmt"
8+
"os"
9+
"runtime"
10+
"sync"
11+
"sync/atomic"
12+
13+
"github.com/scroll-tech/go-ethereum/common"
14+
"github.com/scroll-tech/go-ethereum/core/types"
15+
"github.com/scroll-tech/go-ethereum/crypto"
16+
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
17+
"github.com/scroll-tech/go-ethereum/rlp"
18+
"github.com/scroll-tech/go-ethereum/trie"
19+
)
20+
21+
var accountsDone atomic.Uint64
22+
var trieCheckers = make(chan struct{}, runtime.GOMAXPROCS(0)*4)
23+
24+
type dbs struct {
25+
zkDb *leveldb.Database
26+
mptDb *leveldb.Database
27+
}
28+
29+
func main() {
30+
var (
31+
mptDbPath = flag.String("mpt-db", "", "path to the MPT node DB")
32+
zkDbPath = flag.String("zk-db", "", "path to the ZK node DB")
33+
mptRoot = flag.String("mpt-root", "", "root hash of the MPT node")
34+
zkRoot = flag.String("zk-root", "", "root hash of the ZK node")
35+
)
36+
flag.Parse()
37+
38+
zkDb, err := leveldb.New(*zkDbPath, 1024, 128, "", true)
39+
panicOnError(err, "", "failed to open zk db")
40+
mptDb, err := leveldb.New(*mptDbPath, 1024, 128, "", true)
41+
panicOnError(err, "", "failed to open mpt db")
42+
43+
zkRootHash := common.HexToHash(*zkRoot)
44+
mptRootHash := common.HexToHash(*mptRoot)
45+
46+
for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
47+
trieCheckers <- struct{}{}
48+
}
49+
50+
checkTrieEquality(&dbs{
51+
zkDb: zkDb,
52+
mptDb: mptDb,
53+
}, zkRootHash, mptRootHash, "", checkAccountEquality, true)
54+
}
55+
56+
func panicOnError(err error, label, msg string) {
57+
if err != nil {
58+
panic(fmt.Sprint(label, " error: ", msg, " ", err))
59+
}
60+
}
61+
62+
func dup(s []byte) []byte {
63+
return append([]byte{}, s...)
64+
}
65+
func checkTrieEquality(dbs *dbs, zkRoot, mptRoot common.Hash, label string, leafChecker func(string, *dbs, []byte, []byte), top bool) {
66+
zkTrie, err := trie.NewZkTrie(zkRoot, trie.NewZktrieDatabaseFromTriedb(trie.NewDatabaseWithConfig(dbs.zkDb, &trie.Config{Preimages: true})))
67+
panicOnError(err, label, "failed to create zk trie")
68+
mptTrie, err := trie.NewSecureNoTracer(mptRoot, trie.NewDatabaseWithConfig(dbs.mptDb, &trie.Config{Preimages: true}))
69+
panicOnError(err, label, "failed to create mpt trie")
70+
71+
mptLeafCh := loadMPT(mptTrie, top)
72+
zkLeafCh := loadZkTrie(zkTrie, top)
73+
74+
mptLeafMap := <-mptLeafCh
75+
zkLeafMap := <-zkLeafCh
76+
77+
if len(mptLeafMap) != len(zkLeafMap) {
78+
panic(fmt.Sprintf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d", label, len(mptLeafMap), len(zkLeafMap)))
79+
}
80+
81+
for preimageKey, zkValue := range zkLeafMap {
82+
if top {
83+
// ZkTrie pads preimages with 0s to make them 32 bytes.
84+
// So we might need to clear those zeroes here since we need 20 byte addresses at top level (ie state trie)
85+
if len(preimageKey) > 20 {
86+
for b := range preimageKey[20:] {
87+
if b != 0 {
88+
panic(fmt.Sprintf("%s padded byte is not 0 (preimage %s)", label, hex.EncodeToString([]byte(preimageKey))))
89+
}
90+
}
91+
preimageKey = preimageKey[:20]
92+
}
93+
} else if len(preimageKey) != 32 {
94+
// storage leafs should have 32 byte keys, pad them if needed
95+
zeroes := make([]byte, 32)
96+
copy(zeroes, []byte(preimageKey))
97+
preimageKey = string(zeroes)
98+
}
99+
100+
mptKey := crypto.Keccak256([]byte(preimageKey))
101+
mptVal, ok := mptLeafMap[string(mptKey)]
102+
if !ok {
103+
panic(fmt.Sprintf("%s key %s (preimage %s) not found in mpt", label, hex.EncodeToString([]byte(mptKey)), hex.EncodeToString([]byte(preimageKey))))
104+
}
105+
106+
leafChecker(fmt.Sprintf("%s key: %s", label, hex.EncodeToString([]byte(preimageKey))), dbs, zkValue, mptVal)
107+
}
108+
}
109+
110+
func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountBytes []byte) {
111+
mptAccount := &types.StateAccount{}
112+
panicOnError(rlp.DecodeBytes(mptAccountBytes, mptAccount), label, "failed to decode mpt account")
113+
zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
114+
panicOnError(err, label, "failed to decode zk account")
115+
116+
if mptAccount.Nonce != zkAccount.Nonce {
117+
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
118+
}
119+
120+
if mptAccount.Balance.Cmp(zkAccount.Balance) != 0 {
121+
panic(fmt.Sprintf("%s balance mismatch: zk: %s, mpt: %s", label, zkAccount.Balance.String(), mptAccount.Balance.String()))
122+
}
123+
124+
if !bytes.Equal(mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) {
125+
panic(fmt.Sprintf("%s code hash mismatch: zk: %s, mpt: %s", label, hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash)))
126+
}
127+
128+
if (zkAccount.Root == common.Hash{}) != (mptAccount.Root == types.EmptyRootHash) {
129+
panic(fmt.Sprintf("%s empty account root mismatch", label))
130+
} else if zkAccount.Root != (common.Hash{}) {
131+
zkRoot := common.BytesToHash(zkAccount.Root[:])
132+
mptRoot := common.BytesToHash(mptAccount.Root[:])
133+
<-trieCheckers
134+
go func() {
135+
defer func() {
136+
if p := recover(); p != nil {
137+
fmt.Println(p)
138+
os.Exit(1)
139+
}
140+
}()
141+
142+
checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false)
143+
accountsDone.Add(1)
144+
fmt.Println("Accounts done:", accountsDone.Load())
145+
trieCheckers <- struct{}{}
146+
}()
147+
} else {
148+
accountsDone.Add(1)
149+
fmt.Println("Accounts done:", accountsDone.Load())
150+
}
151+
}
152+
153+
func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
154+
zkValue := common.BytesToHash(zkStorageBytes)
155+
_, content, _, err := rlp.Split(mptStorageBytes)
156+
panicOnError(err, label, "failed to decode mpt storage")
157+
mptValue := common.BytesToHash(content)
158+
if !bytes.Equal(zkValue[:], mptValue[:]) {
159+
panic(fmt.Sprintf("%s storage mismatch: zk: %s, mpt: %s", label, zkValue.Hex(), mptValue.Hex()))
160+
}
161+
}
162+
163+
func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
164+
startKey := make([]byte, 32)
165+
workers := 1 << 5
166+
if !parallel {
167+
workers = 1
168+
}
169+
step := byte(0xFF) / byte(workers)
170+
171+
mptLeafMap := make(map[string][]byte, 1000)
172+
var mptLeafMutex sync.Mutex
173+
174+
var mptWg sync.WaitGroup
175+
for i := 0; i < workers; i++ {
176+
startKey[0] = byte(i) * step
177+
trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))
178+
179+
mptWg.Add(1)
180+
go func() {
181+
defer mptWg.Done()
182+
for trieIt.Next() {
183+
if parallel {
184+
mptLeafMutex.Lock()
185+
}
186+
187+
if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
188+
mptLeafMutex.Unlock()
189+
break
190+
}
191+
192+
mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)
193+
194+
if parallel {
195+
mptLeafMutex.Unlock()
196+
}
197+
198+
if parallel && len(mptLeafMap)%10000 == 0 {
199+
fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
200+
}
201+
}
202+
}()
203+
}
204+
205+
respChan := make(chan map[string][]byte)
206+
go func() {
207+
mptWg.Wait()
208+
respChan <- mptLeafMap
209+
}()
210+
return respChan
211+
}
212+
213+
func loadZkTrie(zkTrie *trie.ZkTrie, parallel bool) chan map[string][]byte {
214+
zkLeafMap := make(map[string][]byte, 1000)
215+
var zkLeafMutex sync.Mutex
216+
zkDone := make(chan map[string][]byte)
217+
go func() {
218+
zkTrie.CountLeaves(func(key, value []byte) {
219+
preimageKey := zkTrie.GetKey(key)
220+
if len(preimageKey) == 0 {
221+
panic(fmt.Sprintf("preimage not found zk trie %s", hex.EncodeToString(key)))
222+
}
223+
224+
if parallel {
225+
zkLeafMutex.Lock()
226+
}
227+
228+
zkLeafMap[string(dup(preimageKey))] = value
229+
230+
if parallel {
231+
zkLeafMutex.Unlock()
232+
}
233+
234+
if parallel && len(zkLeafMap)%10000 == 0 {
235+
fmt.Println("ZK Accounts Loaded:", len(zkLeafMap))
236+
}
237+
}, parallel)
238+
zkDone <- zkLeafMap
239+
}()
240+
return zkDone
241+
}

trie/secure_trie.go

+10
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
6565
return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
6666
}
6767

68+
func NewSecureNoTracer(root common.Hash, db *Database) (*SecureTrie, error) {
69+
t, err := NewSecure(root, db)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
t.trie.tracer = nil
75+
return t, nil
76+
}
77+
6878
// Get returns the value for key stored in the trie.
6979
// The value bytes must not be modified by the caller.
7080
func (t *SecureTrie) Get(key []byte) []byte {

trie/tracer.go

+24
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ func newTracer() *tracer {
6060
// blob internally. Don't change the value outside of function since
6161
// it's not deep-copied.
6262
func (t *tracer) onRead(path []byte, val []byte) {
63+
if t == nil {
64+
return
65+
}
66+
6367
t.accessList[string(path)] = val
6468
}
6569

6670
// onInsert tracks the newly inserted trie node. If it's already
6771
// in the deletion set (resurrected node), then just wipe it from
6872
// the deletion set as it's "untouched".
6973
func (t *tracer) onInsert(path []byte) {
74+
if t == nil {
75+
return
76+
}
77+
7078
if _, present := t.deletes[string(path)]; present {
7179
delete(t.deletes, string(path))
7280
return
@@ -78,6 +86,10 @@ func (t *tracer) onInsert(path []byte) {
7886
// in the addition set, then just wipe it from the addition set
7987
// as it's untouched.
8088
func (t *tracer) onDelete(path []byte) {
89+
if t == nil {
90+
return
91+
}
92+
8193
if _, present := t.inserts[string(path)]; present {
8294
delete(t.inserts, string(path))
8395
return
@@ -87,13 +99,21 @@ func (t *tracer) onDelete(path []byte) {
8799

88100
// reset clears the content tracked by tracer.
89101
func (t *tracer) reset() {
102+
if t == nil {
103+
return
104+
}
105+
90106
t.inserts = make(map[string]struct{})
91107
t.deletes = make(map[string]struct{})
92108
t.accessList = make(map[string][]byte)
93109
}
94110

95111
// copy returns a deep copied tracer instance.
96112
func (t *tracer) copy() *tracer {
113+
if t == nil {
114+
return nil
115+
}
116+
97117
accessList := make(map[string][]byte, len(t.accessList))
98118
for path, blob := range t.accessList {
99119
accessList[path] = common.CopyBytes(blob)
@@ -107,6 +127,10 @@ func (t *tracer) copy() *tracer {
107127

108128
// deletedNodes returns a list of node paths which are deleted from the trie.
109129
func (t *tracer) deletedNodes() []string {
130+
if t == nil {
131+
return nil
132+
}
133+
110134
var paths []string
111135
for path := range t.deletes {
112136
// It's possible a few deleted nodes were embedded

trie/zk_trie.go

+39
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,42 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
238238
func (t *ZkTrie) Witness() map[string]struct{} {
239239
panic("not implemented")
240240
}
241+
242+
func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel bool) uint64 {
243+
root, err := t.ZkTrie.Tree().Root()
244+
if err != nil {
245+
panic("CountLeaves cannot get root")
246+
}
247+
return t.countLeaves(root, cb, 0, parallel)
248+
}
249+
250+
func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int, parallel bool) uint64 {
251+
if root == nil {
252+
return 0
253+
}
254+
255+
rootNode, err := t.ZkTrie.Tree().GetNode(root)
256+
if err != nil {
257+
panic("countLeaves cannot get rootNode")
258+
}
259+
260+
if rootNode.Type == zktrie.NodeTypeLeaf_New {
261+
cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
262+
return 1
263+
} else {
264+
if parallel && depth < 5 {
265+
count := make(chan uint64)
266+
leftT := t.Copy()
267+
rightT := t.Copy()
268+
go func() {
269+
count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1, parallel)
270+
}()
271+
go func() {
272+
count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1, parallel)
273+
}()
274+
return <-count + <-count
275+
} else {
276+
return t.countLeaves(rootNode.ChildL, cb, depth+1, parallel) + t.countLeaves(rootNode.ChildR, cb, depth+1, parallel)
277+
}
278+
}
279+
}

0 commit comments

Comments
 (0)