diff options
-rw-r--r-- | core/tx_list.go | 21 | ||||
-rw-r--r-- | core/tx_pool.go | 130 | ||||
-rw-r--r-- | core/tx_pool_test.go | 50 | ||||
-rw-r--r-- | core/vm/instructions.go | 2 | ||||
-rw-r--r-- | trie/iterator.go | 71 | ||||
-rw-r--r-- | trie/proof_test.go | 127 |
6 files changed, 283 insertions, 118 deletions
diff --git a/core/tx_list.go b/core/tx_list.go index ea6ee7019..287dda4c3 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -397,13 +397,13 @@ func (h *priceHeap) Pop() interface{} { // txPricedList is a price-sorted heap to allow operating on transactions pool // contents in a price-incrementing way. type txPricedList struct { - all *map[common.Hash]*types.Transaction // Pointer to the map of all transactions - items *priceHeap // Heap of prices of all the stored transactions - stales int // Number of stale price points to (re-heap trigger) + all *txLookup // Pointer to the map of all transactions + items *priceHeap // Heap of prices of all the stored transactions + stales int // Number of stale price points to (re-heap trigger) } // newTxPricedList creates a new price-sorted transaction heap. -func newTxPricedList(all *map[common.Hash]*types.Transaction) *txPricedList { +func newTxPricedList(all *txLookup) *txPricedList { return &txPricedList{ all: all, items: new(priceHeap), @@ -425,12 +425,13 @@ func (l *txPricedList) Removed() { return } // Seems we've reached a critical number of stale transactions, reheap - reheap := make(priceHeap, 0, len(*l.all)) + reheap := make(priceHeap, 0, l.all.Count()) l.stales, l.items = 0, &reheap - for _, tx := range *l.all { + l.all.Range(func(hash common.Hash, tx *types.Transaction) bool { *l.items = append(*l.items, tx) - } + return true + }) heap.Init(l.items) } @@ -443,7 +444,7 @@ func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transact for len(*l.items) > 0 { // Discard stale transactions if found during cleanup tx := heap.Pop(l.items).(*types.Transaction) - if _, ok := (*l.all)[tx.Hash()]; !ok { + if l.all.Get(tx.Hash()) == nil { l.stales-- continue } @@ -475,7 +476,7 @@ func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) boo // Discard stale price points if found at the heap start for len(*l.items) > 0 { head := []*types.Transaction(*l.items)[0] - if _, ok := (*l.all)[head.Hash()]; !ok { + if l.all.Get(head.Hash()) == nil { l.stales-- heap.Pop(l.items) continue @@ -500,7 +501,7 @@ func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions for len(*l.items) > 0 && count > 0 { // Discard stale transactions if found during cleanup tx := heap.Pop(l.items).(*types.Transaction) - if _, ok := (*l.all)[tx.Hash()]; !ok { + if l.all.Get(tx.Hash()) == nil { l.stales-- continue } diff --git a/core/tx_pool.go b/core/tx_pool.go index f89e11441..1c9516b1b 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -200,11 +200,11 @@ type TxPool struct { locals *accountSet // Set of local transaction to exempt from eviction rules journal *txJournal // Journal of local transaction to back up to disk - pending map[common.Address]*txList // All currently processable transactions - queue map[common.Address]*txList // Queued but non-processable transactions - beats map[common.Address]time.Time // Last heartbeat from each known account - all map[common.Hash]*types.Transaction // All transactions to allow lookups - priced *txPricedList // All transactions sorted by price + pending map[common.Address]*txList // All currently processable transactions + queue map[common.Address]*txList // Queued but non-processable transactions + beats map[common.Address]time.Time // Last heartbeat from each known account + all *txLookup // All transactions to allow lookups + priced *txPricedList // All transactions sorted by price wg sync.WaitGroup // for shutdown sync @@ -226,12 +226,12 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block pending: make(map[common.Address]*txList), queue: make(map[common.Address]*txList), beats: make(map[common.Address]time.Time), - all: make(map[common.Hash]*types.Transaction), + all: newTxLookup(), chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize), gasPrice: new(big.Int).SetUint64(config.PriceLimit), } pool.locals = newAccountSet(pool.signer) - pool.priced = newTxPricedList(&pool.all) + pool.priced = newTxPricedList(pool.all) pool.reset(nil, chain.CurrentBlock().Header()) // If local transactions and journaling is enabled, load from disk @@ -605,7 +605,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { // If the transaction is already known, discard it hash := tx.Hash() - if pool.all[hash] != nil { + if pool.all.Get(hash) != nil { log.Trace("Discarding already known transaction", "hash", hash) return false, fmt.Errorf("known transaction: %x", hash) } @@ -616,7 +616,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { return false, err } // If the transaction pool is full, discard underpriced transactions - if uint64(len(pool.all)) >= pool.config.GlobalSlots+pool.config.GlobalQueue { + if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue { // If the new transaction is underpriced, don't accept it if !local && pool.priced.Underpriced(tx, pool.locals) { log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice()) @@ -624,7 +624,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { return false, ErrUnderpriced } // New transaction is better than our worse ones, make room for it - drop := pool.priced.Discard(len(pool.all)-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals) + drop := pool.priced.Discard(pool.all.Count()-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals) for _, tx := range drop { log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice()) underpricedTxCounter.Inc(1) @@ -642,11 +642,11 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { } // New transaction is better, replace old one if old != nil { - delete(pool.all, old.Hash()) + pool.all.Remove(old.Hash()) pool.priced.Removed() pendingReplaceCounter.Inc(1) } - pool.all[tx.Hash()] = tx + pool.all.Add(tx) pool.priced.Put(tx) pool.journalTx(from, tx) @@ -689,12 +689,12 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er } // Discard any previous transaction and mark this if old != nil { - delete(pool.all, old.Hash()) + pool.all.Remove(old.Hash()) pool.priced.Removed() queuedReplaceCounter.Inc(1) } - if pool.all[hash] == nil { - pool.all[hash] = tx + if pool.all.Get(hash) == nil { + pool.all.Add(tx) pool.priced.Put(tx) } return old != nil, nil @@ -726,7 +726,7 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T inserted, old := list.Add(tx, pool.config.PriceBump) if !inserted { // An older transaction was better, discard this - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() pendingDiscardCounter.Inc(1) @@ -734,14 +734,14 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T } // Otherwise discard any previous transaction and mark this if old != nil { - delete(pool.all, old.Hash()) + pool.all.Remove(old.Hash()) pool.priced.Removed() pendingReplaceCounter.Inc(1) } // Failsafe to work around direct pending inserts (tests) - if pool.all[hash] == nil { - pool.all[hash] = tx + if pool.all.Get(hash) == nil { + pool.all.Add(tx) pool.priced.Put(tx) } // Set the potentially new pending nonce and notify any subsystems of the new tx @@ -840,7 +840,7 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { status := make([]TxStatus, len(hashes)) for i, hash := range hashes { - if tx := pool.all[hash]; tx != nil { + if tx := pool.all.Get(hash); tx != nil { from, _ := types.Sender(pool.signer, tx) // already validated if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { status[i] = TxStatusPending @@ -855,24 +855,21 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { // Get returns a transaction if it is contained in the pool // and nil otherwise. func (pool *TxPool) Get(hash common.Hash) *types.Transaction { - pool.mu.RLock() - defer pool.mu.RUnlock() - - return pool.all[hash] + return pool.all.Get(hash) } // removeTx removes a single transaction from the queue, moving all subsequent // transactions back to the future queue. func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // Fetch the transaction we wish to delete - tx, ok := pool.all[hash] - if !ok { + tx := pool.all.Get(hash) + if tx == nil { return } addr, _ := types.Sender(pool.signer, tx) // already validated during insertion // Remove it from the list of known transactions - delete(pool.all, hash) + pool.all.Remove(hash) if outofbound { pool.priced.Removed() } @@ -928,7 +925,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Forward(pool.currentState.GetNonce(addr)) { hash := tx.Hash() log.Trace("Removed old queued transaction", "hash", hash) - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() } // Drop all transactions that are too costly (low balance or out of gas) @@ -936,7 +933,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range drops { hash := tx.Hash() log.Trace("Removed unpayable queued transaction", "hash", hash) - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() queuedNofundsCounter.Inc(1) } @@ -952,7 +949,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { if !pool.locals.contains(addr) { for _, tx := range list.Cap(int(pool.config.AccountQueue)) { hash := tx.Hash() - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() queuedRateLimitCounter.Inc(1) log.Trace("Removed cap-exceeding queued transaction", "hash", hash) @@ -1001,7 +998,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Cap(list.Len() - 1) { // Drop the transaction from the global pools too hash := tx.Hash() - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() // Update the account nonce to the dropped transaction @@ -1023,7 +1020,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Cap(list.Len() - 1) { // Drop the transaction from the global pools too hash := tx.Hash() - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() // Update the account nonce to the dropped transaction @@ -1092,7 +1089,7 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range list.Forward(nonce) { hash := tx.Hash() log.Trace("Removed old pending transaction", "hash", hash) - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() } // Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later @@ -1100,7 +1097,7 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range drops { hash := tx.Hash() log.Trace("Removed unpayable pending transaction", "hash", hash) - delete(pool.all, hash) + pool.all.Remove(hash) pool.priced.Removed() pendingNofundsCounter.Inc(1) } @@ -1172,3 +1169,68 @@ func (as *accountSet) containsTx(tx *types.Transaction) bool { func (as *accountSet) add(addr common.Address) { as.accounts[addr] = struct{}{} } + +// txLookup is used internally by TxPool to track transactions while allowing lookup without +// mutex contention. +// +// Note, although this type is properly protected against concurrent access, it +// is **not** a type that should ever be mutated or even exposed outside of the +// transaction pool, since its internal state is tightly coupled with the pools +// internal mechanisms. The sole purpose of the type is to permit out-of-bound +// peeking into the pool in TxPool.Get without having to acquire the widely scoped +// TxPool.mu mutex. +type txLookup struct { + all map[common.Hash]*types.Transaction + lock sync.RWMutex +} + +// newTxLookup returns a new txLookup structure. +func newTxLookup() *txLookup { + return &txLookup{ + all: make(map[common.Hash]*types.Transaction), + } +} + +// Range calls f on each key and value present in the map. +func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { + t.lock.RLock() + defer t.lock.RUnlock() + + for key, value := range t.all { + if !f(key, value) { + break + } + } +} + +// Get returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) Get(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.all[hash] +} + +// Count returns the current number of items in the lookup. +func (t *txLookup) Count() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.all) +} + +// Add adds a transaction to the lookup. +func (t *txLookup) Add(tx *types.Transaction) { + t.lock.Lock() + defer t.lock.Unlock() + + t.all[tx.Hash()] = tx +} + +// Remove removes a transaction from the lookup. +func (t *txLookup) Remove(hash common.Hash) { + t.lock.Lock() + defer t.lock.Unlock() + + delete(t.all, hash) +} diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 25993c258..5a5920544 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -94,7 +94,7 @@ func validateTxPoolInternals(pool *TxPool) error { // Ensure the total transaction set is consistent with pending + queued pending, queued := pool.stats() - if total := len(pool.all); total != pending+queued { + if total := pool.all.Count(); total != pending+queued { return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued) } if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued { @@ -401,8 +401,8 @@ func TestTransactionDoubleNonce(t *testing.T) { t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash()) } // Ensure the total transaction count is correct - if len(pool.all) != 1 { - t.Error("expected 1 total transactions, got", len(pool.all)) + if pool.all.Count() != 1 { + t.Error("expected 1 total transactions, got", pool.all.Count()) } } @@ -424,8 +424,8 @@ func TestTransactionMissingNonce(t *testing.T) { if pool.queue[addr].Len() != 1 { t.Error("expected 1 queued transaction, got", pool.queue[addr].Len()) } - if len(pool.all) != 1 { - t.Error("expected 1 total transactions, got", len(pool.all)) + if pool.all.Count() != 1 { + t.Error("expected 1 total transactions, got", pool.all.Count()) } } @@ -488,8 +488,8 @@ func TestTransactionDropping(t *testing.T) { if pool.queue[account].Len() != 3 { t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3) } - if len(pool.all) != 6 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6) + if pool.all.Count() != 6 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6) } pool.lockedReset(nil, nil) if pool.pending[account].Len() != 3 { @@ -498,8 +498,8 @@ func TestTransactionDropping(t *testing.T) { if pool.queue[account].Len() != 3 { t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3) } - if len(pool.all) != 6 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6) + if pool.all.Count() != 6 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6) } // Reduce the balance of the account, and check that invalidated transactions are dropped pool.currentState.AddBalance(account, big.NewInt(-650)) @@ -523,8 +523,8 @@ func TestTransactionDropping(t *testing.T) { if _, ok := pool.queue[account].txs.items[tx12.Nonce()]; ok { t.Errorf("out-of-fund queued transaction present: %v", tx11) } - if len(pool.all) != 4 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 4) + if pool.all.Count() != 4 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4) } // Reduce the block gas limit, check that invalidated transactions are dropped pool.chain.(*testBlockChain).gasLimit = 100 @@ -542,8 +542,8 @@ func TestTransactionDropping(t *testing.T) { if _, ok := pool.queue[account].txs.items[tx11.Nonce()]; ok { t.Errorf("over-gased queued transaction present: %v", tx11) } - if len(pool.all) != 2 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 2) + if pool.all.Count() != 2 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 2) } } @@ -596,8 +596,8 @@ func TestTransactionPostponing(t *testing.T) { if len(pool.queue) != 0 { t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0) } - if len(pool.all) != len(txs) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)) + if pool.all.Count() != len(txs) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)) } pool.lockedReset(nil, nil) if pending := pool.pending[accs[0]].Len() + pool.pending[accs[1]].Len(); pending != len(txs) { @@ -606,8 +606,8 @@ func TestTransactionPostponing(t *testing.T) { if len(pool.queue) != 0 { t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0) } - if len(pool.all) != len(txs) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)) + if pool.all.Count() != len(txs) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)) } // Reduce the balance of the account, and check that transactions are reorganised for _, addr := range accs { @@ -656,8 +656,8 @@ func TestTransactionPostponing(t *testing.T) { } } } - if len(pool.all) != len(txs)/2 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)/2) + if pool.all.Count() != len(txs)/2 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)/2) } } @@ -748,8 +748,8 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { } } } - if len(pool.all) != int(testTxPoolConfig.AccountQueue) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue) + if pool.all.Count() != int(testTxPoolConfig.AccountQueue) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue) } } @@ -942,8 +942,8 @@ func TestTransactionPendingLimiting(t *testing.T) { t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, pool.queue[account].Len(), 0) } } - if len(pool.all) != int(testTxPoolConfig.AccountQueue+5) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue+5) + if pool.all.Count() != int(testTxPoolConfig.AccountQueue+5) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue+5) } if err := validateEvents(events, int(testTxPoolConfig.AccountQueue+5)); err != nil { t.Fatalf("event firing failed: %v", err) @@ -993,8 +993,8 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { if len(pool1.queue) != len(pool2.queue) { t.Errorf("queued transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.queue), len(pool2.queue)) } - if len(pool1.all) != len(pool2.all) { - t.Errorf("total transaction count mismatch: one-by-one algo %d, batch algo %d", len(pool1.all), len(pool2.all)) + if pool1.all.Count() != pool2.all.Count() { + t.Errorf("total transaction count mismatch: one-by-one algo %d, batch algo %d", pool1.all.Count(), pool2.all.Count()) } if err := validateTxPoolInternals(pool1); err != nil { t.Errorf("pool 1 internal state corrupted: %v", err) diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 0689ee39c..3a67e1865 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -850,7 +850,7 @@ func makePush(size uint64, pushByteSize int) executionFunc { } } -// make push instruction function +// make dup instruction function func makeDup(size int64) executionFunc { return func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { stack.dup(evm.interpreter.intPool, int(size)) diff --git a/trie/iterator.go b/trie/iterator.go index 64110c6d9..00b890eb8 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" ) // Iterator is a key-value trie iterator that traverses a Trie. @@ -55,31 +56,50 @@ func (it *Iterator) Next() bool { return false } +// Prove generates the Merkle proof for the leaf node the iterator is currently +// positioned on. +func (it *Iterator) Prove() [][]byte { + return it.nodeIt.LeafProof() +} + // NodeIterator is an iterator to traverse the trie pre-order. type NodeIterator interface { // Next moves the iterator to the next node. If the parameter is false, any child // nodes will be skipped. Next(bool) bool + // Error returns the error status of the iterator. Error() error // Hash returns the hash of the current node. Hash() common.Hash + // Parent returns the hash of the parent of the current node. The hash may be the one // grandparent if the immediate parent is an internal node with no hash. Parent() common.Hash + // Path returns the hex-encoded path to the current node. // Callers must not retain references to the return value after calling Next. // For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. Path() []byte // Leaf returns true iff the current node is a leaf node. - // LeafBlob, LeafKey return the contents and key of the leaf node. These - // method panic if the iterator is not positioned at a leaf. - // Callers must not retain references to their return value after calling Next Leaf() bool - LeafBlob() []byte + + // LeafKey returns the key of the leaf. The method panics if the iterator is not + // positioned at a leaf. Callers must not retain references to the value after + // calling Next. LeafKey() []byte + + // LeafBlob returns the content of the leaf. The method panics if the iterator + // is not positioned at a leaf. Callers must not retain references to the value + // after calling Next. + LeafBlob() []byte + + // LeafProof returns the Merkle proof of the leaf. The method panics if the + // iterator is not positioned at a leaf. Callers must not retain references + // to the value after calling Next. + LeafProof() [][]byte } // nodeIteratorState represents the iteration state at one particular node of the @@ -139,6 +159,15 @@ func (it *nodeIterator) Leaf() bool { return hasTerm(it.path) } +func (it *nodeIterator) LeafKey() []byte { + if len(it.stack) > 0 { + if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { + return hexToKeybytes(it.path) + } + } + panic("not at leaf") +} + func (it *nodeIterator) LeafBlob() []byte { if len(it.stack) > 0 { if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { @@ -148,10 +177,22 @@ func (it *nodeIterator) LeafBlob() []byte { panic("not at leaf") } -func (it *nodeIterator) LeafKey() []byte { +func (it *nodeIterator) LeafProof() [][]byte { if len(it.stack) > 0 { if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { - return hexToKeybytes(it.path) + hasher := newHasher(0, 0, nil) + proofs := make([][]byte, 0, len(it.stack)) + + for i, item := range it.stack[:len(it.stack)-1] { + // Gather nodes that end up as hash nodes (or the root) + node, _, _ := hasher.hashChildren(item.node, nil) + hashed, _ := hasher.store(node, nil, false) + if _, ok := hashed.(hashNode); ok || i == 0 { + enc, _ := rlp.EncodeToBytes(node) + proofs = append(proofs, enc) + } + } + return proofs } } panic("not at leaf") @@ -361,12 +402,16 @@ func (it *differenceIterator) Leaf() bool { return it.b.Leaf() } +func (it *differenceIterator) LeafKey() []byte { + return it.b.LeafKey() +} + func (it *differenceIterator) LeafBlob() []byte { return it.b.LeafBlob() } -func (it *differenceIterator) LeafKey() []byte { - return it.b.LeafKey() +func (it *differenceIterator) LeafProof() [][]byte { + return it.b.LeafProof() } func (it *differenceIterator) Path() []byte { @@ -464,12 +509,16 @@ func (it *unionIterator) Leaf() bool { return (*it.items)[0].Leaf() } +func (it *unionIterator) LeafKey() []byte { + return (*it.items)[0].LeafKey() +} + func (it *unionIterator) LeafBlob() []byte { return (*it.items)[0].LeafBlob() } -func (it *unionIterator) LeafKey() []byte { - return (*it.items)[0].LeafKey() +func (it *unionIterator) LeafProof() [][]byte { + return (*it.items)[0].LeafProof() } func (it *unionIterator) Path() []byte { @@ -509,12 +558,10 @@ func (it *unionIterator) Next(descend bool) bool { heap.Push(it.items, skipped) } } - if least.Next(descend) { it.count++ heap.Push(it.items, least) } - return len(*it.items) > 0 } diff --git a/trie/proof_test.go b/trie/proof_test.go index dee6f7d85..996f87478 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -32,20 +32,46 @@ func init() { mrand.Seed(time.Now().Unix()) } +// makeProvers creates Merkle trie provers based on different implementations to +// test all variations. +func makeProvers(trie *Trie) []func(key []byte) *ethdb.MemDatabase { + var provers []func(key []byte) *ethdb.MemDatabase + + // Create a direct trie based Merkle prover + provers = append(provers, func(key []byte) *ethdb.MemDatabase { + proof := ethdb.NewMemDatabase() + trie.Prove(key, 0, proof) + return proof + }) + // Create a leaf iterator based Merkle prover + provers = append(provers, func(key []byte) *ethdb.MemDatabase { + proof := ethdb.NewMemDatabase() + if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { + for _, p := range it.Prove() { + proof.Put(crypto.Keccak256(p), p) + } + } + return proof + }) + return provers +} + func TestProof(t *testing.T) { trie, vals := randomTrie(500) root := trie.Hash() - for _, kv := range vals { - proofs := ethdb.NewMemDatabase() - if trie.Prove(kv.k, 0, proofs) != nil { - t.Fatalf("missing key %x while constructing proof", kv.k) - } - val, _, err := VerifyProof(root, kv.k, proofs) - if err != nil { - t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs) - } - if !bytes.Equal(val, kv.v) { - t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) + } + val, _, err := VerifyProof(root, kv.k, proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) + } } } } @@ -53,37 +79,66 @@ func TestProof(t *testing.T) { func TestOneElementProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") - proofs := ethdb.NewMemDatabase() - trie.Prove([]byte("k"), 0, proofs) - if len(proofs.Keys()) != 1 { - t.Error("proof should have one element") - } - val, _, err := VerifyProof(trie.Hash(), []byte("k"), proofs) - if err != nil { - t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys()) - } - if !bytes.Equal(val, []byte("v")) { - t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) + for i, prover := range makeProvers(trie) { + proof := prover([]byte("k")) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + if proof.Len() != 1 { + t.Errorf("prover %d: proof should have one element", i) + } + val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if !bytes.Equal(val, []byte("v")) { + t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) + } } } -func TestVerifyBadProof(t *testing.T) { +func TestBadProof(t *testing.T) { trie, vals := randomTrie(800) root := trie.Hash() - for _, kv := range vals { - proofs := ethdb.NewMemDatabase() - trie.Prove(kv.k, 0, proofs) - if len(proofs.Keys()) == 0 { - t.Fatal("zero length proof") + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + key := proof.Keys()[mrand.Intn(proof.Len())] + val, _ := proof.Get(key) + proof.Delete(key) + + mutateByte(val) + proof.Put(crypto.Keccak256(val), val) + + if _, _, err := VerifyProof(root, kv.k, proof); err == nil { + t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) + } + } + } +} + +// Tests that missing keys can also be proven. The test explicitly uses a single +// entry trie and checks for missing keys both before and after the single entry. +func TestMissingKeyProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + for i, key := range []string{"a", "j", "l", "z"} { + proof := ethdb.NewMemDatabase() + trie.Prove([]byte(key), 0, proof) + + if proof.Len() != 1 { + t.Errorf("test %d: proof should have one element", i) + } + val, _, err := VerifyProof(trie.Hash(), []byte(key), proof) + if err != nil { + t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } - keys := proofs.Keys() - key := keys[mrand.Intn(len(keys))] - node, _ := proofs.Get(key) - proofs.Delete(key) - mutateByte(node) - proofs.Put(crypto.Keccak256(node), node) - if _, _, err := VerifyProof(root, kv.k, proofs); err == nil { - t.Fatalf("expected proof to fail for key %x", kv.k) + if val != nil { + t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val) } } } |