aboutsummaryrefslogtreecommitdiffstats
path: root/trie/iterator_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie/iterator_test.go')
-rw-r--r--trie/iterator_test.go65
1 files changed, 55 insertions, 10 deletions
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 04d51aaf5..f161fd99d 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -17,6 +17,8 @@
package trie
import (
+ "bytes"
+ "fmt"
"testing"
"github.com/ethereum/go-ethereum/common"
@@ -42,7 +44,7 @@ func TestIterator(t *testing.T) {
trie.Commit()
found := make(map[string]string)
- it := NewIterator(trie.NodeIterator())
+ it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
found[string(it.Key)] = string(it.Value)
}
@@ -72,7 +74,7 @@ func TestIteratorLargeData(t *testing.T) {
vals[string(value2.k)] = value2
}
- it := NewIterator(trie.NodeIterator())
+ it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
vals[string(it.Key)].t = true
}
@@ -99,7 +101,7 @@ func TestNodeIteratorCoverage(t *testing.T) {
// Gather all the node hashes found by the iterator
hashes := make(map[common.Hash]struct{})
- for it := trie.NodeIterator(); it.Next(true); {
+ for it := trie.NodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) {
hashes[it.Hash()] = struct{}{}
}
@@ -117,18 +119,20 @@ func TestNodeIteratorCoverage(t *testing.T) {
}
}
-var testdata1 = []struct{ k, v string }{
- {"bar", "b"},
+type kvs struct{ k, v string }
+
+var testdata1 = []kvs{
{"barb", "ba"},
- {"bars", "bb"},
{"bard", "bc"},
+ {"bars", "bb"},
+ {"bar", "b"},
{"fab", "z"},
- {"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
+ {"foo", "a"},
}
-var testdata2 = []struct{ k, v string }{
+var testdata2 = []kvs{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "bd"},
@@ -140,6 +144,47 @@ var testdata2 = []struct{ k, v string }{
{"jars", "d"},
}
+func TestIteratorSeek(t *testing.T) {
+ trie := newEmpty()
+ for _, val := range testdata1 {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+
+ // Seek to the middle.
+ it := NewIterator(trie.NodeIterator([]byte("fab")))
+ if err := checkIteratorOrder(testdata1[4:], it); err != nil {
+ t.Fatal(err)
+ }
+
+ // Seek to a non-existent key.
+ it = NewIterator(trie.NodeIterator([]byte("barc")))
+ if err := checkIteratorOrder(testdata1[1:], it); err != nil {
+ t.Fatal(err)
+ }
+
+ // Seek beyond the end.
+ it = NewIterator(trie.NodeIterator([]byte("z")))
+ if err := checkIteratorOrder(nil, it); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func checkIteratorOrder(want []kvs, it *Iterator) error {
+ for it.Next() {
+ if len(want) == 0 {
+ return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
+ }
+ if !bytes.Equal(it.Key, []byte(want[0].k)) {
+ return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
+ }
+ want = want[1:]
+ }
+ if len(want) > 0 {
+ return fmt.Errorf("iterator ended early, want key %q", want[0])
+ }
+ return nil
+}
+
func TestDifferenceIterator(t *testing.T) {
triea := newEmpty()
for _, val := range testdata1 {
@@ -154,7 +199,7 @@ func TestDifferenceIterator(t *testing.T) {
trieb.Commit()
found := make(map[string]string)
- di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator())
+ di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
it := NewIterator(di)
for it.Next() {
found[string(it.Key)] = string(it.Value)
@@ -189,7 +234,7 @@ func TestUnionIterator(t *testing.T) {
}
trieb.Commit()
- di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()})
+ di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
it := NewIterator(di)
all := []struct{ k, v string }{