author     Yondon Fu <yondon.fu@gmail.com>  2017-12-19 06:17:41 +0800
committer  Yondon Fu <yondon.fu@gmail.com>  2017-12-19 06:17:41 +0800
commit     3857cdc267e3192697f561df0a0f827f65dfb6b5 (patch)
tree       401c52c4972a68229ea283a394a0b0a5f3cfdc8e /trie
parent     a5330fe0c569b75cb8a524f60f7e8dc06498262b (diff)
parent     fe070ab5c32702033489f1b9d1655ea1b894c29e (diff)
Merge branch 'master' into abi-offset-fixed-arrays
Diffstat (limited to 'trie')
-rw-r--r--  trie/hasher.go       111
-rw-r--r--  trie/proof.go         53
-rw-r--r--  trie/proof_test.go    52
-rw-r--r--  trie/secure_trie.go    8
-rw-r--r--  trie/trie.go           1
-rw-r--r--  trie/trie_test.go     44
6 files changed, 175 insertions, 94 deletions
diff --git a/trie/hasher.go b/trie/hasher.go
index 4719aabf6..5186d7669 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -26,27 +26,46 @@ import (
"github.com/ethereum/go-ethereum/rlp"
)
-type hasher struct {
- tmp *bytes.Buffer
- sha hash.Hash
- cachegen, cachelimit uint16
+// calculator is a utility used by the hasher to calculate the hash value of a trie node.
+type calculator struct {
+ sha hash.Hash
+ buffer *bytes.Buffer
}
-// hashers live in a global pool.
-var hasherPool = sync.Pool{
+// calculatorPool is a set of temporary calculators that may be individually saved and retrieved.
+var calculatorPool = sync.Pool{
New: func() interface{} {
- return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
+ return &calculator{buffer: new(bytes.Buffer), sha: sha3.NewKeccak256()}
},
}
+// hasher is used to calculate the hash value of the whole tree.
+type hasher struct {
+ cachegen uint16
+ cachelimit uint16
+ threaded bool
+ mu sync.Mutex
+}
+
func newHasher(cachegen, cachelimit uint16) *hasher {
- h := hasherPool.Get().(*hasher)
- h.cachegen, h.cachelimit = cachegen, cachelimit
+ h := &hasher{
+ cachegen: cachegen,
+ cachelimit: cachelimit,
+ }
return h
}
-func returnHasherToPool(h *hasher) {
- hasherPool.Put(h)
+// newCalculator retrieves a cleaned calculator from the calculator pool.
+func (h *hasher) newCalculator() *calculator {
+ calculator := calculatorPool.Get().(*calculator)
+ calculator.buffer.Reset()
+ calculator.sha.Reset()
+ return calculator
+}
+
+// returnCalculator returns a no longer used calculator to the pool.
+func (h *hasher) returnCalculator(calculator *calculator) {
+ calculatorPool.Put(calculator)
}
// hash collapses a node down into a hash node, also returning a copy of the
@@ -123,16 +142,49 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
// Hash the full node's children, caching the newly hashed subtrees
collapsed, cached := n.copy(), n.copy()
- for i := 0; i < 16; i++ {
- if n.Children[i] != nil {
- collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false)
- if err != nil {
- return original, original, err
- }
- } else {
- collapsed.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings.
+ // hashChild is a helper to hash a single child, which is called either on the
+ // same thread as the caller or in a goroutine for the toplevel branching.
+ hashChild := func(index int, wg *sync.WaitGroup) {
+ if wg != nil {
+ defer wg.Done()
+ }
+ // Ensure that nil children are encoded as empty strings.
+ if collapsed.Children[index] == nil {
+ collapsed.Children[index] = valueNode(nil)
+ return
+ }
+ // Hash all other children properly
+ var herr error
+ collapsed.Children[index], cached.Children[index], herr = h.hash(n.Children[index], db, false)
+ if herr != nil {
+ h.mu.Lock() // rarely if ever locked, no congestion
+ err = herr
+ h.mu.Unlock()
}
}
+ // If we're not running in threaded mode yet, spawn a goroutine for each child
+ if !h.threaded {
+ // Disable further threading
+ h.threaded = true
+
+ // Hash all the children concurrently
+ var wg sync.WaitGroup
+ for i := 0; i < 16; i++ {
+ wg.Add(1)
+ go hashChild(i, &wg)
+ }
+ wg.Wait()
+
+ // Reenable threading for subsequent hash calls
+ h.threaded = false
+ } else {
+ for i := 0; i < 16; i++ {
+ hashChild(i, nil)
+ }
+ }
+ if err != nil {
+ return original, original, err
+ }
cached.Children[16] = n.Children[16]
if collapsed.Children[16] == nil {
collapsed.Children[16] = valueNode(nil)
@@ -150,24 +202,29 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
if _, isHash := n.(hashNode); n == nil || isHash {
return n, nil
}
+ calculator := h.newCalculator()
+ defer h.returnCalculator(calculator)
+
// Generate the RLP encoding of the node
- h.tmp.Reset()
- if err := rlp.Encode(h.tmp, n); err != nil {
+ if err := rlp.Encode(calculator.buffer, n); err != nil {
panic("encode error: " + err.Error())
}
-
- if h.tmp.Len() < 32 && !force {
+ if calculator.buffer.Len() < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent
}
// Larger nodes are replaced by their hash and stored in the database.
hash, _ := n.cache()
if hash == nil {
- h.sha.Reset()
- h.sha.Write(h.tmp.Bytes())
- hash = hashNode(h.sha.Sum(nil))
+ calculator.sha.Write(calculator.buffer.Bytes())
+ hash = hashNode(calculator.sha.Sum(nil))
}
if db != nil {
- return hash, db.Put(hash, h.tmp.Bytes())
+ // db might be a leveldb batch, which is not safe for concurrent writes
+ h.mu.Lock()
+ err := db.Put(hash, calculator.buffer.Bytes())
+ h.mu.Unlock()
+
+ return hash, err
}
return hash, nil
}
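
The hasher change above parallelizes hashing of a full node's sixteen children, but only at the top level of the trie: the threaded flag is flipped before the goroutines are spawned, so recursion below the first branching stays on its own goroutine. What follows is a minimal standalone sketch of that single-level fan-out, using SHA-256 from the standard library in place of Keccak-256 and a toy node type; error propagation and the database-write lock from the diff are omitted.

package main

import (
	"crypto/sha256"
	"fmt"
	"sync"
)

// node is a toy stand-in for a trie full node with sixteen children.
type node struct {
	children [16]*node
	value    []byte
}

// hasher mirrors the diff's approach: the threaded flag limits the
// goroutine fan-out to the top-level branching only.
type hasher struct {
	threaded bool
}

func (h *hasher) hash(n *node) [32]byte {
	if n == nil {
		return sha256.Sum256(nil)
	}
	var hashes [16][32]byte

	// hashChild hashes a single child, either inline or inside a goroutine.
	hashChild := func(i int, wg *sync.WaitGroup) {
		if wg != nil {
			defer wg.Done()
		}
		hashes[i] = h.hash(n.children[i])
	}
	if !h.threaded {
		// Top-level branching: hash all children concurrently, and disable
		// further threading for the recursive calls underneath.
		h.threaded = true
		var wg sync.WaitGroup
		for i := 0; i < 16; i++ {
			wg.Add(1)
			go hashChild(i, &wg)
		}
		wg.Wait()
		h.threaded = false
	} else {
		for i := 0; i < 16; i++ {
			hashChild(i, nil)
		}
	}
	// Fold the child hashes and the node's own value into one digest.
	sum := sha256.New()
	for i := range hashes {
		sum.Write(hashes[i][:])
	}
	sum.Write(n.value)
	var out [32]byte
	copy(out[:], sum.Sum(nil))
	return out
}

func main() {
	root := &node{value: []byte("root")}
	for i := range root.children {
		root.children[i] = &node{value: []byte{byte(i)}}
	}
	fmt.Printf("root hash: %x\n", (&hasher{}).hash(root))
}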
diff --git a/trie/proof.go b/trie/proof.go
index 298f648c4..5e886a259 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -18,11 +18,10 @@ package trie
import (
"bytes"
- "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -36,7 +35,7 @@ import (
// contains all nodes of the longest existing prefix of the key
// (at least the root node), ending with the node that proves the
// absence of the key.
-func (t *Trie) Prove(key []byte) []rlp.RawValue {
+func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
// Collect all nodes on the path to key.
key = keybytesToHex(key)
nodes := []node{}
@@ -61,67 +60,63 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
tn, err = t.resolveHash(n, nil)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
- return nil
+ return err
}
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
hasher := newHasher(0, 0)
- proof := make([]rlp.RawValue, 0, len(nodes))
for i, n := range nodes {
// Don't bother checking for errors here since hasher panics
// if encoding doesn't work and we're not writing to any database.
n, _, _ = hasher.hashChildren(n, nil)
hn, _ := hasher.store(n, nil, false)
- if _, ok := hn.(hashNode); ok || i == 0 {
+ if hash, ok := hn.(hashNode); ok || i == 0 {
// If the node's database encoding is a hash (or is the
// root node), it becomes a proof element.
- enc, _ := rlp.EncodeToBytes(n)
- proof = append(proof, enc)
+ if fromLevel > 0 {
+ fromLevel--
+ } else {
+ enc, _ := rlp.EncodeToBytes(n)
+ if !ok {
+ hash = crypto.Keccak256(enc)
+ }
+ proofDb.Put(hash, enc)
+ }
}
}
- return proof
+ return nil
}
// VerifyProof checks merkle proofs. The given proof must contain the
// value for key in a trie with the given root hash. VerifyProof
// returns an error if the proof contains invalid trie nodes or the
// wrong value.
-func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
+func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
key = keybytesToHex(key)
- sha := sha3.NewKeccak256()
- wantHash := rootHash.Bytes()
- for i, buf := range proof {
- sha.Reset()
- sha.Write(buf)
- if !bytes.Equal(sha.Sum(nil), wantHash) {
- return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
+ wantHash := rootHash[:]
+ for i := 0; ; i++ {
+ buf, _ := proofDb.Get(wantHash)
+ if buf == nil {
+ return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i
}
n, err := decodeNode(wantHash, buf, 0)
if err != nil {
- return nil, fmt.Errorf("bad proof node %d: %v", i, err)
+ return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
}
keyrest, cld := get(n, key)
switch cld := cld.(type) {
case nil:
- if i != len(proof)-1 {
- return nil, fmt.Errorf("key mismatch at proof node %d", i)
- } else {
- // The trie doesn't contain the key.
- return nil, nil
- }
+ // The trie doesn't contain the key.
+ return nil, nil, i
case hashNode:
key = keyrest
wantHash = cld
case valueNode:
- if i != len(proof)-1 {
- return nil, errors.New("additional nodes at end of proof")
- }
- return cld, nil
+ return cld, nil, i + 1
}
}
- return nil, errors.New("unexpected end of proof")
}
func get(tn node, key []byte) ([]byte, node) {
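
With the new API, Prove writes each proof node into a DatabaseWriter keyed by the node's hash, and VerifyProof reads those nodes back from a DatabaseReader while walking down from the root hash. A hedged usage sketch, mirroring the test code further below: proveAndVerify is a hypothetical helper that would sit inside the trie package and assumes the in-memory ethdb database the tests use.

// Sketch only: would live inside the trie package, alongside the tests.
// Requires the github.com/ethereum/go-ethereum/ethdb import.
func proveAndVerify(tr *Trie, key []byte) ([]byte, error) {
	proofDb, err := ethdb.NewMemDatabase()
	if err != nil {
		return nil, err
	}
	// Record every node on the path to key, keyed by that node's hash.
	if err := tr.Prove(key, 0, proofDb); err != nil {
		return nil, err
	}
	// Walk the recorded nodes from the root hash down to the value; val is
	// nil if the proof shows the key is absent from the trie.
	val, err, _ := VerifyProof(tr.Hash(), key, proofDb)
	return val, err
}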
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 91ebcd4a5..fff313d7f 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -24,7 +24,8 @@ import (
"time"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
)
func init() {
@@ -35,13 +36,13 @@ func TestProof(t *testing.T) {
trie, vals := randomTrie(500)
root := trie.Hash()
for _, kv := range vals {
- proof := trie.Prove(kv.k)
- if proof == nil {
+ proofs, _ := ethdb.NewMemDatabase()
+ if trie.Prove(kv.k, 0, proofs) != nil {
t.Fatalf("missing key %x while constructing proof", kv.k)
}
- val, err := VerifyProof(root, kv.k, proof)
+ val, err, _ := VerifyProof(root, kv.k, proofs)
if err != nil {
- t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof)
+ t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs)
}
if !bytes.Equal(val, kv.v) {
t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v)
@@ -52,16 +53,14 @@ func TestProof(t *testing.T) {
func TestOneElementProof(t *testing.T) {
trie := new(Trie)
updateString(trie, "k", "v")
- proof := trie.Prove([]byte("k"))
- if proof == nil {
- t.Fatal("nil proof")
- }
- if len(proof) != 1 {
+ proofs, _ := ethdb.NewMemDatabase()
+ trie.Prove([]byte("k"), 0, proofs)
+ if len(proofs.Keys()) != 1 {
t.Error("proof should have one element")
}
- val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+ val, err, _ := VerifyProof(trie.Hash(), []byte("k"), proofs)
if err != nil {
- t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof)
+ t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys())
}
if !bytes.Equal(val, []byte("v")) {
t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val)
@@ -72,12 +71,18 @@ func TestVerifyBadProof(t *testing.T) {
trie, vals := randomTrie(800)
root := trie.Hash()
for _, kv := range vals {
- proof := trie.Prove(kv.k)
- if proof == nil {
- t.Fatal("nil proof")
+ proofs, _ := ethdb.NewMemDatabase()
+ trie.Prove(kv.k, 0, proofs)
+ if len(proofs.Keys()) == 0 {
+ t.Fatal("zero length proof")
}
- mutateByte(proof[mrand.Intn(len(proof))])
- if _, err := VerifyProof(root, kv.k, proof); err == nil {
+ keys := proofs.Keys()
+ key := keys[mrand.Intn(len(keys))]
+ node, _ := proofs.Get(key)
+ proofs.Delete(key)
+ mutateByte(node)
+ proofs.Put(crypto.Keccak256(node), node)
+ if _, err, _ := VerifyProof(root, kv.k, proofs); err == nil {
t.Fatalf("expected proof to fail for key %x", kv.k)
}
}
@@ -104,8 +109,9 @@ func BenchmarkProve(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
kv := vals[keys[i%len(keys)]]
- if trie.Prove(kv.k) == nil {
- b.Fatalf("nil proof for %x", kv.k)
+ proofs, _ := ethdb.NewMemDatabase()
+ if trie.Prove(kv.k, 0, proofs); len(proofs.Keys()) == 0 {
+ b.Fatalf("zero length proof for %x", kv.k)
}
}
}
@@ -114,16 +120,18 @@ func BenchmarkVerifyProof(b *testing.B) {
trie, vals := randomTrie(100)
root := trie.Hash()
var keys []string
- var proofs [][]rlp.RawValue
+ var proofs []*ethdb.MemDatabase
for k := range vals {
keys = append(keys, k)
- proofs = append(proofs, trie.Prove([]byte(k)))
+ proof, _ := ethdb.NewMemDatabase()
+ trie.Prove([]byte(k), 0, proof)
+ proofs = append(proofs, proof)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
im := i % len(keys)
- if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+ if _, err, _ := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
b.Fatalf("key %x: %v", keys[im], err)
}
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 20c303f31..1fde45165 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -199,10 +199,10 @@ func (t *SecureTrie) secKey(key []byte) []byte {
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte {
h := newHasher(0, 0)
- h.sha.Reset()
- h.sha.Write(key)
- buf := h.sha.Sum(t.hashKeyBuf[:0])
- returnHasherToPool(h)
+ calculator := h.newCalculator()
+ calculator.sha.Write(key)
+ buf := calculator.sha.Sum(t.hashKeyBuf[:0])
+ h.returnCalculator(calculator)
return buf
}
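
hashKey now borrows a calculator from the pool, hashes the key, and sums straight into the trie's reusable hashKeyBuf. Below is a small standalone sketch of the same pool-plus-scratch-buffer pattern, again substituting SHA-256 for Keccak-256; the calculator type here is a stand-in, not the trie package's own.

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"hash"
	"sync"
)

// calculator bundles a hash state with a scratch buffer, as in the diff.
type calculator struct {
	sha    hash.Hash
	buffer *bytes.Buffer
}

// calculatorPool hands out reusable calculators.
var calculatorPool = sync.Pool{
	New: func() interface{} {
		return &calculator{sha: sha256.New(), buffer: new(bytes.Buffer)}
	},
}

// hashKey hashes key, appending the digest into out's backing array so the
// caller's buffer is reused between calls (the hashKeyBuf idea).
func hashKey(key, out []byte) []byte {
	c := calculatorPool.Get().(*calculator)
	c.sha.Reset()
	c.buffer.Reset()
	c.sha.Write(key)
	sum := c.sha.Sum(out[:0])
	calculatorPool.Put(c)
	return sum
}

func main() {
	var buf [32]byte
	fmt.Printf("%x\n", hashKey([]byte("hello"), buf[:]))
}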
diff --git a/trie/trie.go b/trie/trie.go
index 7f69a3d1d..c211e7554 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -501,6 +501,5 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
return hashNode(emptyRoot.Bytes()), nil, nil
}
h := newHasher(t.cachegen, t.cachelimit)
- defer returnHasherToPool(h)
return h.hash(t.root, db, true)
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 1c9095070..1e28c3bc4 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"io/ioutil"
+ "math/big"
"math/rand"
"os"
"reflect"
@@ -30,7 +31,9 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/rlp"
)
func init() {
@@ -505,8 +508,6 @@ func BenchmarkGet(b *testing.B) { benchGet(b, false) }
func BenchmarkGetDB(b *testing.B) { benchGet(b, true) }
func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
-func BenchmarkHashBE(b *testing.B) { benchHash(b, binary.BigEndian) }
-func BenchmarkHashLE(b *testing.B) { benchHash(b, binary.LittleEndian) }
const benchElemCount = 20000
@@ -549,18 +550,39 @@ func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
return trie
}
-func benchHash(b *testing.B, e binary.ByteOrder) {
+// Benchmarks the trie hashing. Since the trie caches the result of any operation,
+// we cannot use b.N as the number of hashing rounds: all rounds apart from the
+// first one would be no-ops. Instead, we use b.N as the number of accounts to
+// insert into the trie before measuring the hashing.
+func BenchmarkHash(b *testing.B) {
+ // Make the random benchmark deterministic
+ random := rand.New(rand.NewSource(0))
+
+ // Create a realistic account trie to hash
+ addresses := make([][20]byte, b.N)
+ for i := 0; i < len(addresses); i++ {
+ for j := 0; j < len(addresses[i]); j++ {
+ addresses[i][j] = byte(random.Intn(256))
+ }
+ }
+ accounts := make([][]byte, len(addresses))
+ for i := 0; i < len(accounts); i++ {
+ var (
+ nonce = uint64(random.Int63())
+ balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+ root = emptyRoot
+ code = crypto.Keccak256(nil)
+ )
+ accounts[i], _ = rlp.EncodeToBytes([]interface{}{nonce, balance, root, code})
+ }
+ // Insert the accounts into the trie and hash it
trie := newEmpty()
- k := make([]byte, 32)
- for i := 0; i < benchElemCount; i++ {
- e.PutUint64(k, uint64(i))
- trie.Update(k, k)
+ for i := 0; i < len(addresses); i++ {
+ trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
}
-
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- trie.Hash()
- }
+ b.ReportAllocs()
+ trie.Hash()
}
func tempDB() (string, Database) {
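
The new BenchmarkHash builds b.N random, RLP-encoded accounts, inserts them under their Keccak-256 address hashes, and hashes the trie once, so the reported time is effectively per inserted account. Assuming the standard Go tooling, it can be run from the repository root with something like:

go test ./trie -run NONE -bench BenchmarkHash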