From b3622c655271c594d19c78eb92b26c1f539ea31b Mon Sep 17 00:00:00 2001
From: Sonic <sonic@dexon.org>
Date: Tue, 12 Feb 2019 16:16:13 +0800
Subject: dex: Add a flag to GetBlockHeadersMsg and GetBlockBodiesMsg (#188)

* dex: Add a flag to GetBlockHeadersMsg and GetBlockBodiesMsg

So that we can dispatch the response msg to fetcher or downloader
easily.

* fixup! dex: Add a flag to GetBlockHeadersMsg and GetBlockBodiesMsg
---
 dex/downloader/downloader_test.go |  8 ++---
 dex/downloader/fakepeer.go        |  4 +--
 dex/downloader/peer.go            |  8 ++---
 dex/handler.go                    | 66 +++++++++++++++++++++------------------
 dex/handler_test.go               |  8 ++---
 dex/peer.go                       | 37 ++++++++++++----------
 dex/protocol.go                   | 27 +++++++++++++++-
 7 files changed, 95 insertions(+), 63 deletions(-)

diff --git a/dex/downloader/downloader_test.go b/dex/downloader/downloader_test.go
index 80993bd75..e8ec0056b 100644
--- a/dex/downloader/downloader_test.go
+++ b/dex/downloader/downloader_test.go
@@ -376,10 +376,10 @@ func (dlp *downloadTesterPeer) RequestGovStateByHash(hash common.Hash) error {
 	return nil
 }
 
-// RequestBodies constructs a getBlockBodies method associated with a particular
+// DownloadBodies constructs a getBlockBodies method associated with a particular
 // peer in the download tester. The returned function can be used to retrieve
 // batches of block bodies from the particularly requested peer.
-func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash) error {
+func (dlp *downloadTesterPeer) DownloadBodies(hashes []common.Hash) error {
 	txs, uncles := dlp.chain.bodies(hashes)
 	go dlp.dl.downloader.DeliverBodies(dlp.id, txs, uncles)
 	return nil
@@ -1322,8 +1322,8 @@ func (ftp *floodingTestPeer) RequestHeadersByHash(hash common.Hash, count int, s
 func (ftp *floodingTestPeer) RequestGovStateByHash(hash common.Hash) error {
 	return ftp.peer.RequestGovStateByHash(hash)
 }
-func (ftp *floodingTestPeer) RequestBodies(hashes []common.Hash) error {
-	return ftp.peer.RequestBodies(hashes)
+func (ftp *floodingTestPeer) DownloadBodies(hashes []common.Hash) error {
+	return ftp.peer.DownloadBodies(hashes)
 }
 func (ftp *floodingTestPeer) RequestReceipts(hashes []common.Hash) error {
 	return ftp.peer.RequestReceipts(hashes)
diff --git a/dex/downloader/fakepeer.go b/dex/downloader/fakepeer.go
index f0d596a4b..f4ff9b517 100644
--- a/dex/downloader/fakepeer.go
+++ b/dex/downloader/fakepeer.go
@@ -132,9 +132,9 @@ func (p *FakePeer) RequestHeadersByNumber(number uint64, amount int, skip int, r
 	return nil
 }
 
-// RequestBodies implements downloader.Peer, returning a batch of block bodies
+// DownloadBodies implements downloader.Peer, returning a batch of block bodies
 // corresponding to the specified block hashes.
-func (p *FakePeer) RequestBodies(hashes []common.Hash) error {
+func (p *FakePeer) DownloadBodies(hashes []common.Hash) error {
 	var (
 		txs    [][]*types.Transaction
 		uncles [][]*types.Header
diff --git a/dex/downloader/peer.go b/dex/downloader/peer.go
index 25c355df1..e1c6960f1 100644
--- a/dex/downloader/peer.go
+++ b/dex/downloader/peer.go
@@ -85,7 +85,7 @@ type LightPeer interface {
 // Peer encapsulates the methods required to synchronise with a remote full peer.
 type Peer interface {
 	LightPeer
-	RequestBodies([]common.Hash) error
+	DownloadBodies([]common.Hash) error
 	RequestReceipts([]common.Hash) error
 	RequestNodeData([]common.Hash) error
 }
@@ -106,8 +106,8 @@ func (w *lightPeerWrapper) RequestGovStateByHash(common.Hash) error {
 	// TODO(sonic): support this
 	panic("RequestGovStateByHash not supported in light client mode sync")
 }
-func (w *lightPeerWrapper) RequestBodies([]common.Hash) error {
-	panic("RequestBodies not supported in light client mode sync")
+func (w *lightPeerWrapper) DownloadBodies([]common.Hash) error {
+	panic("DownloadBodies not supported in light client mode sync")
 }
 func (w *lightPeerWrapper) RequestReceipts([]common.Hash) error {
 	panic("RequestReceipts not supported in light client mode sync")
@@ -182,7 +182,7 @@ func (p *peerConnection) FetchBodies(request *fetchRequest) error {
 	for _, header := range request.Headers {
 		hashes = append(hashes, header.Hash())
 	}
-	go p.peer.RequestBodies(hashes)
+	go p.peer.DownloadBodies(hashes)
 
 	return nil
 }
diff --git a/dex/handler.go b/dex/handler.go
index 490e1ec33..71962b865 100644
--- a/dex/handler.go
+++ b/dex/handler.go
@@ -34,6 +34,7 @@
 package dex
 
 import (
+	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -504,7 +505,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			if currentBlock.NumberU64() < last.Number.Uint64() {
 				log.Debug("Current block < last request",
 					"current", currentBlock.NumberU64(), "last", last.Number.Uint64())
-				return p.SendBlockHeaders([]*types.HeaderWithGovState{})
+				return p.SendBlockHeaders(query.Flag, []*types.HeaderWithGovState{})
 			}
 
 			snapshotHeight := map[uint64]struct{}{}
@@ -526,44 +527,44 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 					s, err := pm.blockchain.GetGovStateByHash(header.Hash())
 					if err != nil {
 						log.Warn("Get gov state by hash fail", "number", header.Number.Uint64(), "err", err)
-						return p.SendBlockHeaders([]*types.HeaderWithGovState{})
+						return p.SendBlockHeaders(query.Flag, []*types.HeaderWithGovState{})
 					}
 					header.GovState = s
 				}
 				log.Trace("Send header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil)
 			}
 		}
-		return p.SendBlockHeaders(headers)
+		return p.SendBlockHeaders(query.Flag, headers)
 
 	case msg.Code == BlockHeadersMsg:
 		// A batch of headers arrived to one of our previous requests
-		var headers []*types.HeaderWithGovState
-		if err := msg.Decode(&headers); err != nil {
+		var data headersData
+		if err := msg.Decode(&data); err != nil {
 			return errResp(ErrDecode, "msg %v: %v", msg, err)
 		}
-		// Filter out any explicitly requested headers, deliver the rest to the downloader
-		filter := len(headers) == 1
-		if filter {
-			h := []*types.Header{headers[0].Header}
-			h = pm.fetcher.FilterHeaders(p.id, h, time.Now())
-			if len(h) == 0 {
-				headers = nil
+
+		switch data.Flag {
+		case fetcherReq:
+			if len(data.Headers) > 0 {
+				pm.fetcher.FilterHeaders(p.id, []*types.Header{data.Headers[0].Header}, time.Now())
 			}
-		}
-		for _, header := range headers {
-			log.Trace("Received header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil)
-		}
-		if len(headers) > 0 || !filter {
-			// if the header that has gov state is filter out, the header's gov state is useless
-			err := pm.downloader.DeliverHeaders(p.id, headers)
+		case downloaderReq:
+			err := pm.downloader.DeliverHeaders(p.id, data.Headers)
 			if err != nil {
 				log.Debug("Failed to deliver headers", "err", err)
 			}
+		default:
+			log.Debug("Got headers with unexpected flag", "flag", data.Flag)
 		}
 
 	case msg.Code == GetBlockBodiesMsg:
 		// Decode the retrieval message
-		msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
+		var query getBlockBodiesData
+		if err := msg.Decode(&query); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+
+		msgStream := rlp.NewStream(bytes.NewBuffer(query.Hashes), uint64(len(query.Hashes)))
 		if _, err := msgStream.List(); err != nil {
 			return err
 		}
@@ -586,7 +587,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				bytes += len(data)
 			}
 		}
-		return p.SendBlockBodiesRLP(bodies)
+		return p.SendBlockBodiesRLP(query.Flag, bodies)
 
 	case msg.Code == BlockBodiesMsg:
 		// A batch of block bodies arrived to one of our previous requests
@@ -595,23 +596,26 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			return errResp(ErrDecode, "msg %v: %v", msg, err)
 		}
 		// Deliver them all to the downloader for queuing
-		transactions := make([][]*types.Transaction, len(request))
-		uncles := make([][]*types.Header, len(request))
+		transactions := make([][]*types.Transaction, len(request.Bodies))
+		uncles := make([][]*types.Header, len(request.Bodies))
 
-		for i, body := range request {
+		for i, body := range request.Bodies {
 			transactions[i] = body.Transactions
 			uncles[i] = body.Uncles
 		}
-		// Filter out any explicitly requested bodies, deliver the rest to the downloader
-		filter := len(transactions) > 0 || len(uncles) > 0
-		if filter {
-			transactions, uncles = pm.fetcher.FilterBodies(p.id, transactions, uncles, time.Now())
-		}
-		if len(transactions) > 0 || len(uncles) > 0 || !filter {
+
+		switch request.Flag {
+		case fetcherReq:
+			if len(transactions) > 0 || len(uncles) > 0 {
+				pm.fetcher.FilterBodies(p.id, transactions, uncles, time.Now())
+			}
+		case downloaderReq:
 			err := pm.downloader.DeliverBodies(p.id, transactions, uncles)
 			if err != nil {
 				log.Debug("Failed to deliver bodies", "err", err)
 			}
+		default:
+			log.Debug("Got bodies with unexpected flag", "flag", request.Flag)
 		}
 
 	case msg.Code == GetNodeDataMsg:
@@ -716,7 +720,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			}
 		}
 		for _, block := range unknown {
-			pm.fetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.RequestBodies)
+			pm.fetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.FetchBodies)
 		}
 
 	case msg.Code == NewBlockMsg:
diff --git a/dex/handler_test.go b/dex/handler_test.go
index d8398bd2d..75a57c125 100644
--- a/dex/handler_test.go
+++ b/dex/handler_test.go
@@ -205,7 +205,7 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
 		}
 		// Send the hash request and verify the response
 		p2p.Send(peer.app, 0x03, tt.query)
-		if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
+		if err := p2p.ExpectMsg(peer.app, 0x04, headersData{Headers: headers}); err != nil {
 			t.Errorf("test %d: headers mismatch: %v", i, err)
 		}
 		// If the test used number origins, repeat with hashes as the too
@@ -214,7 +214,7 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
 				tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
 
 				p2p.Send(peer.app, 0x03, tt.query)
-				if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil {
+				if err := p2p.ExpectMsg(peer.app, 0x04, headersData{Headers: headers}); err != nil {
 					t.Errorf("test %d: headers mismatch: %v", i, err)
 				}
 			}
@@ -287,8 +287,8 @@ func testGetBlockBodies(t *testing.T, protocol int) {
 			}
 		}
 		// Send the hash request and verify the response
-		p2p.Send(peer.app, 0x05, hashes)
-		if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil {
+		p2p.Send(peer.app, 0x05, []interface{}{downloaderReq, hashes})
+		if err := p2p.ExpectMsg(peer.app, 0x06, blockBodiesData{Flag: downloaderReq, Bodies: bodies}); err != nil {
 			t.Errorf("test %d: bodies mismatch: %v", i, err)
 		}
 	}
diff --git a/dex/peer.go b/dex/peer.go
index 97f42ccac..2c531ee07 100644
--- a/dex/peer.go
+++ b/dex/peer.go
@@ -563,19 +563,14 @@ func (p *peer) AsyncSendPullRandomness(hashes coreCommon.Hashes) {
 }
 
 // SendBlockHeaders sends a batch of block headers to the remote peer.
-func (p *peer) SendBlockHeaders(headers []*types.HeaderWithGovState) error {
-	return p2p.Send(p.rw, BlockHeadersMsg, headers)
-}
-
-// SendBlockBodies sends a batch of block contents to the remote peer.
-func (p *peer) SendBlockBodies(bodies []*blockBody) error {
-	return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies))
+func (p *peer) SendBlockHeaders(flag uint8, headers []*types.HeaderWithGovState) error {
+	return p2p.Send(p.rw, BlockHeadersMsg, headersData{Flag: flag, Headers: headers})
 }
 
 // SendBlockBodiesRLP sends a batch of block contents to the remote peer from
 // an already RLP encoded format.
-func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
-	return p2p.Send(p.rw, BlockBodiesMsg, bodies)
+func (p *peer) SendBlockBodiesRLP(flag uint8, bodies []rlp.RawValue) error {
+	return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesDataRLP{Flag: flag, Bodies: bodies})
 }
 
 // SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the
@@ -598,21 +593,21 @@ func (p *peer) SendGovState(govState *types.GovState) error {
 // single header. It is used solely by the fetcher.
 func (p *peer) RequestOneHeader(hash common.Hash) error {
 	p.Log().Debug("Fetching single header", "hash", hash)
-	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false, WithGov: false})
+	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false, WithGov: false, Flag: fetcherReq})
 }
 
 // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
 // specified header query, based on the hash of an origin block.
 func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse, withGov bool) error {
-	p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse, "withgov", withGov)
-	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse, WithGov: withGov})
+	p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse, "withgov", withGov, "flag", downloaderReq)
+	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse, WithGov: withGov, Flag: downloaderReq})
 }
 
 // RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
 // specified header query, based on the number of an origin block.
 func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse, withGov bool) error {
-	p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse, "withgov", withGov)
-	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse, WithGov: withGov})
+	p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse, "withgov", withGov, "flag", downloaderReq)
+	return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse, WithGov: withGov, Flag: downloaderReq})
 }
 
 func (p *peer) RequestGovStateByHash(hash common.Hash) error {
@@ -622,9 +617,17 @@ func (p *peer) RequestGovStateByHash(hash common.Hash) error {
 
 // RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
 // specified.
-func (p *peer) RequestBodies(hashes []common.Hash) error {
-	p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
-	return p2p.Send(p.rw, GetBlockBodiesMsg, hashes)
+func (p *peer) RequestBodies(flag uint8, hashes []common.Hash) error {
+	p.Log().Debug("Fetching batch of block bodies", "count", len(hashes), "flag", flag)
+	return p2p.Send(p.rw, GetBlockBodiesMsg, []interface{}{flag, hashes})
+}
+
+func (p *peer) FetchBodies(hashes []common.Hash) error {
+	return p.RequestBodies(fetcherReq, hashes)
+}
+
+func (p *peer) DownloadBodies(hashes []common.Hash) error {
+	return p.RequestBodies(downloaderReq, hashes)
 }
 
 // RequestNodeData fetches a batch of arbitrary data from a node's known state
diff --git a/dex/protocol.go b/dex/protocol.go
index 0cb00ada6..6ee02959a 100644
--- a/dex/protocol.go
+++ b/dex/protocol.go
@@ -113,6 +113,11 @@ const (
 	ErrSuspendedPeer
 )
 
+const (
+	fetcherReq = uint8(iota)
+	downloaderReq
+)
+
 func (e errCode) String() string {
 	return errorToString[int(e)]
 }
@@ -195,6 +200,7 @@ type getBlockHeadersData struct {
 	Reverse bool         // Query direction (false = rising towards latest, true = falling towards genesis)
 
 	WithGov bool
+	Flag    uint8
 }
 
 // hashOrNumber is a combined field for specifying an origin block.
@@ -233,11 +239,27 @@ func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
 	return err
 }
 
+// headersData is the network packet for header content distribution.
+type headersData struct {
+	Flag    uint8
+	Headers []*types.HeaderWithGovState
+}
+
 // newBlockData is the network packet for the block propagation message.
 type newBlockData struct {
 	Block *types.Block
 }
 
+type getBlockBodiesData struct {
+	Flag   uint8
+	Hashes rlp.RawValue
+}
+
+type blockBodiesDataRLP struct {
+	Flag   uint8
+	Bodies []rlp.RawValue
+}
+
 // blockBody represents the data content of a single block.
 type blockBody struct {
 	Transactions []*types.Transaction // Transactions contained within a block
@@ -245,7 +267,10 @@ type blockBody struct {
 }
 
 // blockBodiesData is the network packet for block content distribution.
-type blockBodiesData []*blockBody
+type blockBodiesData struct {
+	Flag   uint8
+	Bodies []*blockBody
+}
 
 func rlpHash(x interface{}) (h common.Hash) {
 	hw := sha3.NewLegacyKeccak256()
-- 
cgit v1.2.3