diff options
Diffstat (limited to 'dex/handler.go')
-rw-r--r-- | dex/handler.go | 85 |
1 files changed, 80 insertions, 5 deletions
diff --git a/dex/handler.go b/dex/handler.go index d66403fe6..0c7a3e919 100644 --- a/dex/handler.go +++ b/dex/handler.go @@ -34,6 +34,7 @@ package dex import ( + "context" "encoding/json" "errors" "fmt" @@ -220,7 +221,7 @@ func NewProtocolManager( return 0, nil } atomic.StoreUint32(&manager.acceptTxs, 1) // Mark initial sync done on any fetcher import - return manager.blockchain.InsertChain(blocks) + return manager.blockchain.InsertChain2(blocks) } manager.fetcher = fetcher.New(blockchain.GetBlockByHash, validator, manager.BroadcastBlock, heighter, inserter, manager.removePeer) @@ -414,10 +415,11 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { first := true maxNonCanonical := uint64(100) + round := map[uint64]uint64{} // Gather headers until the fetch or network limits is reached var ( bytes common.StorageSize - headers []*types.Header + headers []*types.HeaderWithGovState unknown bool ) for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch { @@ -439,7 +441,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if origin == nil { break } - headers = append(headers, origin) + headers = append(headers, &types.HeaderWithGovState{Header: origin}) + if round[origin.Round] == 0 { + round[origin.Round] = origin.Number.Uint64() + } bytes += estHeaderRlpSize // Advance to the next header of the query @@ -489,20 +494,71 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { query.Origin.Number += query.Skip + 1 } } + + ctx := context.Background() + if query.WithGov && len(headers) > 0 { + last := headers[len(headers)-1] + currentBlock := pm.blockchain.CurrentBlock() + + // Do not reply if we don't have current gov state + if currentBlock.NumberU64() < last.Number.Uint64() { + log.Debug("Current block < last request", + "current", currentBlock.NumberU64(), "last", last.Number.Uint64()) + return p.SendBlockHeaders([]*types.HeaderWithGovState{}) + } + + snapshotHeight := map[uint64]struct{}{} + for r, height := range round { + log.Trace("#Include round", "round", r) + if r == 0 { + continue + } + h, err := pm.gov.GetRoundHeight(ctx, r) + if err != nil { + log.Warn("Get round height fail", "err", err) + return p.SendBlockHeaders([]*types.HeaderWithGovState{}) + } + log.Trace("#Snapshot height", "height", h) + if h == 0 { + h = height + } + snapshotHeight[h] = struct{}{} + } + + for _, header := range headers { + if _, exist := snapshotHeight[header.Number.Uint64()]; exist { + s, err := pm.blockchain.GetGovStateByHash(header.Hash()) + if err != nil { + log.Warn("Get gov state by hash fail", "number", header.Number.Uint64(), "err", err) + return p.SendBlockHeaders([]*types.HeaderWithGovState{}) + } + header.GovState = s + } + log.Trace("Send header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil) + } + } return p.SendBlockHeaders(headers) case msg.Code == BlockHeadersMsg: // A batch of headers arrived to one of our previous requests - var headers []*types.Header + var headers []*types.HeaderWithGovState if err := msg.Decode(&headers); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Filter out any explicitly requested headers, deliver the rest to the downloader filter := len(headers) == 1 if filter { - headers = pm.fetcher.FilterHeaders(p.id, headers, time.Now()) + h := []*types.Header{headers[0].Header} + h = pm.fetcher.FilterHeaders(p.id, h, time.Now()) + if len(h) == 0 { + headers = nil + } + } + for _, header := range headers { + log.Trace("Received header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil) } if len(headers) > 0 || !filter { + // if the header that has gov state is filter out, the header's gov state is useless err := pm.downloader.DeliverHeaders(p.id, headers) if err != nil { log.Debug("Failed to deliver headers", "err", err) @@ -834,6 +890,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return err } } + case msg.Code == GetGovStateMsg: + var hash common.Hash + if err := msg.Decode(&hash); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + govState, err := pm.blockchain.GetGovStateByHash(hash) + if err != nil { + // TODO(sonic): handle this error + panic(err) + } + return p.SendGovState(govState) + case msg.Code == GovStateMsg: + var govState types.GovState + if err := msg.Decode(&govState); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + if err := pm.downloader.DeliverGovState(p.id, &govState); err != nil { + log.Debug("Failed to deliver govstates", "err", err) + } default: return errResp(ErrInvalidMsgCode, "%v", msg.Code) } |