package trie import ( "bytes" "github.com/ethereum/go-ethereum/ethutil" ) type NodeType byte const ( EmptyNode NodeType = iota BranchNode LeafNode ExtNode ) func getType(node *ethutil.Value) NodeType { if node.Len() == 0 { return EmptyNode } if node.Len() == 2 { k := CompactDecode(node.Get(0).Str()) if HasTerm(k) { return LeafNode } return ExtNode } return BranchNode } type Iterator struct { Path [][]byte trie *Trie Key []byte Value *ethutil.Value } func NewIterator(trie *Trie) *Iterator { return &Iterator{trie: trie} } func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte { switch getType(node) { case LeafNode: k := RemTerm(CompactDecode(node.Get(0).Str())) self.Path = append(path, k) self.Value = node.Get(1) return k case BranchNode: if node.Get(16).Len() > 0 { return []byte{16} } for i := byte(0); i < 16; i++ { o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i})) if o != nil { return append([]byte{i}, o...) } } case ExtNode: currKey := node.Get(0).Bytes() return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey)) } return nil } func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte { switch typ := getType(node); typ { case EmptyNode: return nil case BranchNode: if len(key) > 0 { subNode := self.trie.getNode(node.Get(int(key[0])).Raw()) o := self.next(subNode, key[1:], append(path, key[:1])) if o != nil { return append([]byte{key[0]}, o...) } } var r byte = 0 if len(key) > 0 { r = key[0] + 1 } for i := r; i < 16; i++ { subNode := self.trie.getNode(node.Get(int(i)).Raw()) o := self.key(subNode, append(path, []byte{i})) if o != nil { return append([]byte{i}, o...) } } case LeafNode, ExtNode: k := RemTerm(CompactDecode(node.Get(0).Str())) if typ == LeafNode { if bytes.Compare([]byte(k), []byte(key)) > 0 { self.Value = node.Get(1) self.Path = append(path, k) return k } } else { subNode := self.trie.getNode(node.Get(1).Raw()) subKey := key[len(k):] var ret []byte if BeginsWith(key, k) { ret = self.next(subNode, subKey, append(path, k)) } else if bytes.Compare(k, key[:len(k)]) > 0 { ret = self.key(node, append(path, k)) } else { ret = nil } if ret != nil { return append(k, ret...) } } } return nil } // Get the next in keys func (self *Iterator) Next(key string) []byte { self.trie.mut.Lock() defer self.trie.mut.Unlock() k := RemTerm(CompactHexDecode(key)) n := self.next(self.trie.getNode(self.trie.Root), k, nil) self.Key = []byte(DecodeCompact(n)) return self.Key }