aboutsummaryrefslogtreecommitdiffstats
path: root/trie/iterator.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie/iterator.go')
-rw-r--r--trie/iterator.go127
1 files changed, 76 insertions, 51 deletions
diff --git a/trie/iterator.go b/trie/iterator.go
index fef5b2593..26ae1d5ad 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -19,10 +19,13 @@ package trie
import (
"bytes"
"container/heap"
+ "errors"
"github.com/ethereum/go-ethereum/common"
)
+var iteratorEnd = errors.New("end of iteration")
+
// Iterator is a key-value trie iterator that traverses a Trie.
type Iterator struct {
nodeIt NodeIterator
@@ -79,25 +82,24 @@ type nodeIteratorState struct {
hash common.Hash // Hash of the node being iterated (nil if not standalone)
node node // Trie node being iterated
parent common.Hash // Hash of the first full ancestor node (nil if current is the root)
- child int // Child to be processed next
+ index int // Child to be processed next
pathlen int // Length of the path to this node
}
type nodeIterator struct {
trie *Trie // Trie being iterated
stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state
-
- err error // Failure set in case of an internal error in the iterator
-
- path []byte // Path to the current node
+ err error // Failure set in case of an internal error in the iterator
+ path []byte // Path to the current node
}
-// newNodeIterator creates an post-order trie iterator.
-func newNodeIterator(trie *Trie) NodeIterator {
+func newNodeIterator(trie *Trie, start []byte) NodeIterator {
if trie.Hash() == emptyState {
return new(nodeIterator)
}
- return &nodeIterator{trie: trie}
+ it := &nodeIterator{trie: trie}
+ it.seek(start)
+ return it
}
// Hash returns the hash of the current node
@@ -147,6 +149,9 @@ func (it *nodeIterator) Path() []byte {
// Error returns the error set in case of an internal error in the iterator
func (it *nodeIterator) Error() error {
+ if it.err == iteratorEnd {
+ return nil
+ }
return it.err
}
@@ -155,47 +160,54 @@ func (it *nodeIterator) Error() error {
// sets the Error field to the encountered failure. If `descend` is false,
// skips iterating over any subnodes of the current node.
func (it *nodeIterator) Next(descend bool) bool {
- // If the iterator failed previously, don't do anything
if it.err != nil {
return false
}
// Otherwise step forward with the iterator and report any errors
- if err := it.step(descend); err != nil {
+ state, parentIndex, path, err := it.peek(descend)
+ if err != nil {
it.err = err
return false
}
- return it.trie != nil
+ it.push(state, parentIndex, path)
+ return true
}
-// step moves the iterator to the next node of the trie.
-func (it *nodeIterator) step(descend bool) error {
- if it.trie == nil {
- // Abort if we reached the end of the iteration
- return nil
+func (it *nodeIterator) seek(prefix []byte) {
+ // The path we're looking for is the hex encoded key without terminator.
+ key := keybytesToHex(prefix)
+ key = key[:len(key)-1]
+ // Move forward until we're just before the closest match to key.
+ for {
+ state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path))
+ if err != nil || bytes.Compare(path, key) >= 0 {
+ it.err = err
+ return
+ }
+ it.push(state, parentIndex, path)
}
+}
+
+// peek creates the next state of the iterator.
+func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
if len(it.stack) == 0 {
// Initialize the iterator if we've just started.
root := it.trie.Hash()
- state := &nodeIteratorState{node: it.trie.root, child: -1}
+ state := &nodeIteratorState{node: it.trie.root, index: -1}
if root != emptyRoot {
state.hash = root
}
- it.stack = append(it.stack, state)
- return nil
+ return state, nil, nil, nil
}
-
if !descend {
// If we're skipping children, pop the current node first
- it.path = it.path[:it.stack[len(it.stack)-1].pathlen]
- it.stack = it.stack[:len(it.stack)-1]
+ it.pop()
}
// Continue iteration to the next child
-outer:
for {
if len(it.stack) == 0 {
- it.trie = nil
- return nil
+ return nil, nil, nil, iteratorEnd
}
parent := it.stack[len(it.stack)-1]
ancestor := parent.hash
@@ -203,63 +215,76 @@ outer:
ancestor = parent.parent
}
if node, ok := parent.node.(*fullNode); ok {
- // Full node, iterate over children
- for parent.child++; parent.child < len(node.Children); parent.child++ {
- child := node.Children[parent.child]
+ // Full node, move to the first non-nil child.
+ for i := parent.index + 1; i < len(node.Children); i++ {
+ child := node.Children[i]
if child != nil {
hash, _ := child.cache()
- it.stack = append(it.stack, &nodeIteratorState{
+ state := &nodeIteratorState{
hash: common.BytesToHash(hash),
node: child,
parent: ancestor,
- child: -1,
+ index: -1,
pathlen: len(it.path),
- })
- it.path = append(it.path, byte(parent.child))
- break outer
+ }
+ path := append(it.path, byte(i))
+ parent.index = i - 1
+ return state, &parent.index, path, nil
}
}
} else if node, ok := parent.node.(*shortNode); ok {
// Short node, return the pointer singleton child
- if parent.child < 0 {
- parent.child++
+ if parent.index < 0 {
hash, _ := node.Val.cache()
- it.stack = append(it.stack, &nodeIteratorState{
+ state := &nodeIteratorState{
hash: common.BytesToHash(hash),
node: node.Val,
parent: ancestor,
- child: -1,
+ index: -1,
pathlen: len(it.path),
- })
+ }
+ var path []byte
if hasTerm(node.Key) {
- it.path = append(it.path, node.Key[:len(node.Key)-1]...)
+ path = append(it.path, node.Key[:len(node.Key)-1]...)
} else {
- it.path = append(it.path, node.Key...)
+ path = append(it.path, node.Key...)
}
- break
+ return state, &parent.index, path, nil
}
} else if hash, ok := parent.node.(hashNode); ok {
// Hash node, resolve the hash child from the database
- if parent.child < 0 {
- parent.child++
+ if parent.index < 0 {
node, err := it.trie.resolveHash(hash, nil, nil)
if err != nil {
- return err
+ return it.stack[len(it.stack)-1], &parent.index, it.path, err
}
- it.stack = append(it.stack, &nodeIteratorState{
+ state := &nodeIteratorState{
hash: common.BytesToHash(hash),
node: node,
parent: ancestor,
- child: -1,
+ index: -1,
pathlen: len(it.path),
- })
- break
+ }
+ return state, &parent.index, it.path, nil
}
}
- it.path = it.path[:parent.pathlen]
- it.stack = it.stack[:len(it.stack)-1]
+ // No more child nodes, move back up.
+ it.pop()
}
- return nil
+}
+
+func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
+ it.path = path
+ it.stack = append(it.stack, state)
+ if parentIndex != nil {
+ *parentIndex += 1
+ }
+}
+
+func (it *nodeIterator) pop() {
+ parent := it.stack[len(it.stack)-1]
+ it.path = it.path[:parent.pathlen]
+ it.stack = it.stack[:len(it.stack)-1]
}
func compareNodes(a, b NodeIterator) int {