From accc0fab4f407eaeab428127bd5395a28f371f9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Fri, 16 Nov 2018 13:15:05 +0200
Subject: core, eth/downloader: fix ancestor lookup for fast sync

---
 eth/downloader/downloader.go      | 33 +++++++++++++++++++++++++++++----
 eth/downloader/downloader_test.go | 30 +++++++++++++++++++++---------
 2 files changed, 50 insertions(+), 13 deletions(-)

(limited to 'eth')

diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index 1f52fc6db..f81a5cbac 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -181,6 +181,9 @@ type BlockChain interface {
 	// HasBlock verifies a block's presence in the local chain.
 	HasBlock(common.Hash, uint64) bool
 
+	// HasFastBlock verifies a fast block's presence in the local chain.
+	HasFastBlock(common.Hash, uint64) bool
+
 	// GetBlockByHash retrieves a block from the local chain.
 	GetBlockByHash(common.Hash) *types.Block
 
@@ -663,8 +666,9 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
 	if localHeight >= MaxForkAncestry {
 		floor = int64(localHeight - MaxForkAncestry)
 	}
-
 	from, count, skip, max := calculateRequestSpan(remoteHeight, localHeight)
+
+	p.log.Trace("Span searching for common ancestor", "count", count, "from", from, "skip", skip)
 	go p.peer.RequestHeadersByNumber(uint64(from), count, skip, false)
 
 	// Wait for the remote response to the head fetch
@@ -708,8 +712,17 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
 				// Otherwise check if we already know the header or not
 				h := headers[i].Hash()
 				n := headers[i].Number.Uint64()
-				if (d.mode == FullSync && d.blockchain.HasBlock(h, n)) ||
-					(d.mode != FullSync && d.lightchain.HasHeader(h, n)) {
+
+				var known bool
+				switch d.mode {
+				case FullSync:
+					known = d.blockchain.HasBlock(h, n)
+				case FastSync:
+					known = d.blockchain.HasFastBlock(h, n)
+				default:
+					known = d.lightchain.HasHeader(h, n)
+				}
+				if known {
 					number, hash = n, h
 					break
 				}
@@ -738,6 +751,8 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
 	if floor > 0 {
 		start = uint64(floor)
 	}
+	p.log.Trace("Binary searching for common ancestor", "start", start, "end", end)
+
 	for start+1 < end {
 		// Split our chain interval in two, and request the hash to cross check
 		check := (start + end) / 2
@@ -770,7 +785,17 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
 				// Modify the search interval based on the response
 				h := headers[0].Hash()
 				n := headers[0].Number.Uint64()
-				if (d.mode == FullSync && !d.blockchain.HasBlock(h, n)) || (d.mode != FullSync && !d.lightchain.HasHeader(h, n)) {
+
+				var known bool
+				switch d.mode {
+				case FullSync:
+					known = d.blockchain.HasBlock(h, n)
+				case FastSync:
+					known = d.blockchain.HasFastBlock(h, n)
+				default:
+					known = d.lightchain.HasHeader(h, n)
+				}
+				if !known {
 					end = check
 					break
 				}
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index b30976d72..1a42965d3 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -115,6 +115,15 @@ func (dl *downloadTester) HasBlock(hash common.Hash, number uint64) bool {
 	return dl.GetBlockByHash(hash) != nil
 }
 
+// HasFastBlock checks if a block is present in the testers canonical chain.
+func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	_, ok := dl.ownReceipts[hash]
+	return ok
+}
+
 // GetHeader retrieves a header from the testers canonical chain.
 func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header {
 	dl.lock.RLock()
@@ -235,6 +244,7 @@ func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) {
 			dl.ownHeaders[block.Hash()] = block.Header()
 		}
 		dl.ownBlocks[block.Hash()] = block
+		dl.ownReceipts[block.Hash()] = make(types.Receipts, 0)
 		dl.stateDb.Put(block.Root().Bytes(), []byte{0x00})
 		dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty())
 	}
@@ -375,28 +385,28 @@ func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error {
 // assertOwnChain checks if the local chain contains the correct number of items
 // of the various chain components.
 func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
+	// Mark this method as a helper to report errors at callsite, not in here
+	t.Helper()
+
 	assertOwnForkedChain(t, tester, 1, []int{length})
 }
 
 // assertOwnForkedChain checks if the local forked chain contains the correct
 // number of items of the various chain components.
 func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
+	// Mark this method as a helper to report errors at callsite, not in here
+	t.Helper()
+
 	// Initialize the counters for the first fork
-	headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks
+	headers, blocks, receipts := lengths[0], lengths[0], lengths[0]
 
-	if receipts < 0 {
-		receipts = 1
-	}
 	// Update the counters for each subsequent fork
 	for _, length := range lengths[1:] {
 		headers += length - common
 		blocks += length - common
-		receipts += length - common - fsMinFullBlocks
+		receipts += length - common
 	}
-	switch tester.downloader.mode {
-	case FullSync:
-		receipts = 1
-	case LightSync:
+	if tester.downloader.mode == LightSync {
 		blocks, receipts = 1, 1
 	}
 	if hs := len(tester.ownHeaders); hs != headers {
@@ -1150,7 +1160,9 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 }
 
 func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.SyncProgress) {
+	// Mark this method as a helper to report errors at callsite, not in here
 	t.Helper()
+
 	p := d.Progress()
 	p.KnownStates, p.PulledStates = 0, 0
 	want.KnownStates, want.PulledStates = 0, 0
-- 
cgit v1.2.3