diff options
Diffstat (limited to 'trie')
-rw-r--r-- | trie/proof.go | 122 | ||||
-rw-r--r-- | trie/proof_test.go | 139 |
2 files changed, 261 insertions, 0 deletions
diff --git a/trie/proof.go b/trie/proof.go new file mode 100644 index 000000000..a705c49db --- /dev/null +++ b/trie/proof.go @@ -0,0 +1,122 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +// Prove constructs a merkle proof for key. The result contains all +// encoded nodes on the path to the value at key. The value itself is +// also included in the last node and can be retrieved by verifying +// the proof. +// +// The returned proof is nil if the trie does not contain a value for key. +// For existing keys, the proof will have at least one element. +func (t *Trie) Prove(key []byte) []rlp.RawValue { + // Collect all nodes on the path to key. + key = compactHexDecode(key) + nodes := []node{} + tn := t.root + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + return nil + } + tn = n.Val + key = key[len(n.Key):] + nodes = append(nodes, n) + case fullNode: + tn = n[key[0]] + key = key[1:] + nodes = append(nodes, n) + case nil: + return nil + case hashNode: + tn = t.resolveHash(n) + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + if t.hasher == nil { + t.hasher = newHasher() + } + proof := make([]rlp.RawValue, 0, len(nodes)) + for i, n := range nodes { + // Don't bother checking for errors here since hasher panics + // if encoding doesn't work and we're not writing to any database. + n, _ = t.hasher.replaceChildren(n, nil) + hn, _ := t.hasher.store(n, nil, false) + if _, ok := hn.(hashNode); ok || i == 0 { + // If the node's database encoding is a hash (or is the + // root node), it becomes a proof element. + enc, _ := rlp.EncodeToBytes(n) + proof = append(proof, enc) + } + } + return proof +} + +// VerifyProof checks merkle proofs. The given proof must contain the +// value for key in a trie with the given root hash. VerifyProof +// returns an error if the proof contains invalid trie nodes or the +// wrong value. +func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { + key = compactHexDecode(key) + sha := sha3.NewKeccak256() + wantHash := rootHash.Bytes() + for i, buf := range proof { + sha.Reset() + sha.Write(buf) + if !bytes.Equal(sha.Sum(nil), wantHash) { + return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) + } + n, err := decodeNode(buf) + if err != nil { + return nil, fmt.Errorf("bad proof node %d: %v", i, err) + } + keyrest, cld := get(n, key) + switch cld := cld.(type) { + case nil: + return nil, fmt.Errorf("key mismatch at proof node %d", i) + case hashNode: + key = keyrest + wantHash = cld + case valueNode: + if i != len(proof)-1 { + return nil, errors.New("additional nodes at end of proof") + } + return cld, nil + } + } + return nil, errors.New("unexpected end of proof") +} + +func get(tn node, key []byte) ([]byte, node) { + for len(key) > 0 { + switch n := tn.(type) { + case shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + return nil, nil + } + tn = n.Val + key = key[len(n.Key):] + case fullNode: + tn = n[key[0]] + key = key[1:] + case hashNode: + return key, n + case nil: + return key, nil + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + return nil, tn.(valueNode) +} diff --git a/trie/proof_test.go b/trie/proof_test.go new file mode 100644 index 000000000..6b5bef05c --- /dev/null +++ b/trie/proof_test.go @@ -0,0 +1,139 @@ +package trie + +import ( + "bytes" + crand "crypto/rand" + mrand "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) + +func init() { + mrand.Seed(time.Now().Unix()) +} + +func TestProof(t *testing.T) { + trie, vals := randomTrie(500) + root := trie.Hash() + for _, kv := range vals { + proof := trie.Prove(kv.k) + if proof == nil { + t.Fatalf("missing key %x while constructing proof", kv.k) + } + val, err := VerifyProof(root, kv.k, proof) + if err != nil { + t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) + } + } +} + +func TestOneElementProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + proof := trie.Prove([]byte("k")) + if proof == nil { + t.Fatal("nil proof") + } + if len(proof) != 1 { + t.Error("proof should have one element") + } + val, err := VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof) + } + if !bytes.Equal(val, []byte("v")) { + t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) + } +} + +func TestVerifyBadProof(t *testing.T) { + trie, vals := randomTrie(800) + root := trie.Hash() + for _, kv := range vals { + proof := trie.Prove(kv.k) + if proof == nil { + t.Fatal("nil proof") + } + mutateByte(proof[mrand.Intn(len(proof))]) + if _, err := VerifyProof(root, kv.k, proof); err == nil { + t.Fatalf("expected proof to fail for key %x", kv.k) + } + } +} + +// mutateByte changes one byte in b. +func mutateByte(b []byte) { + for r := mrand.Intn(len(b)); ; { + new := byte(mrand.Intn(255)) + if new != b[r] { + b[r] = new + break + } + } +} + +func BenchmarkProve(b *testing.B) { + trie, vals := randomTrie(100) + var keys []string + for k := range vals { + keys = append(keys, k) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := vals[keys[i%len(keys)]] + if trie.Prove(kv.k) == nil { + b.Fatalf("nil proof for %x", kv.k) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + trie, vals := randomTrie(100) + root := trie.Hash() + var keys []string + var proofs [][]rlp.RawValue + for k := range vals { + keys = append(keys, k) + proofs = append(proofs, trie.Prove([]byte(k))) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + im := i % len(keys) + if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + b.Fatalf("key %x: error", keys[im], err) + } + } +} + +func randomTrie(n int) (*Trie, map[string]*kv) { + trie := new(Trie) + vals := make(map[string]*kv) + for i := byte(0); i < 100; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + for i := 0; i < n; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + vals[string(value.k)] = value + } + return trie, vals +} + +func randBytes(n int) []byte { + r := make([]byte, n) + crand.Read(r) + return r +} |