From 58497f46bd0bdd105828c30500e863e826e598cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= Date: Thu, 30 May 2019 20:51:13 +0200 Subject: les, les/flowcontrol: implement LES/3 (#19329) les, les/flowcontrol: implement LES/3 --- les/handler.go | 649 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 349 insertions(+), 300 deletions(-) (limited to 'les/handler.go') diff --git a/les/handler.go b/les/handler.go index f53a4722f..59bfd81cd 100644 --- a/les/handler.go +++ b/les/handler.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/les/csvlogger" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" @@ -118,6 +119,7 @@ type ProtocolManager struct { wg *sync.WaitGroup eventMux *event.TypeMux + logger *csvlogger.Logger // Callbacks synced func() bool @@ -165,8 +167,6 @@ func NewProtocolManager( if odr != nil { manager.retriever = odr.retriever manager.reqDist = odr.retriever.dist - } else { - manager.servingQueue = newServingQueue(int64(time.Millisecond * 10)) } if ulcConfig != nil { @@ -272,6 +272,7 @@ func (pm *ProtocolManager) handle(p *peer) error { // Ignore maxPeers if this is a trusted peer // In server mode we try to check into the client pool after handshake if pm.client && pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { + pm.logger.Event("Rejected (too many peers), " + p.id) return p2p.DiscTooManyPeers } // Reject light clients if server is not synced. @@ -290,6 +291,7 @@ func (pm *ProtocolManager) handle(p *peer) error { ) if err := p.Handshake(td, hash, number, genesis.Hash(), pm.server); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) + pm.logger.Event("Handshake error: " + err.Error() + ", " + p.id) return err } if p.fcClient != nil { @@ -303,9 +305,12 @@ func (pm *ProtocolManager) handle(p *peer) error { // Register the peer locally if err := pm.peers.Register(p); err != nil { p.Log().Error("Light Ethereum peer registration failed", "err", err) + pm.logger.Event("Peer registration error: " + err.Error() + ", " + p.id) return err } + pm.logger.Event("Connection established, " + p.id) defer func() { + pm.logger.Event("Closed connection, " + p.id) pm.removePeer(p.id) }() @@ -326,6 +331,7 @@ func (pm *ProtocolManager) handle(p *peer) error { // main loop. handle incoming messages. for { if err := pm.handleMsg(p); err != nil { + pm.logger.Event("Message handling error: " + err.Error() + ", " + p.id) p.Log().Debug("Light Ethereum message handling failed", "err", err) if p.fcServer != nil { p.fcServer.DumpLogs() @@ -358,23 +364,40 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { ) accept := func(reqID, reqCnt, maxCnt uint64) bool { - if reqCnt == 0 { - return false + inSizeCost := func() uint64 { + if pm.server.costTracker != nil { + return pm.server.costTracker.realCost(0, msg.Size, 0) + } + return 0 } - if p.fcClient == nil || reqCnt > maxCnt { + if p.isFrozen() || reqCnt == 0 || p.fcClient == nil || reqCnt > maxCnt { + p.fcClient.OneTimeCost(inSizeCost()) return false } - maxCost = p.fcCosts.getCost(msg.Code, reqCnt) + maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt) + gf := float64(1) + if pm.server.costTracker != nil { + gf = pm.server.costTracker.globalFactor() + if gf < 0.001 { + p.Log().Error("Invalid global cost factor", "globalFactor", gf) + gf = 1 + } + } + maxTime := uint64(float64(maxCost) / gf) if accepted, bufShort, servingPriority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost); !accepted { - if bufShort > 0 { - p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) - } + p.freezeClient() + p.Log().Warn("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) + p.fcClient.OneTimeCost(inSizeCost()) return false } else { - task = pm.servingQueue.newTask(servingPriority) + task = pm.servingQueue.newTask(p, maxTime, servingPriority) + } + if task.start() { + return true } - return task.start() + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost()) + return false } if msg.Size > ProtocolMaxMsgSize { @@ -388,6 +411,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { p.responseLock.Lock() defer p.responseLock.Unlock() + if p.isFrozen() { + amount = 0 + reply = nil + } var replySize uint32 if reply != nil { replySize = reply.size() @@ -395,7 +422,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { var realCost uint64 if pm.server.costTracker != nil { realCost = pm.server.costTracker.realCost(servingTime, msg.Size, replySize) - pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) + if amount != 0 { + pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) + } } else { realCost = maxCost } @@ -463,94 +492,94 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } query := req.Query - if !accept(req.ReqID, query.Amount, MaxHeaderFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - hashMode := query.Origin.Hash != (common.Hash{}) - first := true - maxNonCanonical := uint64(100) - - // Gather headers until the fetch or network limits is reached - var ( - bytes common.StorageSize - headers []*types.Header - unknown bool - ) - for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { - if !first && !task.waitOrStop() { - return - } - // Retrieve the next header satisfying the query - var origin *types.Header - if hashMode { - if first { - origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) - if origin != nil { - query.Origin.Number = origin.Number.Uint64() + if accept(req.ReqID, query.Amount, MaxHeaderFetch) { + go func() { + hashMode := query.Origin.Hash != (common.Hash{}) + first := true + maxNonCanonical := uint64(100) + + // Gather headers until the fetch or network limits is reached + var ( + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { + if !first && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + // Retrieve the next header satisfying the query + var origin *types.Header + if hashMode { + if first { + origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + if origin != nil { + query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) } } else { - origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) + origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) } - } else { - origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) - } - if origin == nil { - break - } - headers = append(headers, origin) - bytes += estHeaderRlpSize - - // Advance to the next header of the query - switch { - case hashMode && query.Reverse: - // Hash based traversal towards the genesis block - ancestor := query.Skip + 1 - if ancestor == 0 { - unknown = true - } else { - query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) - unknown = (query.Origin.Hash == common.Hash{}) + if origin == nil { + break } - case hashMode && !query.Reverse: - // Hash based traversal towards the leaf block - var ( - current = origin.Number.Uint64() - next = current + query.Skip + 1 - ) - if next <= current { - infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") - p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) - unknown = true - } else { - if header := pm.blockchain.GetHeaderByNumber(next); header != nil { - nextHash := header.Hash() - expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) - if expOldHash == query.Origin.Hash { - query.Origin.Hash, query.Origin.Number = nextHash, next + headers = append(headers, origin) + bytes += estHeaderRlpSize + + // Advance to the next header of the query + switch { + case hashMode && query.Reverse: + // Hash based traversal towards the genesis block + ancestor := query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) + unknown = (query.Origin.Hash == common.Hash{}) + } + case hashMode && !query.Reverse: + // Hash based traversal towards the leaf block + var ( + current = origin.Number.Uint64() + next = current + query.Skip + 1 + ) + if next <= current { + infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") + p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) + unknown = true + } else { + if header := pm.blockchain.GetHeaderByNumber(next); header != nil { + nextHash := header.Hash() + expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) + if expOldHash == query.Origin.Hash { + query.Origin.Hash, query.Origin.Number = nextHash, next + } else { + unknown = true + } } else { unknown = true } + } + case query.Reverse: + // Number based traversal towards the genesis block + if query.Origin.Number >= query.Skip+1 { + query.Origin.Number -= query.Skip + 1 } else { unknown = true } - } - case query.Reverse: - // Number based traversal towards the genesis block - if query.Origin.Number >= query.Skip+1 { - query.Origin.Number -= query.Skip + 1 - } else { - unknown = true - } - case !query.Reverse: - // Number based traversal towards the leaf block - query.Origin.Number += query.Skip + 1 + case !query.Reverse: + // Number based traversal towards the leaf block + query.Origin.Number += query.Skip + 1 + } + first = false } - first = false - } - sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) - }() + sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) + }() + } case BlockHeadersMsg: if pm.downloader == nil { @@ -592,27 +621,27 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { bodies []rlp.RawValue ) reqCnt := len(req.Hashes) - if !accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - return - } - if bytes >= softResponseLimit { - break - } - // Retrieve the requested block body, stopping if enough was found - if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { - if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { - bodies = append(bodies, data) - bytes += len(data) + if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { + go func() { + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block body, stopping if enough was found + if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { + if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { + bodies = append(bodies, data) + bytes += len(data) + } } } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done()) + }() + } case BlockBodiesMsg: if pm.odr == nil { @@ -651,45 +680,45 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { data [][]byte ) reqCnt := len(req.Reqs) - if !accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - for i, req := range req.Reqs { - if i != 0 && !task.waitOrStop() { - return - } - // Look up the root hash belonging to the request - number := rawdb.ReadHeaderNumber(pm.chainDb, req.BHash) - if number == nil { - p.Log().Warn("Failed to retrieve block num for code", "hash", req.BHash) - continue - } - header := rawdb.ReadHeader(pm.chainDb, req.BHash, *number) - if header == nil { - p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", req.BHash) - continue - } - triedb := pm.blockchain.StateCache().TrieDB() + if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { + go func() { + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + // Look up the root hash belonging to the request + number := rawdb.ReadHeaderNumber(pm.chainDb, request.BHash) + if number == nil { + p.Log().Warn("Failed to retrieve block num for code", "hash", request.BHash) + continue + } + header := rawdb.ReadHeader(pm.chainDb, request.BHash, *number) + if header == nil { + p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", request.BHash) + continue + } + triedb := pm.blockchain.StateCache().TrieDB() - account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(req.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "err", err) - continue - } - code, err := triedb.Node(common.BytesToHash(account.CodeHash)) - if err != nil { - p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) - continue - } - // Accumulate the code and abort if enough data was retrieved - data = append(data, code) - if bytes += len(code); bytes >= softResponseLimit { - break + account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + continue + } + code, err := triedb.Node(common.BytesToHash(account.CodeHash)) + if err != nil { + p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) + continue + } + // Accumulate the code and abort if enough data was retrieved + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break + } } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done()) + }() + } case CodeMsg: if pm.odr == nil { @@ -728,37 +757,37 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { receipts []rlp.RawValue ) reqCnt := len(req.Hashes) - if !accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - return - } - if bytes >= softResponseLimit { - break - } - // Retrieve the requested block's receipts, skipping if unknown to us - var results types.Receipts - if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { - results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) - } - if results == nil { - if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { - continue + if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { + go func() { + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block's receipts, skipping if unknown to us + var results types.Receipts + if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { + results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) + } + if results == nil { + if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + log.Error("Failed to encode receipt", "err", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) } } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(results); err != nil { - log.Error("Failed to encode receipt", "err", err) - } else { - receipts = append(receipts, encoded) - bytes += len(encoded) - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done()) + }() + } case ReceiptsMsg: if pm.odr == nil { @@ -797,70 +826,70 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { root common.Hash ) reqCnt := len(req.Reqs) - if !accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - nodes := light.NewNodeSet() - - for i, req := range req.Reqs { - if i != 0 && !task.waitOrStop() { - return - } - // Look up the root hash belonging to the request - var ( - number *uint64 - header *types.Header - trie state.Trie - ) - if req.BHash != lastBHash { - root, lastBHash = common.Hash{}, req.BHash - - if number = rawdb.ReadHeaderNumber(pm.chainDb, req.BHash); number == nil { - p.Log().Warn("Failed to retrieve block num for proof", "hash", req.BHash) - continue + if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { + go func() { + nodes := light.NewNodeSet() + + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return } - if header = rawdb.ReadHeader(pm.chainDb, req.BHash, *number); header == nil { - p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", req.BHash) - continue + // Look up the root hash belonging to the request + var ( + number *uint64 + header *types.Header + trie state.Trie + ) + if request.BHash != lastBHash { + root, lastBHash = common.Hash{}, request.BHash + + if number = rawdb.ReadHeaderNumber(pm.chainDb, request.BHash); number == nil { + p.Log().Warn("Failed to retrieve block num for proof", "hash", request.BHash) + continue + } + if header = rawdb.ReadHeader(pm.chainDb, request.BHash, *number); header == nil { + p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash) + continue + } + root = header.Root } - root = header.Root - } - // Open the account or storage trie for the request - statedb := pm.blockchain.StateCache() - - switch len(req.AccKey) { - case 0: - // No account key specified, open an account trie - trie, err = statedb.OpenTrie(root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) - continue + // Open the account or storage trie for the request + statedb := pm.blockchain.StateCache() + + switch len(request.AccKey) { + case 0: + // No account key specified, open an account trie + trie, err = statedb.OpenTrie(root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) + continue + } + default: + // Account key specified, open a storage trie + account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + continue + } + trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + continue + } } - default: - // Account key specified, open a storage trie - account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(req.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "err", err) + // Prove the user's request from the account or stroage trie + if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { + p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) continue } - trie, err = statedb.OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "root", account.Root, "err", err) - continue + if nodes.DataSize() >= softResponseLimit { + break } } - // Prove the user's request from the account or stroage trie - if err := trie.Prove(req.Key, req.FromLevel, nodes); err != nil { - p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) - continue - } - if nodes.DataSize() >= softResponseLimit { - break - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyProofsV2(req.ReqID, nodes.NodeList()), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyProofsV2(req.ReqID, nodes.NodeList()), task.done()) + }() + } case ProofsV2Msg: if pm.odr == nil { @@ -899,53 +928,53 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { auxData [][]byte ) reqCnt := len(req.Reqs) - if !accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { - return errResp(ErrRequestRejected, "") - } - go func() { - - var ( - lastIdx uint64 - lastType uint - root common.Hash - auxTrie *trie.Trie - ) - nodes := light.NewNodeSet() - for i, req := range req.Reqs { - if i != 0 && !task.waitOrStop() { - return - } - if auxTrie == nil || req.Type != lastType || req.TrieIdx != lastIdx { - auxTrie, lastType, lastIdx = nil, req.Type, req.TrieIdx + if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { + go func() { - var prefix string - if root, prefix = pm.getHelperTrie(req.Type, req.TrieIdx); root != (common.Hash{}) { - auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix))) - } - } - if req.AuxReq == auxRoot { - var data []byte - if root != (common.Hash{}) { - data = root[:] + var ( + lastIdx uint64 + lastType uint + root common.Hash + auxTrie *trie.Trie + ) + nodes := light.NewNodeSet() + for i, request := range req.Reqs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return } - auxData = append(auxData, data) - auxBytes += len(data) - } else { - if auxTrie != nil { - auxTrie.Prove(req.Key, req.FromLevel, nodes) + if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { + auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx + + var prefix string + if root, prefix = pm.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) { + auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix))) + } } - if req.AuxReq != 0 { - data := pm.getHelperTrieAuxData(req) + if request.AuxReq == auxRoot { + var data []byte + if root != (common.Hash{}) { + data = root[:] + } auxData = append(auxData, data) auxBytes += len(data) + } else { + if auxTrie != nil { + auxTrie.Prove(request.Key, request.FromLevel, nodes) + } + if request.AuxReq != 0 { + data := pm.getHelperTrieAuxData(request) + auxData = append(auxData, data) + auxBytes += len(data) + } + } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break } } - if nodes.DataSize()+auxBytes >= softResponseLimit { - break - } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done()) + }() + } case HelperTrieProofsMsg: if pm.odr == nil { @@ -981,27 +1010,27 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } reqCnt := len(req.Txs) - if !accept(req.ReqID, uint64(reqCnt), MaxTxSend) { - return errResp(ErrRequestRejected, "") - } - go func() { - stats := make([]light.TxStatus, len(req.Txs)) - for i, tx := range req.Txs { - if i != 0 && !task.waitOrStop() { - return - } - hash := tx.Hash() - stats[i] = pm.txStatus(hash) - if stats[i].Status == core.TxStatusUnknown { - if errs := pm.txpool.AddRemotes([]*types.Transaction{tx}); errs[0] != nil { - stats[i].Error = errs[0].Error() - continue + if accept(req.ReqID, uint64(reqCnt), MaxTxSend) { + go func() { + stats := make([]light.TxStatus, len(req.Txs)) + for i, tx := range req.Txs { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return } + hash := tx.Hash() stats[i] = pm.txStatus(hash) + if stats[i].Status == core.TxStatusUnknown { + if errs := pm.txpool.AddRemotes([]*types.Transaction{tx}); errs[0] != nil { + stats[i].Error = errs[0].Error() + continue + } + stats[i] = pm.txStatus(hash) + } } - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) + }() + } case GetTxStatusMsg: if pm.txpool == nil { @@ -1016,19 +1045,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } reqCnt := len(req.Hashes) - if !accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { - return errResp(ErrRequestRejected, "") - } - go func() { - stats := make([]light.TxStatus, len(req.Hashes)) - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - return + if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { + go func() { + stats := make([]light.TxStatus, len(req.Hashes)) + for i, hash := range req.Hashes { + if i != 0 && !task.waitOrStop() { + sendResponse(req.ReqID, 0, nil, task.servingTime) + return + } + stats[i] = pm.txStatus(hash) } - stats[i] = pm.txStatus(hash) - } - sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) - }() + sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) + }() + } case TxStatusMsg: if pm.odr == nil { @@ -1053,6 +1082,26 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Status, } + case StopMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + p.freezeServer(true) + pm.retriever.frozen(p) + p.Log().Warn("Service stopped") + + case ResumeMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + var bv uint64 + if err := msg.Decode(&bv); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.ResumeFreeze(bv) + p.freezeServer(false) + p.Log().Warn("Service resumed") + default: p.Log().Trace("Received unknown message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) -- cgit v1.2.3