diff options
Diffstat (limited to 'trie/iterator_test.go')
-rw-r--r-- | trie/iterator_test.go | 65 |
1 files changed, 55 insertions, 10 deletions
diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 04d51aaf5..f161fd99d 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "fmt" "testing" "github.com/ethereum/go-ethereum/common" @@ -42,7 +44,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie.NodeIterator()) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +74,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie.NodeIterator()) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +101,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := trie.NodeIterator(); it.Next(true); { + for it := trie.NodeIterator(nil); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -117,18 +119,20 @@ func TestNodeIteratorCoverage(t *testing.T) { } } -var testdata1 = []struct{ k, v string }{ - {"bar", "b"}, +type kvs struct{ k, v string } + +var testdata1 = []kvs{ {"barb", "ba"}, - {"bars", "bb"}, {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, {"fab", "z"}, - {"foo", "a"}, {"food", "ab"}, {"foos", "aa"}, + {"foo", "a"}, } -var testdata2 = []struct{ k, v string }{ +var testdata2 = []kvs{ {"aardvark", "c"}, {"bar", "b"}, {"barb", "bd"}, @@ -140,6 +144,47 @@ var testdata2 = []struct{ k, v string }{ {"jars", "d"}, } +func TestIteratorSeek(t *testing.T) { + trie := newEmpty() + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) + } + + // Seek to the middle. + it := NewIterator(trie.NodeIterator([]byte("fab"))) + if err := checkIteratorOrder(testdata1[4:], it); err != nil { + t.Fatal(err) + } + + // Seek to a non-existent key. + it = NewIterator(trie.NodeIterator([]byte("barc"))) + if err := checkIteratorOrder(testdata1[1:], it); err != nil { + t.Fatal(err) + } + + // Seek beyond the end. + it = NewIterator(trie.NodeIterator([]byte("z"))) + if err := checkIteratorOrder(nil, it); err != nil { + t.Fatal(err) + } +} + +func checkIteratorOrder(want []kvs, it *Iterator) error { + for it.Next() { + if len(want) == 0 { + return fmt.Errorf("didn't expect any more values, got key %q", it.Key) + } + if !bytes.Equal(it.Key, []byte(want[0].k)) { + return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k) + } + want = want[1:] + } + if len(want) > 0 { + return fmt.Errorf("iterator ended early, want key %q", want[0]) + } + return nil +} + func TestDifferenceIterator(t *testing.T) { triea := newEmpty() for _, val := range testdata1 { @@ -154,7 +199,7 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator()) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) @@ -189,7 +234,7 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()}) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) it := NewIterator(di) all := []struct{ k, v string }{ |