aboutsummaryrefslogtreecommitdiffstats
path: root/ptrie
diff options
context:
space:
mode:
Diffstat (limited to 'ptrie')
-rw-r--r--ptrie/fullnode.go10
-rw-r--r--ptrie/iterator.go2
-rw-r--r--ptrie/node.go4
-rw-r--r--ptrie/trie.go49
-rw-r--r--ptrie/trie_test.go2
5 files changed, 49 insertions, 18 deletions
diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go
index 7a7f7d22d..4dd98049d 100644
--- a/ptrie/fullnode.go
+++ b/ptrie/fullnode.go
@@ -1,5 +1,7 @@
package ptrie
+import "fmt"
+
type FullNode struct {
trie *Trie
nodes [17]Node
@@ -21,7 +23,9 @@ func (self *FullNode) Branches() []Node {
func (self *FullNode) Copy() Node {
nnode := NewFullNode(self.trie)
for i, node := range self.nodes {
- nnode.nodes[i] = node
+ if node != nil {
+ nnode.nodes[i] = node
+ }
}
return nnode
@@ -56,6 +60,10 @@ func (self *FullNode) RlpData() interface{} {
}
func (self *FullNode) set(k byte, value Node) {
+ if _, ok := value.(*ValueNode); ok && k != 16 {
+ fmt.Println(value, k)
+ }
+
self.nodes[int(k)] = value
}
diff --git a/ptrie/iterator.go b/ptrie/iterator.go
index 5714bdbc8..787ba09c0 100644
--- a/ptrie/iterator.go
+++ b/ptrie/iterator.go
@@ -14,7 +14,7 @@ type Iterator struct {
}
func NewIterator(trie *Trie) *Iterator {
- return &Iterator{trie: trie, Key: []byte{0}}
+ return &Iterator{trie: trie, Key: make([]byte, 32)}
}
func (self *Iterator) Next() bool {
diff --git a/ptrie/node.go b/ptrie/node.go
index 2c85dbce7..ab90a1a02 100644
--- a/ptrie/node.go
+++ b/ptrie/node.go
@@ -17,7 +17,7 @@ type Node interface {
func (self *ValueNode) String() string { return self.fstring("") }
func (self *FullNode) String() string { return self.fstring("") }
func (self *ShortNode) String() string { return self.fstring("") }
-func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%s ", self.data) }
+func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) }
func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.key) }
// Full node
@@ -36,5 +36,5 @@ func (self *FullNode) fstring(ind string) string {
// Short node
func (self *ShortNode) fstring(ind string) string {
- return fmt.Sprintf("[ %s: %v ] ", self.key, self.value.fstring(ind+" "))
+ return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" "))
}
diff --git a/ptrie/trie.go b/ptrie/trie.go
index 9fe9ea52a..5c83b57d0 100644
--- a/ptrie/trie.go
+++ b/ptrie/trie.go
@@ -19,7 +19,7 @@ func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
t2.Update(it.Key, it.Value)
}
- return bytes.Compare(t2.Hash(), t1.Hash()) == 0, t2
+ return bytes.Equal(t2.Hash(), t1.Hash()), t2
}
type Trie struct {
@@ -49,14 +49,17 @@ func (self *Trie) Iterator() *Iterator {
return NewIterator(self)
}
+func (self *Trie) Copy() *Trie {
+ return New(self.roothash, self.cache.backend)
+}
+
// Legacy support
func (self *Trie) Root() []byte { return self.Hash() }
func (self *Trie) Hash() []byte {
var hash []byte
if self.root != nil {
- //hash = self.root.Hash().([]byte)
t := self.root.Hash()
- if byts, ok := t.([]byte); ok {
+ if byts, ok := t.([]byte); ok && len(byts) > 0 {
hash = byts
} else {
hash = crypto.Sha3(ethutil.Encode(self.root.RlpData()))
@@ -73,6 +76,9 @@ func (self *Trie) Hash() []byte {
return hash
}
func (self *Trie) Commit() {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
// Hash first
self.Hash()
@@ -81,10 +87,15 @@ func (self *Trie) Commit() {
// Reset should only be called if the trie has been hashed
func (self *Trie) Reset() {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
self.cache.Reset()
- revision := self.revisions.Remove(self.revisions.Back()).([]byte)
- self.roothash = revision
+ if self.revisions.Len() > 0 {
+ revision := self.revisions.Remove(self.revisions.Back()).([]byte)
+ self.roothash = revision
+ }
value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash))
self.root = self.mknode(value)
}
@@ -173,7 +184,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
return cpy
default:
- panic("Invalid node")
+ panic(fmt.Sprintf("%T: invalid node: %v", node, node))
}
}
@@ -204,7 +215,7 @@ func (self *Trie) get(node Node, key []byte) Node {
}
func (self *Trie) delete(node Node, key []byte) Node {
- if len(key) == 0 {
+ if len(key) == 0 && node == nil {
return nil
}
@@ -223,7 +234,9 @@ func (self *Trie) delete(node Node, key []byte) Node {
nkey := append(k, child.Key()...)
n = NewShortNode(self, nkey, child.Value())
case *FullNode:
- n = NewShortNode(self, node.key, child)
+ sn := NewShortNode(self, node.Key(), child)
+ sn.key = node.key
+ n = sn
}
return n
@@ -264,9 +277,10 @@ func (self *Trie) delete(node Node, key []byte) Node {
}
return nnode
-
+ case nil:
+ return nil
default:
- panic("Invalid node")
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
}
}
@@ -274,8 +288,13 @@ func (self *Trie) delete(node Node, key []byte) Node {
func (self *Trie) mknode(value *ethutil.Value) Node {
l := value.Len()
switch l {
+ case 0:
+ return nil
case 2:
- return NewShortNode(self, trie.CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
+ // A value node may consists of 2 bytes.
+ if value.Get(0).Len() != 0 {
+ return NewShortNode(self, trie.CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
+ }
case 17:
fnode := NewFullNode(self)
for i := 0; i < l; i++ {
@@ -284,9 +303,9 @@ func (self *Trie) mknode(value *ethutil.Value) Node {
return fnode
case 32:
return &HashNode{value.Bytes()}
- default:
- return &ValueNode{self, value.Bytes()}
}
+
+ return &ValueNode{self, value.Bytes()}
}
func (self *Trie) trans(node Node) Node {
@@ -310,3 +329,7 @@ func (self *Trie) store(node Node) interface{} {
return node.RlpData()
}
+
+func (self *Trie) PrintRoot() {
+ fmt.Println(self.root)
+}
diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go
index 5b1c64140..63a8ed36e 100644
--- a/ptrie/trie_test.go
+++ b/ptrie/trie_test.go
@@ -141,7 +141,7 @@ func TestReplication(t *testing.T) {
trie2 := New(trie.roothash, trie.cache.backend)
if string(trie2.GetString("horse")) != "stallion" {
- t.Error("expected to have harse => stallion")
+ t.Error("expected to have horse => stallion")
}
hash := trie2.Hash()