diff options
Diffstat (limited to 'trie/trie.go')
-rw-r--r-- | trie/trie.go | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/trie/trie.go b/trie/trie.go index d990338ee..7e17baa2f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -117,7 +117,9 @@ func (self *Trie) Update(key, value []byte) Node { k := CompactHexDecode(string(key)) if len(value) != 0 { - self.root = self.insert(self.root, k, &ValueNode{self, value}) + node := NewValueNode(self, value) + node.dirty = true + self.root = self.insert(self.root, k, node) } else { self.root = self.delete(self.root, k) } @@ -157,7 +159,9 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { } if node == nil { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node } switch node := node.(type) { @@ -165,7 +169,10 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { k := node.Key() cnode := node.Value() if bytes.Equal(k, key) { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node + } var n Node @@ -176,6 +183,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { pnode := self.insert(nil, k[matchlength+1:], cnode) nnode := self.insert(nil, key[matchlength+1:], value) fulln := NewFullNode(self) + fulln.dirty = true fulln.set(k[matchlength], pnode) fulln.set(key[matchlength], nnode) n = fulln @@ -184,11 +192,14 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { return n } - return NewShortNode(self, key[:matchlength], n) + snode := NewShortNode(self, key[:matchlength], n) + snode.dirty = true + return snode case *FullNode: cpy := node.Copy(self).(*FullNode) cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) + cpy.dirty = true return cpy @@ -242,8 +253,10 @@ func (self *Trie) delete(node Node, key []byte) Node { case *ShortNode: nkey := append(k, child.Key()...) n = NewShortNode(self, nkey, child.Value()) + n.(*ShortNode).dirty = true case *FullNode: sn := NewShortNode(self, node.Key(), child) + sn.dirty = true sn.key = node.key n = sn } @@ -256,6 +269,7 @@ func (self *Trie) delete(node Node, key []byte) Node { case *FullNode: n := node.Copy(self).(*FullNode) n.set(key[0], self.delete(n.branch(key[0]), key[1:])) + n.dirty = true pos := -1 for i := 0; i < 17; i++ { @@ -271,6 +285,7 @@ func (self *Trie) delete(node Node, key []byte) Node { var nnode Node if pos == 16 { nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } else if pos >= 0 { cnode := n.branch(byte(pos)) switch cnode := cnode.(type) { @@ -278,8 +293,10 @@ func (self *Trie) delete(node Node, key []byte) Node { // Stitch keys k := append([]byte{byte(pos)}, cnode.Key()...) nnode = NewShortNode(self, k, cnode.Value()) + nnode.(*ShortNode).dirty = true case *FullNode: nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } } else { nnode = n @@ -304,7 +321,7 @@ func (self *Trie) mknode(value *common.Value) Node { if value.Get(0).Len() != 0 { key := CompactDecode(string(value.Get(0).Bytes())) if key[len(key)-1] == 16 { - return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()}) + return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes())) } else { return NewShortNode(self, key, self.mknode(value.Get(1))) } @@ -318,10 +335,10 @@ func (self *Trie) mknode(value *common.Value) Node { return fnode } case 32: - return &HashNode{value.Bytes(), self} + return NewHash(value.Bytes(), self) } - return &ValueNode{self, value.Bytes()} + return NewValueNode(self, value.Bytes()) } func (self *Trie) trans(node Node) Node { @@ -338,7 +355,11 @@ func (self *Trie) store(node Node) interface{} { data := common.Encode(node) if len(data) >= 32 { key := crypto.Sha3(data) - self.cache.Put(key, data) + if node.Dirty() { + //fmt.Println("save", node) + //fmt.Println() + self.cache.Put(key, data) + } return key } |