aboutsummaryrefslogtreecommitdiffstats
path: root/eth/downloader/downloader_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'eth/downloader/downloader_test.go')
-rw-r--r--eth/downloader/downloader_test.go123
1 files changed, 89 insertions, 34 deletions
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 68c4ca26e..8944ae4b0 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -27,11 +27,13 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
)
var (
@@ -115,6 +117,7 @@ func makeChainFork(n, f int, parent *types.Block, parentReceipts types.Receipts)
// downloadTester is a test simulator for mocking out local block chain.
type downloadTester struct {
+ stateDb ethdb.Database
downloader *Downloader
ownHashes []common.Hash // Hash chain belonging to the tester
@@ -146,8 +149,10 @@ func newTester(mode SyncMode) *downloadTester {
peerReceipts: make(map[string]map[common.Hash]types.Receipts),
peerChainTds: make(map[string]map[common.Hash]*big.Int),
}
- tester.downloader = New(mode, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.getBlock,
- tester.headHeader, tester.headBlock, tester.headFastBlock, tester.getTd, tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.dropPeer)
+ tester.stateDb, _ = ethdb.NewMemDatabase()
+ tester.downloader = New(mode, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
+ tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
+ tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.dropPeer)
return tester
}
@@ -213,7 +218,7 @@ func (dl *downloadTester) headHeader() *types.Header {
return header
}
}
- return nil
+ return genesis.Header()
}
// headBlock retrieves the current head block from the canonical chain.
@@ -223,10 +228,12 @@ func (dl *downloadTester) headBlock() *types.Block {
for i := len(dl.ownHashes) - 1; i >= 0; i-- {
if block := dl.getBlock(dl.ownHashes[i]); block != nil {
- return block
+ if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
+ return block
+ }
}
}
- return nil
+ return genesis
}
// headFastBlock retrieves the current head fast-sync block from the canonical chain.
@@ -236,12 +243,20 @@ func (dl *downloadTester) headFastBlock() *types.Block {
for i := len(dl.ownHashes) - 1; i >= 0; i-- {
if block := dl.getBlock(dl.ownHashes[i]); block != nil {
- if _, ok := dl.ownReceipts[block.Hash()]; ok {
- return block
- }
+ return block
}
}
- return nil
+ return genesis
+}
+
+// commitHeadBlock manually sets the head block to a given hash.
+func (dl *downloadTester) commitHeadBlock(hash common.Hash) error {
+ // For now only check that the state trie is correct
+ if block := dl.getBlock(hash); block != nil {
+ _, err := trie.NewSecure(block.Root(), dl.stateDb)
+ return err
+ }
+ return fmt.Errorf("non existent block: %x", hash[:4])
}
// getTd retrieves the block's total difficulty from the canonical chain.
@@ -283,6 +298,7 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
dl.ownHashes = append(dl.ownHashes, block.Hash())
dl.ownHeaders[block.Hash()] = block.Header()
dl.ownBlocks[block.Hash()] = block
+ dl.stateDb.Put(block.Root().Bytes(), []byte{})
dl.ownChainTd[block.Hash()] = dl.ownChainTd[block.ParentHash()]
}
return len(blocks), nil
@@ -321,13 +337,13 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha
var err error
switch version {
case 61:
- err = dl.downloader.RegisterPeer(id, version, hashes[0], dl.peerGetRelHashesFn(id, delay), dl.peerGetAbsHashesFn(id, delay), dl.peerGetBlocksFn(id, delay), nil, nil, nil, nil)
+ err = dl.downloader.RegisterPeer(id, version, hashes[0], dl.peerGetRelHashesFn(id, delay), dl.peerGetAbsHashesFn(id, delay), dl.peerGetBlocksFn(id, delay), nil, nil, nil, nil, nil)
case 62:
- err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), nil)
+ err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), nil, nil)
case 63:
- err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay))
+ err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
case 64:
- err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay))
+ err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
}
if err == nil {
// Assign the owned hashes, headers and blocks to the peer (deep copy)
@@ -399,7 +415,7 @@ func (dl *downloadTester) peerGetRelHashesFn(id string, delay time.Duration) fun
// Delay delivery a bit to allow attacks to unfold
go func() {
time.Sleep(time.Millisecond)
- dl.downloader.DeliverHashes61(id, result)
+ dl.downloader.DeliverHashes(id, result)
}()
return nil
}
@@ -424,7 +440,7 @@ func (dl *downloadTester) peerGetAbsHashesFn(id string, delay time.Duration) fun
// Delay delivery a bit to allow attacks to unfold
go func() {
time.Sleep(time.Millisecond)
- dl.downloader.DeliverHashes61(id, result)
+ dl.downloader.DeliverHashes(id, result)
}()
return nil
}
@@ -447,7 +463,7 @@ func (dl *downloadTester) peerGetBlocksFn(id string, delay time.Duration) func([
result = append(result, block)
}
}
- go dl.downloader.DeliverBlocks61(id, result)
+ go dl.downloader.DeliverBlocks(id, result)
return nil
}
@@ -553,17 +569,54 @@ func (dl *downloadTester) peerGetReceiptsFn(id string, delay time.Duration) func
}
}
+// peerGetNodeDataFn constructs a getNodeData method associated with a particular
+// peer in the download tester. The returned function can be used to retrieve
+// batches of node state data from the particularly requested peer.
+func (dl *downloadTester) peerGetNodeDataFn(id string, delay time.Duration) func([]common.Hash) error {
+ return func(hashes []common.Hash) error {
+ time.Sleep(delay)
+
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ results := make([][]byte, 0, len(hashes))
+ for _, hash := range hashes {
+ if data, err := testdb.Get(hash.Bytes()); err == nil {
+ results = append(results, data)
+ }
+ }
+ go dl.downloader.DeliverNodeData(id, results)
+
+ return nil
+ }
+}
+
// assertOwnChain checks if the local chain contains the correct number of items
// of the various chain components.
func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
- headers, blocks, receipts := length, length, length
+ assertOwnForkedChain(t, tester, 1, []int{length})
+}
+
+// assertOwnForkedChain checks if the local forked chain contains the correct
+// number of items of the various chain components.
+func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
+ // Initialize the counters for the first fork
+ headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-minFullBlocks
+ if receipts < 0 {
+ receipts = 1
+ }
+ // Update the counters for each subsequent fork
+ for _, length := range lengths[1:] {
+ headers += length - common
+ blocks += length - common
+ receipts += length - common - minFullBlocks
+ }
switch tester.downloader.mode {
case FullSync:
receipts = 1
case LightSync:
blocks, receipts = 1, 1
}
-
if hs := len(tester.ownHeaders); hs != headers {
t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
}
@@ -573,6 +626,14 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
if rs := len(tester.ownReceipts); rs != receipts {
t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts)
}
+ // Verify the state trie too for fast syncs
+ if tester.downloader.mode == FastSync {
+ if index := lengths[len(lengths)-1] - minFullBlocks - 1; index > 0 {
+ if statedb := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil {
+ t.Fatalf("state reconstruction failed")
+ }
+ }
+ }
}
// Tests that simple synchronization against a canonical chain works correctly.
@@ -647,7 +708,9 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
cached = len(tester.downloader.queue.blockDonePool)
if mode == FastSync {
if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached {
- cached = receipts
+ if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
+ cached = receipts
+ }
}
}
tester.downloader.queue.lock.RUnlock()
@@ -704,7 +767,7 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
if err := tester.sync("fork B", nil); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err)
}
- assertOwnChain(t, tester, common+2*fork+1)
+ assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1})
}
// Tests that an inactive downloader will not accept incoming hashes and blocks.
@@ -712,10 +775,10 @@ func TestInactiveDownloader61(t *testing.T) {
tester := newTester(FullSync)
// Check that neither hashes nor blocks are accepted
- if err := tester.downloader.DeliverHashes61("bad peer", []common.Hash{}); err != errNoSyncActive {
+ if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive {
t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
}
- if err := tester.downloader.DeliverBlocks61("bad peer", []*types.Block{}); err != errNoSyncActive {
+ if err := tester.downloader.DeliverBlocks("bad peer", []*types.Block{}); err != errNoSyncActive {
t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
}
}
@@ -809,14 +872,6 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
id := fmt.Sprintf("peer #%d", i)
tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts)
}
- // Synchronise with the middle peer and make sure half of the blocks were retrieved
- id := fmt.Sprintf("peer #%d", targetPeers/2)
- if err := tester.sync(id, nil); err != nil {
- t.Fatalf("failed to synchronise blocks: %v", err)
- }
- assertOwnChain(t, tester, len(tester.peerHashes[id]))
-
- // Synchronise with the best peer and make sure everything is retrieved
if err := tester.sync("peer #0", nil); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err)
}
@@ -870,8 +925,8 @@ func TestEmptyShortCircuit64Fast(t *testing.T) { testEmptyShortCircuit(t, 64, F
func TestEmptyShortCircuit64Light(t *testing.T) { testEmptyShortCircuit(t, 64, LightSync) }
func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
- // Create a small enough block chain to download
- targetBlocks := blockCacheLimit - 15
+ // Create a block chain to download
+ targetBlocks := 2*blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode)
@@ -898,8 +953,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
bodiesNeeded++
}
}
- for _, receipt := range receipts {
- if mode == FastSync && len(receipt) > 0 {
+ for hash, receipt := range receipts {
+ if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= uint64(targetBlocks-minFullBlocks) {
receiptsNeeded++
}
}