aboutsummaryrefslogtreecommitdiffstats
path: root/ethtrie/iterator.go
diff options
context:
space:
mode:
Diffstat (limited to 'ethtrie/iterator.go')
-rw-r--r--ethtrie/iterator.go143
1 files changed, 143 insertions, 0 deletions
diff --git a/ethtrie/iterator.go b/ethtrie/iterator.go
new file mode 100644
index 000000000..43f497416
--- /dev/null
+++ b/ethtrie/iterator.go
@@ -0,0 +1,143 @@
+package ethtrie
+
+import (
+ "bytes"
+
+ "github.com/ethereum/go-ethereum/ethutil"
+)
+
+type NodeType byte
+
+const (
+ EmptyNode NodeType = iota
+ BranchNode
+ LeafNode
+ ExtNode
+)
+
+func getType(node *ethutil.Value) NodeType {
+ if node.Len() == 0 {
+ return EmptyNode
+ }
+
+ if node.Len() == 2 {
+ k := CompactDecode(node.Get(0).Str())
+ if HasTerm(k) {
+ return LeafNode
+ }
+
+ return ExtNode
+ }
+
+ return BranchNode
+}
+
+type Iterator struct {
+ Path [][]byte
+ trie *Trie
+
+ Key []byte
+ Value *ethutil.Value
+}
+
+func NewIterator(trie *Trie) *Iterator {
+ return &Iterator{trie: trie}
+}
+
+func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte {
+ switch getType(node) {
+ case LeafNode:
+ k := RemTerm(CompactDecode(node.Get(0).Str()))
+
+ self.Path = append(path, k)
+ self.Value = node.Get(1)
+
+ return k
+ case BranchNode:
+ if node.Get(16).Len() > 0 {
+ return []byte{16}
+ }
+
+ for i := byte(0); i < 16; i++ {
+ o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i}))
+ if o != nil {
+ return append([]byte{i}, o...)
+ }
+ }
+ case ExtNode:
+ currKey := node.Get(0).Bytes()
+
+ return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey))
+ }
+
+ return nil
+}
+
+func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte {
+ switch typ := getType(node); typ {
+ case EmptyNode:
+ return nil
+ case BranchNode:
+ if len(key) > 0 {
+ subNode := self.trie.getNode(node.Get(int(key[0])).Raw())
+
+ o := self.next(subNode, key[1:], append(path, key[:1]))
+ if o != nil {
+ return append([]byte{key[0]}, o...)
+ }
+ }
+
+ var r byte = 0
+ if len(key) > 0 {
+ r = key[0] + 1
+ }
+
+ for i := r; i < 16; i++ {
+ subNode := self.trie.getNode(node.Get(int(i)).Raw())
+ o := self.key(subNode, append(path, []byte{i}))
+ if o != nil {
+ return append([]byte{i}, o...)
+ }
+ }
+ case LeafNode, ExtNode:
+ k := RemTerm(CompactDecode(node.Get(0).Str()))
+ if typ == LeafNode {
+ if bytes.Compare([]byte(k), []byte(key)) > 0 {
+ self.Value = node.Get(1)
+ self.Path = append(path, k)
+
+ return k
+ }
+ } else {
+ subNode := self.trie.getNode(node.Get(1).Raw())
+ subKey := key[len(k):]
+ var ret []byte
+ if BeginsWith(key, k) {
+ ret = self.next(subNode, subKey, append(path, k))
+ } else if bytes.Compare(k, key[:len(k)]) > 0 {
+ ret = self.key(node, append(path, k))
+ } else {
+ ret = nil
+ }
+
+ if ret != nil {
+ return append(k, ret...)
+ }
+ }
+ }
+
+ return nil
+}
+
+// Get the next in keys
+func (self *Iterator) Next(key string) []byte {
+ self.trie.mut.Lock()
+ defer self.trie.mut.Unlock()
+
+ k := RemTerm(CompactHexDecode(key))
+ n := self.next(self.trie.getNode(self.trie.Root), k, nil)
+
+ self.Key = []byte(DecodeCompact(n))
+
+ return self.Key
+}