aboutsummaryrefslogtreecommitdiffstats
path: root/trie
diff options
context:
space:
mode:
Diffstat (limited to 'trie')
-rw-r--r--trie/arc.go194
-rw-r--r--trie/cache.go78
-rw-r--r--trie/encoding.go72
-rw-r--r--trie/encoding_test.go36
-rw-r--r--trie/fullnode.go94
-rw-r--r--trie/hashnode.go46
-rw-r--r--trie/iterator.go64
-rw-r--r--trie/iterator_test.go6
-rw-r--r--trie/node.go174
-rw-r--r--trie/secure_trie.go97
-rw-r--r--trie/secure_trie_test.go74
-rw-r--r--trie/shortnode.go57
-rw-r--r--trie/slice.go69
-rw-r--r--trie/trie.go639
-rw-r--r--trie/trie_test.go264
-rw-r--r--trie/valuenode.go42
16 files changed, 1075 insertions, 931 deletions
diff --git a/trie/arc.go b/trie/arc.go
new file mode 100644
index 000000000..9da012e16
--- /dev/null
+++ b/trie/arc.go
@@ -0,0 +1,194 @@
+// Copyright (c) 2015 Hans Alexander Gugel <alexander.gugel@gmail.com>
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// This file contains a modified version of package arc from
+// https://github.com/alexanderGugel/arc
+//
+// It implements the ARC (Adaptive Replacement Cache) algorithm as detailed in
+// https://www.usenix.org/legacy/event/fast03/tech/full_papers/megiddo/megiddo.pdf
+
+package trie
+
+import (
+ "container/list"
+ "sync"
+)
+
+type arc struct {
+ p int
+ c int
+ t1 *list.List
+ b1 *list.List
+ t2 *list.List
+ b2 *list.List
+ cache map[string]*entry
+ mutex sync.Mutex
+}
+
+type entry struct {
+ key hashNode
+ value node
+ ll *list.List
+ el *list.Element
+}
+
+// newARC returns a new Adaptive Replacement Cache with the
+// given capacity.
+func newARC(c int) *arc {
+ return &arc{
+ c: c,
+ t1: list.New(),
+ b1: list.New(),
+ t2: list.New(),
+ b2: list.New(),
+ cache: make(map[string]*entry, c),
+ }
+}
+
+// Put inserts a new key-value pair into the cache.
+// This optimizes future access to this entry (side effect).
+func (a *arc) Put(key hashNode, value node) bool {
+ a.mutex.Lock()
+ defer a.mutex.Unlock()
+ ent, ok := a.cache[string(key)]
+ if ok != true {
+ ent = &entry{key: key, value: value}
+ a.req(ent)
+ a.cache[string(key)] = ent
+ } else {
+ ent.value = value
+ a.req(ent)
+ }
+ return ok
+}
+
+// Get retrieves a previously via Set inserted entry.
+// This optimizes future access to this entry (side effect).
+func (a *arc) Get(key hashNode) (value node, ok bool) {
+ a.mutex.Lock()
+ defer a.mutex.Unlock()
+ ent, ok := a.cache[string(key)]
+ if ok {
+ a.req(ent)
+ return ent.value, ent.value != nil
+ }
+ return nil, false
+}
+
+func (a *arc) req(ent *entry) {
+ if ent.ll == a.t1 || ent.ll == a.t2 {
+ // Case I
+ ent.setMRU(a.t2)
+ } else if ent.ll == a.b1 {
+ // Case II
+ // Cache Miss in t1 and t2
+
+ // Adaptation
+ var d int
+ if a.b1.Len() >= a.b2.Len() {
+ d = 1
+ } else {
+ d = a.b2.Len() / a.b1.Len()
+ }
+ a.p = a.p + d
+ if a.p > a.c {
+ a.p = a.c
+ }
+
+ a.replace(ent)
+ ent.setMRU(a.t2)
+ } else if ent.ll == a.b2 {
+ // Case III
+ // Cache Miss in t1 and t2
+
+ // Adaptation
+ var d int
+ if a.b2.Len() >= a.b1.Len() {
+ d = 1
+ } else {
+ d = a.b1.Len() / a.b2.Len()
+ }
+ a.p = a.p - d
+ if a.p < 0 {
+ a.p = 0
+ }
+
+ a.replace(ent)
+ ent.setMRU(a.t2)
+ } else if ent.ll == nil {
+ // Case IV
+
+ if a.t1.Len()+a.b1.Len() == a.c {
+ // Case A
+ if a.t1.Len() < a.c {
+ a.delLRU(a.b1)
+ a.replace(ent)
+ } else {
+ a.delLRU(a.t1)
+ }
+ } else if a.t1.Len()+a.b1.Len() < a.c {
+ // Case B
+ if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() >= a.c {
+ if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() == 2*a.c {
+ a.delLRU(a.b2)
+ }
+ a.replace(ent)
+ }
+ }
+
+ ent.setMRU(a.t1)
+ }
+}
+
+func (a *arc) delLRU(list *list.List) {
+ lru := list.Back()
+ list.Remove(lru)
+ delete(a.cache, string(lru.Value.(*entry).key))
+}
+
+func (a *arc) replace(ent *entry) {
+ if a.t1.Len() > 0 && ((a.t1.Len() > a.p) || (ent.ll == a.b2 && a.t1.Len() == a.p)) {
+ lru := a.t1.Back().Value.(*entry)
+ lru.value = nil
+ lru.setMRU(a.b1)
+ } else {
+ lru := a.t2.Back().Value.(*entry)
+ lru.value = nil
+ lru.setMRU(a.b2)
+ }
+}
+
+func (e *entry) setLRU(list *list.List) {
+ e.detach()
+ e.ll = list
+ e.el = e.ll.PushBack(e)
+}
+
+func (e *entry) setMRU(list *list.List) {
+ e.detach()
+ e.ll = list
+ e.el = e.ll.PushFront(e)
+}
+
+func (e *entry) detach() {
+ if e.ll != nil {
+ e.ll.Remove(e.el)
+ }
+}
diff --git a/trie/cache.go b/trie/cache.go
deleted file mode 100644
index e475fc861..000000000
--- a/trie/cache.go
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/logger/glog"
- "github.com/syndtr/goleveldb/leveldb"
-)
-
-type Backend interface {
- Get([]byte) ([]byte, error)
- Put([]byte, []byte) error
-}
-
-type Cache struct {
- batch *leveldb.Batch
- store map[string][]byte
- backend Backend
-}
-
-func NewCache(backend Backend) *Cache {
- return &Cache{new(leveldb.Batch), make(map[string][]byte), backend}
-}
-
-func (self *Cache) Get(key []byte) []byte {
- data := self.store[string(key)]
- if data == nil {
- data, _ = self.backend.Get(key)
- }
-
- return data
-}
-
-func (self *Cache) Put(key []byte, data []byte) {
- self.batch.Put(key, data)
- self.store[string(key)] = data
-}
-
-// Flush flushes the trie to the backing layer. If this is a leveldb instance
-// we'll use a batched write, otherwise we'll use regular put.
-func (self *Cache) Flush() {
- if db, ok := self.backend.(*ethdb.LDBDatabase); ok {
- if err := db.LDB().Write(self.batch, nil); err != nil {
- glog.Fatal("db write err:", err)
- }
- } else {
- for k, v := range self.store {
- self.backend.Put([]byte(k), v)
- }
- }
-}
-
-func (self *Cache) Copy() *Cache {
- cache := NewCache(self.backend)
- for k, v := range self.store {
- cache.store[k] = v
- }
- return cache
-}
-
-func (self *Cache) Reset() {
- //self.store = make(map[string][]byte)
-}
diff --git a/trie/encoding.go b/trie/encoding.go
index 9c862d78f..3c172b843 100644
--- a/trie/encoding.go
+++ b/trie/encoding.go
@@ -16,34 +16,36 @@
package trie
-func CompactEncode(hexSlice []byte) []byte {
- terminator := 0
+func compactEncode(hexSlice []byte) []byte {
+ terminator := byte(0)
if hexSlice[len(hexSlice)-1] == 16 {
terminator = 1
- }
-
- if terminator == 1 {
hexSlice = hexSlice[:len(hexSlice)-1]
}
-
- oddlen := len(hexSlice) % 2
- flags := byte(2*terminator + oddlen)
- if oddlen != 0 {
- hexSlice = append([]byte{flags}, hexSlice...)
- } else {
- hexSlice = append([]byte{flags, 0}, hexSlice...)
+ var (
+ odd = byte(len(hexSlice) % 2)
+ buflen = len(hexSlice)/2 + 1
+ bi, hi = 0, 0 // indices
+ hs = byte(0) // shift: flips between 0 and 4
+ )
+ if odd == 0 {
+ bi = 1
+ hs = 4
}
-
- l := len(hexSlice) / 2
- var buf = make([]byte, l)
- for i := 0; i < l; i++ {
- buf[i] = 16*hexSlice[2*i] + hexSlice[2*i+1]
+ buf := make([]byte, buflen)
+ buf[0] = terminator<<5 | byte(odd)<<4
+ for bi < len(buf) && hi < len(hexSlice) {
+ buf[bi] |= hexSlice[hi] << hs
+ if hs == 0 {
+ bi++
+ }
+ hi, hs = hi+1, hs^(1<<2)
}
return buf
}
-func CompactDecode(str []byte) []byte {
- base := CompactHexDecode(str)
+func compactDecode(str []byte) []byte {
+ base := compactHexDecode(str)
base = base[:len(base)-1]
if base[0] >= 2 {
base = append(base, 16)
@@ -53,11 +55,10 @@ func CompactDecode(str []byte) []byte {
} else {
base = base[2:]
}
-
return base
}
-func CompactHexDecode(str []byte) []byte {
+func compactHexDecode(str []byte) []byte {
l := len(str)*2 + 1
var nibbles = make([]byte, l)
for i, b := range str {
@@ -68,7 +69,7 @@ func CompactHexDecode(str []byte) []byte {
return nibbles
}
-func DecodeCompact(key []byte) []byte {
+func decodeCompact(key []byte) []byte {
l := len(key) / 2
var res = make([]byte, l)
for i := 0; i < l; i++ {
@@ -77,3 +78,30 @@ func DecodeCompact(key []byte) []byte {
}
return res
}
+
+// prefixLen returns the length of the common prefix of a and b.
+func prefixLen(a, b []byte) int {
+ var i, length = 0, len(a)
+ if len(b) < length {
+ length = len(b)
+ }
+ for ; i < length; i++ {
+ if a[i] != b[i] {
+ break
+ }
+ }
+ return i
+}
+
+func hasTerm(s []byte) bool {
+ return s[len(s)-1] == 16
+}
+
+func remTerm(s []byte) []byte {
+ if hasTerm(s) {
+ b := make([]byte, len(s)-1)
+ copy(b, s)
+ return b
+ }
+ return s
+}
diff --git a/trie/encoding_test.go b/trie/encoding_test.go
index e49b57ef0..061d48d58 100644
--- a/trie/encoding_test.go
+++ b/trie/encoding_test.go
@@ -23,7 +23,7 @@ import (
checker "gopkg.in/check.v1"
)
-func Test(t *testing.T) { checker.TestingT(t) }
+func TestEncoding(t *testing.T) { checker.TestingT(t) }
type TrieEncodingSuite struct{}
@@ -32,64 +32,64 @@ var _ = checker.Suite(&TrieEncodingSuite{})
func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) {
// even compact encode
test1 := []byte{1, 2, 3, 4, 5}
- res1 := CompactEncode(test1)
+ res1 := compactEncode(test1)
c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45"))
// odd compact encode
test2 := []byte{0, 1, 2, 3, 4, 5}
- res2 := CompactEncode(test2)
+ res2 := compactEncode(test2)
c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45"))
//odd terminated compact encode
test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
- res3 := CompactEncode(test3)
+ res3 := compactEncode(test3)
c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8"))
// even terminated compact encode
test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16}
- res4 := CompactEncode(test4)
+ res4 := compactEncode(test4)
c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8"))
}
func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) {
exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
- res := CompactHexDecode([]byte("verb"))
+ res := compactHexDecode([]byte("verb"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) {
// odd compact decode
exp := []byte{1, 2, 3, 4, 5}
- res := CompactDecode([]byte("\x11\x23\x45"))
+ res := compactDecode([]byte("\x11\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even compact decode
exp = []byte{0, 1, 2, 3, 4, 5}
- res = CompactDecode([]byte("\x00\x01\x23\x45"))
+ res = compactDecode([]byte("\x00\x01\x23\x45"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
- res = CompactDecode([]byte("\x20\x0f\x1c\xb8"))
+ res = compactDecode([]byte("\x20\x0f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
// even terminated compact decode
exp = []byte{15, 1, 12, 11, 8 /*term*/, 16}
- res = CompactDecode([]byte("\x3f\x1c\xb8"))
+ res = compactDecode([]byte("\x3f\x1c\xb8"))
c.Assert(res, checker.DeepEquals, exp)
}
func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) {
exp, _ := hex.DecodeString("012345")
- res := DecodeCompact([]byte{0, 1, 2, 3, 4, 5})
+ res := decodeCompact([]byte{0, 1, 2, 3, 4, 5})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("012345")
- res = DecodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
+ res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
c.Assert(res, checker.DeepEquals, exp)
exp, _ = hex.DecodeString("abcdef")
- res = DecodeCompact([]byte{10, 11, 12, 13, 14, 15})
+ res = decodeCompact([]byte{10, 11, 12, 13, 14, 15})
c.Assert(res, checker.DeepEquals, exp)
}
@@ -97,29 +97,27 @@ func BenchmarkCompactEncode(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
for i := 0; i < b.N; i++ {
- CompactEncode(testBytes)
+ compactEncode(testBytes)
}
}
func BenchmarkCompactDecode(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
for i := 0; i < b.N; i++ {
- CompactDecode(testBytes)
+ compactDecode(testBytes)
}
}
func BenchmarkCompactHexDecode(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
- CompactHexDecode(testBytes)
+ compactHexDecode(testBytes)
}
-
}
func BenchmarkDecodeCompact(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
- DecodeCompact(testBytes)
+ decodeCompact(testBytes)
}
-
}
diff --git a/trie/fullnode.go b/trie/fullnode.go
deleted file mode 100644
index 8ff019ec4..000000000
--- a/trie/fullnode.go
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-type FullNode struct {
- trie *Trie
- nodes [17]Node
- dirty bool
-}
-
-func NewFullNode(t *Trie) *FullNode {
- return &FullNode{trie: t}
-}
-
-func (self *FullNode) Dirty() bool { return self.dirty }
-func (self *FullNode) Value() Node {
- self.nodes[16] = self.trie.trans(self.nodes[16])
- return self.nodes[16]
-}
-func (self *FullNode) Branches() []Node {
- return self.nodes[:16]
-}
-
-func (self *FullNode) Copy(t *Trie) Node {
- nnode := NewFullNode(t)
- for i, node := range self.nodes {
- if node != nil {
- nnode.nodes[i] = node
- }
- }
- nnode.dirty = true
-
- return nnode
-}
-
-// Returns the length of non-nil nodes
-func (self *FullNode) Len() (amount int) {
- for _, node := range self.nodes {
- if node != nil {
- amount++
- }
- }
-
- return
-}
-
-func (self *FullNode) Hash() interface{} {
- return self.trie.store(self)
-}
-
-func (self *FullNode) RlpData() interface{} {
- t := make([]interface{}, 17)
- for i, node := range self.nodes {
- if node != nil {
- t[i] = node.Hash()
- } else {
- t[i] = ""
- }
- }
-
- return t
-}
-
-func (self *FullNode) set(k byte, value Node) {
- self.nodes[int(k)] = value
- self.dirty = true
-}
-
-func (self *FullNode) branch(i byte) Node {
- if self.nodes[int(i)] != nil {
- self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)])
-
- return self.nodes[int(i)]
- }
- return nil
-}
-
-func (self *FullNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
diff --git a/trie/hashnode.go b/trie/hashnode.go
deleted file mode 100644
index d4a0bc7ec..000000000
--- a/trie/hashnode.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type HashNode struct {
- key []byte
- trie *Trie
- dirty bool
-}
-
-func NewHash(key []byte, trie *Trie) *HashNode {
- return &HashNode{key, trie, false}
-}
-
-func (self *HashNode) RlpData() interface{} {
- return self.key
-}
-
-func (self *HashNode) Hash() interface{} {
- return self.key
-}
-
-func (self *HashNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
-
-// These methods will never be called but we have to satisfy Node interface
-func (self *HashNode) Value() Node { return nil }
-func (self *HashNode) Dirty() bool { return true }
-func (self *HashNode) Copy(t *Trie) Node { return NewHash(common.CopyBytes(self.key), t) }
diff --git a/trie/iterator.go b/trie/iterator.go
index 9c4c7fbe5..38555fe08 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -16,9 +16,7 @@
package trie
-import (
- "bytes"
-)
+import "bytes"
type Iterator struct {
trie *Trie
@@ -32,32 +30,29 @@ func NewIterator(trie *Trie) *Iterator {
}
func (self *Iterator) Next() bool {
- self.trie.mu.Lock()
- defer self.trie.mu.Unlock()
-
isIterStart := false
if self.Key == nil {
isIterStart = true
self.Key = make([]byte, 32)
}
- key := RemTerm(CompactHexDecode(self.Key))
+ key := remTerm(compactHexDecode(self.Key))
k := self.next(self.trie.root, key, isIterStart)
- self.Key = []byte(DecodeCompact(k))
+ self.Key = []byte(decodeCompact(k))
return len(k) > 0
}
-func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
+func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byte {
if node == nil {
return nil
}
switch node := node.(type) {
- case *FullNode:
+ case fullNode:
if len(key) > 0 {
- k := self.next(node.branch(key[0]), key[1:], isIterStart)
+ k := self.next(node[key[0]], key[1:], isIterStart)
if k != nil {
return append([]byte{key[0]}, k...)
}
@@ -69,31 +64,31 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
}
for i := r; i < 16; i++ {
- k := self.key(node.branch(byte(i)))
+ k := self.key(node[i])
if k != nil {
return append([]byte{i}, k...)
}
}
- case *ShortNode:
- k := RemTerm(node.Key())
- if vnode, ok := node.Value().(*ValueNode); ok {
+ case shortNode:
+ k := remTerm(node.Key)
+ if vnode, ok := node.Val.(valueNode); ok {
switch bytes.Compare([]byte(k), key) {
case 0:
if isIterStart {
- self.Value = vnode.Val()
+ self.Value = vnode
return k
}
case 1:
- self.Value = vnode.Val()
+ self.Value = vnode
return k
}
} else {
- cnode := node.Value()
+ cnode := node.Val
var ret []byte
skey := key[len(k):]
- if BeginsWith(key, k) {
+ if bytes.HasPrefix(key, k) {
ret = self.next(cnode, skey, isIterStart)
} else if bytes.Compare(k, key[:len(k)]) > 0 {
return self.key(node)
@@ -103,37 +98,36 @@ func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
return append(k, ret...)
}
}
- }
+ case hashNode:
+ return self.next(self.trie.resolveHash(node), key, isIterStart)
+ }
return nil
}
-func (self *Iterator) key(node Node) []byte {
+func (self *Iterator) key(node interface{}) []byte {
switch node := node.(type) {
- case *ShortNode:
+ case shortNode:
// Leaf node
- if vnode, ok := node.Value().(*ValueNode); ok {
- k := RemTerm(node.Key())
- self.Value = vnode.Val()
-
+ k := remTerm(node.Key)
+ if vnode, ok := node.Val.(valueNode); ok {
+ self.Value = vnode
return k
- } else {
- k := RemTerm(node.Key())
- return append(k, self.key(node.Value())...)
}
- case *FullNode:
- if node.Value() != nil {
- self.Value = node.Value().(*ValueNode).Val()
-
+ return append(k, self.key(node.Val)...)
+ case fullNode:
+ if node[16] != nil {
+ self.Value = node[16].(valueNode)
return []byte{16}
}
-
for i := 0; i < 16; i++ {
- k := self.key(node.branch(byte(i)))
+ k := self.key(node[i])
if k != nil {
return append([]byte{byte(i)}, k...)
}
}
+ case hashNode:
+ return self.key(self.trie.resolveHash(node))
}
return nil
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 148f9adf9..fdc60b412 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -19,7 +19,7 @@ package trie
import "testing"
func TestIterator(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -32,11 +32,11 @@ func TestIterator(t *testing.T) {
v := make(map[string]bool)
for _, val := range vals {
v[val.k] = false
- trie.UpdateString(val.k, val.v)
+ trie.Update([]byte(val.k), []byte(val.v))
}
trie.Commit()
- it := trie.Iterator()
+ it := NewIterator(trie)
for it.Next() {
v[string(it.Key)] = true
}
diff --git a/trie/node.go b/trie/node.go
index 9d49029de..0bfa21dc4 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -16,46 +16,172 @@
package trie
-import "fmt"
+import (
+ "fmt"
+ "io"
+ "strings"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
-type Node interface {
- Value() Node
- Copy(*Trie) Node // All nodes, for now, return them self
- Dirty() bool
+type node interface {
fstring(string) string
- Hash() interface{}
- RlpData() interface{}
- setDirty(dirty bool)
}
-// Value node
-func (self *ValueNode) String() string { return self.fstring("") }
-func (self *FullNode) String() string { return self.fstring("") }
-func (self *ShortNode) String() string { return self.fstring("") }
-func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) }
+type (
+ fullNode [17]node
+ shortNode struct {
+ Key []byte
+ Val node
+ }
+ hashNode []byte
+ valueNode []byte
+)
-//func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("< %x > ", self.key) }
-func (self *HashNode) fstring(ind string) string {
- return fmt.Sprintf("%v", self.trie.trans(self))
-}
+// Pretty printing.
+func (n fullNode) String() string { return n.fstring("") }
+func (n shortNode) String() string { return n.fstring("") }
+func (n hashNode) String() string { return n.fstring("") }
+func (n valueNode) String() string { return n.fstring("") }
-// Full node
-func (self *FullNode) fstring(ind string) string {
+func (n fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
- for i, node := range self.nodes {
+ for i, node := range n {
if node == nil {
resp += fmt.Sprintf("%s: <nil> ", indices[i])
} else {
resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" "))
}
}
-
return resp + fmt.Sprintf("\n%s] ", ind)
}
+func (n shortNode) fstring(ind string) string {
+ return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" "))
+}
+func (n hashNode) fstring(ind string) string {
+ return fmt.Sprintf("<%x> ", []byte(n))
+}
+func (n valueNode) fstring(ind string) string {
+ return fmt.Sprintf("%x ", []byte(n))
+}
+
+func mustDecodeNode(dbkey, buf []byte) node {
+ n, err := decodeNode(buf)
+ if err != nil {
+ panic(fmt.Sprintf("node %x: %v", dbkey, err))
+ }
+ return n
+}
+
+// decodeNode parses the RLP encoding of a trie node.
+func decodeNode(buf []byte) (node, error) {
+ if len(buf) == 0 {
+ return nil, io.ErrUnexpectedEOF
+ }
+ elems, _, err := rlp.SplitList(buf)
+ if err != nil {
+ return nil, fmt.Errorf("decode error: %v", err)
+ }
+ switch c, _ := rlp.CountValues(elems); c {
+ case 2:
+ n, err := decodeShort(elems)
+ return n, wrapError(err, "short")
+ case 17:
+ n, err := decodeFull(elems)
+ return n, wrapError(err, "full")
+ default:
+ return nil, fmt.Errorf("invalid number of list elements: %v", c)
+ }
+}
+
+func decodeShort(buf []byte) (node, error) {
+ kbuf, rest, err := rlp.SplitString(buf)
+ if err != nil {
+ return nil, err
+ }
+ key := compactDecode(kbuf)
+ if key[len(key)-1] == 16 {
+ // value node
+ val, _, err := rlp.SplitString(rest)
+ if err != nil {
+ return nil, fmt.Errorf("invalid value node: %v", err)
+ }
+ return shortNode{key, valueNode(val)}, nil
+ }
+ r, _, err := decodeRef(rest)
+ if err != nil {
+ return nil, wrapError(err, "val")
+ }
+ return shortNode{key, r}, nil
+}
+
+func decodeFull(buf []byte) (fullNode, error) {
+ var n fullNode
+ for i := 0; i < 16; i++ {
+ cld, rest, err := decodeRef(buf)
+ if err != nil {
+ return n, wrapError(err, fmt.Sprintf("[%d]", i))
+ }
+ n[i], buf = cld, rest
+ }
+ val, _, err := rlp.SplitString(buf)
+ if err != nil {
+ return n, err
+ }
+ if len(val) > 0 {
+ n[16] = valueNode(val)
+ }
+ return n, nil
+}
+
+const hashLen = len(common.Hash{})
+
+func decodeRef(buf []byte) (node, []byte, error) {
+ kind, val, rest, err := rlp.Split(buf)
+ if err != nil {
+ return nil, buf, err
+ }
+ switch {
+ case kind == rlp.List:
+ // 'embedded' node reference. The encoding must be smaller
+ // than a hash in order to be valid.
+ if size := len(buf) - len(rest); size > hashLen {
+ err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
+ return nil, buf, err
+ }
+ n, err := decodeNode(buf)
+ return n, rest, err
+ case kind == rlp.String && len(val) == 0:
+ // empty node
+ return nil, rest, nil
+ case kind == rlp.String && len(val) == 32:
+ return hashNode(val), rest, nil
+ default:
+ return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
+ }
+}
+
+// wraps a decoding error with information about the path to the
+// invalid child node (for debugging encoding issues).
+type decodeError struct {
+ what error
+ stack []string
+}
+
+func wrapError(err error, ctx string) error {
+ if err == nil {
+ return nil
+ }
+ if decErr, ok := err.(*decodeError); ok {
+ decErr.stack = append(decErr.stack, ctx)
+ return decErr
+ }
+ return &decodeError{err, []string{ctx}}
+}
-// Short node
-func (self *ShortNode) fstring(ind string) string {
- return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" "))
+func (err *decodeError) Error() string {
+ return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-"))
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 47c7542bb..47d1934d0 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -16,46 +16,93 @@
package trie
-import "github.com/ethereum/go-ethereum/crypto"
+import (
+ "hash"
-var keyPrefix = []byte("secure-key-")
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+)
+var secureKeyPrefix = []byte("secure-key-")
+
+// SecureTrie wraps a trie with key hashing. In a secure trie, all
+// access operations hash the key using keccak256. This prevents
+// calling code from creating long chains of nodes that
+// increase the access time.
+//
+// Contrary to a regular trie, a SecureTrie can only be created with
+// New and must have an attached database. The database also stores
+// the preimage of each key.
+//
+// SecureTrie is not safe for concurrent use.
type SecureTrie struct {
*Trie
-}
-func NewSecure(root []byte, backend Backend) *SecureTrie {
- return &SecureTrie{New(root, backend)}
+ hash hash.Hash
+ secKeyBuf []byte
+ hashKeyBuf []byte
}
-func (self *SecureTrie) Update(key, value []byte) Node {
- shaKey := crypto.Sha3(key)
- self.Trie.cache.Put(append(keyPrefix, shaKey...), key)
-
- return self.Trie.Update(shaKey, value)
-}
-func (self *SecureTrie) UpdateString(key, value string) Node {
- return self.Update([]byte(key), []byte(value))
+// NewSecure creates a trie with an existing root node from db.
+//
+// If root is the zero hash or the sha3 hash of an empty string, the
+// trie is initially empty. Otherwise, New will panics if db is nil
+// and returns ErrMissingRoot if the root node cannpt be found.
+// Accessing the trie loads nodes from db on demand.
+func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
+ if db == nil {
+ panic("NewSecure called with nil database")
+ }
+ trie, err := New(root, db)
+ if err != nil {
+ return nil, err
+ }
+ return &SecureTrie{Trie: trie}, nil
}
-func (self *SecureTrie) Get(key []byte) []byte {
- return self.Trie.Get(crypto.Sha3(key))
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *SecureTrie) Get(key []byte) []byte {
+ return t.Trie.Get(t.hashKey(key))
}
-func (self *SecureTrie) GetString(key string) []byte {
- return self.Get([]byte(key))
+
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *SecureTrie) Update(key, value []byte) {
+ hk := t.hashKey(key)
+ t.Trie.Update(hk, value)
+ t.Trie.db.Put(t.secKey(hk), key)
}
-func (self *SecureTrie) Delete(key []byte) Node {
- return self.Trie.Delete(crypto.Sha3(key))
+// Delete removes any existing value for key from the trie.
+func (t *SecureTrie) Delete(key []byte) {
+ t.Trie.Delete(t.hashKey(key))
}
-func (self *SecureTrie) DeleteString(key string) Node {
- return self.Delete([]byte(key))
+
+// GetKey returns the sha3 preimage of a hashed key that was
+// previously used to store a value.
+func (t *SecureTrie) GetKey(shaKey []byte) []byte {
+ key, _ := t.Trie.db.Get(t.secKey(shaKey))
+ return key
}
-func (self *SecureTrie) Copy() *SecureTrie {
- return &SecureTrie{self.Trie.Copy()}
+func (t *SecureTrie) secKey(key []byte) []byte {
+ t.secKeyBuf = append(t.secKeyBuf[:0], secureKeyPrefix...)
+ t.secKeyBuf = append(t.secKeyBuf, key...)
+ return t.secKeyBuf
}
-func (self *SecureTrie) GetKey(shaKey []byte) []byte {
- return self.Trie.cache.Get(append(keyPrefix, shaKey...))
+func (t *SecureTrie) hashKey(key []byte) []byte {
+ if t.hash == nil {
+ t.hash = sha3.NewKeccak256()
+ t.hashKeyBuf = make([]byte, 32)
+ }
+ t.hash.Reset()
+ t.hash.Write(key)
+ t.hashKeyBuf = t.hash.Sum(t.hashKeyBuf[:0])
+ return t.hashKeyBuf
}
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
new file mode 100644
index 000000000..13c6cd02e
--- /dev/null
+++ b/trie/secure_trie_test.go
@@ -0,0 +1,74 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+)
+
+func newEmptySecure() *SecureTrie {
+ db, _ := ethdb.NewMemDatabase()
+ trie, _ := NewSecure(common.Hash{}, db)
+ return trie
+}
+
+func TestSecureDelete(t *testing.T) {
+ trie := newEmptySecure()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ trie.Update([]byte(val.k), []byte(val.v))
+ } else {
+ trie.Delete([]byte(val.k))
+ }
+ }
+ hash := trie.Hash()
+ exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
+ if hash != exp {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestSecureGetKey(t *testing.T) {
+ trie := newEmptySecure()
+ trie.Update([]byte("foo"), []byte("bar"))
+
+ key := []byte("foo")
+ value := []byte("bar")
+ seckey := crypto.Sha3(key)
+
+ if !bytes.Equal(trie.Get(key), value) {
+ t.Errorf("Get did not return bar")
+ }
+ if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
+ t.Errorf("GetKey returned %q, want %q", k, key)
+ }
+}
diff --git a/trie/shortnode.go b/trie/shortnode.go
deleted file mode 100644
index 569d5f109..000000000
--- a/trie/shortnode.go
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type ShortNode struct {
- trie *Trie
- key []byte
- value Node
- dirty bool
-}
-
-func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
- return &ShortNode{t, CompactEncode(key), value, false}
-}
-func (self *ShortNode) Value() Node {
- self.value = self.trie.trans(self.value)
-
- return self.value
-}
-func (self *ShortNode) Dirty() bool { return self.dirty }
-func (self *ShortNode) Copy(t *Trie) Node {
- node := &ShortNode{t, nil, self.value.Copy(t), self.dirty}
- node.key = common.CopyBytes(self.key)
- node.dirty = true
- return node
-}
-
-func (self *ShortNode) RlpData() interface{} {
- return []interface{}{self.key, self.value.Hash()}
-}
-func (self *ShortNode) Hash() interface{} {
- return self.trie.store(self)
-}
-
-func (self *ShortNode) Key() []byte {
- return CompactDecode(self.key)
-}
-
-func (self *ShortNode) setDirty(dirty bool) {
- self.dirty = dirty
-}
diff --git a/trie/slice.go b/trie/slice.go
deleted file mode 100644
index ccefbd064..000000000
--- a/trie/slice.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
- "bytes"
- "math"
-)
-
-// Helper function for comparing slices
-func CompareIntSlice(a, b []int) bool {
- if len(a) != len(b) {
- return false
- }
- for i, v := range a {
- if v != b[i] {
- return false
- }
- }
- return true
-}
-
-// Returns the amount of nibbles that match each other from 0 ...
-func MatchingNibbleLength(a, b []byte) int {
- var i, length = 0, int(math.Min(float64(len(a)), float64(len(b))))
-
- for i < length {
- if a[i] != b[i] {
- break
- }
- i++
- }
-
- return i
-}
-
-func HasTerm(s []byte) bool {
- return s[len(s)-1] == 16
-}
-
-func RemTerm(s []byte) []byte {
- if HasTerm(s) {
- return s[:len(s)-1]
- }
-
- return s
-}
-
-func BeginsWith(a, b []byte) bool {
- if len(b) > len(a) {
- return false
- }
-
- return bytes.Equal(a[:len(b)], b)
-}
diff --git a/trie/trie.go b/trie/trie.go
index abf48a850..aa8d39fe2 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -19,372 +19,425 @@ package trie
import (
"bytes"
- "container/list"
+ "errors"
"fmt"
- "sync"
+ "hash"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/rlp"
)
-func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
- t2 := New(nil, backend)
+const defaultCacheCapacity = 800
- it := t1.Iterator()
- for it.Next() {
- t2.Update(it.Key, it.Value)
- }
-
- return bytes.Equal(t2.Hash(), t1.Hash()), t2
-}
-
-type Trie struct {
- mu sync.Mutex
- root Node
- roothash []byte
- cache *Cache
-
- revisions *list.List
-}
-
-func New(root []byte, backend Backend) *Trie {
- trie := &Trie{}
- trie.revisions = list.New()
- trie.roothash = root
- if backend != nil {
- trie.cache = NewCache(backend)
- }
+var (
+ // The global cache stores decoded trie nodes by hash as they get loaded.
+ globalCache = newARC(defaultCacheCapacity)
+ // This is the known root hash of an empty trie.
+ emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+)
- if root != nil {
- value := common.NewValueFromBytes(trie.cache.Get(root))
- trie.root = trie.mknode(value)
- }
+var ErrMissingRoot = errors.New("missing root node")
- return trie
+// Database must be implemented by backing stores for the trie.
+type Database interface {
+ DatabaseWriter
+ // Get returns the value for key from the database.
+ Get(key []byte) (value []byte, err error)
}
-func (self *Trie) Iterator() *Iterator {
- return NewIterator(self)
+// DatabaseWriter wraps the Put method of a backing store for the trie.
+type DatabaseWriter interface {
+ // Put stores the mapping key->value in the database.
+ // Implementations must not hold onto the value bytes, the trie
+ // will reuse the slice across calls to Put.
+ Put(key, value []byte) error
}
-func (self *Trie) Copy() *Trie {
- cpy := make([]byte, 32)
- copy(cpy, self.roothash) // NOTE: cpy isn't being used anywhere?
- trie := New(nil, nil)
- trie.cache = self.cache.Copy()
- if self.root != nil {
- trie.root = self.root.Copy(trie)
- }
-
- return trie
+// Trie is a Merkle Patricia Trie.
+// The zero value is an empty trie with no database.
+// Use New to create a trie that sits on top of a database.
+//
+// Trie is not safe for concurrent use.
+type Trie struct {
+ root node
+ db Database
+ *hasher
}
-// Legacy support
-func (self *Trie) Root() []byte { return self.Hash() }
-func (self *Trie) Hash() []byte {
- var hash []byte
- if self.root != nil {
- t := self.root.Hash()
- if byts, ok := t.([]byte); ok && len(byts) > 0 {
- hash = byts
- } else {
- hash = crypto.Sha3(common.Encode(self.root.RlpData()))
+// New creates a trie with an existing root node from db.
+//
+// If root is the zero hash or the sha3 hash of an empty string, the
+// trie is initially empty and does not require a database. Otherwise,
+// New will panics if db is nil or root does not exist in the
+// database. Accessing the trie loads nodes from db on demand.
+func New(root common.Hash, db Database) (*Trie, error) {
+ trie := &Trie{db: db}
+ if (root != common.Hash{}) && root != emptyRoot {
+ if db == nil {
+ panic("trie.New: cannot use existing root without a database")
}
- } else {
- hash = crypto.Sha3(common.Encode(""))
- }
-
- if !bytes.Equal(hash, self.roothash) {
- self.revisions.PushBack(self.roothash)
- self.roothash = hash
+ if v, _ := trie.db.Get(root[:]); len(v) == 0 {
+ return nil, ErrMissingRoot
+ }
+ trie.root = hashNode(root.Bytes())
}
-
- return hash
+ return trie, nil
}
-func (self *Trie) Commit() {
- self.mu.Lock()
- defer self.mu.Unlock()
- // Hash first
- self.Hash()
-
- self.cache.Flush()
+// Iterator returns an iterator over all mappings in the trie.
+func (t *Trie) Iterator() *Iterator {
+ return NewIterator(t)
}
-// Reset should only be called if the trie has been hashed
-func (self *Trie) Reset() {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- self.cache.Reset()
-
- if self.revisions.Len() > 0 {
- revision := self.revisions.Remove(self.revisions.Back()).([]byte)
- self.roothash = revision
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *Trie) Get(key []byte) []byte {
+ key = compactHexDecode(key)
+ tn := t.root
+ for len(key) > 0 {
+ switch n := tn.(type) {
+ case shortNode:
+ if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+ return nil
+ }
+ tn = n.Val
+ key = key[len(n.Key):]
+ case fullNode:
+ tn = n[key[0]]
+ key = key[1:]
+ case nil:
+ return nil
+ case hashNode:
+ tn = t.resolveHash(n)
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+ }
}
- value := common.NewValueFromBytes(self.cache.Get(self.roothash))
- self.root = self.mknode(value)
+ return tn.(valueNode)
}
-func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
-func (self *Trie) Update(key, value []byte) Node {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
-
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *Trie) Update(key, value []byte) {
+ k := compactHexDecode(key)
if len(value) != 0 {
- node := NewValueNode(self, value)
- node.dirty = true
- self.root = self.insert(self.root, k, node)
+ t.root = t.insert(t.root, k, valueNode(value))
} else {
- self.root = self.delete(self.root, k)
+ t.root = t.delete(t.root, k)
}
-
- return self.root
-}
-
-func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
-func (self *Trie) Get(key []byte) []byte {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
-
- n := self.get(self.root, k)
- if n != nil {
- return n.(*ValueNode).Val()
- }
-
- return nil
}
-func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
-func (self *Trie) Delete(key []byte) Node {
- self.mu.Lock()
- defer self.mu.Unlock()
-
- k := CompactHexDecode(key)
- self.root = self.delete(self.root, k)
-
- return self.root
-}
-
-func (self *Trie) insert(node Node, key []byte, value Node) Node {
+func (t *Trie) insert(n node, key []byte, value node) node {
if len(key) == 0 {
return value
}
-
- if node == nil {
- node := NewShortNode(self, key, value)
- node.dirty = true
- return node
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
- if bytes.Equal(k, key) {
- node := NewShortNode(self, key, value)
- node.dirty = true
- return node
-
+ switch n := n.(type) {
+ case shortNode:
+ matchlen := prefixLen(key, n.Key)
+ // If the whole key matches, keep this short node as is
+ // and only update the value.
+ if matchlen == len(n.Key) {
+ return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)}
}
-
- var n Node
- matchlength := MatchingNibbleLength(key, k)
- if matchlength == len(k) {
- n = self.insert(cnode, key[matchlength:], value)
- } else {
- pnode := self.insert(nil, k[matchlength+1:], cnode)
- nnode := self.insert(nil, key[matchlength+1:], value)
- fulln := NewFullNode(self)
- fulln.dirty = true
- fulln.set(k[matchlength], pnode)
- fulln.set(key[matchlength], nnode)
- n = fulln
- }
- if matchlength == 0 {
- return n
+ // Otherwise branch out at the index where they differ.
+ var branch fullNode
+ branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val)
+ branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value)
+ // Replace this shortNode with the branch if it occurs at index 0.
+ if matchlen == 0 {
+ return branch
}
+ // Otherwise, replace it with a short node leading up to the branch.
+ return shortNode{key[:matchlen], branch}
- snode := NewShortNode(self, key[:matchlength], n)
- snode.dirty = true
- return snode
+ case fullNode:
+ n[key[0]] = t.insert(n[key[0]], key[1:], value)
+ return n
- case *FullNode:
- cpy := node.Copy(self).(*FullNode)
- cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
- cpy.dirty = true
+ case nil:
+ return shortNode{key, value}
- return cpy
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and insert into it. This leaves all child nodes on
+ // the path to the value in the trie.
+ //
+ // TODO: track whether insertion changed the value and keep
+ // n as a hash node if it didn't.
+ return t.insert(t.resolveHash(n), key, value)
default:
- panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
-func (self *Trie) get(node Node, key []byte) Node {
- if len(key) == 0 {
- return node
- }
-
- if node == nil {
- return nil
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
-
- if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
- return self.get(cnode, key[len(k):])
- }
-
- return nil
- case *FullNode:
- return self.get(node.branch(key[0]), key[1:])
- default:
- panic(fmt.Sprintf("%T: invalid node: %v", node, node))
- }
+// Delete removes any existing value for key from the trie.
+func (t *Trie) Delete(key []byte) {
+ k := compactHexDecode(key)
+ t.root = t.delete(t.root, k)
}
-func (self *Trie) delete(node Node, key []byte) Node {
- if len(key) == 0 && node == nil {
- return nil
- }
-
- switch node := node.(type) {
- case *ShortNode:
- k := node.Key()
- cnode := node.Value()
- if bytes.Equal(key, k) {
- return nil
- } else if bytes.Equal(key[:len(k)], k) {
- child := self.delete(cnode, key[len(k):])
-
- var n Node
- switch child := child.(type) {
- case *ShortNode:
- nkey := append(k, child.Key()...)
- n = NewShortNode(self, nkey, child.Value())
- n.(*ShortNode).dirty = true
- case *FullNode:
- sn := NewShortNode(self, node.Key(), child)
- sn.dirty = true
- sn.key = node.key
- n = sn
- }
-
- return n
- } else {
- return node
+// delete returns the new root of the trie with key deleted.
+// It reduces the trie to minimal form by simplifying
+// nodes on the way up after deleting recursively.
+func (t *Trie) delete(n node, key []byte) node {
+ switch n := n.(type) {
+ case shortNode:
+ matchlen := prefixLen(key, n.Key)
+ if matchlen < len(n.Key) {
+ return n // don't replace n on mismatch
+ }
+ if matchlen == len(key) {
+ return nil // remove n entirely for whole matches
+ }
+ // The key is longer than n.Key. Remove the remaining suffix
+ // from the subtrie. Child can never be nil here since the
+ // subtrie must contain at least two other values with keys
+ // longer than n.Key.
+ child := t.delete(n.Val, key[len(n.Key):])
+ switch child := child.(type) {
+ case shortNode:
+ // Deleting from the subtrie reduced it to another
+ // short node. Merge the nodes to avoid creating a
+ // shortNode{..., shortNode{...}}. Use concat (which
+ // always creates a new slice) instead of append to
+ // avoid modifying n.Key since it might be shared with
+ // other nodes.
+ return shortNode{concat(n.Key, child.Key...), child.Val}
+ default:
+ return shortNode{n.Key, child}
}
- case *FullNode:
- n := node.Copy(self).(*FullNode)
- n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
- n.dirty = true
-
+ case fullNode:
+ n[key[0]] = t.delete(n[key[0]], key[1:])
+ // Check how many non-nil entries are left after deleting and
+ // reduce the full node to a short node if only one entry is
+ // left. Since n must've contained at least two children
+ // before deletion (otherwise it would not be a full node) n
+ // can never be reduced to nil.
+ //
+ // When the loop is done, pos contains the index of the single
+ // value that is left in n or -2 if n contains at least two
+ // values.
pos := -1
- for i := 0; i < 17; i++ {
- if n.branch(byte(i)) != nil {
+ for i, cld := range n {
+ if cld != nil {
if pos == -1 {
pos = i
} else {
pos = -2
+ break
}
}
}
-
- var nnode Node
- if pos == 16 {
- nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
- nnode.(*ShortNode).dirty = true
- } else if pos >= 0 {
- cnode := n.branch(byte(pos))
- switch cnode := cnode.(type) {
- case *ShortNode:
- // Stitch keys
- k := append([]byte{byte(pos)}, cnode.Key()...)
- nnode = NewShortNode(self, k, cnode.Value())
- nnode.(*ShortNode).dirty = true
- case *FullNode:
- nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
- nnode.(*ShortNode).dirty = true
+ if pos >= 0 {
+ if pos != 16 {
+ // If the remaining entry is a short node, it replaces
+ // n and its key gets the missing nibble tacked to the
+ // front. This avoids creating an invalid
+ // shortNode{..., shortNode{...}}. Since the entry
+ // might not be loaded yet, resolve it just for this
+ // check.
+ cnode := t.resolve(n[pos])
+ if cnode, ok := cnode.(shortNode); ok {
+ k := append([]byte{byte(pos)}, cnode.Key...)
+ return shortNode{k, cnode.Val}
+ }
}
- } else {
- nnode = n
+ // Otherwise, n is replaced by a one-nibble short node
+ // containing the child.
+ return shortNode{[]byte{byte(pos)}, n[pos]}
}
+ // n still contains at least two values and cannot be reduced.
+ return n
- return nnode
case nil:
return nil
+
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and delete from it. This leaves all child nodes on
+ // the path to the value in the trie.
+ //
+ // TODO: track whether deletion actually hit a key and keep
+ // n as a hash node if it didn't.
+ return t.delete(t.resolveHash(n), key)
+
default:
- panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
}
}
-// casting functions and cache storing
-func (self *Trie) mknode(value *common.Value) Node {
- l := value.Len()
- switch l {
- case 0:
- return nil
- case 2:
- // A value node may consists of 2 bytes.
- if value.Get(0).Len() != 0 {
- key := CompactDecode(value.Get(0).Bytes())
- if key[len(key)-1] == 16 {
- return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes()))
- } else {
- return NewShortNode(self, key, self.mknode(value.Get(1)))
- }
- }
- case 17:
- if len(value.Bytes()) != 17 {
- fnode := NewFullNode(self)
- for i := 0; i < 16; i++ {
- fnode.set(byte(i), self.mknode(value.Get(i)))
- }
- return fnode
+func concat(s1 []byte, s2 ...byte) []byte {
+ r := make([]byte, len(s1)+len(s2))
+ copy(r, s1)
+ copy(r[len(s1):], s2)
+ return r
+}
+
+func (t *Trie) resolve(n node) node {
+ if n, ok := n.(hashNode); ok {
+ return t.resolveHash(n)
+ }
+ return n
+}
+
+func (t *Trie) resolveHash(n hashNode) node {
+ if v, ok := globalCache.Get(n); ok {
+ return v
+ }
+ enc, err := t.db.Get(n)
+ if err != nil || enc == nil {
+ // TODO: This needs to be improved to properly distinguish errors.
+ // Disk I/O errors shouldn't produce nil (and cause a
+ // consensus failure or weird crash), but it is unclear how
+ // they could be handled because the entire stack above the trie isn't
+ // prepared to cope with missing state nodes.
+ if glog.V(logger.Error) {
+ glog.Errorf("Dangling hash node ref %x: %v", n, err)
}
- case 32:
- return NewHash(value.Bytes(), self)
+ return nil
+ }
+ dec := mustDecodeNode(n, enc)
+ if dec != nil {
+ globalCache.Put(n, dec)
}
+ return dec
+}
+
+// Root returns the root hash of the trie.
+// Deprecated: use Hash instead.
+func (t *Trie) Root() []byte { return t.Hash().Bytes() }
- return NewValueNode(self, value.Bytes())
+// Hash returns the root hash of the trie. It does not write to the
+// database and can be used even if the trie doesn't have one.
+func (t *Trie) Hash() common.Hash {
+ root, _ := t.hashRoot(nil)
+ return common.BytesToHash(root.(hashNode))
}
-func (self *Trie) trans(node Node) Node {
- switch node := node.(type) {
- case *HashNode:
- value := common.NewValueFromBytes(self.cache.Get(node.key))
- return self.mknode(value)
- default:
- return node
+// Commit writes all nodes to the trie's database.
+// Nodes are stored with their sha3 hash as the key.
+//
+// Committing flushes nodes from memory.
+// Subsequent Get calls will load nodes from the database.
+func (t *Trie) Commit() (root common.Hash, err error) {
+ if t.db == nil {
+ panic("Commit called on trie with nil database")
}
+ return t.CommitTo(t.db)
}
-func (self *Trie) store(node Node) interface{} {
- data := common.Encode(node)
- if len(data) >= 32 {
- key := crypto.Sha3(data)
- if node.Dirty() {
- //fmt.Println("save", node)
- //fmt.Println()
- self.cache.Put(key, data)
- }
+// CommitTo writes all nodes to the given database.
+// Nodes are stored with their sha3 hash as the key.
+//
+// Committing flushes nodes from memory. Subsequent Get calls will
+// load nodes from the trie's database. Calling code must ensure that
+// the changes made to db are written back to the trie's attached
+// database before using the trie.
+func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
+ n, err := t.hashRoot(db)
+ if err != nil {
+ return (common.Hash{}), err
+ }
+ t.root = n
+ return common.BytesToHash(n.(hashNode)), nil
+}
- return key
+func (t *Trie) hashRoot(db DatabaseWriter) (node, error) {
+ if t.root == nil {
+ return hashNode(emptyRoot.Bytes()), nil
+ }
+ if t.hasher == nil {
+ t.hasher = newHasher()
}
+ return t.hasher.hash(t.root, db, true)
+}
- return node.RlpData()
+type hasher struct {
+ tmp *bytes.Buffer
+ sha hash.Hash
}
-func (self *Trie) PrintRoot() {
- fmt.Println(self.root)
- fmt.Printf("root=%x\n", self.Root())
+func newHasher() *hasher {
+ return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
+}
+
+func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) {
+ hashed, err := h.replaceChildren(n, db)
+ if err != nil {
+ return hashNode{}, err
+ }
+ if n, err = h.store(hashed, db, force); err != nil {
+ return hashNode{}, err
+ }
+ return n, nil
+}
+
+// hashChildren replaces child nodes of n with their hashes if the encoded
+// size of the child is larger than a hash.
+func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) {
+ var err error
+ switch n := n.(type) {
+ case shortNode:
+ n.Key = compactEncode(n.Key)
+ if _, ok := n.Val.(valueNode); !ok {
+ if n.Val, err = h.hash(n.Val, db, false); err != nil {
+ return n, err
+ }
+ }
+ if n.Val == nil {
+ // Ensure that nil children are encoded as empty strings.
+ n.Val = valueNode(nil)
+ }
+ return n, nil
+ case fullNode:
+ for i := 0; i < 16; i++ {
+ if n[i] != nil {
+ if n[i], err = h.hash(n[i], db, false); err != nil {
+ return n, err
+ }
+ } else {
+ // Ensure that nil children are encoded as empty strings.
+ n[i] = valueNode(nil)
+ }
+ }
+ if n[16] == nil {
+ n[16] = valueNode(nil)
+ }
+ return n, nil
+ default:
+ return n, nil
+ }
+}
+
+func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
+ // Don't store hashes or empty nodes.
+ if _, isHash := n.(hashNode); n == nil || isHash {
+ return n, nil
+ }
+ h.tmp.Reset()
+ if err := rlp.Encode(h.tmp, n); err != nil {
+ panic("encode error: " + err.Error())
+ }
+ if h.tmp.Len() < 32 && !force {
+ // Nodes smaller than 32 bytes are stored inside their parent.
+ return n, nil
+ }
+ // Larger nodes are replaced by their hash and stored in the database.
+ h.sha.Reset()
+ h.sha.Write(h.tmp.Bytes())
+ key := hashNode(h.sha.Sum(nil))
+ if db != nil {
+ err := db.Put(key, h.tmp.Bytes())
+ return key, err
+ }
+ return key, nil
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 607c96b0f..c96861bed 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -24,86 +24,103 @@ import (
"os"
"testing"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
)
-type Db map[string][]byte
-
-func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil }
-func (self Db) Put(k, v []byte) error { self[string(k)] = v; return nil }
-
-// Used for testing
-func NewEmpty() *Trie {
- return New(nil, make(Db))
+func init() {
+ spew.Config.Indent = " "
+ spew.Config.DisableMethods = true
}
-func NewEmptySecure() *SecureTrie {
- return NewSecure(nil, make(Db))
+// Used for testing
+func newEmpty() *Trie {
+ db, _ := ethdb.NewMemDatabase()
+ trie, _ := New(common.Hash{}, db)
+ return trie
}
func TestEmptyTrie(t *testing.T) {
- trie := NewEmpty()
+ var trie Trie
res := trie.Hash()
- exp := crypto.Sha3(common.Encode(""))
- if !bytes.Equal(res, exp) {
+ exp := emptyRoot
+ if res != common.Hash(exp) {
t.Errorf("expected %x got %x", exp, res)
}
}
func TestNull(t *testing.T) {
- trie := NewEmpty()
-
+ var trie Trie
key := make([]byte, 32)
value := common.FromHex("0x823140710bf13990e4500136726d8b55")
trie.Update(key, value)
value = trie.Get(key)
}
+func TestMissingRoot(t *testing.T) {
+ db, _ := ethdb.NewMemDatabase()
+ trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db)
+ if trie != nil {
+ t.Error("New returned non-nil trie for invalid root")
+ }
+ if err != ErrMissingRoot {
+ t.Error("New returned wrong error: %v", err)
+ }
+}
+
func TestInsert(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
- trie.UpdateString("doe", "reindeer")
- trie.UpdateString("dog", "puppy")
- trie.UpdateString("dogglesworth", "cat")
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
- exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
+ exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash()
- if !bytes.Equal(root, exp) {
+ if root != exp {
t.Errorf("exp %x got %x", exp, root)
}
- trie = NewEmpty()
- trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+ trie = newEmpty()
+ updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
- exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
- root = trie.Hash()
- if !bytes.Equal(root, exp) {
+ exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
+ root, err := trie.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
+ }
+ if root != exp {
t.Errorf("exp %x got %x", exp, root)
}
}
func TestGet(t *testing.T) {
- trie := NewEmpty()
-
- trie.UpdateString("doe", "reindeer")
- trie.UpdateString("dog", "puppy")
- trie.UpdateString("dogglesworth", "cat")
+ trie := newEmpty()
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
+
+ for i := 0; i < 2; i++ {
+ res := getString(trie, "dog")
+ if !bytes.Equal(res, []byte("puppy")) {
+ t.Errorf("expected puppy got %x", res)
+ }
- res := trie.GetString("dog")
- if !bytes.Equal(res, []byte("puppy")) {
- t.Errorf("expected puppy got %x", res)
- }
+ unknown := getString(trie, "unknown")
+ if unknown != nil {
+ t.Errorf("expected nil got %x", unknown)
+ }
- unknown := trie.GetString("unknown")
- if unknown != nil {
- t.Errorf("expected nil got %x", unknown)
+ if i == 1 {
+ return
+ }
+ trie.Commit()
}
}
func TestDelete(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
@@ -116,21 +133,21 @@ func TestDelete(t *testing.T) {
}
for _, val := range vals {
if val.v != "" {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
} else {
- trie.DeleteString(val.k)
+ deleteString(trie, val.k)
}
}
hash := trie.Hash()
- exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
- if !bytes.Equal(hash, exp) {
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestEmptyValues(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -143,78 +160,85 @@ func TestEmptyValues(t *testing.T) {
{"shaman", ""},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
hash := trie.Hash()
- exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
- if !bytes.Equal(hash, exp) {
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestReplication(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
- {"ether", ""},
{"dog", "puppy"},
- {"shaman", ""},
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
- trie.Commit()
-
- trie2 := New(trie.Root(), trie.cache.backend)
- if string(trie2.GetString("horse")) != "stallion" {
- t.Error("expected to have horse => stallion")
+ exp, err := trie.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
}
- hash := trie2.Hash()
- exp := trie.Hash()
- if !bytes.Equal(hash, exp) {
+ // create a new trie on top of the database and check that lookups work.
+ trie2, err := New(exp, trie.db)
+ if err != nil {
+ t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ }
+ for _, kv := range vals {
+ if string(getString(trie2, kv.k)) != kv.v {
+ t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
+ }
+ }
+ hash, err := trie2.Commit()
+ if err != nil {
+ t.Fatalf("commit error: %v", err)
+ }
+ if hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
-}
-
-func TestReset(t *testing.T) {
- trie := NewEmpty()
- vals := []struct{ k, v string }{
+ // perform some insertions on the new trie.
+ vals2 := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
+ // {"shaman", "horse"},
+ // {"doge", "coin"},
+ // {"ether", ""},
+ // {"dog", "puppy"},
+ // {"somethingveryoddindeedthis is", "myothernodedata"},
+ // {"shaman", ""},
}
- for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ for _, val := range vals2 {
+ updateString(trie2, val.k, val.v)
}
- trie.Commit()
-
- before := common.CopyBytes(trie.roothash)
- trie.UpdateString("should", "revert")
- trie.Hash()
- // Should have no effect
- trie.Hash()
- trie.Hash()
- // ###
-
- trie.Reset()
- after := common.CopyBytes(trie.roothash)
+ if trie2.Hash() != exp {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+}
- if !bytes.Equal(before, after) {
- t.Errorf("expected roots to be equal. %x - %x", before, after)
+func paranoiaCheck(t1 *Trie) (bool, *Trie) {
+ t2 := new(Trie)
+ it := NewIterator(t1)
+ for it.Next() {
+ t2.Update(it.Key, it.Value)
}
+ return t2.Hash() == t1.Hash(), t2
}
func TestParanoia(t *testing.T) {
t.Skip()
- trie := NewEmpty()
+ trie := newEmpty()
vals := []struct{ k, v string }{
{"do", "verb"},
@@ -228,13 +252,13 @@ func TestParanoia(t *testing.T) {
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
- trie.UpdateString(val.k, val.v)
+ updateString(trie, val.k, val.v)
}
trie.Commit()
- ok, t2 := ParanoiaCheck(trie, trie.cache.backend)
+ ok, t2 := paranoiaCheck(trie)
if !ok {
- t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash)
+ t.Errorf("trie paranoia check failed %x %x", trie.Hash(), t2.Hash())
}
}
@@ -243,27 +267,35 @@ func TestOutput(t *testing.T) {
t.Skip()
base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- trie := NewEmpty()
+ trie := newEmpty()
for i := 0; i < 50; i++ {
- trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
+ updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
}
fmt.Println("############################## FULL ################################")
fmt.Println(trie.root)
trie.Commit()
fmt.Println("############################## SMALL ################################")
- trie2 := New(trie.roothash, trie.cache.backend)
- trie2.GetString(base + "20")
+ trie2, _ := New(trie.Hash(), trie.db)
+ getString(trie2, base+"20")
fmt.Println(trie2.root)
}
+func TestLargeValue(t *testing.T) {
+ trie := newEmpty()
+ trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
+ trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
+ trie.Hash()
+
+}
+
type kv struct {
k, v []byte
t bool
}
func TestLargeData(t *testing.T) {
- trie := NewEmpty()
+ trie := newEmpty()
vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ {
@@ -275,7 +307,7 @@ func TestLargeData(t *testing.T) {
vals[string(value2.k)] = value2
}
- it := trie.Iterator()
+ it := NewIterator(trie)
for it.Next() {
vals[string(it.Key)].t = true
}
@@ -295,34 +327,6 @@ func TestLargeData(t *testing.T) {
}
}
-func TestSecureDelete(t *testing.T) {
- trie := NewEmptySecure()
-
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"ether", ""},
- {"dog", "puppy"},
- {"shaman", ""},
- }
- for _, val := range vals {
- if val.v != "" {
- trie.UpdateString(val.k, val.v)
- } else {
- trie.DeleteString(val.k)
- }
- }
-
- hash := trie.Hash()
- exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
- if !bytes.Equal(hash, exp) {
- t.Errorf("expected %x got %x", exp, hash)
- }
-}
-
func BenchmarkGet(b *testing.B) { benchGet(b, false) }
func BenchmarkGetDB(b *testing.B) { benchGet(b, true) }
func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
@@ -333,11 +337,11 @@ func BenchmarkHashLE(b *testing.B) { benchHash(b, binary.LittleEndian) }
const benchElemCount = 20000
func benchGet(b *testing.B, commit bool) {
- trie := New(nil, nil)
+ trie := new(Trie)
if commit {
dir, tmpdb := tempDB()
defer os.RemoveAll(dir)
- trie = New(nil, tmpdb)
+ trie, _ = New(common.Hash{}, tmpdb)
}
k := make([]byte, 32)
for i := 0; i < benchElemCount; i++ {
@@ -356,7 +360,7 @@ func benchGet(b *testing.B, commit bool) {
}
func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
- trie := NewEmpty()
+ trie := newEmpty()
k := make([]byte, 32)
for i := 0; i < b.N; i++ {
e.PutUint64(k, uint64(i))
@@ -366,7 +370,7 @@ func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
}
func benchHash(b *testing.B, e binary.ByteOrder) {
- trie := NewEmpty()
+ trie := newEmpty()
k := make([]byte, 32)
for i := 0; i < benchElemCount; i++ {
e.PutUint64(k, uint64(i))
@@ -379,7 +383,7 @@ func benchHash(b *testing.B, e binary.ByteOrder) {
}
}
-func tempDB() (string, Backend) {
+func tempDB() (string, Database) {
dir, err := ioutil.TempDir("", "trie-bench")
if err != nil {
panic(fmt.Sprintf("can't create temporary directory: %v", err))
@@ -390,3 +394,15 @@ func tempDB() (string, Backend) {
}
return dir, db
}
+
+func getString(trie *Trie, k string) []byte {
+ return trie.Get([]byte(k))
+}
+
+func updateString(trie *Trie, k, v string) {
+ trie.Update([]byte(k), []byte(v))
+}
+
+func deleteString(trie *Trie, k string) {
+ trie.Delete([]byte(k))
+}
diff --git a/trie/valuenode.go b/trie/valuenode.go
deleted file mode 100644
index 0afa64d54..000000000
--- a/trie/valuenode.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2014 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import "github.com/ethereum/go-ethereum/common"
-
-type ValueNode struct {
- trie *Trie
- data []byte
- dirty bool
-}
-
-func NewValueNode(trie *Trie, data []byte) *ValueNode {
- return &ValueNode{trie, data, false}
-}
-
-func (self *ValueNode) Value() Node { return self } // Best not to call :-)
-func (self *ValueNode) Val() []byte { return self.data }
-func (self *ValueNode) Dirty() bool { return self.dirty }
-func (self *ValueNode) Copy(t *Trie) Node {
- return &ValueNode{t, common.CopyBytes(self.data), self.dirty}
-}
-func (self *ValueNode) RlpData() interface{} { return self.data }
-func (self *ValueNode) Hash() interface{} { return self.data }
-
-func (self *ValueNode) setDirty(dirty bool) {
- self.dirty = dirty
-}