aboutsummaryrefslogtreecommitdiffstats
path: root/ptrie
diff options
context:
space:
mode:
authorobscuren <geffobscura@gmail.com>2014-11-19 22:05:08 +0800
committerobscuren <geffobscura@gmail.com>2014-11-19 22:05:08 +0800
commite70529a97785012368e7e0d5b272cccab705e551 (patch)
treeee74d588bda2352d0026f179f2e7a9e8c210e57b /ptrie
parent14e2e488fdf0f4d6ed1a5a48ffbbe883faa7edb6 (diff)
downloaddexon-e70529a97785012368e7e0d5b272cccab705e551.tar
dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.gz
dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.bz2
dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.lz
dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.xz
dexon-e70529a97785012368e7e0d5b272cccab705e551.tar.zst
dexon-e70529a97785012368e7e0d5b272cccab705e551.zip
Added new iterator and tests
Diffstat (limited to 'ptrie')
-rw-r--r--ptrie/fullnode.go5
-rw-r--r--ptrie/iterator.go114
-rw-r--r--ptrie/iterator_test.go28
-rw-r--r--ptrie/trie.go18
-rw-r--r--ptrie/trie_test.go2
5 files changed, 159 insertions, 8 deletions
diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go
index 2b1a62789..eaa4611b6 100644
--- a/ptrie/fullnode.go
+++ b/ptrie/fullnode.go
@@ -14,6 +14,9 @@ func (self *FullNode) Value() Node {
self.nodes[16] = self.trie.trans(self.nodes[16])
return self.nodes[16]
}
+func (self *FullNode) Branches() []Node {
+ return self.nodes[:16]
+}
func (self *FullNode) Copy() Node { return self }
@@ -49,7 +52,7 @@ func (self *FullNode) set(k byte, value Node) {
self.nodes[int(k)] = value
}
-func (self *FullNode) get(i byte) Node {
+func (self *FullNode) branch(i byte) Node {
if self.nodes[int(i)] != nil {
self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)])
diff --git a/ptrie/iterator.go b/ptrie/iterator.go
new file mode 100644
index 000000000..c6d4f64a0
--- /dev/null
+++ b/ptrie/iterator.go
@@ -0,0 +1,114 @@
+package ptrie
+
+import (
+ "bytes"
+
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+type Iterator struct {
+ trie *Trie
+
+ Key []byte
+ Value []byte
+}
+
+func NewIterator(trie *Trie) *Iterator {
+ return &Iterator{trie: trie, Key: []byte{0}}
+}
+
+func (self *Iterator) Next() bool {
+ self.trie.mu.Lock()
+ defer self.trie.mu.Unlock()
+
+ key := trie.RemTerm(trie.CompactHexDecode(string(self.Key)))
+ k := self.next(self.trie.root, key)
+
+ self.Key = []byte(trie.DecodeCompact(k))
+
+ return len(k) > 0
+
+}
+
+func (self *Iterator) next(node Node, key []byte) []byte {
+ if node == nil {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *FullNode:
+ if len(key) > 0 {
+ k := self.next(node.branch(key[0]), key[1:])
+ if k != nil {
+ return append([]byte{key[0]}, k...)
+ }
+ }
+
+ var r byte
+ if len(key) > 0 {
+ r = key[0] + 1
+ }
+
+ for i := r; i < 16; i++ {
+ k := self.key(node.branch(byte(i)))
+ if k != nil {
+ return append([]byte{i}, k...)
+ }
+ }
+
+ case *ShortNode:
+ k := trie.RemTerm(node.Key())
+ if vnode, ok := node.Value().(*ValueNode); ok {
+ if bytes.Compare([]byte(k), key) > 0 {
+ self.Value = vnode.Val()
+ return k
+ }
+ } else {
+ cnode := node.Value()
+ skey := key[len(k):]
+
+ var ret []byte
+ if trie.BeginsWith(key, k) {
+ ret = self.next(cnode, skey)
+ } else if bytes.Compare(k, key[:len(k)]) > 0 {
+ ret = self.key(node)
+ }
+
+ if ret != nil {
+ return append(k, ret...)
+ }
+ }
+ }
+
+ return nil
+}
+
+func (self *Iterator) key(node Node) []byte {
+ switch node := node.(type) {
+ case *ShortNode:
+ // Leaf node
+ if vnode, ok := node.Value().(*ValueNode); ok {
+ k := trie.RemTerm(node.Key())
+ self.Value = vnode.Val()
+
+ return k
+ } else {
+ return self.key(node.Value())
+ }
+ case *FullNode:
+ if node.Value() != nil {
+ self.Value = node.Value().(*ValueNode).Val()
+
+ return []byte{16}
+ }
+
+ for i := 0; i < 16; i++ {
+ k := self.key(node.branch(byte(i)))
+ if k != nil {
+ return append([]byte{byte(i)}, k...)
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/ptrie/iterator_test.go b/ptrie/iterator_test.go
new file mode 100644
index 000000000..8921bb670
--- /dev/null
+++ b/ptrie/iterator_test.go
@@ -0,0 +1,28 @@
+package ptrie
+
+import "testing"
+
+func TestIterator(t *testing.T) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ }
+ v := make(map[string]bool)
+ for _, val := range vals {
+ v[val.k] = false
+ trie.UpdateString(val.k, val.v)
+ }
+
+ it := trie.Iterator()
+ for it.Next() {
+ v[string(it.Key)] = true
+ }
+
+ for k, found := range v {
+ if !found {
+ t.Error("iterator didn't find", k)
+ }
+ }
+}
diff --git a/ptrie/trie.go b/ptrie/trie.go
index 207aad91e..bb2b3845a 100644
--- a/ptrie/trie.go
+++ b/ptrie/trie.go
@@ -45,6 +45,10 @@ func New(root []byte, backend Backend) *Trie {
return trie
}
+func (self *Trie) Iterator() *Iterator {
+ return NewIterator(self)
+}
+
// Legacy support
func (self *Trie) Root() []byte { return self.Hash() }
func (self *Trie) Hash() []byte {
@@ -144,7 +148,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
case *FullNode:
cpy := node.Copy().(*FullNode)
- cpy.set(key[0], self.insert(node.get(key[0]), key[1:], value))
+ cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
return cpy
@@ -173,7 +177,7 @@ func (self *Trie) get(node Node, key []byte) Node {
return nil
case *FullNode:
- return self.get(node.get(key[0]), key[1:])
+ return self.get(node.branch(key[0]), key[1:])
default:
panic("Invalid node")
}
@@ -209,11 +213,11 @@ func (self *Trie) delete(node Node, key []byte) Node {
case *FullNode:
n := node.Copy().(*FullNode)
- n.set(key[0], self.delete(n.get(key[0]), key[1:]))
+ n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
pos := -1
for i := 0; i < 17; i++ {
- if n.get(byte(i)) != nil {
+ if n.branch(byte(i)) != nil {
if pos == -1 {
pos = i
} else {
@@ -224,16 +228,16 @@ func (self *Trie) delete(node Node, key []byte) Node {
var nnode Node
if pos == 16 {
- nnode = NewShortNode(self, []byte{16}, n.get(byte(pos)))
+ nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
} else if pos >= 0 {
- cnode := n.get(byte(pos))
+ cnode := n.branch(byte(pos))
switch cnode := cnode.(type) {
case *ShortNode:
// Stitch keys
k := append([]byte{byte(pos)}, cnode.Key()...)
nnode = NewShortNode(self, k, cnode.Value())
case *FullNode:
- nnode = NewShortNode(self, []byte{byte(pos)}, n.get(byte(pos)))
+ nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
}
} else {
nnode = n
diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go
index 6cdd2bde4..6af6e1b40 100644
--- a/ptrie/trie_test.go
+++ b/ptrie/trie_test.go
@@ -139,6 +139,8 @@ func BenchmarkUpdate(b *testing.B) {
// Not actual test
func TestOutput(t *testing.T) {
+ t.Skip()
+
base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
trie := NewEmpty()
for i := 0; i < 50; i++ {