-rw-r--r--  core/tx_list.go         |  21
-rw-r--r--  core/tx_pool.go         | 130
-rw-r--r--  core/tx_pool_test.go    |  50
-rw-r--r--  core/vm/instructions.go |   2
-rw-r--r--  trie/iterator.go        |  71
-rw-r--r--  trie/proof_test.go      | 127
6 files changed, 283 insertions(+), 118 deletions(-)
diff --git a/core/tx_list.go b/core/tx_list.go
index ea6ee7019..287dda4c3 100644
--- a/core/tx_list.go
+++ b/core/tx_list.go
@@ -397,13 +397,13 @@ func (h *priceHeap) Pop() interface{} {
// txPricedList is a price-sorted heap to allow operating on transactions pool
// contents in a price-incrementing way.
type txPricedList struct {
- all *map[common.Hash]*types.Transaction // Pointer to the map of all transactions
- items *priceHeap // Heap of prices of all the stored transactions
- stales int // Number of stale price points to (re-heap trigger)
+ all *txLookup // Pointer to the map of all transactions
+ items *priceHeap // Heap of prices of all the stored transactions
+ stales int // Number of stale price points to (re-heap trigger)
}
// newTxPricedList creates a new price-sorted transaction heap.
-func newTxPricedList(all *map[common.Hash]*types.Transaction) *txPricedList {
+func newTxPricedList(all *txLookup) *txPricedList {
return &txPricedList{
all: all,
items: new(priceHeap),
@@ -425,12 +425,13 @@ func (l *txPricedList) Removed() {
return
}
// Seems we've reached a critical number of stale transactions, reheap
- reheap := make(priceHeap, 0, len(*l.all))
+ reheap := make(priceHeap, 0, l.all.Count())
l.stales, l.items = 0, &reheap
- for _, tx := range *l.all {
+ l.all.Range(func(hash common.Hash, tx *types.Transaction) bool {
*l.items = append(*l.items, tx)
- }
+ return true
+ })
heap.Init(l.items)
}
@@ -443,7 +444,7 @@ func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transact
for len(*l.items) > 0 {
// Discard stale transactions if found during cleanup
tx := heap.Pop(l.items).(*types.Transaction)
- if _, ok := (*l.all)[tx.Hash()]; !ok {
+ if l.all.Get(tx.Hash()) == nil {
l.stales--
continue
}
@@ -475,7 +476,7 @@ func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) boo
// Discard stale price points if found at the heap start
for len(*l.items) > 0 {
head := []*types.Transaction(*l.items)[0]
- if _, ok := (*l.all)[head.Hash()]; !ok {
+ if l.all.Get(head.Hash()) == nil {
l.stales--
heap.Pop(l.items)
continue
@@ -500,7 +501,7 @@ func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions
for len(*l.items) > 0 && count > 0 {
// Discard stale transactions if found during cleanup
tx := heap.Pop(l.items).(*types.Transaction)
- if _, ok := (*l.all)[tx.Hash()]; !ok {
+ if l.all.Get(tx.Hash()) == nil {
l.stales--
continue
}
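Editor's note: the re-heap above swaps a direct range over a shared map for a callback-driven l.all.Range, so every traversal of the global transaction set now runs under the lookup's own read lock. A minimal, self-contained sketch of that pattern follows; the lookup and intHeap types here are simplified stand-ins invented for illustration, not the pool's real ones.

// Sketch: rebuild a heap through a lock-guarded Range callback instead of
// ranging over the shared map directly (assumed, simplified types).
package main

import (
	"container/heap"
	"fmt"
	"sync"
)

type intHeap []int

func (h intHeap) Len() int            { return len(h) }
func (h intHeap) Less(i, j int) bool  { return h[i] < h[j] }
func (h intHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *intHeap) Push(x interface{}) { *h = append(*h, x.(int)) }
func (h *intHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}

// lookup is a stand-in for the pool's lookup type: a map guarded by an RWMutex.
type lookup struct {
	all  map[string]int
	lock sync.RWMutex
}

func (l *lookup) Count() int {
	l.lock.RLock()
	defer l.lock.RUnlock()
	return len(l.all)
}

// Range calls f for every entry while holding the read lock; returning false stops early.
func (l *lookup) Range(f func(key string, val int) bool) {
	l.lock.RLock()
	defer l.lock.RUnlock()
	for k, v := range l.all {
		if !f(k, v) {
			break
		}
	}
}

func main() {
	l := &lookup{all: map[string]int{"a": 3, "b": 1, "c": 2}}

	// Rebuild the heap from the lookup, mirroring the re-heap in txPricedList.Removed.
	reheap := make(intHeap, 0, l.Count())
	l.Range(func(_ string, v int) bool {
		reheap = append(reheap, v)
		return true
	})
	heap.Init(&reheap)
	fmt.Println(heap.Pop(&reheap)) // prints the cheapest entry: 1
}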
diff --git a/core/tx_pool.go b/core/tx_pool.go
index f89e11441..1c9516b1b 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -200,11 +200,11 @@ type TxPool struct {
locals *accountSet // Set of local transaction to exempt from eviction rules
journal *txJournal // Journal of local transaction to back up to disk
- pending map[common.Address]*txList // All currently processable transactions
- queue map[common.Address]*txList // Queued but non-processable transactions
- beats map[common.Address]time.Time // Last heartbeat from each known account
- all map[common.Hash]*types.Transaction // All transactions to allow lookups
- priced *txPricedList // All transactions sorted by price
+ pending map[common.Address]*txList // All currently processable transactions
+ queue map[common.Address]*txList // Queued but non-processable transactions
+ beats map[common.Address]time.Time // Last heartbeat from each known account
+ all *txLookup // All transactions to allow lookups
+ priced *txPricedList // All transactions sorted by price
wg sync.WaitGroup // for shutdown sync
@@ -226,12 +226,12 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block
pending: make(map[common.Address]*txList),
queue: make(map[common.Address]*txList),
beats: make(map[common.Address]time.Time),
- all: make(map[common.Hash]*types.Transaction),
+ all: newTxLookup(),
chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize),
gasPrice: new(big.Int).SetUint64(config.PriceLimit),
}
pool.locals = newAccountSet(pool.signer)
- pool.priced = newTxPricedList(&pool.all)
+ pool.priced = newTxPricedList(pool.all)
pool.reset(nil, chain.CurrentBlock().Header())
// If local transactions and journaling is enabled, load from disk
@@ -605,7 +605,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error {
func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
// If the transaction is already known, discard it
hash := tx.Hash()
- if pool.all[hash] != nil {
+ if pool.all.Get(hash) != nil {
log.Trace("Discarding already known transaction", "hash", hash)
return false, fmt.Errorf("known transaction: %x", hash)
}
@@ -616,7 +616,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
return false, err
}
// If the transaction pool is full, discard underpriced transactions
- if uint64(len(pool.all)) >= pool.config.GlobalSlots+pool.config.GlobalQueue {
+ if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue {
// If the new transaction is underpriced, don't accept it
if !local && pool.priced.Underpriced(tx, pool.locals) {
log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice())
@@ -624,7 +624,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
return false, ErrUnderpriced
}
// New transaction is better than our worse ones, make room for it
- drop := pool.priced.Discard(len(pool.all)-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals)
+ drop := pool.priced.Discard(pool.all.Count()-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals)
for _, tx := range drop {
log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice())
underpricedTxCounter.Inc(1)
@@ -642,11 +642,11 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
}
// New transaction is better, replace old one
if old != nil {
- delete(pool.all, old.Hash())
+ pool.all.Remove(old.Hash())
pool.priced.Removed()
pendingReplaceCounter.Inc(1)
}
- pool.all[tx.Hash()] = tx
+ pool.all.Add(tx)
pool.priced.Put(tx)
pool.journalTx(from, tx)
@@ -689,12 +689,12 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er
}
// Discard any previous transaction and mark this
if old != nil {
- delete(pool.all, old.Hash())
+ pool.all.Remove(old.Hash())
pool.priced.Removed()
queuedReplaceCounter.Inc(1)
}
- if pool.all[hash] == nil {
- pool.all[hash] = tx
+ if pool.all.Get(hash) == nil {
+ pool.all.Add(tx)
pool.priced.Put(tx)
}
return old != nil, nil
@@ -726,7 +726,7 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T
inserted, old := list.Add(tx, pool.config.PriceBump)
if !inserted {
// An older transaction was better, discard this
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
pendingDiscardCounter.Inc(1)
@@ -734,14 +734,14 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T
}
// Otherwise discard any previous transaction and mark this
if old != nil {
- delete(pool.all, old.Hash())
+ pool.all.Remove(old.Hash())
pool.priced.Removed()
pendingReplaceCounter.Inc(1)
}
// Failsafe to work around direct pending inserts (tests)
- if pool.all[hash] == nil {
- pool.all[hash] = tx
+ if pool.all.Get(hash) == nil {
+ pool.all.Add(tx)
pool.priced.Put(tx)
}
// Set the potentially new pending nonce and notify any subsystems of the new tx
@@ -840,7 +840,7 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus {
status := make([]TxStatus, len(hashes))
for i, hash := range hashes {
- if tx := pool.all[hash]; tx != nil {
+ if tx := pool.all.Get(hash); tx != nil {
from, _ := types.Sender(pool.signer, tx) // already validated
if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil {
status[i] = TxStatusPending
@@ -855,24 +855,21 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus {
// Get returns a transaction if it is contained in the pool
// and nil otherwise.
func (pool *TxPool) Get(hash common.Hash) *types.Transaction {
- pool.mu.RLock()
- defer pool.mu.RUnlock()
-
- return pool.all[hash]
+ return pool.all.Get(hash)
}
// removeTx removes a single transaction from the queue, moving all subsequent
// transactions back to the future queue.
func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) {
// Fetch the transaction we wish to delete
- tx, ok := pool.all[hash]
- if !ok {
+ tx := pool.all.Get(hash)
+ if tx == nil {
return
}
addr, _ := types.Sender(pool.signer, tx) // already validated during insertion
// Remove it from the list of known transactions
- delete(pool.all, hash)
+ pool.all.Remove(hash)
if outofbound {
pool.priced.Removed()
}
@@ -928,7 +925,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) {
for _, tx := range list.Forward(pool.currentState.GetNonce(addr)) {
hash := tx.Hash()
log.Trace("Removed old queued transaction", "hash", hash)
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
}
// Drop all transactions that are too costly (low balance or out of gas)
@@ -936,7 +933,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) {
for _, tx := range drops {
hash := tx.Hash()
log.Trace("Removed unpayable queued transaction", "hash", hash)
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
queuedNofundsCounter.Inc(1)
}
@@ -952,7 +949,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) {
if !pool.locals.contains(addr) {
for _, tx := range list.Cap(int(pool.config.AccountQueue)) {
hash := tx.Hash()
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
queuedRateLimitCounter.Inc(1)
log.Trace("Removed cap-exceeding queued transaction", "hash", hash)
@@ -1001,7 +998,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) {
for _, tx := range list.Cap(list.Len() - 1) {
// Drop the transaction from the global pools too
hash := tx.Hash()
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
// Update the account nonce to the dropped transaction
@@ -1023,7 +1020,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) {
for _, tx := range list.Cap(list.Len() - 1) {
// Drop the transaction from the global pools too
hash := tx.Hash()
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
// Update the account nonce to the dropped transaction
@@ -1092,7 +1089,7 @@ func (pool *TxPool) demoteUnexecutables() {
for _, tx := range list.Forward(nonce) {
hash := tx.Hash()
log.Trace("Removed old pending transaction", "hash", hash)
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
}
// Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later
@@ -1100,7 +1097,7 @@ func (pool *TxPool) demoteUnexecutables() {
for _, tx := range drops {
hash := tx.Hash()
log.Trace("Removed unpayable pending transaction", "hash", hash)
- delete(pool.all, hash)
+ pool.all.Remove(hash)
pool.priced.Removed()
pendingNofundsCounter.Inc(1)
}
@@ -1172,3 +1169,68 @@ func (as *accountSet) containsTx(tx *types.Transaction) bool {
func (as *accountSet) add(addr common.Address) {
as.accounts[addr] = struct{}{}
}
+
+// txLookup is used internally by TxPool to track transactions while allowing lookup without
+// mutex contention.
+//
+// Note, although this type is properly protected against concurrent access, it
+// is **not** a type that should ever be mutated or even exposed outside of the
+// transaction pool, since its internal state is tightly coupled with the pools
+// internal mechanisms. The sole purpose of the type is to permit out-of-bound
+// peeking into the pool in TxPool.Get without having to acquire the widely scoped
+// TxPool.mu mutex.
+type txLookup struct {
+ all map[common.Hash]*types.Transaction
+ lock sync.RWMutex
+}
+
+// newTxLookup returns a new txLookup structure.
+func newTxLookup() *txLookup {
+ return &txLookup{
+ all: make(map[common.Hash]*types.Transaction),
+ }
+}
+
+// Range calls f on each key and value present in the map.
+func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ for key, value := range t.all {
+ if !f(key, value) {
+ break
+ }
+ }
+}
+
+// Get returns a transaction if it exists in the lookup, or nil if not found.
+func (t *txLookup) Get(hash common.Hash) *types.Transaction {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ return t.all[hash]
+}
+
+// Count returns the current number of items in the lookup.
+func (t *txLookup) Count() int {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ return len(t.all)
+}
+
+// Add adds a transaction to the lookup.
+func (t *txLookup) Add(tx *types.Transaction) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ t.all[tx.Hash()] = tx
+}
+
+// Remove removes a transaction from the lookup.
+func (t *txLookup) Remove(hash common.Hash) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ delete(t.all, hash)
+}
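Editor's note: the point of the new txLookup type is visible in TxPool.Get above, which now answers from the lookup's own RWMutex instead of taking the pool-wide pool.mu. A hedged sketch of that property with a stripped-down stand-in type; txLookupSketch and its string keys are illustrative, not the real go-ethereum types.

// Sketch: concurrent writers and readers synchronise on the lookup alone,
// so a Get-style peek never touches a wider pool mutex.
package main

import (
	"fmt"
	"sync"
)

type txLookupSketch struct {
	all  map[string]string
	lock sync.RWMutex
}

func (t *txLookupSketch) Add(hash, tx string) {
	t.lock.Lock()
	defer t.lock.Unlock()
	t.all[hash] = tx
}

func (t *txLookupSketch) Get(hash string) string {
	t.lock.RLock()
	defer t.lock.RUnlock()
	return t.all[hash]
}

func main() {
	lookup := &txLookupSketch{all: make(map[string]string)}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			hash := fmt.Sprintf("0x%02x", i)
			lookup.Add(hash, "tx")
			_ = lookup.Get(hash) // read path only needs the lookup's RLock
		}(i)
	}
	wg.Wait()
	fmt.Println("entries:", len(lookup.all)) // safe: all writers have finished
}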
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index 25993c258..5a5920544 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -94,7 +94,7 @@ func validateTxPoolInternals(pool *TxPool) error {
// Ensure the total transaction set is consistent with pending + queued
pending, queued := pool.stats()
- if total := len(pool.all); total != pending+queued {
+ if total := pool.all.Count(); total != pending+queued {
return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued)
}
if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued {
@@ -401,8 +401,8 @@ func TestTransactionDoubleNonce(t *testing.T) {
t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash())
}
// Ensure the total transaction count is correct
- if len(pool.all) != 1 {
- t.Error("expected 1 total transactions, got", len(pool.all))
+ if pool.all.Count() != 1 {
+ t.Error("expected 1 total transactions, got", pool.all.Count())
}
}
@@ -424,8 +424,8 @@ func TestTransactionMissingNonce(t *testing.T) {
if pool.queue[addr].Len() != 1 {
t.Error("expected 1 queued transaction, got", pool.queue[addr].Len())
}
- if len(pool.all) != 1 {
- t.Error("expected 1 total transactions, got", len(pool.all))
+ if pool.all.Count() != 1 {
+ t.Error("expected 1 total transactions, got", pool.all.Count())
}
}
@@ -488,8 +488,8 @@ func TestTransactionDropping(t *testing.T) {
if pool.queue[account].Len() != 3 {
t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3)
}
- if len(pool.all) != 6 {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6)
+ if pool.all.Count() != 6 {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6)
}
pool.lockedReset(nil, nil)
if pool.pending[account].Len() != 3 {
@@ -498,8 +498,8 @@ func TestTransactionDropping(t *testing.T) {
if pool.queue[account].Len() != 3 {
t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3)
}
- if len(pool.all) != 6 {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6)
+ if pool.all.Count() != 6 {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6)
}
// Reduce the balance of the account, and check that invalidated transactions are dropped
pool.currentState.AddBalance(account, big.NewInt(-650))
@@ -523,8 +523,8 @@ func TestTransactionDropping(t *testing.T) {
if _, ok := pool.queue[account].txs.items[tx12.Nonce()]; ok {
t.Errorf("out-of-fund queued transaction present: %v", tx11)
}
- if len(pool.all) != 4 {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 4)
+ if pool.all.Count() != 4 {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4)
}
// Reduce the block gas limit, check that invalidated transactions are dropped
pool.chain.(*testBlockChain).gasLimit = 100
@@ -542,8 +542,8 @@ func TestTransactionDropping(t *testing.T) {
if _, ok := pool.queue[account].txs.items[tx11.Nonce()]; ok {
t.Errorf("over-gased queued transaction present: %v", tx11)
}
- if len(pool.all) != 2 {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 2)
+ if pool.all.Count() != 2 {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 2)
}
}
@@ -596,8 +596,8 @@ func TestTransactionPostponing(t *testing.T) {
if len(pool.queue) != 0 {
t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0)
}
- if len(pool.all) != len(txs) {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs))
+ if pool.all.Count() != len(txs) {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs))
}
pool.lockedReset(nil, nil)
if pending := pool.pending[accs[0]].Len() + pool.pending[accs[1]].Len(); pending != len(txs) {
@@ -606,8 +606,8 @@ func TestTransactionPostponing(t *testing.T) {
if len(pool.queue) != 0 {
t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0)
}
- if len(pool.all) != len(txs) {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs))
+ if pool.all.Count() != len(txs) {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs))
}
// Reduce the balance of the account, and check that transactions are reorganised
for _, addr := range accs {
@@ -656,8 +656,8 @@ func TestTransactionPostponing(t *testing.T) {
}
}
}
- if len(pool.all) != len(txs)/2 {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)/2)
+ if pool.all.Count() != len(txs)/2 {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)/2)
}
}
@@ -748,8 +748,8 @@ func TestTransactionQueueAccountLimiting(t *testing.T) {
}
}
}
- if len(pool.all) != int(testTxPoolConfig.AccountQueue) {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue)
+ if pool.all.Count() != int(testTxPoolConfig.AccountQueue) {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue)
}
}
@@ -942,8 +942,8 @@ func TestTransactionPendingLimiting(t *testing.T) {
t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, pool.queue[account].Len(), 0)
}
}
- if len(pool.all) != int(testTxPoolConfig.AccountQueue+5) {
- t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue+5)
+ if pool.all.Count() != int(testTxPoolConfig.AccountQueue+5) {
+ t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue+5)
}
if err := validateEvents(events, int(testTxPoolConfig.AccountQueue+5)); err != nil {
t.Fatalf("event firing failed: %v", err)
@@ -993,8 +993,8 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
if len(pool1.queue) != len(pool2.queue) {
t.Errorf("queued transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.queue), len(pool2.queue))
}
- if len(pool1.all) != len(pool2.all) {
- t.Errorf("total transaction count mismatch: one-by-one algo %d, batch algo %d", len(pool1.all), len(pool2.all))
+ if pool1.all.Count() != pool2.all.Count() {
+ t.Errorf("total transaction count mismatch: one-by-one algo %d, batch algo %d", pool1.all.Count(), pool2.all.Count())
}
if err := validateTxPoolInternals(pool1); err != nil {
t.Errorf("pool 1 internal state corrupted: %v", err)
diff --git a/core/vm/instructions.go b/core/vm/instructions.go
index 0689ee39c..3a67e1865 100644
--- a/core/vm/instructions.go
+++ b/core/vm/instructions.go
@@ -850,7 +850,7 @@ func makePush(size uint64, pushByteSize int) executionFunc {
}
}
-// make push instruction function
+// make dup instruction function
func makeDup(size int64) executionFunc {
return func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
stack.dup(evm.interpreter.intPool, int(size))
diff --git a/trie/iterator.go b/trie/iterator.go
index 64110c6d9..00b890eb8 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -22,6 +22,7 @@ import (
"errors"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
)
// Iterator is a key-value trie iterator that traverses a Trie.
@@ -55,31 +56,50 @@ func (it *Iterator) Next() bool {
return false
}
+// Prove generates the Merkle proof for the leaf node the iterator is currently
+// positioned on.
+func (it *Iterator) Prove() [][]byte {
+ return it.nodeIt.LeafProof()
+}
+
// NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator interface {
// Next moves the iterator to the next node. If the parameter is false, any child
// nodes will be skipped.
Next(bool) bool
+
// Error returns the error status of the iterator.
Error() error
// Hash returns the hash of the current node.
Hash() common.Hash
+
// Parent returns the hash of the parent of the current node. The hash may be the one
// grandparent if the immediate parent is an internal node with no hash.
Parent() common.Hash
+
// Path returns the hex-encoded path to the current node.
// Callers must not retain references to the return value after calling Next.
// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
Path() []byte
// Leaf returns true iff the current node is a leaf node.
- // LeafBlob, LeafKey return the contents and key of the leaf node. These
- // method panic if the iterator is not positioned at a leaf.
- // Callers must not retain references to their return value after calling Next
Leaf() bool
- LeafBlob() []byte
+
+ // LeafKey returns the key of the leaf. The method panics if the iterator is not
+ // positioned at a leaf. Callers must not retain references to the value after
+ // calling Next.
LeafKey() []byte
+
+ // LeafBlob returns the content of the leaf. The method panics if the iterator
+ // is not positioned at a leaf. Callers must not retain references to the value
+ // after calling Next.
+ LeafBlob() []byte
+
+ // LeafProof returns the Merkle proof of the leaf. The method panics if the
+ // iterator is not positioned at a leaf. Callers must not retain references
+ // to the value after calling Next.
+ LeafProof() [][]byte
}
// nodeIteratorState represents the iteration state at one particular node of the
@@ -139,6 +159,15 @@ func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path)
}
+func (it *nodeIterator) LeafKey() []byte {
+ if len(it.stack) > 0 {
+ if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
+ return hexToKeybytes(it.path)
+ }
+ }
+ panic("not at leaf")
+}
+
func (it *nodeIterator) LeafBlob() []byte {
if len(it.stack) > 0 {
if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
@@ -148,10 +177,22 @@ func (it *nodeIterator) LeafBlob() []byte {
panic("not at leaf")
}
-func (it *nodeIterator) LeafKey() []byte {
+func (it *nodeIterator) LeafProof() [][]byte {
if len(it.stack) > 0 {
if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
- return hexToKeybytes(it.path)
+ hasher := newHasher(0, 0, nil)
+ proofs := make([][]byte, 0, len(it.stack))
+
+ for i, item := range it.stack[:len(it.stack)-1] {
+ // Gather nodes that end up as hash nodes (or the root)
+ node, _, _ := hasher.hashChildren(item.node, nil)
+ hashed, _ := hasher.store(node, nil, false)
+ if _, ok := hashed.(hashNode); ok || i == 0 {
+ enc, _ := rlp.EncodeToBytes(node)
+ proofs = append(proofs, enc)
+ }
+ }
+ return proofs
}
}
panic("not at leaf")
@@ -361,12 +402,16 @@ func (it *differenceIterator) Leaf() bool {
return it.b.Leaf()
}
+func (it *differenceIterator) LeafKey() []byte {
+ return it.b.LeafKey()
+}
+
func (it *differenceIterator) LeafBlob() []byte {
return it.b.LeafBlob()
}
-func (it *differenceIterator) LeafKey() []byte {
- return it.b.LeafKey()
+func (it *differenceIterator) LeafProof() [][]byte {
+ return it.b.LeafProof()
}
func (it *differenceIterator) Path() []byte {
@@ -464,12 +509,16 @@ func (it *unionIterator) Leaf() bool {
return (*it.items)[0].Leaf()
}
+func (it *unionIterator) LeafKey() []byte {
+ return (*it.items)[0].LeafKey()
+}
+
func (it *unionIterator) LeafBlob() []byte {
return (*it.items)[0].LeafBlob()
}
-func (it *unionIterator) LeafKey() []byte {
- return (*it.items)[0].LeafKey()
+func (it *unionIterator) LeafProof() [][]byte {
+ return (*it.items)[0].LeafProof()
}
func (it *unionIterator) Path() []byte {
@@ -509,12 +558,10 @@ func (it *unionIterator) Next(descend bool) bool {
heap.Push(it.items, skipped)
}
}
-
if least.Next(descend) {
it.count++
heap.Push(it.items, least)
}
-
return len(*it.items) > 0
}
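Editor's note: the new Iterator.Prove and NodeIterator.LeafProof hooks return the Merkle path of the leaf the iterator is positioned on. A small consumer sketch follows, using only calls that already appear in this diff (NewIterator, NodeIterator, Prove, crypto.Keccak256, ethdb.NewMemDatabase); the proveAllLeaves helper and the package name are made up for illustration.

// Sketch: collect a Merkle proof for every leaf of a trie into a memory
// database, keyed by Keccak256 hash, the same layout trie.Prove writes and
// VerifyProof expects to read.
package proofsketch

import (
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/trie"
)

func proveAllLeaves(t *trie.Trie) *ethdb.MemDatabase {
	proof := ethdb.NewMemDatabase()
	it := trie.NewIterator(t.NodeIterator(nil))
	for it.Next() {
		for _, node := range it.Prove() { // Merkle path for the current leaf
			proof.Put(crypto.Keccak256(node), node)
		}
	}
	return proof
}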
diff --git a/trie/proof_test.go b/trie/proof_test.go
index dee6f7d85..996f87478 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -32,20 +32,46 @@ func init() {
mrand.Seed(time.Now().Unix())
}
+// makeProvers creates Merkle trie provers based on different implementations to
+// test all variations.
+func makeProvers(trie *Trie) []func(key []byte) *ethdb.MemDatabase {
+ var provers []func(key []byte) *ethdb.MemDatabase
+
+ // Create a direct trie based Merkle prover
+ provers = append(provers, func(key []byte) *ethdb.MemDatabase {
+ proof := ethdb.NewMemDatabase()
+ trie.Prove(key, 0, proof)
+ return proof
+ })
+ // Create a leaf iterator based Merkle prover
+ provers = append(provers, func(key []byte) *ethdb.MemDatabase {
+ proof := ethdb.NewMemDatabase()
+ if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
+ for _, p := range it.Prove() {
+ proof.Put(crypto.Keccak256(p), p)
+ }
+ }
+ return proof
+ })
+ return provers
+}
+
func TestProof(t *testing.T) {
trie, vals := randomTrie(500)
root := trie.Hash()
- for _, kv := range vals {
- proofs := ethdb.NewMemDatabase()
- if trie.Prove(kv.k, 0, proofs) != nil {
- t.Fatalf("missing key %x while constructing proof", kv.k)
- }
- val, _, err := VerifyProof(root, kv.k, proofs)
- if err != nil {
- t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs)
- }
- if !bytes.Equal(val, kv.v) {
- t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v)
+ for i, prover := range makeProvers(trie) {
+ for _, kv := range vals {
+ proof := prover(kv.k)
+ if proof == nil {
+ t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
+ }
+ val, _, err := VerifyProof(root, kv.k, proof)
+ if err != nil {
+ t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
+ }
+ if !bytes.Equal(val, kv.v) {
+ t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
+ }
}
}
}
@@ -53,37 +79,66 @@ func TestProof(t *testing.T) {
func TestOneElementProof(t *testing.T) {
trie := new(Trie)
updateString(trie, "k", "v")
- proofs := ethdb.NewMemDatabase()
- trie.Prove([]byte("k"), 0, proofs)
- if len(proofs.Keys()) != 1 {
- t.Error("proof should have one element")
- }
- val, _, err := VerifyProof(trie.Hash(), []byte("k"), proofs)
- if err != nil {
- t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys())
- }
- if !bytes.Equal(val, []byte("v")) {
- t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val)
+ for i, prover := range makeProvers(trie) {
+ proof := prover([]byte("k"))
+ if proof == nil {
+ t.Fatalf("prover %d: nil proof", i)
+ }
+ if proof.Len() != 1 {
+ t.Errorf("prover %d: proof should have one element", i)
+ }
+ val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+ if err != nil {
+ t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
+ }
+ if !bytes.Equal(val, []byte("v")) {
+ t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val)
+ }
}
}
-func TestVerifyBadProof(t *testing.T) {
+func TestBadProof(t *testing.T) {
trie, vals := randomTrie(800)
root := trie.Hash()
- for _, kv := range vals {
- proofs := ethdb.NewMemDatabase()
- trie.Prove(kv.k, 0, proofs)
- if len(proofs.Keys()) == 0 {
- t.Fatal("zero length proof")
+ for i, prover := range makeProvers(trie) {
+ for _, kv := range vals {
+ proof := prover(kv.k)
+ if proof == nil {
+ t.Fatalf("prover %d: nil proof", i)
+ }
+ key := proof.Keys()[mrand.Intn(proof.Len())]
+ val, _ := proof.Get(key)
+ proof.Delete(key)
+
+ mutateByte(val)
+ proof.Put(crypto.Keccak256(val), val)
+
+ if _, _, err := VerifyProof(root, kv.k, proof); err == nil {
+ t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
+ }
+ }
+ }
+}
+
+// Tests that missing keys can also be proven. The test explicitly uses a single
+// entry trie and checks for missing keys both before and after the single entry.
+func TestMissingKeyProof(t *testing.T) {
+ trie := new(Trie)
+ updateString(trie, "k", "v")
+
+ for i, key := range []string{"a", "j", "l", "z"} {
+ proof := ethdb.NewMemDatabase()
+ trie.Prove([]byte(key), 0, proof)
+
+ if proof.Len() != 1 {
+ t.Errorf("test %d: proof should have one element", i)
+ }
+ val, _, err := VerifyProof(trie.Hash(), []byte(key), proof)
+ if err != nil {
+ t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
}
- keys := proofs.Keys()
- key := keys[mrand.Intn(len(keys))]
- node, _ := proofs.Get(key)
- proofs.Delete(key)
- mutateByte(node)
- proofs.Put(crypto.Keccak256(node), node)
- if _, _, err := VerifyProof(root, kv.k, proofs); err == nil {
- t.Fatalf("expected proof to fail for key %x", kv.k)
+ if val != nil {
+ t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
}
}
}