diff options
-rw-r--r-- | core/blockchain.go | 37 | ||||
-rw-r--r-- | core/blockchain_test.go | 15 | ||||
-rw-r--r-- | eth/downloader/downloader.go | 60 | ||||
-rw-r--r-- | eth/downloader/downloader_test.go | 78 | ||||
-rw-r--r-- | eth/downloader/types.go | 3 | ||||
-rw-r--r-- | eth/handler.go | 2 |
6 files changed, 180 insertions, 15 deletions
diff --git a/core/blockchain.go b/core/blockchain.go index 3e7dfa9ee..490552ea0 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -245,7 +245,21 @@ func (bc *BlockChain) SetHead(head uint64) { if bc.currentBlock == nil { bc.currentBlock = bc.genesisBlock } - bc.insert(bc.currentBlock) + if bc.currentHeader == nil { + bc.currentHeader = bc.genesisBlock.Header() + } + if bc.currentFastBlock == nil { + bc.currentFastBlock = bc.genesisBlock + } + if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil { + glog.Fatalf("failed to reset head block hash: %v", err) + } + if err := WriteHeadHeaderHash(bc.chainDb, bc.currentHeader.Hash()); err != nil { + glog.Fatalf("failed to reset head header hash: %v", err) + } + if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil { + glog.Fatalf("failed to reset head fast block hash: %v", err) + } bc.loadLastState() } @@ -790,6 +804,27 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) return 0, nil } +// Rollback is designed to remove a chain of links from the database that aren't +// certain enough to be valid. +func (self *BlockChain) Rollback(chain []common.Hash) { + for i := len(chain) - 1; i >= 0; i-- { + hash := chain[i] + + if self.currentHeader.Hash() == hash { + self.currentHeader = self.GetHeader(self.currentHeader.ParentHash) + WriteHeadHeaderHash(self.chainDb, self.currentHeader.Hash()) + } + if self.currentFastBlock.Hash() == hash { + self.currentFastBlock = self.GetBlock(self.currentFastBlock.ParentHash()) + WriteHeadFastBlockHash(self.chainDb, self.currentFastBlock.Hash()) + } + if self.currentBlock.Hash() == hash { + self.currentBlock = self.GetBlock(self.currentBlock.ParentHash()) + WriteHeadBlockHash(self.chainDb, self.currentBlock.Hash()) + } + } +} + // InsertReceiptChain attempts to complete an already existing header chain with // transaction and receipt data. func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain []types.Receipts) (int, error) { diff --git a/core/blockchain_test.go b/core/blockchain_test.go index a614aaa2f..01667c21e 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -599,8 +599,8 @@ func testReorgBadHashes(t *testing.T, full bool) { t.Errorf("last block gasLimit mismatch: have: %x, want %x", ncm.GasLimit(), blocks[2].Header().GasLimit) } } else { - if ncm.CurrentHeader().Hash() != genesis.Hash() { - t.Errorf("last header hash mismatch: have: %x, want %x", ncm.CurrentHeader().Hash(), genesis.Hash()) + if ncm.CurrentHeader().Hash() != headers[2].Hash() { + t.Errorf("last header hash mismatch: have: %x, want %x", ncm.CurrentHeader().Hash(), headers[2].Hash()) } } } @@ -775,6 +775,11 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { height := uint64(1024) blocks, receipts := GenerateChain(genesis, gendb, int(height), nil) + // Configure a subchain to roll back + remove := []common.Hash{} + for _, block := range blocks[height/2:] { + remove = append(remove, block.Hash()) + } // Create a small assertion method to check the three heads assert := func(t *testing.T, kind string, chain *BlockChain, header uint64, fast uint64, block uint64) { if num := chain.CurrentBlock().NumberU64(); num != block { @@ -798,6 +803,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { t.Fatalf("failed to process block %d: %v", n, err) } assert(t, "archive", archive, height, height, height) + archive.Rollback(remove) + assert(t, "archive", archive, height/2, height/2, height/2) // Import the chain as a non-archive node and ensure all pointers are updated fastDb, _ := ethdb.NewMemDatabase() @@ -816,6 +823,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { t.Fatalf("failed to insert receipt %d: %v", n, err) } assert(t, "fast", fast, height, height, 0) + fast.Rollback(remove) + assert(t, "fast", fast, height/2, height/2, 0) // Import the chain as a light node and ensure all pointers are updated lightDb, _ := ethdb.NewMemDatabase() @@ -827,6 +836,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { t.Fatalf("failed to insert header %d: %v", n, err) } assert(t, "light", light, height, 0, 0) + light.Rollback(remove) + assert(t, "light", light, height/2, 0, 0) } // Tests that chain reorganizations handle transaction removals and reinsertions. diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index e19b70dfd..0298dfa0b 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -59,8 +59,8 @@ var ( maxQueuedStates = 256 * 1024 // [eth/63] Maximum number of state requests to queue (DOS protection) maxResultsProcess = 256 // Number of download results to import at once into the chain - headerCheckFrequency = 64 // Verification frequency of the downloaded headers during fast sync - minCheckedHeaders = 1024 // Number of headers to verify fully when approaching the chain head + headerCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync + minCheckedHeaders = 2048 // Number of headers to verify fully when approaching the chain head minFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync ) @@ -117,6 +117,7 @@ type Downloader struct { insertHeaders headerChainInsertFn // Injects a batch of headers into the chain insertBlocks blockChainInsertFn // Injects a batch of blocks into the chain insertReceipts receiptChainInsertFn // Injects a batch of blocks and their receipts into the chain + rollback chainRollbackFn // Removes a batch of recently added chain links dropPeer peerDropFn // Drops a peer for misbehaving // Status @@ -152,7 +153,7 @@ type Downloader struct { func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, - insertReceipts receiptChainInsertFn, dropPeer peerDropFn) *Downloader { + insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { return &Downloader{ mode: mode, @@ -171,6 +172,7 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader he insertHeaders: insertHeaders, insertBlocks: insertBlocks, insertReceipts: insertReceipts, + rollback: rollback, dropPeer: dropPeer, newPeerCh: make(chan *peer, 1), hashCh: make(chan dataPack, 1), @@ -383,7 +385,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e d.syncStatsChainHeight = latest d.syncStatsLock.Unlock() - // Initiate the sync using a concurrent header and content retrieval algorithm + // Initiate the sync using a concurrent header and content retrieval algorithm pivot := uint64(0) if latest > uint64(minFullBlocks) { pivot = latest - uint64(minFullBlocks) @@ -394,10 +396,10 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e d.syncInitHook(origin, latest) } errc := make(chan error, 4) - go func() { errc <- d.fetchHeaders(p, td, origin+1) }() // Headers are always retrieved - go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync - go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync - go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync + go func() { errc <- d.fetchHeaders(p, td, origin+1, latest) }() // Headers are always retrieved + go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync + go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync + go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync // If any fetcher fails, cancel the others var fail error @@ -1049,10 +1051,28 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) { // // The queue parameter can be used to switch between queuing headers for block // body download too, or directly import as pure header chains. -func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { +func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) error { glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from) defer glog.V(logger.Debug).Infof("%v: header download terminated", p) + // Keep a count of uncertain headers to roll back + rollback := []*types.Header{} + defer func() { + if len(rollback) > 0 { + hashes := make([]common.Hash, len(rollback)) + for i, header := range rollback { + hashes[i] = header.Hash() + } + d.rollback(hashes) + } + }() + // Calculate the pivoting point for switching from fast to slow sync + pivot := uint64(0) + if d.mode == FastSync && latest > uint64(minFullBlocks) { + pivot = latest - uint64(minFullBlocks) + } else if d.mode == LightSync { + pivot = latest + } // Create a timeout timer, and the associated hash fetcher request := time.Now() // time of the last fetch request timeout := time.NewTimer(0) // timer to dump a non-responsive active peer @@ -1124,10 +1144,30 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { glog.V(logger.Detail).Infof("%v: schedule %d headers from #%d", p, len(headers), from) if d.mode == FastSync || d.mode == LightSync { - if n, err := d.insertHeaders(headers, headerCheckFrequency); err != nil { + // Collect the yet unknown headers to mark them as uncertain + unknown := make([]*types.Header, 0, len(headers)) + for _, header := range headers { + if !d.hasHeader(header.Hash()) { + unknown = append(unknown, header) + } + } + // If we're importing pure headers, verify based on their recentness + frequency := headerCheckFrequency + if headers[len(headers)-1].Number.Uint64()+uint64(minCheckedHeaders) > pivot { + frequency = 1 + } + if n, err := d.insertHeaders(headers, frequency); err != nil { glog.V(logger.Debug).Infof("%v: invalid header #%d [%x…]: %v", p, headers[n].Number, headers[n].Hash().Bytes()[:4], err) return errInvalidChain } + // All verifications passed, store newly found uncertain headers + rollback = append(rollback, unknown...) + if len(rollback) > minCheckedHeaders { + rollback = append(rollback[:0], rollback[len(rollback)-minCheckedHeaders:]...) + } + if headers[len(headers)-1].Number.Uint64() >= pivot { + rollback = rollback[:0] + } } if d.mode == FullSync || d.mode == FastSync { inserts := d.queue.Schedule(headers, from) diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 0e60371b3..f01650ebd 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -152,7 +152,7 @@ func newTester(mode SyncMode) *downloadTester { tester.stateDb, _ = ethdb.NewMemDatabase() tester.downloader = New(mode, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, - tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.dropPeer) + tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) return tester } @@ -272,6 +272,16 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) dl.lock.Lock() defer dl.lock.Unlock() + // Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anthing in case of errors + if _, ok := dl.ownHeaders[headers[0].ParentHash]; !ok { + return 0, errors.New("unknown parent") + } + for i := 1; i < len(headers); i++ { + if headers[i].ParentHash != headers[i-1].Hash() { + return i, errors.New("unknown parent") + } + } + // Do a full insert if pre-checks passed for i, header := range headers { if _, ok := dl.ownHeaders[header.Hash()]; ok { continue @@ -322,6 +332,22 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R return len(blocks), nil } +// rollback removes some recently added elements from the chain. +func (dl *downloadTester) rollback(hashes []common.Hash) { + dl.lock.Lock() + defer dl.lock.Unlock() + + for i := len(hashes) - 1; i >= 0; i-- { + if dl.ownHashes[len(dl.ownHashes)-1] == hashes[i] { + dl.ownHashes = dl.ownHashes[:len(dl.ownHashes)-1] + } + delete(dl.ownChainTd, hashes[i]) + delete(dl.ownHeaders, hashes[i]) + delete(dl.ownReceipts, hashes[i]) + delete(dl.ownBlocks, hashes[i]) + } +} + // newPeer registers a new block download source into the downloader. func (dl *downloadTester) newPeer(id string, version int, hashes []common.Hash, headers map[common.Hash]*types.Header, blocks map[common.Hash]*types.Block, receipts map[common.Hash]types.Receipts) error { return dl.newSlowPeer(id, version, hashes, headers, blocks, receipts, 0) @@ -1031,6 +1057,56 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { assertOwnChain(t, tester, targetBlocks+1) } +// Tests that upon detecting an invalid header, the recent ones are rolled back +func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) } +func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) } +func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) } + +func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { + // Create a small enough block chain to download + targetBlocks := 3*minCheckedHeaders + minFullBlocks + hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) + + tester := newTester(mode) + + // Attempt to sync with an attacker that feeds junk during the fast sync phase + tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts) + missing := minCheckedHeaders + MaxHeaderFetch + 1 + delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing]) + + if err := tester.sync("fast-attack", nil); err == nil { + t.Fatalf("succeeded fast attacker synchronisation") + } + if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch { + t.Fatalf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) + } + // Attempt to sync with an attacker that feeds junk during the block import phase + tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts) + missing = 3*minCheckedHeaders + MaxHeaderFetch + 1 + delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing]) + + if err := tester.sync("block-attack", nil); err == nil { + t.Fatalf("succeeded block attacker synchronisation") + } + if mode == FastSync { + // Fast sync should not discard anything below the verified pivot point + if head := tester.headHeader().Number.Int64(); int(head) < 3*minCheckedHeaders { + t.Fatalf("rollback head mismatch: have %v, want at least %v", head, 3*minCheckedHeaders) + } + } else if mode == LightSync { + // Light sync should still discard data as before + if head := tester.headHeader().Number.Int64(); int(head) > 3*minCheckedHeaders { + t.Fatalf("rollback head mismatch: have %v, want at most %v", head, 3*minCheckedHeaders) + } + } + // Synchronise with the valid peer and make sure sync succeeds + tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) + if err := tester.sync("valid", nil); err != nil { + t.Fatalf("failed to synchronise blocks: %v", err) + } + assertOwnChain(t, tester, targetBlocks+1) +} + // Tests that if a peer sends an invalid block piece (body or receipt) for a // requested block, it gets dropped immediately by the downloader. func TestInvalidContentAttack62(t *testing.T) { testInvalidContentAttack(t, 62, FullSync) } diff --git a/eth/downloader/types.go b/eth/downloader/types.go index 60d9a2b12..5937be606 100644 --- a/eth/downloader/types.go +++ b/eth/downloader/types.go @@ -60,6 +60,9 @@ type blockChainInsertFn func(types.Blocks) (int, error) // receiptChainInsertFn is a callback type to insert a batch of receipts into the local chain. type receiptChainInsertFn func(types.Blocks, []types.Receipts) (int, error) +// chainRollbackFn is a callback type to remove a few recently added elements from the local chain. +type chainRollbackFn func([]common.Hash) + // peerDropFn is a callback type for dropping a peer detected as malicious. type peerDropFn func(id string) diff --git a/eth/handler.go b/eth/handler.go index b0916d50b..40a578842 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -131,7 +131,7 @@ func NewProtocolManager(mode Mode, networkId int, mux *event.TypeMux, txpool txP } manager.downloader = downloader.New(syncMode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, - blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, manager.removePeer) + blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) validator := func(block *types.Block, parent *types.Block) error { return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) |