aboutsummaryrefslogtreecommitdiffstats
path: root/trie
diff options
context:
space:
mode:
Diffstat (limited to 'trie')
-rw-r--r--trie/hasher.go31
-rw-r--r--trie/trie_test.go43
2 files changed, 42 insertions, 32 deletions
diff --git a/trie/hasher.go b/trie/hasher.go
index 57e156ebf..b6223bf32 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
if err != nil {
return hashNode{}, n, err
}
- // Cache the hash of the ndoe for later reuse.
- if hash, ok := hashed.(hashNode); ok && !force {
- switch cached := cached.(type) {
- case *shortNode:
- cached = cached.copy()
- cached.flags.hash = hash
- if db != nil {
- cached.flags.dirty = false
- }
- return hashed, cached, nil
- case *fullNode:
- cached = cached.copy()
- cached.flags.hash = hash
- if db != nil {
- cached.flags.dirty = false
- }
- return hashed, cached, nil
+ // Cache the hash of the ndoe for later reuse and remove
+ // the dirty flag in commit mode. It's fine to assign these values directly
+ // without copying the node first because hashChildren copies it.
+ cachedHash, _ := hashed.(hashNode)
+ switch cn := cached.(type) {
+ case *shortNode:
+ cn.flags.hash = cachedHash
+ if db != nil {
+ cn.flags.dirty = false
+ }
+ case *fullNode:
+ cn.flags.hash = cachedHash
+ if db != nil {
+ cn.flags.dirty = false
}
}
return hashed, cached, nil
diff --git a/trie/trie_test.go b/trie/trie_test.go
index da0d2360b..14ac5a666 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -462,31 +462,44 @@ func runRandTest(rt randTest) bool {
return false
}
case opCheckCacheInvariant:
- return checkCacheInvariant(tr.root, tr.cachegen, 0)
+ return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
}
}
return true
}
-func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool {
+func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool {
+ var children []node
+ var flag nodeFlag
switch n := n.(type) {
case *shortNode:
- if n.flags.gen > parentCachegen {
- fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
- return false
- }
- return checkCacheInvariant(n.Val, n.flags.gen, depth+1)
+ flag = n.flags
+ children = []node{n.Val}
case *fullNode:
- if n.flags.gen > parentCachegen {
- fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
+ flag = n.flags
+ children = n.Children[:]
+ default:
+ return true
+ }
+
+ showerror := func() {
+ fmt.Printf("at depth %d node %s", depth, spew.Sdump(n))
+ fmt.Printf("parent: %s", spew.Sdump(parent))
+ }
+ if flag.gen > parentCachegen {
+ fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
+ showerror()
+ return false
+ }
+ if depth > 0 && !parentDirty && flag.dirty {
+ fmt.Printf("cache invariant violation: child is dirty but parent isn't\n")
+ showerror()
+ return false
+ }
+ for _, child := range children {
+ if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) {
return false
}
- for _, child := range n.Children {
- if !checkCacheInvariant(child, n.flags.gen, depth+1) {
- return false
- }
- }
- return true
}
return true
}