diff options
Diffstat (limited to 'ethtrie/iterator.go')
-rw-r--r-- | ethtrie/iterator.go | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/ethtrie/iterator.go b/ethtrie/iterator.go new file mode 100644 index 000000000..43f497416 --- /dev/null +++ b/ethtrie/iterator.go @@ -0,0 +1,143 @@ +package ethtrie + +import ( + "bytes" + + "github.com/ethereum/go-ethereum/ethutil" +) + +type NodeType byte + +const ( + EmptyNode NodeType = iota + BranchNode + LeafNode + ExtNode +) + +func getType(node *ethutil.Value) NodeType { + if node.Len() == 0 { + return EmptyNode + } + + if node.Len() == 2 { + k := CompactDecode(node.Get(0).Str()) + if HasTerm(k) { + return LeafNode + } + + return ExtNode + } + + return BranchNode +} + +type Iterator struct { + Path [][]byte + trie *Trie + + Key []byte + Value *ethutil.Value +} + +func NewIterator(trie *Trie) *Iterator { + return &Iterator{trie: trie} +} + +func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte { + switch getType(node) { + case LeafNode: + k := RemTerm(CompactDecode(node.Get(0).Str())) + + self.Path = append(path, k) + self.Value = node.Get(1) + + return k + case BranchNode: + if node.Get(16).Len() > 0 { + return []byte{16} + } + + for i := byte(0); i < 16; i++ { + o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i})) + if o != nil { + return append([]byte{i}, o...) + } + } + case ExtNode: + currKey := node.Get(0).Bytes() + + return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey)) + } + + return nil +} + +func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte { + switch typ := getType(node); typ { + case EmptyNode: + return nil + case BranchNode: + if len(key) > 0 { + subNode := self.trie.getNode(node.Get(int(key[0])).Raw()) + + o := self.next(subNode, key[1:], append(path, key[:1])) + if o != nil { + return append([]byte{key[0]}, o...) + } + } + + var r byte = 0 + if len(key) > 0 { + r = key[0] + 1 + } + + for i := r; i < 16; i++ { + subNode := self.trie.getNode(node.Get(int(i)).Raw()) + o := self.key(subNode, append(path, []byte{i})) + if o != nil { + return append([]byte{i}, o...) + } + } + case LeafNode, ExtNode: + k := RemTerm(CompactDecode(node.Get(0).Str())) + if typ == LeafNode { + if bytes.Compare([]byte(k), []byte(key)) > 0 { + self.Value = node.Get(1) + self.Path = append(path, k) + + return k + } + } else { + subNode := self.trie.getNode(node.Get(1).Raw()) + subKey := key[len(k):] + var ret []byte + if BeginsWith(key, k) { + ret = self.next(subNode, subKey, append(path, k)) + } else if bytes.Compare(k, key[:len(k)]) > 0 { + ret = self.key(node, append(path, k)) + } else { + ret = nil + } + + if ret != nil { + return append(k, ret...) + } + } + } + + return nil +} + +// Get the next in keys +func (self *Iterator) Next(key string) []byte { + self.trie.mut.Lock() + defer self.trie.mut.Unlock() + + k := RemTerm(CompactHexDecode(key)) + n := self.next(self.trie.getNode(self.trie.Root), k, nil) + + self.Key = []byte(DecodeCompact(n)) + + return self.Key +} |