aboutsummaryrefslogtreecommitdiffstats
path: root/trie/trie.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie/trie.go')
-rw-r--r--trie/trie.go114
1 files changed, 57 insertions, 57 deletions
diff --git a/trie/trie.go b/trie/trie.go
index 55598af98..65005bae8 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -62,6 +62,23 @@ type Trie struct {
root node
db Database
originalRoot common.Hash
+
+ // Cache generation values.
+ // cachegen increase by one with each commit operation.
+ // new nodes are tagged with the current generation and unloaded
+ // when their generation is older than than cachegen-cachelimit.
+ cachegen, cachelimit uint16
+}
+
+// SetCacheLimit sets the number of 'cache generations' to keep.
+// A cache generations is created by a call to Commit.
+func (t *Trie) SetCacheLimit(l uint16) {
+ t.cachelimit = l
+}
+
+// newFlag returns the cache flag value for a newly created node.
+func (t *Trie) newFlag() nodeFlag {
+ return nodeFlag{dirty: true, gen: t.cachegen}
}
// New creates a trie with an existing root node from db.
@@ -120,27 +137,25 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
return nil, nil, false, nil
case valueNode:
return n, n, false, nil
- case shortNode:
+ case *shortNode:
if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
// key not found in trie
return nil, n, false, nil
}
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
if err == nil && didResolve {
+ n = n.copy()
n.Val = newnode
- return value, n, didResolve, err
- } else {
- return value, origNode, didResolve, err
}
- case fullNode:
- child := n.Children[key[pos]]
- value, newnode, didResolve, err = t.tryGet(child, key, pos+1)
+ return value, n, didResolve, err
+ case *fullNode:
+ value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
+ n = n.copy()
n.Children[key[pos]] = newnode
- return value, n, didResolve, err
- } else {
- return value, origNode, didResolve, err
+
}
+ return value, n, didResolve, err
case hashNode:
child, err := t.resolveHash(n, key[:pos], key[pos:])
if err != nil {
@@ -199,22 +214,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, value, nil
}
switch n := n.(type) {
- case shortNode:
+ case *shortNode:
matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
- if err != nil {
- return false, nil, err
+ if !dirty || err != nil {
+ return false, n, err
}
- if !dirty {
- return false, n, nil
- }
- return true, shortNode{n.Key, nn, nil, true}, nil
+ return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}
// Otherwise branch out at the index where they differ.
- branch := fullNode{dirty: true}
+ branch := &fullNode{flags: t.newFlag()}
var err error
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil {
@@ -229,21 +241,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, branch, nil
}
// Otherwise, replace it with a short node leading up to the branch.
- return true, shortNode{key[:matchlen], branch, nil, true}, nil
+ return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
- case fullNode:
+ case *fullNode:
dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
- if err != nil {
- return false, nil, err
+ if !dirty || err != nil {
+ return false, n, err
}
- if !dirty {
- return false, n, nil
- }
- n.Children[key[0]], n.hash, n.dirty = nn, nil, true
+ n = n.copy()
+ n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
return true, n, nil
case nil:
- return true, shortNode{key, value, nil, true}, nil
+ return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
@@ -254,11 +264,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, nil, err
}
dirty, nn, err := t.insert(rn, prefix, key, value)
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, rn, nil
+ if !dirty || err != nil {
+ return false, rn, err
}
return true, nn, nil
@@ -291,7 +298,7 @@ func (t *Trie) TryDelete(key []byte) error {
// nodes on the way up after deleting recursively.
func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
switch n := n.(type) {
- case shortNode:
+ case *shortNode:
matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) {
return false, n, nil // don't replace n on mismatch
@@ -304,34 +311,29 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
// subtrie must contain at least two other values with keys
// longer than n.Key.
dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, n, nil
+ if !dirty || err != nil {
+ return false, n, err
}
switch child := child.(type) {
- case shortNode:
+ case *shortNode:
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
- return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil
+ return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
default:
- return true, shortNode{n.Key, child, nil, true}, nil
+ return true, &shortNode{n.Key, child, t.newFlag()}, nil
}
- case fullNode:
+ case *fullNode:
dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, n, nil
+ if !dirty || err != nil {
+ return false, n, err
}
- n.Children[key[0]], n.hash, n.dirty = nn, nil, true
+ n = n.copy()
+ n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
@@ -365,14 +367,14 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
if err != nil {
return false, nil, err
}
- if cnode, ok := cnode.(shortNode); ok {
+ if cnode, ok := cnode.(*shortNode); ok {
k := append([]byte{byte(pos)}, cnode.Key...)
- return true, shortNode{k, cnode.Val, nil, true}, nil
+ return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
}
// Otherwise, n is replaced by a one-nibble short node
// containing the child.
- return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil
+ return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
}
// n still contains at least two values and cannot be reduced.
return true, n, nil
@@ -392,11 +394,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err
}
dirty, nn, err := t.delete(rn, prefix, key)
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, rn, nil
+ if !dirty || err != nil {
+ return false, rn, err
}
return true, nn, nil
@@ -471,6 +470,7 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
return (common.Hash{}), err
}
t.root = cached
+ t.cachegen++
return common.BytesToHash(hash.(hashNode)), nil
}
@@ -478,7 +478,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil
}
- h := newHasher()
+ h := newHasher(t.cachegen, t.cachelimit)
defer returnHasherToPool(h)
return h.hash(t.root, db, true)
}