diff options
Diffstat (limited to 'dex/handler.go')
-rw-r--r-- | dex/handler.go | 66 |
1 files changed, 35 insertions, 31 deletions
diff --git a/dex/handler.go b/dex/handler.go index 854b646b6..855445a9e 100644 --- a/dex/handler.go +++ b/dex/handler.go @@ -34,6 +34,7 @@ package dex import ( + "bytes" "encoding/json" "errors" "fmt" @@ -509,7 +510,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if currentBlock.NumberU64() < last.Number.Uint64() { log.Debug("Current block < last request", "current", currentBlock.NumberU64(), "last", last.Number.Uint64()) - return p.SendBlockHeaders([]*types.HeaderWithGovState{}) + return p.SendBlockHeaders(query.Flag, []*types.HeaderWithGovState{}) } snapshotHeight := map[uint64]struct{}{} @@ -531,44 +532,44 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { s, err := pm.blockchain.GetGovStateByHash(header.Hash()) if err != nil { log.Warn("Get gov state by hash fail", "number", header.Number.Uint64(), "err", err) - return p.SendBlockHeaders([]*types.HeaderWithGovState{}) + return p.SendBlockHeaders(query.Flag, []*types.HeaderWithGovState{}) } header.GovState = s } log.Trace("Send header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil) } } - return p.SendBlockHeaders(headers) + return p.SendBlockHeaders(query.Flag, headers) case msg.Code == BlockHeadersMsg: // A batch of headers arrived to one of our previous requests - var headers []*types.HeaderWithGovState - if err := msg.Decode(&headers); err != nil { + var data headersData + if err := msg.Decode(&data); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - // Filter out any explicitly requested headers, deliver the rest to the downloader - filter := len(headers) == 1 - if filter { - h := []*types.Header{headers[0].Header} - h = pm.fetcher.FilterHeaders(p.id, h, time.Now()) - if len(h) == 0 { - headers = nil + + switch data.Flag { + case fetcherReq: + if len(data.Headers) > 0 { + pm.fetcher.FilterHeaders(p.id, []*types.Header{data.Headers[0].Header}, time.Now()) } - } - for _, header := range headers { - log.Trace("Received header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil) - } - if len(headers) > 0 || !filter { - // if the header that has gov state is filter out, the header's gov state is useless - err := pm.downloader.DeliverHeaders(p.id, headers) + case downloaderReq: + err := pm.downloader.DeliverHeaders(p.id, data.Headers) if err != nil { log.Debug("Failed to deliver headers", "err", err) } + default: + log.Debug("Got headers with unexpected flag", "flag", data.Flag) } case msg.Code == GetBlockBodiesMsg: // Decode the retrieval message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) + var query getBlockBodiesData + if err := msg.Decode(&query); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + msgStream := rlp.NewStream(bytes.NewBuffer(query.Hashes), uint64(len(query.Hashes))) if _, err := msgStream.List(); err != nil { return err } @@ -591,7 +592,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { bytes += len(data) } } - return p.SendBlockBodiesRLP(bodies) + return p.SendBlockBodiesRLP(query.Flag, bodies) case msg.Code == BlockBodiesMsg: // A batch of block bodies arrived to one of our previous requests @@ -600,23 +601,26 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Deliver them all to the downloader for queuing - transactions := make([][]*types.Transaction, len(request)) - uncles := make([][]*types.Header, len(request)) + transactions := make([][]*types.Transaction, len(request.Bodies)) + uncles := make([][]*types.Header, len(request.Bodies)) - for i, body := range request { + for i, body := range request.Bodies { transactions[i] = body.Transactions uncles[i] = body.Uncles } - // Filter out any explicitly requested bodies, deliver the rest to the downloader - filter := len(transactions) > 0 || len(uncles) > 0 - if filter { - transactions, uncles = pm.fetcher.FilterBodies(p.id, transactions, uncles, time.Now()) - } - if len(transactions) > 0 || len(uncles) > 0 || !filter { + + switch request.Flag { + case fetcherReq: + if len(transactions) > 0 || len(uncles) > 0 { + pm.fetcher.FilterBodies(p.id, transactions, uncles, time.Now()) + } + case downloaderReq: err := pm.downloader.DeliverBodies(p.id, transactions, uncles) if err != nil { log.Debug("Failed to deliver bodies", "err", err) } + default: + log.Debug("Got bodies with unexpected flag", "flag", request.Flag) } case msg.Code == GetNodeDataMsg: @@ -721,7 +725,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } for _, block := range unknown { - pm.fetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.RequestBodies) + pm.fetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.FetchBodies) } case msg.Code == NewBlockMsg: |