diff options
author | obscuren <geffobscura@gmail.com> | 2014-11-19 22:05:08 +0800 |
---|---|---|
committer | obscuren <geffobscura@gmail.com> | 2014-11-19 22:05:08 +0800 |
commit | e70529a97785012368e7e0d5b272cccab705e551 (patch) | |
tree | ee74d588bda2352d0026f179f2e7a9e8c210e57b /ptrie | |
parent | 14e2e488fdf0f4d6ed1a5a48ffbbe883faa7edb6 (diff) | |
download | dexon-e70529a97785012368e7e0d5b272cccab705e551.tar dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.gz dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.bz2 dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.lz dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.xz dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.zst dexon-e70529a97785012368e7e0d5b272cccab705e551.zip |
Added new iterator and tests
Diffstat (limited to 'ptrie')
-rw-r--r-- | ptrie/fullnode.go | 5 | ||||
-rw-r--r-- | ptrie/iterator.go | 114 | ||||
-rw-r--r-- | ptrie/iterator_test.go | 28 | ||||
-rw-r--r-- | ptrie/trie.go | 18 | ||||
-rw-r--r-- | ptrie/trie_test.go | 2 |
5 files changed, 159 insertions, 8 deletions
diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go index 2b1a62789..eaa4611b6 100644 --- a/ptrie/fullnode.go +++ b/ptrie/fullnode.go @@ -14,6 +14,9 @@ func (self *FullNode) Value() Node { self.nodes[16] = self.trie.trans(self.nodes[16]) return self.nodes[16] } +func (self *FullNode) Branches() []Node { + return self.nodes[:16] +} func (self *FullNode) Copy() Node { return self } @@ -49,7 +52,7 @@ func (self *FullNode) set(k byte, value Node) { self.nodes[int(k)] = value } -func (self *FullNode) get(i byte) Node { +func (self *FullNode) branch(i byte) Node { if self.nodes[int(i)] != nil { self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)]) diff --git a/ptrie/iterator.go b/ptrie/iterator.go new file mode 100644 index 000000000..c6d4f64a0 --- /dev/null +++ b/ptrie/iterator.go @@ -0,0 +1,114 @@ +package ptrie + +import ( + "bytes" + + "github.com/ethereum/go-ethereum/trie" +) + +type Iterator struct { + trie *Trie + + Key []byte + Value []byte +} + +func NewIterator(trie *Trie) *Iterator { + return &Iterator{trie: trie, Key: []byte{0}} +} + +func (self *Iterator) Next() bool { + self.trie.mu.Lock() + defer self.trie.mu.Unlock() + + key := trie.RemTerm(trie.CompactHexDecode(string(self.Key))) + k := self.next(self.trie.root, key) + + self.Key = []byte(trie.DecodeCompact(k)) + + return len(k) > 0 + +} + +func (self *Iterator) next(node Node, key []byte) []byte { + if node == nil { + return nil + } + + switch node := node.(type) { + case *FullNode: + if len(key) > 0 { + k := self.next(node.branch(key[0]), key[1:]) + if k != nil { + return append([]byte{key[0]}, k...) + } + } + + var r byte + if len(key) > 0 { + r = key[0] + 1 + } + + for i := r; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{i}, k...) + } + } + + case *ShortNode: + k := trie.RemTerm(node.Key()) + if vnode, ok := node.Value().(*ValueNode); ok { + if bytes.Compare([]byte(k), key) > 0 { + self.Value = vnode.Val() + return k + } + } else { + cnode := node.Value() + skey := key[len(k):] + + var ret []byte + if trie.BeginsWith(key, k) { + ret = self.next(cnode, skey) + } else if bytes.Compare(k, key[:len(k)]) > 0 { + ret = self.key(node) + } + + if ret != nil { + return append(k, ret...) + } + } + } + + return nil +} + +func (self *Iterator) key(node Node) []byte { + switch node := node.(type) { + case *ShortNode: + // Leaf node + if vnode, ok := node.Value().(*ValueNode); ok { + k := trie.RemTerm(node.Key()) + self.Value = vnode.Val() + + return k + } else { + return self.key(node.Value()) + } + case *FullNode: + if node.Value() != nil { + self.Value = node.Value().(*ValueNode).Val() + + return []byte{16} + } + + for i := 0; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{byte(i)}, k...) + } + } + } + + return nil +} diff --git a/ptrie/iterator_test.go b/ptrie/iterator_test.go new file mode 100644 index 000000000..8921bb670 --- /dev/null +++ b/ptrie/iterator_test.go @@ -0,0 +1,28 @@ +package ptrie + +import "testing" + +func TestIterator(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + } + v := make(map[string]bool) + for _, val := range vals { + v[val.k] = false + trie.UpdateString(val.k, val.v) + } + + it := trie.Iterator() + for it.Next() { + v[string(it.Key)] = true + } + + for k, found := range v { + if !found { + t.Error("iterator didn't find", k) + } + } +} diff --git a/ptrie/trie.go b/ptrie/trie.go index 207aad91e..bb2b3845a 100644 --- a/ptrie/trie.go +++ b/ptrie/trie.go @@ -45,6 +45,10 @@ func New(root []byte, backend Backend) *Trie { return trie } +func (self *Trie) Iterator() *Iterator { + return NewIterator(self) +} + // Legacy support func (self *Trie) Root() []byte { return self.Hash() } func (self *Trie) Hash() []byte { @@ -144,7 +148,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { case *FullNode: cpy := node.Copy().(*FullNode) - cpy.set(key[0], self.insert(node.get(key[0]), key[1:], value)) + cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) return cpy @@ -173,7 +177,7 @@ func (self *Trie) get(node Node, key []byte) Node { return nil case *FullNode: - return self.get(node.get(key[0]), key[1:]) + return self.get(node.branch(key[0]), key[1:]) default: panic("Invalid node") } @@ -209,11 +213,11 @@ func (self *Trie) delete(node Node, key []byte) Node { case *FullNode: n := node.Copy().(*FullNode) - n.set(key[0], self.delete(n.get(key[0]), key[1:])) + n.set(key[0], self.delete(n.branch(key[0]), key[1:])) pos := -1 for i := 0; i < 17; i++ { - if n.get(byte(i)) != nil { + if n.branch(byte(i)) != nil { if pos == -1 { pos = i } else { @@ -224,16 +228,16 @@ func (self *Trie) delete(node Node, key []byte) Node { var nnode Node if pos == 16 { - nnode = NewShortNode(self, []byte{16}, n.get(byte(pos))) + nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) } else if pos >= 0 { - cnode := n.get(byte(pos)) + cnode := n.branch(byte(pos)) switch cnode := cnode.(type) { case *ShortNode: // Stitch keys k := append([]byte{byte(pos)}, cnode.Key()...) nnode = NewShortNode(self, k, cnode.Value()) case *FullNode: - nnode = NewShortNode(self, []byte{byte(pos)}, n.get(byte(pos))) + nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) } } else { nnode = n diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go index 6cdd2bde4..6af6e1b40 100644 --- a/ptrie/trie_test.go +++ b/ptrie/trie_test.go @@ -139,6 +139,8 @@ func BenchmarkUpdate(b *testing.B) { // Not actual test func TestOutput(t *testing.T) { + t.Skip() + base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" trie := NewEmpty() for i := 0; i < 50; i++ { |