From 187d6a66a5176a1dc3e75d5ad4baad623762acb9 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 17 Oct 2016 21:31:27 +0200 Subject: trie: avoid loading the root node twice New checks whether the root node is present by loading it from the database. Keep the node around instead of discarding it. --- trie/trie.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'trie') diff --git a/trie/trie.go b/trie/trie.go index 65005bae8..cce4cfeb6 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -93,13 +93,11 @@ func New(root common.Hash, db Database) (*Trie, error) { if db == nil { panic("trie.New: cannot use existing root without a database") } - if v, _ := trie.db.Get(root[:]); len(v) == 0 { - return nil, &MissingNodeError{ - RootHash: root, - NodeHash: root, - } + rootnode, err := trie.resolveHash(root[:], nil, nil) + if err != nil { + return nil, err } - trie.root = hashNode(root.Bytes()) + trie.root = rootnode } return trie, nil } -- cgit v1.2.3 From 177cab5fe70910ee0af3fcf493d51999ae2d923d Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 17 Oct 2016 16:13:50 +0200 Subject: trie: ensure resolved nodes stay loaded Commit 40cdcf1183 broke the optimisation which kept nodes resolved during Get in the trie. The decoder assigned cache generation 0 unconditionally, causing resolved nodes to get flushed on Commit. This commit fixes it and adds two tests. --- trie/hasher.go | 2 +- trie/node.go | 26 ++++++++-------- trie/proof.go | 2 +- trie/sync.go | 6 ++-- trie/trie.go | 11 ++++--- trie/trie_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++------------- 6 files changed, 95 insertions(+), 43 deletions(-) (limited to 'trie') diff --git a/trie/hasher.go b/trie/hasher.go index e395e00d7..57e156ebf 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) return hash, n, nil } if n.canUnload(h.cachegen, h.cachelimit) { - // Evict the node from cache. 
All of its subnodes will have a lower or equal + // Unload the node from cache. All of its subnodes will have a lower or equal // cache generation number. return hash, hash, nil } diff --git a/trie/node.go b/trie/node.go index de9752c93..4aa0cab65 100644 --- a/trie/node.go +++ b/trie/node.go @@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string { return fmt.Sprintf("%x ", []byte(n)) } -func mustDecodeNode(hash, buf []byte) node { - n, err := decodeNode(hash, buf) +func mustDecodeNode(hash, buf []byte, cachegen uint16) node { + n, err := decodeNode(hash, buf, cachegen) if err != nil { panic(fmt.Sprintf("node %x: %v", hash, err)) } @@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node { } // decodeNode parses the RLP encoding of a trie node. -func decodeNode(hash, buf []byte) (node, error) { +func decodeNode(hash, buf []byte, cachegen uint16) (node, error) { if len(buf) == 0 { return nil, io.ErrUnexpectedEOF } @@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) { } switch c, _ := rlp.CountValues(elems); c { case 2: - n, err := decodeShort(hash, buf, elems) + n, err := decodeShort(hash, buf, elems, cachegen) return n, wrapError(err, "short") case 17: - n, err := decodeFull(hash, buf, elems) + n, err := decodeFull(hash, buf, elems, cachegen) return n, wrapError(err, "full") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) } } -func decodeShort(hash, buf, elems []byte) (node, error) { +func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) { kbuf, rest, err := rlp.SplitString(elems) if err != nil { return nil, err } - flag := nodeFlag{hash: hash} + flag := nodeFlag{hash: hash, gen: cachegen} key := compactDecode(kbuf) if key[len(key)-1] == 16 { // value node @@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) { } return &shortNode{key, append(valueNode{}, val...), flag}, nil } - r, _, err := decodeRef(rest) + r, _, err := decodeRef(rest, cachegen) if err != nil { 
return nil, wrapError(err, "val") } return &shortNode{key, r, flag}, nil } -func decodeFull(hash, buf, elems []byte) (*fullNode, error) { - n := &fullNode{flags: nodeFlag{hash: hash}} +func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) { + n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}} for i := 0; i < 16; i++ { - cld, rest, err := decodeRef(elems) + cld, rest, err := decodeRef(elems, cachegen) if err != nil { return n, wrapError(err, fmt.Sprintf("[%d]", i)) } @@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) { const hashLen = len(common.Hash{}) -func decodeRef(buf []byte) (node, []byte, error) { +func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) { kind, val, rest, err := rlp.Split(buf) if err != nil { return nil, buf, err @@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) { err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) return nil, buf, err } - n, err := decodeNode(nil, buf) + n, err := decodeNode(nil, buf, cachegen) return n, rest, err case kind == rlp.String && len(val) == 0: // empty node diff --git a/trie/proof.go b/trie/proof.go index f193b52df..bea5e5c09 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value if !bytes.Equal(sha.Sum(nil), wantHash) { return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) } - n, err := decodeNode(wantHash, buf) + n, err := decodeNode(wantHash, buf, 0) if err != nil { return nil, fmt.Errorf("bad proof node %d: %v", i, err) } diff --git a/trie/sync.go b/trie/sync.go index 400dff903..30caf6980 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c } key := root.Bytes() blob, _ := s.database.Get(key) - if local, err := decodeNode(key, blob); local != nil && err == nil { + if local, err := 
decodeNode(key, blob, 0); local != nil && err == nil { return } // Assemble the new sub-trie sync request @@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { continue } // Decode the node data content and update the request - node, err := decodeNode(item.Hash[:], item.Data) + node, err := decodeNode(item.Hash[:], item.Data, 0) if err != nil { return i, err } @@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) { if node, ok := (*child.node).(hashNode); ok { // Try to resolve the node from the local database blob, _ := s.database.Get(node) - if local, err := decodeNode(node[:], blob); local != nil && err == nil { + if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil { *child.node = local continue } diff --git a/trie/trie.go b/trie/trie.go index cce4cfeb6..30f566a8d 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -144,14 +144,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode if err == nil && didResolve { n = n.copy() n.Val = newnode + n.flags.gen = t.cachegen } return value, n, didResolve, err case *fullNode: value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { n = n.copy() + n.flags.gen = t.cachegen n.Children[key[pos]] = newnode - } return value, n, didResolve, err case hashNode: @@ -247,7 +248,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return false, n, err } n = n.copy() - n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true + n.flags = t.newFlag() + n.Children[key[0]] = nn return true, n, nil case nil: @@ -331,7 +333,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, n, err } n = n.copy() - n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true + n.flags = t.newFlag() + n.Children[key[0]] = nn // Check how many non-nil entries are left after deleting and // reduce the full node to a short 
node if only one entry is @@ -427,7 +430,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { SuffixLen: len(suffix), } } - dec := mustDecodeNode(n, enc) + dec := mustDecodeNode(n, enc, t.cachegen) return dec, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index 32fbe6801..da0d2360b 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -300,25 +300,6 @@ func TestReplication(t *testing.T) { } } -// Not an actual test -func TestOutput(t *testing.T) { - t.Skip() - - base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - trie := newEmpty() - for i := 0; i < 50; i++ { - updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") - } - fmt.Println("############################## FULL ################################") - fmt.Println(trie.root) - - trie.Commit() - fmt.Println("############################## SMALL ################################") - trie2, _ := New(trie.Hash(), trie.db) - getString(trie2, base+"20") - fmt.Println(trie2.root) -} - func TestLargeValue(t *testing.T) { trie := newEmpty() trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) @@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) { trie.Hash() } +type countingDB struct { + Database + gets map[string]int +} + +func (db *countingDB) Get(key []byte) ([]byte, error) { + db.gets[string(key)]++ + return db.Database.Get(key) +} + +// TestCacheUnload checks that decoded nodes are unloaded after a +// certain number of commit operations. +func TestCacheUnload(t *testing.T) { + // Create test trie with two branches. + trie := newEmpty() + key1 := "---------------------------------" + key2 := "---some other branch" + updateString(trie, key1, "this is the branch of key1.") + updateString(trie, key2, "this is the branch of key2.") + root, _ := trie.Commit() + + // Commit the trie repeatedly and access key1. 
+ // The branch containing it is loaded from DB exactly two times: + // in the 0th and 6th iteration. + db := &countingDB{Database: trie.db, gets: make(map[string]int)} + trie, _ = New(root, db) + trie.SetCacheLimit(5) + for i := 0; i < 12; i++ { + getString(trie, key1) + trie.Commit() + } + + // Check that it got loaded two times. + for dbkey, count := range db.gets { + if count != 2 { + t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2) + } + } +} + +// randTest performs random trie operations. +// Instances of this test are created by Generate. +type randTest []randTestStep + type randTestStep struct { op int key []byte // for opUpdate, opDelete, opGet value []byte // for opUpdate } -type randTest []randTestStep - const ( opUpdate = iota opDelete @@ -342,6 +365,7 @@ const ( opHash opReset opItercheckhash + opCheckCacheInvariant opMax // boundary value, not an actual op ) @@ -437,7 +461,32 @@ func runRandTest(rt randTest) bool { fmt.Println("hashes not equal") return false } + case opCheckCacheInvariant: + return checkCacheInvariant(tr.root, tr.cachegen, 0) + } + } + return true +} + +func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool { + switch n := n.(type) { + case *shortNode: + if n.flags.gen > parentCachegen { + fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) + return false + } + return checkCacheInvariant(n.Val, n.flags.gen, depth+1) + case *fullNode: + if n.flags.gen > parentCachegen { + fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) + return false + } + for _, child := range n.Children { + if !checkCacheInvariant(child, n.flags.gen, depth+1) { + return false + } } + return true } return true } -- cgit v1.2.3 From 8d56bf5ceb74a7ed45c986450848a89e2df61189 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 17 Oct 2016 23:01:29 +0200 Subject: trie: ensure dirty 
flag is unset for embedded child nodes This was caught by the new invariant check. --- trie/hasher.go | 31 ++++++++++++++----------------- trie/trie_test.go | 43 ++++++++++++++++++++++++++++--------------- 2 files changed, 42 insertions(+), 32 deletions(-) (limited to 'trie') diff --git a/trie/hasher.go b/trie/hasher.go index 57e156ebf..b6223bf32 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) if err != nil { return hashNode{}, n, err } - // Cache the hash of the ndoe for later reuse. - if hash, ok := hashed.(hashNode); ok && !force { - switch cached := cached.(type) { - case *shortNode: - cached = cached.copy() - cached.flags.hash = hash - if db != nil { - cached.flags.dirty = false - } - return hashed, cached, nil - case *fullNode: - cached = cached.copy() - cached.flags.hash = hash - if db != nil { - cached.flags.dirty = false - } - return hashed, cached, nil + // Cache the hash of the node for later reuse and remove + // the dirty flag in commit mode. It's fine to assign these values directly + // without copying the node first because hashChildren copies it. 
+ cachedHash, _ := hashed.(hashNode) + switch cn := cached.(type) { + case *shortNode: + cn.flags.hash = cachedHash + if db != nil { + cn.flags.dirty = false + } + case *fullNode: + cn.flags.hash = cachedHash + if db != nil { + cn.flags.dirty = false } } return hashed, cached, nil diff --git a/trie/trie_test.go b/trie/trie_test.go index da0d2360b..14ac5a666 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -462,31 +462,44 @@ func runRandTest(rt randTest) bool { return false } case opCheckCacheInvariant: - return checkCacheInvariant(tr.root, tr.cachegen, 0) + return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0) } } return true } -func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool { +func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool { + var children []node + var flag nodeFlag switch n := n.(type) { case *shortNode: - if n.flags.gen > parentCachegen { - fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) - return false - } - return checkCacheInvariant(n.Val, n.flags.gen, depth+1) + flag = n.flags + children = []node{n.Val} case *fullNode: - if n.flags.gen > parentCachegen { - fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) + flag = n.flags + children = n.Children[:] + default: + return true + } + + showerror := func() { + fmt.Printf("at depth %d node %s", depth, spew.Sdump(n)) + fmt.Printf("parent: %s", spew.Sdump(parent)) + } + if flag.gen > parentCachegen { + fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen) + showerror() + return false + } + if depth > 0 && !parentDirty && flag.dirty { + fmt.Printf("cache invariant violation: child is dirty but parent isn't\n") + showerror() + return false + } + for _, child := range children { + if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) { return false } - 
for _, child := range n.Children { - if !checkCacheInvariant(child, n.flags.gen, depth+1) { - return false - } - } - return true } return true } -- cgit v1.2.3