author | Yondon Fu <yondon.fu@gmail.com> | 2017-12-19 06:17:41 +0800
---|---|---
committer | Yondon Fu <yondon.fu@gmail.com> | 2017-12-19 06:17:41 +0800
commit | 3857cdc267e3192697f561df0a0f827f65dfb6b5 |
tree | 401c52c4972a68229ea283a394a0b0a5f3cfdc8e /trie |
parent | a5330fe0c569b75cb8a524f60f7e8dc06498262b |
parent | fe070ab5c32702033489f1b9d1655ea1b894c29e |
Merge branch 'master' into abi-offset-fixed-arrays
Diffstat (limited to 'trie')
-rw-r--r-- | trie/hasher.go | 111
-rw-r--r-- | trie/proof.go | 53
-rw-r--r-- | trie/proof_test.go | 52
-rw-r--r-- | trie/secure_trie.go | 8
-rw-r--r-- | trie/trie.go | 1
-rw-r--r-- | trie/trie_test.go | 44
6 files changed, 175 insertions, 94 deletions
diff --git a/trie/hasher.go b/trie/hasher.go
index 4719aabf6..5186d7669 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -26,27 +26,46 @@ import (
    "github.com/ethereum/go-ethereum/rlp"
)

-type hasher struct {
-    tmp                  *bytes.Buffer
-    sha                  hash.Hash
-    cachegen, cachelimit uint16
+// calculator is a utility used by the hasher to calculate the hash value of the tree node.
+type calculator struct {
+    sha    hash.Hash
+    buffer *bytes.Buffer
}

-// hashers live in a global pool.
-var hasherPool = sync.Pool{
+// calculatorPool is a set of temporary calculators that may be individually saved and retrieved.
+var calculatorPool = sync.Pool{
    New: func() interface{} {
-        return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
+        return &calculator{buffer: new(bytes.Buffer), sha: sha3.NewKeccak256()}
    },
}

+// hasher is used to calculate the hash value of the whole tree.
+type hasher struct {
+    cachegen   uint16
+    cachelimit uint16
+    threaded   bool
+    mu         sync.Mutex
+}
+
func newHasher(cachegen, cachelimit uint16) *hasher {
-    h := hasherPool.Get().(*hasher)
-    h.cachegen, h.cachelimit = cachegen, cachelimit
+    h := &hasher{
+        cachegen:   cachegen,
+        cachelimit: cachelimit,
+    }
    return h
}

-func returnHasherToPool(h *hasher) {
-    hasherPool.Put(h)
+// newCalculator retrieves a cleaned calculator from the calculator pool.
+func (h *hasher) newCalculator() *calculator {
+    calculator := calculatorPool.Get().(*calculator)
+    calculator.buffer.Reset()
+    calculator.sha.Reset()
+    return calculator
+}
+
+// returnCalculator returns a no longer used calculator to the pool.
+func (h *hasher) returnCalculator(calculator *calculator) {
+    calculatorPool.Put(calculator)
}

// hash collapses a node down into a hash node, also returning a copy of the
@@ -123,16 +142,49 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
        // Hash the full node's children, caching the newly hashed subtrees
        collapsed, cached := n.copy(), n.copy()

-        for i := 0; i < 16; i++ {
-            if n.Children[i] != nil {
-                collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false)
-                if err != nil {
-                    return original, original, err
-                }
-            } else {
-                collapsed.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings.
+        // hashChild is a helper to hash a single child, which is called either on the
+        // same thread as the caller or in a goroutine for the toplevel branching.
+        hashChild := func(index int, wg *sync.WaitGroup) {
+            if wg != nil {
+                defer wg.Done()
+            }
+            // Ensure that nil children are encoded as empty strings.
+            if collapsed.Children[index] == nil {
+                collapsed.Children[index] = valueNode(nil)
+                return
+            }
+            // Hash all other children properly
+            var herr error
+            collapsed.Children[index], cached.Children[index], herr = h.hash(n.Children[index], db, false)
+            if herr != nil {
+                h.mu.Lock() // rarely if ever locked, no contention
+                err = herr
+                h.mu.Unlock()
            }
        }
+        // If we're not running in threaded mode yet, spawn a goroutine for each child
+        if !h.threaded {
+            // Disable further threading
+            h.threaded = true
+
+            // Hash all the children concurrently
+            var wg sync.WaitGroup
+            for i := 0; i < 16; i++ {
+                wg.Add(1)
+                go hashChild(i, &wg)
+            }
+            wg.Wait()
+
+            // Reenable threading for subsequent hash calls
+            h.threaded = false
+        } else {
+            for i := 0; i < 16; i++ {
+                hashChild(i, nil)
+            }
+        }
+        if err != nil {
+            return original, original, err
+        }
        cached.Children[16] = n.Children[16]
        if collapsed.Children[16] == nil {
            collapsed.Children[16] = valueNode(nil)
@@ -150,24 +202,29 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
    if _, isHash := n.(hashNode); n == nil || isHash {
        return n, nil
    }
+    calculator := h.newCalculator()
+    defer h.returnCalculator(calculator)
+
    // Generate the RLP encoding of the node
-    h.tmp.Reset()
-    if err := rlp.Encode(h.tmp, n); err != nil {
+    if err := rlp.Encode(calculator.buffer, n); err != nil {
        panic("encode error: " + err.Error())
    }
-
-    if h.tmp.Len() < 32 && !force {
+    if calculator.buffer.Len() < 32 && !force {
        return n, nil // Nodes smaller than 32 bytes are stored inside their parent
    }
    // Larger nodes are replaced by their hash and stored in the database.
    hash, _ := n.cache()
    if hash == nil {
-        h.sha.Reset()
-        h.sha.Write(h.tmp.Bytes())
-        hash = hashNode(h.sha.Sum(nil))
+        calculator.sha.Write(calculator.buffer.Bytes())
+        hash = hashNode(calculator.sha.Sum(nil))
    }
    if db != nil {
-        return hash, db.Put(hash, h.tmp.Bytes())
+        // db might be a leveldb batch, which is not safe for concurrent writes
+        h.mu.Lock()
+        err := db.Put(hash, calculator.buffer.Bytes())
+        h.mu.Unlock()
+
+        return hash, err
    }
    return hash, nil
}
diff --git a/trie/proof.go b/trie/proof.go
index 298f648c4..5e886a259 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -18,11 +18,10 @@ package trie

import (
    "bytes"
-    "errors"
    "fmt"

    "github.com/ethereum/go-ethereum/common"
-    "github.com/ethereum/go-ethereum/crypto/sha3"
+    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/log"
    "github.com/ethereum/go-ethereum/rlp"
)
@@ -36,7 +35,7 @@ import (
// contains all nodes of the longest existing prefix of the key
// (at least the root node), ending with the node that proves the
// absence of the key.
-func (t *Trie) Prove(key []byte) []rlp.RawValue {
+func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
    // Collect all nodes on the path to key.
    key = keybytesToHex(key)
    nodes := []node{}
@@ -61,67 +60,63 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
            tn, err = t.resolveHash(n, nil)
            if err != nil {
                log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
-                return nil
+                return err
            }
        default:
            panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
        }
    }
    hasher := newHasher(0, 0)
-    proof := make([]rlp.RawValue, 0, len(nodes))
    for i, n := range nodes {
        // Don't bother checking for errors here since hasher panics
        // if encoding doesn't work and we're not writing to any database.
        n, _, _ = hasher.hashChildren(n, nil)
        hn, _ := hasher.store(n, nil, false)
-        if _, ok := hn.(hashNode); ok || i == 0 {
+        if hash, ok := hn.(hashNode); ok || i == 0 {
            // If the node's database encoding is a hash (or is the
            // root node), it becomes a proof element.
-            enc, _ := rlp.EncodeToBytes(n)
-            proof = append(proof, enc)
+            if fromLevel > 0 {
+                fromLevel--
+            } else {
+                enc, _ := rlp.EncodeToBytes(n)
+                if !ok {
+                    hash = crypto.Keccak256(enc)
+                }
+                proofDb.Put(hash, enc)
+            }
        }
    }
-    return proof
+    return nil
}

// VerifyProof checks merkle proofs. The given proof must contain the
// value for key in a trie with the given root hash. VerifyProof
// returns an error if the proof contains invalid trie nodes or the
// wrong value.
-func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
+func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
    key = keybytesToHex(key)
-    sha := sha3.NewKeccak256()
-    wantHash := rootHash.Bytes()
-    for i, buf := range proof {
-        sha.Reset()
-        sha.Write(buf)
-        if !bytes.Equal(sha.Sum(nil), wantHash) {
-            return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
+    wantHash := rootHash[:]
+    for i := 0; ; i++ {
+        buf, _ := proofDb.Get(wantHash)
+        if buf == nil {
+            return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i
        }
        n, err := decodeNode(wantHash, buf, 0)
        if err != nil {
-            return nil, fmt.Errorf("bad proof node %d: %v", i, err)
+            return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
        }
        keyrest, cld := get(n, key)
        switch cld := cld.(type) {
        case nil:
-            if i != len(proof)-1 {
-                return nil, fmt.Errorf("key mismatch at proof node %d", i)
-            } else {
-                // The trie doesn't contain the key.
-                return nil, nil
-            }
+            // The trie doesn't contain the key.
+            return nil, nil, i
        case hashNode:
            key = keyrest
            wantHash = cld
        case valueNode:
-            if i != len(proof)-1 {
-                return nil, errors.New("additional nodes at end of proof")
-            }
-            return cld, nil
+            return cld, nil, i + 1
        }
    }
-    return nil, errors.New("unexpected end of proof")
}

func get(tn node, key []byte) ([]byte, node) {
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 91ebcd4a5..fff313d7f 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -24,7 +24,8 @@ import (
    "time"

    "github.com/ethereum/go-ethereum/common"
-    "github.com/ethereum/go-ethereum/rlp"
+    "github.com/ethereum/go-ethereum/crypto"
+    "github.com/ethereum/go-ethereum/ethdb"
)

func init() {
@@ -35,13 +36,13 @@ func TestProof(t *testing.T) {
    trie, vals := randomTrie(500)
    root := trie.Hash()
    for _, kv := range vals {
-        proof := trie.Prove(kv.k)
-        if proof == nil {
+        proofs, _ := ethdb.NewMemDatabase()
+        if trie.Prove(kv.k, 0, proofs) != nil {
            t.Fatalf("missing key %x while constructing proof", kv.k)
        }
-        val, err := VerifyProof(root, kv.k, proof)
+        val, err, _ := VerifyProof(root, kv.k, proofs)
        if err != nil {
-            t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof)
+            t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs)
        }
        if !bytes.Equal(val, kv.v) {
            t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v)
@@ -52,16 +53,14 @@ func TestProof(t *testing.T) {
func TestOneElementProof(t *testing.T) {
    trie := new(Trie)
    updateString(trie, "k", "v")
-    proof := trie.Prove([]byte("k"))
-    if proof == nil {
-        t.Fatal("nil proof")
-    }
-    if len(proof) != 1 {
+    proofs, _ := ethdb.NewMemDatabase()
+    trie.Prove([]byte("k"), 0, proofs)
+    if len(proofs.Keys()) != 1 {
        t.Error("proof should have one element")
    }
-    val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+    val, err, _ := VerifyProof(trie.Hash(), []byte("k"), proofs)
    if err != nil {
-        t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof)
+        t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys())
    }
    if !bytes.Equal(val, []byte("v")) {
        t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val)
@@ -72,12 +71,18 @@ func TestVerifyBadProof(t *testing.T) {
    trie, vals := randomTrie(800)
    root := trie.Hash()
    for _, kv := range vals {
-        proof := trie.Prove(kv.k)
-        if proof == nil {
-            t.Fatal("nil proof")
+        proofs, _ := ethdb.NewMemDatabase()
+        trie.Prove(kv.k, 0, proofs)
+        if len(proofs.Keys()) == 0 {
+            t.Fatal("zero length proof")
        }
-        mutateByte(proof[mrand.Intn(len(proof))])
-        if _, err := VerifyProof(root, kv.k, proof); err == nil {
+        keys := proofs.Keys()
+        key := keys[mrand.Intn(len(keys))]
+        node, _ := proofs.Get(key)
+        proofs.Delete(key)
+        mutateByte(node)
+        proofs.Put(crypto.Keccak256(node), node)
+        if _, err, _ := VerifyProof(root, kv.k, proofs); err == nil {
            t.Fatalf("expected proof to fail for key %x", kv.k)
        }
    }
@@ -104,8 +109,9 @@ func BenchmarkProve(b *testing.B) {
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        kv := vals[keys[i%len(keys)]]
-        if trie.Prove(kv.k) == nil {
-            b.Fatalf("nil proof for %x", kv.k)
+        proofs, _ := ethdb.NewMemDatabase()
+        if trie.Prove(kv.k, 0, proofs); len(proofs.Keys()) == 0 {
+            b.Fatalf("zero length proof for %x", kv.k)
        }
    }
}
@@ -114,16 +120,18 @@ func BenchmarkVerifyProof(b *testing.B) {
    trie, vals := randomTrie(100)
    root := trie.Hash()
    var keys []string
-    var proofs [][]rlp.RawValue
+    var proofs []*ethdb.MemDatabase
    for k := range vals {
        keys = append(keys, k)
-        proofs = append(proofs, trie.Prove([]byte(k)))
+        proof, _ := ethdb.NewMemDatabase()
+        trie.Prove([]byte(k), 0, proof)
+        proofs = append(proofs, proof)
    }
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        im := i % len(keys)
-        if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+        if _, err, _ := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
            b.Fatalf("key %x: %v", keys[im], err)
        }
    }
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 20c303f31..1fde45165 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -199,10 +199,10 @@ func (t *SecureTrie) secKey(key []byte) []byte {
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte {
    h := newHasher(0, 0)
-    h.sha.Reset()
-    h.sha.Write(key)
-    buf := h.sha.Sum(t.hashKeyBuf[:0])
-    returnHasherToPool(h)
+    calculator := h.newCalculator()
+    calculator.sha.Write(key)
+    buf := calculator.sha.Sum(t.hashKeyBuf[:0])
+    h.returnCalculator(calculator)
    return buf
}
diff --git a/trie/trie.go b/trie/trie.go
index 7f69a3d1d..c211e7554 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -501,6 +501,5 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
        return hashNode(emptyRoot.Bytes()), nil, nil
    }
    h := newHasher(t.cachegen, t.cachelimit)
-    defer returnHasherToPool(h)
    return h.hash(t.root, db, true)
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 1c9095070..1e28c3bc4 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -22,6 +22,7 @@ import (
    "errors"
    "fmt"
    "io/ioutil"
+    "math/big"
    "math/rand"
    "os"
    "reflect"
@@ -30,7 +31,9 @@

    "github.com/davecgh/go-spew/spew"
    "github.com/ethereum/go-ethereum/common"
+    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/ethdb"
+    "github.com/ethereum/go-ethereum/rlp"
)

func init() {
@@ -505,8 +508,6 @@ func BenchmarkGet(b *testing.B) { benchGet(b, false) }
func BenchmarkGetDB(b *testing.B) { benchGet(b, true) }
func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
-func BenchmarkHashBE(b *testing.B) { benchHash(b, binary.BigEndian) }
-func BenchmarkHashLE(b *testing.B) { benchHash(b, binary.LittleEndian) }

const benchElemCount = 20000

@@ -549,18 +550,39 @@ func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
    return trie
}

-func benchHash(b *testing.B, e binary.ByteOrder) {
+// Benchmarks the trie hashing. Since the trie caches the result of any operation,
+// we cannot use b.N as the number of hashing rounds, since all rounds apart from
+// the first one will be NOOPs. As such, we'll use b.N as the number of accounts to
+// insert into the trie before measuring the hashing.
+func BenchmarkHash(b *testing.B) {
+    // Make the random benchmark deterministic
+    random := rand.New(rand.NewSource(0))
+
+    // Create a realistic account trie to hash
+    addresses := make([][20]byte, b.N)
+    for i := 0; i < len(addresses); i++ {
+        for j := 0; j < len(addresses[i]); j++ {
+            addresses[i][j] = byte(random.Intn(256))
+        }
+    }
+    accounts := make([][]byte, len(addresses))
+    for i := 0; i < len(accounts); i++ {
+        var (
+            nonce   = uint64(random.Int63())
+            balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+            root    = emptyRoot
+            code    = crypto.Keccak256(nil)
+        )
+        accounts[i], _ = rlp.EncodeToBytes([]interface{}{nonce, balance, root, code})
+    }
+    // Insert the accounts into the trie and hash it
    trie := newEmpty()
-    k := make([]byte, 32)
-    for i := 0; i < benchElemCount; i++ {
-        e.PutUint64(k, uint64(i))
-        trie.Update(k, k)
+    for i := 0; i < len(addresses); i++ {
+        trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
    }
-    b.ResetTimer()
-    for i := 0; i < b.N; i++ {
-        trie.Hash()
-    }
+    b.ReportAllocs()
+    trie.Hash()
}

func tempDB() (string, Database) {
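The hasher rework above moves the per-hash scratch state (the keccak state and the RLP encoding buffer) out of `hasher` into a pooled `calculator`, so one `hasher` can be shared by the goroutines spawned in `hashChildren` while the allocations are still reused through `sync.Pool`. Below is a minimal standard-library sketch of that pooling pattern; the `scratch` type and `hashBlob` helper are illustrative names, not part of the trie package.

```go
package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"hash"
	"sync"
)

// scratch mirrors the diff's calculator: reusable hash state plus an encoding buffer.
type scratch struct {
	sha    hash.Hash
	buffer *bytes.Buffer
}

var scratchPool = sync.Pool{
	New: func() interface{} {
		return &scratch{sha: sha256.New(), buffer: new(bytes.Buffer)}
	},
}

// hashBlob takes a scratch from the pool, resets it (as newCalculator does),
// hashes the payload and returns the scratch for reuse. It is safe to call
// from many goroutines because no scratch state lives on a shared object.
func hashBlob(payload []byte) []byte {
	s := scratchPool.Get().(*scratch)
	defer scratchPool.Put(s)

	s.buffer.Reset()
	s.sha.Reset()
	s.buffer.Write(payload)
	s.sha.Write(s.buffer.Bytes())
	return s.sha.Sum(nil)
}

func main() {
	fmt.Printf("%x\n", hashBlob([]byte("node rlp"))[:8])
}
```

In the diff itself the buffer receives the node's RLP encoding and the hash is Keccak-256; sha256 stands in here only to keep the sketch dependency-free.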
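The proof API changes shape as well: `Prove` now writes each proof node into a `DatabaseWriter` keyed by the node hash (skipping `fromLevel` levels below the root), and `VerifyProof` reads the nodes back through a `DatabaseReader`, additionally returning how many nodes it traversed. The sketch below is modelled on `TestProof` in the diff above; it assumes the upstream `github.com/ethereum/go-ethereum` import paths rather than this mirror and uses an in-memory `ethdb.MemDatabase` as the proof container.

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/trie"
)

func main() {
	// Build a small trie backed by an in-memory database.
	db, _ := ethdb.NewMemDatabase()
	tr, _ := trie.New(common.Hash{}, db)
	tr.Update([]byte("k"), []byte("v"))
	root := tr.Hash()

	// Prove writes every proof node into the given DatabaseWriter,
	// keyed by the node's Keccak-256 hash.
	proof, _ := ethdb.NewMemDatabase()
	if err := tr.Prove([]byte("k"), 0, proof); err != nil {
		panic(err)
	}

	// VerifyProof walks the proof from the root hash downwards and
	// reports the number of proof nodes it visited.
	val, err, nodes := trie.VerifyProof(root, []byte("k"), proof)
	fmt.Printf("value=%q err=%v nodes=%d\n", val, err, nodes)
}
```

Writing proofs into a key/value interface lets callers merge the proofs for several keys into one container and hand the same object straight to `VerifyProof`, instead of juggling per-key `[]rlp.RawValue` slices.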