aboutsummaryrefslogtreecommitdiffstats
path: root/dex/handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'dex/handler.go')
-rw-r--r--dex/handler.go85
1 files changed, 80 insertions, 5 deletions
diff --git a/dex/handler.go b/dex/handler.go
index d66403fe6..0c7a3e919 100644
--- a/dex/handler.go
+++ b/dex/handler.go
@@ -34,6 +34,7 @@
package dex
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -220,7 +221,7 @@ func NewProtocolManager(
return 0, nil
}
atomic.StoreUint32(&manager.acceptTxs, 1) // Mark initial sync done on any fetcher import
- return manager.blockchain.InsertChain(blocks)
+ return manager.blockchain.InsertChain2(blocks)
}
manager.fetcher = fetcher.New(blockchain.GetBlockByHash, validator, manager.BroadcastBlock, heighter, inserter, manager.removePeer)
@@ -414,10 +415,11 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
first := true
maxNonCanonical := uint64(100)
+ round := map[uint64]uint64{}
// Gather headers until the fetch or network limits is reached
var (
bytes common.StorageSize
- headers []*types.Header
+ headers []*types.HeaderWithGovState
unknown bool
)
for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch {
@@ -439,7 +441,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
if origin == nil {
break
}
- headers = append(headers, origin)
+ headers = append(headers, &types.HeaderWithGovState{Header: origin})
+ if round[origin.Round] == 0 {
+ round[origin.Round] = origin.Number.Uint64()
+ }
bytes += estHeaderRlpSize
// Advance to the next header of the query
@@ -489,20 +494,71 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
query.Origin.Number += query.Skip + 1
}
}
+
+ ctx := context.Background()
+ if query.WithGov && len(headers) > 0 {
+ last := headers[len(headers)-1]
+ currentBlock := pm.blockchain.CurrentBlock()
+
+ // Do not reply if we don't have current gov state
+ if currentBlock.NumberU64() < last.Number.Uint64() {
+ log.Debug("Current block < last request",
+ "current", currentBlock.NumberU64(), "last", last.Number.Uint64())
+ return p.SendBlockHeaders([]*types.HeaderWithGovState{})
+ }
+
+ snapshotHeight := map[uint64]struct{}{}
+ for r, height := range round {
+ log.Trace("#Include round", "round", r)
+ if r == 0 {
+ continue
+ }
+ h, err := pm.gov.GetRoundHeight(ctx, r)
+ if err != nil {
+ log.Warn("Get round height fail", "err", err)
+ return p.SendBlockHeaders([]*types.HeaderWithGovState{})
+ }
+ log.Trace("#Snapshot height", "height", h)
+ if h == 0 {
+ h = height
+ }
+ snapshotHeight[h] = struct{}{}
+ }
+
+ for _, header := range headers {
+ if _, exist := snapshotHeight[header.Number.Uint64()]; exist {
+ s, err := pm.blockchain.GetGovStateByHash(header.Hash())
+ if err != nil {
+ log.Warn("Get gov state by hash fail", "number", header.Number.Uint64(), "err", err)
+ return p.SendBlockHeaders([]*types.HeaderWithGovState{})
+ }
+ header.GovState = s
+ }
+ log.Trace("Send header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil)
+ }
+ }
return p.SendBlockHeaders(headers)
case msg.Code == BlockHeadersMsg:
// A batch of headers arrived to one of our previous requests
- var headers []*types.Header
+ var headers []*types.HeaderWithGovState
if err := msg.Decode(&headers); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
// Filter out any explicitly requested headers, deliver the rest to the downloader
filter := len(headers) == 1
if filter {
- headers = pm.fetcher.FilterHeaders(p.id, headers, time.Now())
+ h := []*types.Header{headers[0].Header}
+ h = pm.fetcher.FilterHeaders(p.id, h, time.Now())
+ if len(h) == 0 {
+ headers = nil
+ }
+ }
+ for _, header := range headers {
+ log.Trace("Received header", "round", header.Round, "number", header.Number.Uint64(), "gov state == nil", header.GovState == nil)
}
if len(headers) > 0 || !filter {
+ // if the header that has gov state is filter out, the header's gov state is useless
err := pm.downloader.DeliverHeaders(p.id, headers)
if err != nil {
log.Debug("Failed to deliver headers", "err", err)
@@ -834,6 +890,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return err
}
}
+ case msg.Code == GetGovStateMsg:
+ var hash common.Hash
+ if err := msg.Decode(&hash); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ govState, err := pm.blockchain.GetGovStateByHash(hash)
+ if err != nil {
+ // TODO(sonic): handle this error
+ panic(err)
+ }
+ return p.SendGovState(govState)
+ case msg.Code == GovStateMsg:
+ var govState types.GovState
+ if err := msg.Decode(&govState); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ if err := pm.downloader.DeliverGovState(p.id, &govState); err != nil {
+ log.Debug("Failed to deliver govstates", "err", err)
+ }
default:
return errResp(ErrInvalidMsgCode, "%v", msg.Code)
}