diff options
Diffstat (limited to 'trie/trie.go')
-rw-r--r-- | trie/trie.go | 114 |
1 files changed, 57 insertions, 57 deletions
diff --git a/trie/trie.go b/trie/trie.go index 55598af98..65005bae8 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -62,6 +62,23 @@ type Trie struct { root node db Database originalRoot common.Hash + + // Cache generation values. + // cachegen increase by one with each commit operation. + // new nodes are tagged with the current generation and unloaded + // when their generation is older than than cachegen-cachelimit. + cachegen, cachelimit uint16 +} + +// SetCacheLimit sets the number of 'cache generations' to keep. +// A cache generations is created by a call to Commit. +func (t *Trie) SetCacheLimit(l uint16) { + t.cachelimit = l +} + +// newFlag returns the cache flag value for a newly created node. +func (t *Trie) newFlag() nodeFlag { + return nodeFlag{dirty: true, gen: t.cachegen} } // New creates a trie with an existing root node from db. @@ -120,27 +137,25 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode return nil, nil, false, nil case valueNode: return n, n, false, nil - case shortNode: + case *shortNode: if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { // key not found in trie return nil, n, false, nil } value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { + n = n.copy() n.Val = newnode - return value, n, didResolve, err - } else { - return value, origNode, didResolve, err } - case fullNode: - child := n.Children[key[pos]] - value, newnode, didResolve, err = t.tryGet(child, key, pos+1) + return value, n, didResolve, err + case *fullNode: + value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { + n = n.copy() n.Children[key[pos]] = newnode - return value, n, didResolve, err - } else { - return value, origNode, didResolve, err + } + return value, n, didResolve, err case hashNode: child, err := t.resolveHash(n, key[:pos], key[pos:]) if err != nil { @@ -199,22 +214,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, value, nil } switch n := n.(type) { - case shortNode: + case *shortNode: matchlen := prefixLen(key, n.Key) // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) - if err != nil { - return false, nil, err + if !dirty || err != nil { + return false, n, err } - if !dirty { - return false, n, nil - } - return true, shortNode{n.Key, nn, nil, true}, nil + return true, &shortNode{n.Key, nn, t.newFlag()}, nil } // Otherwise branch out at the index where they differ. - branch := fullNode{dirty: true} + branch := &fullNode{flags: t.newFlag()} var err error _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) if err != nil { @@ -229,21 +241,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return true, shortNode{key[:matchlen], branch, nil, true}, nil + return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil - case fullNode: + case *fullNode: dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) - if err != nil { - return false, nil, err + if !dirty || err != nil { + return false, n, err } - if !dirty { - return false, n, nil - } - n.Children[key[0]], n.hash, n.dirty = nn, nil, true + n = n.copy() + n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true return true, n, nil case nil: - return true, shortNode{key, value, nil, true}, nil + return true, &shortNode{key, value, t.newFlag()}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -254,11 +264,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return false, nil, err } dirty, nn, err := t.insert(rn, prefix, key, value) - if err != nil { - return false, nil, err - } - if !dirty { - return false, rn, nil + if !dirty || err != nil { + return false, rn, err } return true, nn, nil @@ -291,7 +298,7 @@ func (t *Trie) TryDelete(key []byte) error { // nodes on the way up after deleting recursively. func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { switch n := n.(type) { - case shortNode: + case *shortNode: matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { return false, n, nil // don't replace n on mismatch @@ -304,34 +311,29 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // subtrie must contain at least two other values with keys // longer than n.Key. dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) - if err != nil { - return false, nil, err - } - if !dirty { - return false, n, nil + if !dirty || err != nil { + return false, n, err } switch child := child.(type) { - case shortNode: + case *shortNode: // Deleting from the subtrie reduced it to another // short node. Merge the nodes to avoid creating a // shortNode{..., shortNode{...}}. Use concat (which // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil + return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil default: - return true, shortNode{n.Key, child, nil, true}, nil + return true, &shortNode{n.Key, child, t.newFlag()}, nil } - case fullNode: + case *fullNode: dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) - if err != nil { - return false, nil, err - } - if !dirty { - return false, n, nil + if !dirty || err != nil { + return false, n, err } - n.Children[key[0]], n.hash, n.dirty = nn, nil, true + n = n.copy() + n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true // Check how many non-nil entries are left after deleting and // reduce the full node to a short node if only one entry is @@ -365,14 +367,14 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { if err != nil { return false, nil, err } - if cnode, ok := cnode.(shortNode); ok { + if cnode, ok := cnode.(*shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return true, shortNode{k, cnode.Val, nil, true}, nil + return true, &shortNode{k, cnode.Val, t.newFlag()}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil + return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil } // n still contains at least two values and cannot be reduced. return true, n, nil @@ -392,11 +394,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, nil, err } dirty, nn, err := t.delete(rn, prefix, key) - if err != nil { - return false, nil, err - } - if !dirty { - return false, rn, nil + if !dirty || err != nil { + return false, rn, err } return true, nn, nil @@ -471,6 +470,7 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { return (common.Hash{}), err } t.root = cached + t.cachegen++ return common.BytesToHash(hash.(hashNode)), nil } @@ -478,7 +478,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } - h := newHasher() + h := newHasher(t.cachegen, t.cachelimit) defer returnHasherToPool(h) return h.hash(t.root, db, true) } |