From 9f8d192991c4f68fa14c91366722bbca601da117 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Fri, 14 Oct 2016 05:51:29 +0200 Subject: les: light client protocol and API --- les/api_backend.go | 144 ++++++++ les/backend.go | 218 ++++++++++++ les/fetcher.go | 295 ++++++++++++++++ les/flowcontrol/control.go | 172 +++++++++ les/flowcontrol/manager.go | 223 ++++++++++++ les/handler.go | 854 +++++++++++++++++++++++++++++++++++++++++++++ les/handler_test.go | 322 +++++++++++++++++ les/helper_test.go | 318 +++++++++++++++++ les/metrics.go | 111 ++++++ les/odr.go | 247 +++++++++++++ les/odr_peerset.go | 119 +++++++ les/odr_requests.go | 325 +++++++++++++++++ les/odr_test.go | 222 ++++++++++++ les/peer.go | 584 +++++++++++++++++++++++++++++++ les/protocol.go | 198 +++++++++++ les/request_test.go | 94 +++++ les/server.go | 401 +++++++++++++++++++++ les/sync.go | 84 +++++ les/txrelay.go | 156 +++++++++ 19 files changed, 5087 insertions(+) create mode 100644 les/api_backend.go create mode 100644 les/backend.go create mode 100644 les/fetcher.go create mode 100644 les/flowcontrol/control.go create mode 100644 les/flowcontrol/manager.go create mode 100644 les/handler.go create mode 100644 les/handler_test.go create mode 100644 les/helper_test.go create mode 100644 les/metrics.go create mode 100644 les/odr.go create mode 100644 les/odr_peerset.go create mode 100644 les/odr_requests.go create mode 100644 les/odr_test.go create mode 100644 les/peer.go create mode 100644 les/protocol.go create mode 100644 les/request_test.go create mode 100644 les/server.go create mode 100644 les/sync.go create mode 100644 les/txrelay.go (limited to 'les') diff --git a/les/api_backend.go b/les/api_backend.go new file mode 100644 index 000000000..d50b3ea33 --- /dev/null +++ b/les/api_backend.go @@ -0,0 +1,144 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package les + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/eth/gasprice" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/light" + rpc "github.com/ethereum/go-ethereum/rpc" + "golang.org/x/net/context" +) + +type LesApiBackend struct { + eth *LightEthereum + gpo *gasprice.LightPriceOracle +} + +func (b *LesApiBackend) SetHead(number uint64) { + b.eth.blockchain.SetHead(number) +} + +func (b *LesApiBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) { + if blockNr == rpc.LatestBlockNumber || blockNr == rpc.PendingBlockNumber { + return b.eth.blockchain.CurrentHeader(), nil + } + + return b.eth.blockchain.GetHeaderByNumberOdr(ctx, uint64(blockNr)) +} + +func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error) { + header, err := b.HeaderByNumber(ctx, blockNr) + if header == nil || err != nil { + return nil, err + } + return b.GetBlock(ctx, header.Hash()) +} + +func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { + header, err := b.HeaderByNumber(ctx, blockNr) + if header == nil || err != nil { + return nil, nil, err + } + return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil +} + +func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { + return b.eth.blockchain.GetBlockByHash(ctx, blockHash) +} + +func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { + return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) +} + +func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { + return b.eth.blockchain.GetTdByHash(blockHash) +} + +func (b *LesApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { + stateDb := state.(*light.LightState).Copy() + addr, _ := msg.From() + from, err := stateDb.GetOrNewStateObject(ctx, addr) + if err != nil { + return nil, nil, err + } + from.SetBalance(common.MaxBig) + env := light.NewEnv(ctx, stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig) + return env, env.Error, nil +} + +func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { + return b.eth.txPool.Add(ctx, signedTx) +} + +func (b *LesApiBackend) RemoveTx(txHash common.Hash) { + b.eth.txPool.RemoveTx(txHash) +} + +func (b *LesApiBackend) GetPoolTransactions() types.Transactions { + return b.eth.txPool.GetTransactions() +} + +func (b *LesApiBackend) GetPoolTransaction(txHash common.Hash) *types.Transaction { + return b.eth.txPool.GetTransaction(txHash) +} + +func (b *LesApiBackend) GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) { + return b.eth.txPool.GetNonce(ctx, addr) +} + +func (b *LesApiBackend) Stats() (pending int, queued int) { + return b.eth.txPool.Stats(), 0 +} + +func (b *LesApiBackend) TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) { + return b.eth.txPool.Content() +} + +func (b *LesApiBackend) Downloader() *downloader.Downloader { + return b.eth.Downloader() +} + +func (b *LesApiBackend) ProtocolVersion() int { + return b.eth.LesVersion() + 10000 +} + +func (b *LesApiBackend) SuggestPrice(ctx context.Context) (*big.Int, error) { + return b.gpo.SuggestPrice(ctx) +} + +func (b *LesApiBackend) ChainDb() ethdb.Database { + return b.eth.chainDb +} + +func (b *LesApiBackend) EventMux() *event.TypeMux { + return b.eth.eventMux +} + +func (b *LesApiBackend) AccountManager() *accounts.Manager { + return b.eth.accountManager +} diff --git a/les/backend.go b/les/backend.go new file mode 100644 index 000000000..8011a4b31 --- /dev/null +++ b/les/backend.go @@ -0,0 +1,218 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "errors" + "fmt" + "time" + + "github.com/ethereum/ethash" + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/compiler" + "github.com/ethereum/go-ethereum/common/httpclient" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/eth/filters" + "github.com/ethereum/go-ethereum/eth/gasprice" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + rpc "github.com/ethereum/go-ethereum/rpc" +) + +type LightEthereum struct { + odr *LesOdr + relay *LesTxRelay + chainConfig *core.ChainConfig + // Channel for shutting down the service + shutdownChan chan bool + // Handlers + txPool *light.TxPool + blockchain *light.LightChain + protocolManager *ProtocolManager + // DB interfaces + chainDb ethdb.Database // Block chain database + + ApiBackend *LesApiBackend + + eventMux *event.TypeMux + pow *ethash.Ethash + httpclient *httpclient.HTTPClient + accountManager *accounts.Manager + solcPath string + solc *compiler.Solidity + + NatSpec bool + PowTest bool + netVersionId int + netRPCService *ethapi.PublicNetAPI +} + +func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { + chainDb, err := eth.CreateDB(ctx, config, "lightchaindata") + if err != nil { + return nil, err + } + if err := eth.SetupGenesisBlock(&chainDb, config); err != nil { + return nil, err + } + pow, err := eth.CreatePoW(config) + if err != nil { + return nil, err + } + + odr := NewLesOdr(chainDb) + relay := NewLesTxRelay() + eth := &LightEthereum{ + odr: odr, + relay: relay, + chainDb: chainDb, + eventMux: ctx.EventMux, + accountManager: ctx.AccountManager, + pow: pow, + shutdownChan: make(chan bool), + httpclient: httpclient.New(config.DocRoot), + netVersionId: config.NetworkId, + NatSpec: config.NatSpec, + PowTest: config.PowTest, + solcPath: config.SolcPath, + } + + if config.ChainConfig == nil { + return nil, errors.New("missing chain config") + } + eth.chainConfig = config.ChainConfig + eth.chainConfig.VmConfig = vm.Config{ + EnableJit: config.EnableJit, + ForceJit: config.ForceJit, + } + eth.blockchain, err = light.NewLightChain(odr, eth.chainConfig, eth.pow, eth.eventMux) + if err != nil { + if err == core.ErrNoGenesis { + return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`) + } + return nil, err + } + + eth.txPool = light.NewTxPool(eth.chainConfig, eth.eventMux, eth.blockchain, eth.relay) + if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, config.LightMode, config.NetworkId, eth.eventMux, eth.pow, eth.blockchain, nil, chainDb, odr, relay); err != nil { + return nil, err + } + + eth.ApiBackend = &LesApiBackend{eth, nil} + eth.ApiBackend.gpo = gasprice.NewLightPriceOracle(eth.ApiBackend) + return eth, nil +} + +type LightDummyAPI struct{} + +// Etherbase is the address that mining rewards will be send to +func (s *LightDummyAPI) Etherbase() (common.Address, error) { + return common.Address{}, fmt.Errorf("not supported") +} + +// Coinbase is the address that mining rewards will be send to (alias for Etherbase) +func (s *LightDummyAPI) Coinbase() (common.Address, error) { + return common.Address{}, fmt.Errorf("not supported") +} + +// Hashrate returns the POW hashrate +func (s *LightDummyAPI) Hashrate() *rpc.HexNumber { + return rpc.NewHexNumber(0) +} + +// Mining returns an indication if this node is currently mining. +func (s *LightDummyAPI) Mining() bool { + return false +} + +// APIs returns the collection of RPC services the ethereum package offers. +// NOTE, some of these services probably need to be moved to somewhere else. +func (s *LightEthereum) APIs() []rpc.API { + return append(ethapi.GetAPIs(s.ApiBackend, s.solcPath), []rpc.API{ + { + Namespace: "eth", + Version: "1.0", + Service: &LightDummyAPI{}, + Public: true, + }, { + Namespace: "eth", + Version: "1.0", + Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), + Public: true, + }, { + Namespace: "eth", + Version: "1.0", + Service: filters.NewPublicFilterAPI(s.ApiBackend, true), + Public: true, + }, { + Namespace: "net", + Version: "1.0", + Service: s.netRPCService, + Public: true, + }, + }...) +} + +func (s *LightEthereum) ResetWithGenesisBlock(gb *types.Block) { + s.blockchain.ResetWithGenesisBlock(gb) +} + +func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchain } +func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool } +func (s *LightEthereum) LesVersion() int { return int(s.protocolManager.SubProtocols[0].Version) } +func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } + +// Protocols implements node.Service, returning all the currently configured +// network protocols to start. +func (s *LightEthereum) Protocols() []p2p.Protocol { + return s.protocolManager.SubProtocols +} + +// Start implements node.Service, starting all internal goroutines needed by the +// Ethereum protocol implementation. +func (s *LightEthereum) Start(srvr *p2p.Server) error { + s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.netVersionId) + s.protocolManager.Start() + return nil +} + +// Stop implements node.Service, terminating all internal goroutines used by the +// Ethereum protocol. +func (s *LightEthereum) Stop() error { + s.odr.Stop() + s.blockchain.Stop() + s.protocolManager.Stop() + s.txPool.Stop() + + s.eventMux.Stop() + + time.Sleep(time.Millisecond * 200) + s.chainDb.Close() + close(s.shutdownChan) + + return nil +} diff --git a/les/fetcher.go b/les/fetcher.go new file mode 100644 index 000000000..3fa5cf0e2 --- /dev/null +++ b/les/fetcher.go @@ -0,0 +1,295 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "math/big" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" +) + +type lightFetcher struct { + pm *ProtocolManager + odr *LesOdr + chain BlockChain + + headAnnouncedMu sync.Mutex + headAnnouncedBy map[common.Hash][]*peer + currentTd *big.Int + deliverChn chan fetchResponse + reqMu sync.RWMutex + requested map[uint64]fetchRequest + timeoutChn chan uint64 + notifyChn chan bool // true if initiated from outside + syncing bool + syncDone chan struct{} +} + +type fetchRequest struct { + hash common.Hash + amount uint64 + peer *peer +} + +type fetchResponse struct { + reqID uint64 + headers []*types.Header +} + +func newLightFetcher(pm *ProtocolManager) *lightFetcher { + f := &lightFetcher{ + pm: pm, + chain: pm.blockchain, + odr: pm.odr, + headAnnouncedBy: make(map[common.Hash][]*peer), + deliverChn: make(chan fetchResponse, 100), + requested: make(map[uint64]fetchRequest), + timeoutChn: make(chan uint64), + notifyChn: make(chan bool, 100), + syncDone: make(chan struct{}), + currentTd: big.NewInt(0), + } + go f.syncLoop() + return f +} + +func (f *lightFetcher) notify(p *peer, head *announceData) { + var headHash common.Hash + if head == nil { + // initial notify + headHash = p.Head() + } else { + if core.GetTd(f.pm.chainDb, head.Hash, head.Number) != nil { + head.haveHeaders = head.Number + } + //fmt.Println("notify", p.id, head.Number, head.ReorgDepth, head.haveHeaders) + if !p.addNotify(head) { + //fmt.Println("addNotify fail") + f.pm.removePeer(p.id) + } + headHash = head.Hash + } + f.headAnnouncedMu.Lock() + f.headAnnouncedBy[headHash] = append(f.headAnnouncedBy[headHash], p) + f.headAnnouncedMu.Unlock() + f.notifyChn <- true +} + +func (f *lightFetcher) gotHeader(header *types.Header) { + f.headAnnouncedMu.Lock() + defer f.headAnnouncedMu.Unlock() + + hash := header.Hash() + peerList := f.headAnnouncedBy[hash] + if peerList == nil { + return + } + number := header.GetNumberU64() + td := core.GetTd(f.pm.chainDb, hash, number) + for _, peer := range peerList { + peer.lock.Lock() + ok := peer.gotHeader(hash, number, td) + peer.lock.Unlock() + if !ok { + //fmt.Println("gotHeader fail") + f.pm.removePeer(peer.id) + } + } + delete(f.headAnnouncedBy, hash) +} + +func (f *lightFetcher) nextRequest() (*peer, *announceData) { + var bestPeer *peer + bestTd := f.currentTd + for _, peer := range f.pm.peers.AllPeers() { + peer.lock.RLock() + if !peer.headInfo.requested && (peer.headInfo.Td.Cmp(bestTd) > 0 || + (bestPeer != nil && peer.headInfo.Td.Cmp(bestTd) == 0 && peer.headInfo.haveHeaders > bestPeer.headInfo.haveHeaders)) { + bestPeer = peer + bestTd = peer.headInfo.Td + } + peer.lock.RUnlock() + } + if bestPeer == nil { + return nil, nil + } + bestPeer.lock.Lock() + res := bestPeer.headInfo + res.requested = true + bestPeer.lock.Unlock() + for _, peer := range f.pm.peers.AllPeers() { + if peer != bestPeer { + peer.lock.Lock() + if peer.headInfo.Hash == bestPeer.headInfo.Hash && peer.headInfo.haveHeaders == bestPeer.headInfo.haveHeaders { + peer.headInfo.requested = true + } + peer.lock.Unlock() + } + } + return bestPeer, res +} + +func (f *lightFetcher) deliverHeaders(reqID uint64, headers []*types.Header) { + f.deliverChn <- fetchResponse{reqID: reqID, headers: headers} +} + +func (f *lightFetcher) requestedID(reqID uint64) bool { + f.reqMu.RLock() + _, ok := f.requested[reqID] + f.reqMu.RUnlock() + return ok +} + +func (f *lightFetcher) request(p *peer, block *announceData) { + //fmt.Println("request", p.id, block.Number, block.haveHeaders) + amount := block.Number - block.haveHeaders + if amount == 0 { + return + } + if amount > 100 { + f.syncing = true + go func() { + //fmt.Println("f.pm.synchronise(p)") + f.pm.synchronise(p) + //fmt.Println("sync done") + f.syncDone <- struct{}{} + }() + return + } + + reqID := f.odr.getNextReqID() + f.reqMu.Lock() + f.requested[reqID] = fetchRequest{hash: block.Hash, amount: amount, peer: p} + f.reqMu.Unlock() + cost := p.GetRequestCost(GetBlockHeadersMsg, int(amount)) + p.fcServer.SendRequest(reqID, cost) + go p.RequestHeadersByHash(reqID, cost, block.Hash, int(amount), 0, true) + go func() { + time.Sleep(hardRequestTimeout) + f.timeoutChn <- reqID + }() +} + +func (f *lightFetcher) processResponse(req fetchRequest, resp fetchResponse) bool { + if uint64(len(resp.headers)) != req.amount || resp.headers[0].Hash() != req.hash { + return false + } + headers := make([]*types.Header, req.amount) + for i, header := range resp.headers { + headers[int(req.amount)-1-i] = header + } + if _, err := f.chain.InsertHeaderChain(headers, 1); err != nil { + return false + } + for _, header := range headers { + td := core.GetTd(f.pm.chainDb, header.Hash(), header.GetNumberU64()) + if td == nil { + return false + } + if td.Cmp(f.currentTd) > 0 { + f.currentTd = td + } + f.gotHeader(header) + } + return true +} + +func (f *lightFetcher) checkSyncedHeaders() { + //fmt.Println("checkSyncedHeaders()") + for _, peer := range f.pm.peers.AllPeers() { + peer.lock.Lock() + h := peer.firstHeadInfo + remove := false + loop: + for h != nil { + if td := core.GetTd(f.pm.chainDb, h.Hash, h.Number); td != nil { + //fmt.Println(" found", h.Number) + ok := peer.gotHeader(h.Hash, h.Number, td) + if !ok { + remove = true + break loop + } + if td.Cmp(f.currentTd) > 0 { + f.currentTd = td + } + } + h = h.next + } + peer.lock.Unlock() + if remove { + //fmt.Println("checkSync fail") + f.pm.removePeer(peer.id) + } + } +} + +func (f *lightFetcher) syncLoop() { + f.pm.wg.Add(1) + defer f.pm.wg.Done() + + srtoNotify := false + for { + select { + case <-f.pm.quitSync: + return + case ext := <-f.notifyChn: + //fmt.Println("<-f.notifyChn", f.syncing, ext, srtoNotify) + s := srtoNotify + srtoNotify = false + if !f.syncing && !(ext && s) { + if p, r := f.nextRequest(); r != nil { + srtoNotify = true + go func() { + time.Sleep(softRequestTimeout) + f.notifyChn <- false + }() + f.request(p, r) + } + } + case reqID := <-f.timeoutChn: + f.reqMu.Lock() + req, ok := f.requested[reqID] + if ok { + delete(f.requested, reqID) + } + f.reqMu.Unlock() + if ok { + //fmt.Println("hard timeout") + f.pm.removePeer(req.peer.id) + } + case resp := <-f.deliverChn: + //fmt.Println("<-f.deliverChn", f.syncing) + f.reqMu.Lock() + req, ok := f.requested[resp.reqID] + delete(f.requested, resp.reqID) + f.reqMu.Unlock() + if !ok || !(f.syncing || f.processResponse(req, resp)) { + //fmt.Println("processResponse fail") + f.pm.removePeer(req.peer.id) + } + case <-f.syncDone: + //fmt.Println("<-f.syncDone", f.syncing) + f.checkSyncedHeaders() + f.syncing = false + } + } +} diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go new file mode 100644 index 000000000..1b569db0b --- /dev/null +++ b/les/flowcontrol/control.go @@ -0,0 +1,172 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package flowcontrol implements a client side flow control mechanism +package flowcontrol + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +const fcTimeConst = 1000000 + +type ServerParams struct { + BufLimit, MinRecharge uint64 +} + +type ClientNode struct { + params *ServerParams + bufValue uint64 + lastTime int64 + lock sync.Mutex + cm *ClientManager + cmNode *cmNode +} + +func NewClientNode(cm *ClientManager, params *ServerParams) *ClientNode { + node := &ClientNode{ + cm: cm, + params: params, + bufValue: params.BufLimit, + lastTime: getTime(), + } + node.cmNode = cm.addNode(node) + return node +} + +func (peer *ClientNode) Remove(cm *ClientManager) { + cm.removeNode(peer.cmNode) +} + +func (peer *ClientNode) recalcBV(time int64) { + dt := uint64(time - peer.lastTime) + if time < peer.lastTime { + dt = 0 + } + peer.bufValue += peer.params.MinRecharge * dt / fcTimeConst + if peer.bufValue > peer.params.BufLimit { + peer.bufValue = peer.params.BufLimit + } + peer.lastTime = time +} + +func (peer *ClientNode) AcceptRequest() (uint64, bool) { + peer.lock.Lock() + defer peer.lock.Unlock() + + time := getTime() + peer.recalcBV(time) + return peer.bufValue, peer.cm.accept(peer.cmNode, time) +} + +func (peer *ClientNode) RequestProcessed(cost uint64) (bv, realCost uint64) { + peer.lock.Lock() + defer peer.lock.Unlock() + + time := getTime() + peer.recalcBV(time) + peer.bufValue -= cost + peer.recalcBV(time) + rcValue, rcost := peer.cm.processed(peer.cmNode, time) + if rcValue < peer.params.BufLimit { + bv := peer.params.BufLimit - rcValue + if bv > peer.bufValue { + peer.bufValue = bv + } + } + return peer.bufValue, rcost +} + +type ServerNode struct { + bufEstimate uint64 + lastTime int64 + params *ServerParams + sumCost uint64 // sum of req costs sent to this server + pending map[uint64]uint64 // value = sumCost after sending the given req + lock sync.Mutex +} + +func NewServerNode(params *ServerParams) *ServerNode { + return &ServerNode{ + bufEstimate: params.BufLimit, + lastTime: getTime(), + params: params, + pending: make(map[uint64]uint64), + } +} + +func getTime() int64 { + return int64(mclock.Now()) +} + +func (peer *ServerNode) recalcBLE(time int64) { + dt := uint64(time - peer.lastTime) + if time < peer.lastTime { + dt = 0 + } + peer.bufEstimate += peer.params.MinRecharge * dt / fcTimeConst + if peer.bufEstimate > peer.params.BufLimit { + peer.bufEstimate = peer.params.BufLimit + } + peer.lastTime = time +} + +func (peer *ServerNode) canSend(maxCost uint64) uint64 { + if peer.bufEstimate >= maxCost { + return 0 + } + return (maxCost - peer.bufEstimate) * fcTimeConst / peer.params.MinRecharge +} + +func (peer *ServerNode) CanSend(maxCost uint64) uint64 { + peer.lock.Lock() + defer peer.lock.Unlock() + + return peer.canSend(maxCost) +} + +// blocks until request can be sent +func (peer *ServerNode) SendRequest(reqID, maxCost uint64) { + peer.lock.Lock() + defer peer.lock.Unlock() + + peer.recalcBLE(getTime()) + for peer.bufEstimate < maxCost { + time.Sleep(time.Duration(peer.canSend(maxCost))) + peer.recalcBLE(getTime()) + } + peer.bufEstimate -= maxCost + peer.sumCost += maxCost + if reqID >= 0 { + peer.pending[reqID] = peer.sumCost + } +} + +func (peer *ServerNode) GotReply(reqID, bv uint64) { + peer.lock.Lock() + defer peer.lock.Unlock() + + sc, ok := peer.pending[reqID] + if !ok { + return + } + delete(peer.pending, reqID) + peer.bufEstimate = bv - (peer.sumCost - sc) + peer.lastTime = getTime() +} diff --git a/les/flowcontrol/manager.go b/les/flowcontrol/manager.go new file mode 100644 index 000000000..c0469e7b6 --- /dev/null +++ b/les/flowcontrol/manager.go @@ -0,0 +1,223 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package flowcontrol implements a client side flow control mechanism +package flowcontrol + +import ( + "sync" + "time" +) + +const rcConst = 1000000 + +type cmNode struct { + node *ClientNode + lastUpdate int64 + reqAccepted int64 + serving, recharging bool + rcWeight uint64 + rcValue, rcDelta int64 + finishRecharge, startValue int64 +} + +func (node *cmNode) update(time int64) { + dt := time - node.lastUpdate + node.rcValue += node.rcDelta * dt / rcConst + node.lastUpdate = time + if node.recharging && time >= node.finishRecharge { + node.recharging = false + node.rcDelta = 0 + node.rcValue = 0 + } +} + +func (node *cmNode) set(serving bool, simReqCnt, sumWeight uint64) { + if node.serving && !serving { + node.recharging = true + sumWeight += node.rcWeight + } + node.serving = serving + if node.recharging && serving { + node.recharging = false + sumWeight -= node.rcWeight + } + + node.rcDelta = 0 + if serving { + node.rcDelta = int64(rcConst / simReqCnt) + } + if node.recharging { + node.rcDelta = -int64(node.node.cm.rcRecharge * node.rcWeight / sumWeight) + node.finishRecharge = node.lastUpdate + node.rcValue*rcConst/(-node.rcDelta) + } +} + +type ClientManager struct { + lock sync.Mutex + nodes map[*cmNode]struct{} + simReqCnt, sumWeight, rcSumValue uint64 + maxSimReq, maxRcSum uint64 + rcRecharge uint64 + resumeQueue chan chan bool + time int64 +} + +func NewClientManager(rcTarget, maxSimReq, maxRcSum uint64) *ClientManager { + cm := &ClientManager{ + nodes: make(map[*cmNode]struct{}), + resumeQueue: make(chan chan bool), + rcRecharge: rcConst * rcConst / (100*rcConst/rcTarget - rcConst), + maxSimReq: maxSimReq, + maxRcSum: maxRcSum, + } + go cm.queueProc() + return cm +} + +func (self *ClientManager) Stop() { + self.lock.Lock() + defer self.lock.Unlock() + + // signal any waiting accept routines to return false + self.nodes = make(map[*cmNode]struct{}) + close(self.resumeQueue) +} + +func (self *ClientManager) addNode(cnode *ClientNode) *cmNode { + time := getTime() + node := &cmNode{ + node: cnode, + lastUpdate: time, + finishRecharge: time, + rcWeight: 1, + } + self.lock.Lock() + defer self.lock.Unlock() + + self.nodes[node] = struct{}{} + self.update(getTime()) + return node +} + +func (self *ClientManager) removeNode(node *cmNode) { + self.lock.Lock() + defer self.lock.Unlock() + + time := getTime() + self.stop(node, time) + delete(self.nodes, node) + self.update(time) +} + +// recalc sumWeight +func (self *ClientManager) updateNodes(time int64) (rce bool) { + var sumWeight, rcSum uint64 + for node, _ := range self.nodes { + rc := node.recharging + node.update(time) + if rc && !node.recharging { + rce = true + } + if node.recharging { + sumWeight += node.rcWeight + } + rcSum += uint64(node.rcValue) + } + self.sumWeight = sumWeight + self.rcSumValue = rcSum + return +} + +func (self *ClientManager) update(time int64) { + for { + firstTime := time + for node, _ := range self.nodes { + if node.recharging && node.finishRecharge < firstTime { + firstTime = node.finishRecharge + } + } + if self.updateNodes(firstTime) { + for node, _ := range self.nodes { + if node.recharging { + node.set(node.serving, self.simReqCnt, self.sumWeight) + } + } + } else { + self.time = time + return + } + } +} + +func (self *ClientManager) canStartReq() bool { + return self.simReqCnt < self.maxSimReq && self.rcSumValue < self.maxRcSum +} + +func (self *ClientManager) queueProc() { + for rc := range self.resumeQueue { + for { + time.Sleep(time.Millisecond * 10) + self.lock.Lock() + self.update(getTime()) + cs := self.canStartReq() + self.lock.Unlock() + if cs { + break + } + } + close(rc) + } +} + +func (self *ClientManager) accept(node *cmNode, time int64) bool { + self.lock.Lock() + defer self.lock.Unlock() + + self.update(time) + if !self.canStartReq() { + resume := make(chan bool) + self.lock.Unlock() + self.resumeQueue <- resume + <-resume + self.lock.Lock() + if _, ok := self.nodes[node]; !ok { + return false // reject if node has been removed or manager has been stopped + } + } + self.simReqCnt++ + node.set(true, self.simReqCnt, self.sumWeight) + node.startValue = node.rcValue + self.update(self.time) + return true +} + +func (self *ClientManager) stop(node *cmNode, time int64) { + if node.serving { + self.update(time) + self.simReqCnt-- + node.set(false, self.simReqCnt, self.sumWeight) + self.update(time) + } +} + +func (self *ClientManager) processed(node *cmNode, time int64) (rcValue, rcCost uint64) { + self.lock.Lock() + defer self.lock.Unlock() + + self.stop(node, time) + return uint64(node.rcValue), uint64(node.rcValue - node.startValue) +} diff --git a/les/handler.go b/les/handler.go new file mode 100644 index 000000000..ef18af4d8 --- /dev/null +++ b/les/handler.go @@ -0,0 +1,854 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "encoding/binary" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/pow" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. + estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header + + ethVersion = 63 // equivalent eth version for the downloader + + MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request + MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request + MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request + MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxHeaderProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxTxSend = 64 // Amount of transactions to be send per request + + disableClientRemovePeer = true +) + +// errIncompatibleConfig is returned if the requested protocols and configs are +// not compatible (low protocol version restrictions and high requirements). +var errIncompatibleConfig = errors.New("incompatible configuration") + +func errResp(code errCode, format string, v ...interface{}) error { + return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...)) +} + +type hashFetcherFn func(common.Hash) error + +type BlockChain interface { + HasHeader(hash common.Hash) bool + GetHeader(hash common.Hash, number uint64) *types.Header + GetHeaderByHash(hash common.Hash) *types.Header + CurrentHeader() *types.Header + GetTdByHash(hash common.Hash) *big.Int + InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) + Rollback(chain []common.Hash) + Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) + GetHeaderByNumber(number uint64) *types.Header + GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash + LastBlockHash() common.Hash + Genesis() *types.Block +} + +type txPool interface { + // AddTransactions should add the given transactions to the pool. + AddBatch([]*types.Transaction) +} + +type ProtocolManager struct { + lightSync bool + txpool txPool + txrelay *LesTxRelay + networkId int + chainConfig *core.ChainConfig + blockchain BlockChain + chainDb ethdb.Database + odr *LesOdr + server *LesServer + + downloader *downloader.Downloader + fetcher *lightFetcher + peers *peerSet + + SubProtocols []p2p.Protocol + + eventMux *event.TypeMux + + // channels for fetcher, syncer, txsyncLoop + newPeerCh chan *peer + quitSync chan struct{} + noMorePeers chan struct{} + + syncMu sync.Mutex + syncing bool + syncDone chan struct{} + + // wait group is used for graceful shutdowns during downloading + // and processing + wg sync.WaitGroup +} + +// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable +// with the ethereum network. +func NewProtocolManager(chainConfig *core.ChainConfig, lightSync bool, networkId int, mux *event.TypeMux, pow pow.PoW, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay) (*ProtocolManager, error) { + // Create the protocol manager with the base fields + manager := &ProtocolManager{ + lightSync: lightSync, + eventMux: mux, + blockchain: blockchain, + chainConfig: chainConfig, + chainDb: chainDb, + networkId: networkId, + txpool: txpool, + txrelay: txrelay, + odr: odr, + peers: newPeerSet(), + newPeerCh: make(chan *peer), + quitSync: make(chan struct{}), + noMorePeers: make(chan struct{}), + } + // Initiate a sub-protocol for every implemented version we can handle + manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) + for i, version := range ProtocolVersions { + // Compatible, initialize the sub-protocol + version := version // Closure for the run + manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ + Name: "les", + Version: version, + Length: ProtocolLengths[i], + Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := manager.newPeer(int(version), networkId, p, rw) + select { + case manager.newPeerCh <- peer: + manager.wg.Add(1) + defer manager.wg.Done() + return manager.handle(peer) + case <-manager.quitSync: + return p2p.DiscQuitting + } + }, + NodeInfo: func() interface{} { + return manager.NodeInfo() + }, + PeerInfo: func(id discover.NodeID) interface{} { + if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil { + return p.Info() + } + return nil + }, + }) + } + if len(manager.SubProtocols) == 0 { + return nil, errIncompatibleConfig + } + + removePeer := manager.removePeer + if disableClientRemovePeer { + removePeer = func(id string) {} + } + + if lightSync { + glog.V(logger.Debug).Infof("LES: create downloader") + manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash, + nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash, + blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer) + manager.fetcher = newLightFetcher(manager) + } + + if odr != nil { + odr.removePeer = removePeer + } + + /*validator := func(block *types.Block, parent *types.Block) error { + return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) + } + heighter := func() uint64 { + return chainman.LastBlockNumberU64() + } + manager.fetcher = fetcher.New(chainman.GetBlockNoOdr, validator, nil, heighter, chainman.InsertChain, manager.removePeer) + */ + return manager, nil +} + +func (pm *ProtocolManager) removePeer(id string) { + // Short circuit if the peer was already removed + peer := pm.peers.Peer(id) + if peer == nil { + return + } + glog.V(logger.Debug).Infoln("Removing peer", id) + + // Unregister the peer from the downloader and Ethereum peer set + glog.V(logger.Debug).Infof("LES: unregister peer %v", id) + if pm.lightSync { + pm.downloader.UnregisterPeer(id) + pm.odr.UnregisterPeer(peer) + if pm.txrelay != nil { + pm.txrelay.removePeer(id) + } + } + if err := pm.peers.Unregister(id); err != nil { + glog.V(logger.Error).Infoln("Removal failed:", err) + } + // Hard disconnect at the networking layer + if peer != nil { + peer.Peer.Disconnect(p2p.DiscUselessPeer) + } +} + +func (pm *ProtocolManager) Start() { + if pm.lightSync { + // start sync handler + go pm.syncer() + } else { + go func() { + for range pm.newPeerCh { + } + }() + } +} + +func (pm *ProtocolManager) Stop() { + // Showing a log message. During download / process this could actually + // take between 5 to 10 seconds and therefor feedback is required. + glog.V(logger.Info).Infoln("Stopping light ethereum protocol handler...") + + // Quit the sync loop. + // After this send has completed, no new peers will be accepted. + pm.noMorePeers <- struct{}{} + + close(pm.quitSync) // quits syncer, fetcher + + // Disconnect existing sessions. + // This also closes the gate for any new registrations on the peer set. + // sessions which are already established but not added to pm.peers yet + // will exit when they try to register. + pm.peers.Close() + + // Wait for any process action + pm.wg.Wait() + + glog.V(logger.Info).Infoln("Light ethereum protocol handler stopped") +} + +func (pm *ProtocolManager) newPeer(pv, nv int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { + return newPeer(pv, nv, p, newMeteredMsgWriter(rw)) +} + +// handle is the callback invoked to manage the life cycle of a les peer. When +// this function terminates, the peer is disconnected. +func (pm *ProtocolManager) handle(p *peer) error { + glog.V(logger.Debug).Infof("%v: peer connected [%s]", p, p.Name()) + + // Execute the LES handshake + td, head, genesis := pm.blockchain.Status() + headNum := core.GetBlockNumber(pm.chainDb, head) + if err := p.Handshake(td, head, headNum, genesis, pm.server); err != nil { + glog.V(logger.Debug).Infof("%v: handshake failed: %v", p, err) + return err + } + if rw, ok := p.rw.(*meteredMsgReadWriter); ok { + rw.Init(p.version) + } + // Register the peer locally + glog.V(logger.Detail).Infof("%v: adding peer", p) + if err := pm.peers.Register(p); err != nil { + glog.V(logger.Error).Infof("%v: addition failed: %v", p, err) + return err + } + defer func() { + if pm.server != nil && pm.server.fcManager != nil && p.fcClient != nil { + p.fcClient.Remove(pm.server.fcManager) + } + pm.removePeer(p.id) + }() + + // Register the peer in the downloader. If the downloader considers it banned, we disconnect + glog.V(logger.Debug).Infof("LES: register peer %v", p.id) + if pm.lightSync { + requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error { + reqID := pm.odr.getNextReqID() + cost := p.GetRequestCost(GetBlockHeadersMsg, amount) + p.fcServer.SendRequest(reqID, cost) + return p.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) + } + requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error { + reqID := pm.odr.getNextReqID() + cost := p.GetRequestCost(GetBlockHeadersMsg, amount) + p.fcServer.SendRequest(reqID, cost) + return p.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) + } + if err := pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, + requestHeadersByHash, requestHeadersByNumber, nil, nil, nil); err != nil { + return err + } + pm.odr.RegisterPeer(p) + if pm.txrelay != nil { + pm.txrelay.addPeer(p) + } + + pm.fetcher.notify(p, nil) + } + + stop := make(chan struct{}) + defer close(stop) + go func() { + // new block announce loop + for { + select { + case announce := <-p.announceChn: + p.SendAnnounce(announce) + //fmt.Println(" BROADCAST sent") + case <-stop: + return + } + } + }() + + // main loop. handle incoming messages. + for { + if err := pm.handleMsg(p); err != nil { + glog.V(logger.Debug).Infof("%v: message handling failed: %v", p, err) + //fmt.Println("handleMsg err:", err) + return err + } + } +} + +var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsMsg, SendTxMsg, GetHeaderProofsMsg} + +// handleMsg is invoked whenever an inbound message is received from a remote +// peer. The remote connection is torn down upon returning any error. +func (pm *ProtocolManager) handleMsg(p *peer) error { + // Read the next message from the remote peer, and ensure it's fully consumed + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + + var costs *requestCosts + var reqCnt, maxReqs int + + //fmt.Println("MSG", msg.Code, msg.Size) + if rc, ok := p.fcCosts[msg.Code]; ok { // check if msg is a supported request type + costs = rc + if p.fcClient == nil { + return errResp(ErrRequestRejected, "") + } + bv, ok := p.fcClient.AcceptRequest() + if !ok || bv < costs.baseCost { + return errResp(ErrRequestRejected, "") + } + maxReqs = 10000 + if bv < pm.server.defParams.BufLimit { + d := bv - costs.baseCost + if d/10000 < costs.reqCost { + maxReqs = int(d / costs.reqCost) + } + } + } + + if msg.Size > ProtocolMaxMsgSize { + return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + } + defer msg.Discard() + + var deliverMsg *Msg + + // Handle the message depending on its contents + switch msg.Code { + case StatusMsg: + glog.V(logger.Debug).Infof("LES: received StatusMsg from peer %v", p.id) + // Status messages should never arrive after the handshake + return errResp(ErrExtraStatusMsg, "uncontrolled status message") + + // Block header query, collect the requested headers and reply + case AnnounceMsg: + var req announceData + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + //fmt.Println("RECEIVED", req.Number, req.Hash, req.Td, req.ReorgDepth) + pm.fetcher.notify(p, &req) + + case GetBlockHeadersMsg: + glog.V(logger.Debug).Infof("LES: received GetBlockHeadersMsg from peer %v", p.id) + // Decode the complex header query + var req struct { + ReqID uint64 + Query getBlockHeadersData + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "%v: %v", msg, err) + } + + query := req.Query + if query.Amount > uint64(maxReqs) || query.Amount > MaxHeaderFetch { + return errResp(ErrRequestRejected, "") + } + + hashMode := query.Origin.Hash != (common.Hash{}) + + // Gather headers until the fetch or network limits is reached + var ( + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { + // Retrieve the next header satisfying the query + var origin *types.Header + if hashMode { + origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + } else { + origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) + } + if origin == nil { + break + } + number := origin.Number.Uint64() + headers = append(headers, origin) + bytes += estHeaderRlpSize + + // Advance to the next header of the query + switch { + case query.Origin.Hash != (common.Hash{}) && query.Reverse: + // Hash based traversal towards the genesis block + for i := 0; i < int(query.Skip)+1; i++ { + if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil { + query.Origin.Hash = header.ParentHash + number-- + } else { + unknown = true + break + } + } + case query.Origin.Hash != (common.Hash{}) && !query.Reverse: + // Hash based traversal towards the leaf block + if header := pm.blockchain.GetHeaderByNumber(origin.Number.Uint64() + query.Skip + 1); header != nil { + if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { + query.Origin.Hash = header.Hash() + } else { + unknown = true + } + } else { + unknown = true + } + case query.Reverse: + // Number based traversal towards the genesis block + if query.Origin.Number >= query.Skip+1 { + query.Origin.Number -= (query.Skip + 1) + } else { + unknown = true + } + + case !query.Reverse: + // Number based traversal towards the leaf block + query.Origin.Number += (query.Skip + 1) + } + } + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + query.Amount*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, query.Amount, rcost) + return p.SendBlockHeaders(req.ReqID, bv, headers) + + case BlockHeadersMsg: + if pm.downloader == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received BlockHeadersMsg from peer %v", p.id) + // A batch of headers arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Headers []*types.Header + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + if pm.fetcher.requestedID(resp.ReqID) { + pm.fetcher.deliverHeaders(resp.ReqID, resp.Headers) + } else { + err := pm.downloader.DeliverHeaders(p.id, resp.Headers) + if err != nil { + glog.V(logger.Debug).Infoln(err) + } + } + + case GetBlockBodiesMsg: + glog.V(logger.Debug).Infof("LES: received GetBlockBodiesMsg from peer %v", p.id) + // Decode the retrieval message + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather blocks until the fetch or network limits is reached + var ( + bytes int + bodies []rlp.RawValue + ) + reqCnt = len(req.Hashes) + if reqCnt > maxReqs || reqCnt > MaxBodyFetch { + return errResp(ErrRequestRejected, "") + } + for _, hash := range req.Hashes { + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block body, stopping if enough was found + if data := core.GetBodyRLP(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 { + bodies = append(bodies, data) + bytes += len(data) + } + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendBlockBodiesRLP(req.ReqID, bv, bodies) + + case BlockBodiesMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received BlockBodiesMsg from peer %v", p.id) + // A batch of block bodies arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data []*types.Body + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgBlockBodies, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case GetCodeMsg: + glog.V(logger.Debug).Infof("LES: received GetCodeMsg from peer %v", p.id) + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []CodeReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + data [][]byte + ) + reqCnt = len(req.Reqs) + if reqCnt > maxReqs || reqCnt > MaxCodeFetch { + return errResp(ErrRequestRejected, "") + } + for _, req := range req.Reqs { + // Retrieve the requested state entry, stopping if enough was found + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil { + sdata := trie.Get(req.AccKey) + var acc state.Account + if err := rlp.DecodeBytes(sdata, &acc); err == nil { + entry, _ := pm.chainDb.Get(acc.CodeHash) + if bytes+len(entry) >= softResponseLimit { + break + } + data = append(data, entry) + bytes += len(entry) + } + } + } + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendCode(req.ReqID, bv, data) + + case CodeMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received CodeMsg from peer %v", p.id) + // A batch of node state data arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data [][]byte + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgCode, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case GetReceiptsMsg: + glog.V(logger.Debug).Infof("LES: received GetReceiptsMsg from peer %v", p.id) + // Decode the retrieval message + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + receipts []rlp.RawValue + ) + reqCnt = len(req.Hashes) + if reqCnt > maxReqs || reqCnt > MaxReceiptFetch { + return errResp(ErrRequestRejected, "") + } + for _, hash := range req.Hashes { + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block's receipts, skipping if unknown to us + results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)) + if results == nil { + if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + glog.V(logger.Error).Infof("failed to encode receipt: %v", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) + } + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendReceiptsRLP(req.ReqID, bv, receipts) + + case ReceiptsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received ReceiptsMsg from peer %v", p.id) + // A batch of receipts arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Receipts []types.Receipts + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgReceipts, + ReqID: resp.ReqID, + Obj: resp.Receipts, + } + + case GetProofsMsg: + glog.V(logger.Debug).Infof("LES: received GetProofsMsg from peer %v", p.id) + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []ProofReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + proofs proofsData + ) + reqCnt = len(req.Reqs) + if reqCnt > maxReqs || reqCnt > MaxProofsFetch { + return errResp(ErrRequestRejected, "") + } + for _, req := range req.Reqs { + if bytes >= softResponseLimit { + break + } + // Retrieve the requested state entry, stopping if enough was found + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil { + if len(req.AccKey) > 0 { + sdata := tr.Get(req.AccKey) + tr = nil + var acc state.Account + if err := rlp.DecodeBytes(sdata, &acc); err == nil { + tr, _ = trie.New(acc.Root, pm.chainDb) + } + } + if tr != nil { + proof := tr.Prove(req.Key) + proofs = append(proofs, proof) + bytes += len(proof) + } + } + } + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendProofs(req.ReqID, bv, proofs) + + case ProofsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received ProofsMsg from peer %v", p.id) + // A batch of merkle proofs arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data [][]rlp.RawValue + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case GetHeaderProofsMsg: + glog.V(logger.Debug).Infof("LES: received GetHeaderProofsMsg from peer %v", p.id) + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []ChtReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + proofs []ChtResp + ) + reqCnt = len(req.Reqs) + if reqCnt > maxReqs || reqCnt > MaxHeaderProofsFetch { + return errResp(ErrRequestRejected, "") + } + for _, req := range req.Reqs { + if bytes >= softResponseLimit { + break + } + + if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { + if root := getChtRoot(pm.chainDb, req.ChtNum); root != (common.Hash{}) { + if tr, _ := trie.New(root, pm.chainDb); tr != nil { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) + proof := tr.Prove(encNumber[:]) + proofs = append(proofs, ChtResp{Header: header, Proof: proof}) + bytes += len(proof) + estHeaderRlpSize + } + } + } + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendHeaderProofs(req.ReqID, bv, proofs) + + case HeaderProofsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + glog.V(logger.Debug).Infof("LES: received HeaderProofsMsg from peer %v", p.id) + var resp struct { + ReqID, BV uint64 + Data []ChtResp + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgHeaderProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case SendTxMsg: + if pm.txpool == nil { + return errResp(ErrUnexpectedResponse, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var txs []*types.Transaction + if err := msg.Decode(&txs); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt = len(txs) + if reqCnt > maxReqs || reqCnt > MaxTxSend { + return errResp(ErrRequestRejected, "") + } + pm.txpool.AddBatch(txs) + _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + + default: + glog.V(logger.Debug).Infof("LES: received unknown message with code %d from peer %v", msg.Code, p.id) + return errResp(ErrInvalidMsgCode, "%v", msg.Code) + } + + if deliverMsg != nil { + return pm.odr.Deliver(p, deliverMsg) + } + + return nil +} + +// NodeInfo retrieves some protocol metadata about the running host node. +func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { + return ð.EthNodeInfo{ + Network: self.networkId, + Difficulty: self.blockchain.GetTdByHash(self.blockchain.LastBlockHash()), + Genesis: self.blockchain.Genesis().Hash(), + Head: self.blockchain.LastBlockHash(), + } +} diff --git a/les/handler_test.go b/les/handler_test.go new file mode 100644 index 000000000..2aa7a5590 --- /dev/null +++ b/les/handler_test.go @@ -0,0 +1,322 @@ +package les + +import ( + "math/rand" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}) error { + type resp struct { + ReqID, BV uint64 + Data interface{} + } + return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) +} + +// Tests that block headers can be retrieved from a remote chain based on user queries. +func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } + +func testGetBlockHeaders(t *testing.T, protocol int) { + pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil) + bc := pm.blockchain.(*core.BlockChain) + peer, _ := newTestPeer(t, "peer", protocol, pm, true) + defer peer.close() + + // Create a "random" unknown hash for testing + var unknown common.Hash + for i, _ := range unknown { + unknown[i] = byte(i) + } + // Create a batch of tests for various scenarios + limit := uint64(MaxHeaderFetch) + tests := []struct { + query *getBlockHeadersData // The query to execute for header retrieval + expect []common.Hash // The hashes of the block whose headers are expected + }{ + // A single random block should be retrievable by hash and number too + { + &getBlockHeadersData{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, + []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, + }, + // Multiple headers should be retrievable in both directions + { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, + []common.Hash{ + bc.GetBlockByNumber(limit / 2).Hash(), + bc.GetBlockByNumber(limit/2 + 1).Hash(), + bc.GetBlockByNumber(limit/2 + 2).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, + []common.Hash{ + bc.GetBlockByNumber(limit / 2).Hash(), + bc.GetBlockByNumber(limit/2 - 1).Hash(), + bc.GetBlockByNumber(limit/2 - 2).Hash(), + }, + }, + // Multiple headers with skip lists should be retrievable + { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, + []common.Hash{ + bc.GetBlockByNumber(limit / 2).Hash(), + bc.GetBlockByNumber(limit/2 + 4).Hash(), + bc.GetBlockByNumber(limit/2 + 8).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ + bc.GetBlockByNumber(limit / 2).Hash(), + bc.GetBlockByNumber(limit/2 - 4).Hash(), + bc.GetBlockByNumber(limit/2 - 8).Hash(), + }, + }, + // The chain endpoints should be retrievable + { + &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, + []common.Hash{bc.GetBlockByNumber(0).Hash()}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1}, + []common.Hash{bc.CurrentBlock().Hash()}, + }, + // Ensure protocol limits are honored + /*{ + &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + bc.GetBlockHashesFromHash(bc.CurrentBlock().Hash(), limit), + },*/ + // Check that requesting more than available is handled gracefully + { + &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + []common.Hash{ + bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), + bc.GetBlockByNumber(bc.CurrentBlock().NumberU64()).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ + bc.GetBlockByNumber(4).Hash(), + bc.GetBlockByNumber(0).Hash(), + }, + }, + // Check that requesting more than available is handled gracefully, even if mid skip + { + &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, + []common.Hash{ + bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), + bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 1).Hash(), + }, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, + []common.Hash{ + bc.GetBlockByNumber(4).Hash(), + bc.GetBlockByNumber(1).Hash(), + }, + }, + // Check that non existing headers aren't returned + { + &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, + []common.Hash{}, + }, { + &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1}, + []common.Hash{}, + }, + } + // Run each of the tests and verify the results against the chain + var reqID uint64 + for i, tt := range tests { + // Collect the headers to expect in the response + headers := []*types.Header{} + for _, hash := range tt.expect { + headers = append(headers, bc.GetHeaderByHash(hash)) + } + // Send the hash request and verify the response + reqID++ + cost := peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount)) + sendRequest(peer.app, GetBlockHeadersMsg, reqID, cost, tt.query) + if err := expectResponse(peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { + t.Errorf("test %d: headers mismatch: %v", i, err) + } + } +} + +// Tests that block contents can be retrieved from a remote chain based on their hashes. +func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } + +func testGetBlockBodies(t *testing.T, protocol int) { + pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil) + bc := pm.blockchain.(*core.BlockChain) + peer, _ := newTestPeer(t, "peer", protocol, pm, true) + defer peer.close() + + // Create a batch of tests for various scenarios + limit := MaxBodyFetch + tests := []struct { + random int // Number of blocks to fetch randomly from the chain + explicit []common.Hash // Explicitly requested blocks + available []bool // Availability of explicitly requested blocks + expected int // Total number of existing blocks to expect + }{ + {1, nil, nil, 1}, // A single random block should be retrievable + {10, nil, nil, 10}, // Multiple random blocks should be retrievable + {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable + //{limit + 1, nil, nil, limit}, // No more than the possible block count should be returned + {0, []common.Hash{bc.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable + {0, []common.Hash{bc.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable + {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned + + // Existing and non-existing blocks interleaved should not cause problems + {0, []common.Hash{ + common.Hash{}, + bc.GetBlockByNumber(1).Hash(), + common.Hash{}, + bc.GetBlockByNumber(10).Hash(), + common.Hash{}, + bc.GetBlockByNumber(100).Hash(), + common.Hash{}, + }, []bool{false, true, false, true, false, true, false}, 3}, + } + // Run each of the tests and verify the results against the chain + var reqID uint64 + for i, tt := range tests { + // Collect the hashes to request, and the response to expect + hashes, seen := []common.Hash{}, make(map[int64]bool) + bodies := []*types.Body{} + + for j := 0; j < tt.random; j++ { + for { + num := rand.Int63n(int64(bc.CurrentBlock().NumberU64())) + if !seen[num] { + seen[num] = true + + block := bc.GetBlockByNumber(uint64(num)) + hashes = append(hashes, block.Hash()) + if len(bodies) < tt.expected { + bodies = append(bodies, &types.Body{Transactions: block.Transactions(), Uncles: block.Uncles()}) + } + break + } + } + } + for j, hash := range tt.explicit { + hashes = append(hashes, hash) + if tt.available[j] && len(bodies) < tt.expected { + block := bc.GetBlockByHash(hash) + bodies = append(bodies, &types.Body{Transactions: block.Transactions(), Uncles: block.Uncles()}) + } + } + reqID++ + // Send the hash request and verify the response + cost := peer.GetRequestCost(GetBlockBodiesMsg, len(hashes)) + sendRequest(peer.app, GetBlockBodiesMsg, reqID, cost, hashes) + if err := expectResponse(peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { + t.Errorf("test %d: bodies mismatch: %v", i, err) + } + } +} + +// Tests that the contract codes can be retrieved based on account addresses. +func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) } + +func testGetCode(t *testing.T, protocol int) { + // Assemble the test environment + pm, _, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + bc := pm.blockchain.(*core.BlockChain) + peer, _ := newTestPeer(t, "peer", protocol, pm, true) + defer peer.close() + + var codereqs []*CodeReq + var codes [][]byte + + for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { + header := bc.GetHeaderByNumber(i) + req := &CodeReq{ + BHash: header.Hash(), + AccKey: crypto.Keccak256(testContractAddr[:]), + } + codereqs = append(codereqs, req) + if i >= testContractDeployed { + codes = append(codes, testContractCodeDeployed) + } + } + + cost := peer.GetRequestCost(GetCodeMsg, len(codereqs)) + sendRequest(peer.app, GetCodeMsg, 42, cost, codereqs) + if err := expectResponse(peer.app, CodeMsg, 42, testBufLimit, codes); err != nil { + t.Errorf("codes mismatch: %v", err) + } +} + +// Tests that the transaction receipts can be retrieved based on hashes. +func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) } + +func testGetReceipt(t *testing.T, protocol int) { + // Assemble the test environment + pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + bc := pm.blockchain.(*core.BlockChain) + peer, _ := newTestPeer(t, "peer", protocol, pm, true) + defer peer.close() + + // Collect the hashes to request, and the response to expect + hashes, receipts := []common.Hash{}, []types.Receipts{} + for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { + block := bc.GetBlockByNumber(i) + + hashes = append(hashes, block.Hash()) + receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64())) + } + // Send the hash request and verify the response + cost := peer.GetRequestCost(GetReceiptsMsg, len(hashes)) + sendRequest(peer.app, GetReceiptsMsg, 42, cost, hashes) + if err := expectResponse(peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { + t.Errorf("receipts mismatch: %v", err) + } +} + +// Tests that trie merkle proofs can be retrieved +func TestGetProofsLes1(t *testing.T) { testGetReceipt(t, 1) } + +func testGetProofs(t *testing.T, protocol int) { + // Assemble the test environment + pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + bc := pm.blockchain.(*core.BlockChain) + peer, _ := newTestPeer(t, "peer", protocol, pm, true) + defer peer.close() + + var proofreqs []ProofReq + var proofs [][]rlp.RawValue + + accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr, common.Address{}} + for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { + header := bc.GetHeaderByNumber(i) + root := header.Root + trie, _ := trie.New(root, db) + + for _, acc := range accounts { + req := ProofReq{ + BHash: header.Hash(), + Key: acc[:], + } + proofreqs = append(proofreqs, req) + + proof := trie.Prove(crypto.Keccak256(acc[:])) + proofs = append(proofs, proof) + } + } + // Send the proof request and verify the response + cost := peer.GetRequestCost(GetProofsMsg, len(proofreqs)) + sendRequest(peer.app, GetProofsMsg, 42, cost, proofreqs) + if err := expectResponse(peer.app, ProofsMsg, 42, testBufLimit, proofs); err != nil { + t.Errorf("proofs mismatch: %v", err) + } +} diff --git a/les/helper_test.go b/les/helper_test.go new file mode 100644 index 000000000..9817e5004 --- /dev/null +++ b/les/helper_test.go @@ -0,0 +1,318 @@ +// This file contains some shares testing functionality, common to multiple +// different files and modules being tested. + +package les + +import ( + "crypto/ecdsa" + "crypto/rand" + "math/big" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/params" +) + +var ( + testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) + testBankFunds = big.NewInt(1000000) + + acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") + acc1Addr = crypto.PubkeyToAddress(acc1Key.PublicKey) + acc2Addr = crypto.PubkeyToAddress(acc2Key.PublicKey) + + testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") + testContractAddr common.Address + testContractCodeDeployed = testContractCode[16:] + testContractDeployed = uint64(2) + + testBufLimit = uint64(100) +) + +/* +contract test { + + uint256[100] data; + + function Put(uint256 addr, uint256 value) { + data[addr] = value; + } + + function Get(uint256 addr) constant returns (uint256 value) { + return data[addr]; + } +} +*/ + +func testChainGen(i int, block *core.BlockGen) { + switch i { + case 0: + // In block 1, the test bank sends account #1 some ether. + tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil).SignECDSA(testBankKey) + block.AddTx(tx) + case 1: + // In block 2, the test bank sends some more ether to account #1. + // acc1Addr passes it on to account #2. + // acc1Addr creates a test contract. + tx1, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testBankKey) + nonce := block.TxNonce(acc1Addr) + tx2, _ := types.NewTransaction(nonce, acc2Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(acc1Key) + nonce++ + tx3, _ := types.NewContractCreation(nonce, big.NewInt(0), big.NewInt(200000), big.NewInt(0), testContractCode).SignECDSA(acc1Key) + testContractAddr = crypto.CreateAddress(acc1Addr, nonce) + block.AddTx(tx1) + block.AddTx(tx2) + block.AddTx(tx3) + case 2: + // Block 3 is empty but was mined by account #2. + block.SetCoinbase(acc2Addr) + block.SetExtra([]byte("yeehaw")) + data := common.Hex2Bytes("C16431B900000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001") + tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), testContractAddr, big.NewInt(0), big.NewInt(100000), nil, data).SignECDSA(testBankKey) + block.AddTx(tx) + case 3: + // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). + b2 := block.PrevBlock(1).Header() + b2.Extra = []byte("foo") + block.AddUncle(b2) + b3 := block.PrevBlock(2).Header() + b3.Extra = []byte("foo") + block.AddUncle(b3) + data := common.Hex2Bytes("C16431B900000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000002") + tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), testContractAddr, big.NewInt(0), big.NewInt(100000), nil, data).SignECDSA(testBankKey) + block.AddTx(tx) + } +} + +func testRCL() RequestCostList { + cl := make(RequestCostList, len(reqList)) + for i, code := range reqList { + cl[i].MsgCode = code + cl[i].BaseCost = 0 + cl[i].ReqCost = 0 + } + return cl +} + +// newTestProtocolManager creates a new protocol manager for testing purposes, +// with the given number of blocks already known, and potential notification +// channels for different events. +func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr, error) { + var ( + evmux = new(event.TypeMux) + pow = new(core.FakePow) + db, _ = ethdb.NewMemDatabase() + genesis = core.WriteGenesisBlockForTesting(db, core.GenesisAccount{Address: testBankAddress, Balance: testBankFunds}) + chainConfig = &core.ChainConfig{HomesteadBlock: big.NewInt(0)} // homestead set to 0 because of chain maker + odr *LesOdr + chain BlockChain + ) + + if lightSync { + odr = NewLesOdr(db) + chain, _ = light.NewLightChain(odr, chainConfig, pow, evmux) + } else { + blockchain, _ := core.NewBlockChain(db, chainConfig, pow, evmux) + gchain, _ := core.GenerateChain(nil, genesis, db, blocks, generator) + if _, err := blockchain.InsertChain(gchain); err != nil { + panic(err) + } + chain = blockchain + } + + pm, err := NewProtocolManager(chainConfig, lightSync, NetworkId, evmux, pow, chain, nil, db, odr, nil) + if err != nil { + return nil, nil, nil, err + } + if !lightSync { + srv := &LesServer{protocolManager: pm} + pm.server = srv + + srv.defParams = &flowcontrol.ServerParams{ + BufLimit: testBufLimit, + MinRecharge: 1, + } + + srv.fcManager = flowcontrol.NewClientManager(50, 10, 1000000000) + srv.fcCostStats = newCostStats(nil) + } + pm.Start() + return pm, db, odr, nil +} + +// newTestProtocolManagerMust creates a new protocol manager for testing purposes, +// with the given number of blocks already known, and potential notification +// channels for different events. In case of an error, the constructor force- +// fails the test. +func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr) { + pm, db, odr, err := newTestProtocolManager(lightSync, blocks, generator) + if err != nil { + t.Fatalf("Failed to create protocol manager: %v", err) + } + return pm, db, odr +} + +// testTxPool is a fake, helper transaction pool for testing purposes +type testTxPool struct { + pool []*types.Transaction // Collection of all transactions + added chan<- []*types.Transaction // Notification channel for new transactions + + lock sync.RWMutex // Protects the transaction pool +} + +// AddTransactions appends a batch of transactions to the pool, and notifies any +// listeners if the addition channel is non nil +func (p *testTxPool) AddBatch(txs []*types.Transaction) { + p.lock.Lock() + defer p.lock.Unlock() + + p.pool = append(p.pool, txs...) + if p.added != nil { + p.added <- txs + } +} + +// GetTransactions returns all the transactions known to the pool +func (p *testTxPool) GetTransactions() types.Transactions { + p.lock.RLock() + defer p.lock.RUnlock() + + txs := make([]*types.Transaction, len(p.pool)) + copy(txs, p.pool) + + return txs +} + +// newTestTransaction create a new dummy transaction. +func newTestTransaction(from *ecdsa.PrivateKey, nonce uint64, datasize int) *types.Transaction { + tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), big.NewInt(100000), big.NewInt(0), make([]byte, datasize)) + tx, _ = tx.SignECDSA(from) + + return tx +} + +// testPeer is a simulated peer to allow testing direct network calls. +type testPeer struct { + net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging + app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side + *peer +} + +// newTestPeer creates a new peer registered at the given protocol manager. +func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) { + // Create a message pipe to communicate through + app, net := p2p.MsgPipe() + + // Generate a random id and create the peer + var id discover.NodeID + rand.Read(id[:]) + + peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + + // Start the peer on a new thread + errc := make(chan error, 1) + go func() { + select { + case pm.newPeerCh <- peer: + errc <- pm.handle(peer) + case <-pm.quitSync: + errc <- p2p.DiscQuitting + } + }() + tp := &testPeer{ + app: app, + net: net, + peer: peer, + } + // Execute any implicitly requested handshakes and return + if shake { + td, head, genesis := pm.blockchain.Status() + headNum := pm.blockchain.CurrentHeader().Number.Uint64() + tp.handshake(t, td, head, headNum, genesis) + } + return tp, errc +} + +func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, <-chan error, *peer, <-chan error) { + // Create a message pipe to communicate through + app, net := p2p.MsgPipe() + + // Generate a random id and create the peer + var id discover.NodeID + rand.Read(id[:]) + + peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app) + + // Start the peer on a new thread + errc := make(chan error, 1) + errc2 := make(chan error, 1) + go func() { + select { + case pm.newPeerCh <- peer: + errc <- pm.handle(peer) + case <-pm.quitSync: + errc <- p2p.DiscQuitting + } + }() + go func() { + select { + case pm2.newPeerCh <- peer2: + errc2 <- pm2.handle(peer2) + case <-pm2.quitSync: + errc2 <- p2p.DiscQuitting + } + }() + return peer, errc, peer2, errc2 +} + +// handshake simulates a trivial handshake that expects the same state from the +// remote side as we are simulating locally. +func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash) { + var expList keyValueList + expList = expList.add("protocolVersion", uint64(p.version)) + expList = expList.add("networkId", uint64(NetworkId)) + expList = expList.add("headTd", td) + expList = expList.add("headHash", head) + expList = expList.add("headNum", headNum) + expList = expList.add("genesisHash", genesis) + sendList := make(keyValueList, len(expList)) + copy(sendList, expList) + expList = expList.add("serveHeaders", nil) + expList = expList.add("serveChainSince", uint64(0)) + expList = expList.add("serveStateSince", uint64(0)) + expList = expList.add("txRelay", nil) + expList = expList.add("flowControl/BL", testBufLimit) + expList = expList.add("flowControl/MRR", uint64(1)) + expList = expList.add("flowControl/MRC", testRCL()) + + if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { + t.Fatalf("status recv: %v", err) + } + if err := p2p.Send(p.app, StatusMsg, sendList); err != nil { + t.Fatalf("status send: %v", err) + } + + p.fcServerParams = &flowcontrol.ServerParams{ + BufLimit: testBufLimit, + MinRecharge: 1, + } +} + +// close terminates the local side of the peer, notifying the remote protocol +// manager of termination. +func (p *testPeer) close() { + p.app.Close() +} diff --git a/les/metrics.go b/les/metrics.go new file mode 100644 index 000000000..88e6726e2 --- /dev/null +++ b/les/metrics.go @@ -0,0 +1,111 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" +) + +var ( + /* propTxnInPacketsMeter = metrics.NewMeter("eth/prop/txns/in/packets") + propTxnInTrafficMeter = metrics.NewMeter("eth/prop/txns/in/traffic") + propTxnOutPacketsMeter = metrics.NewMeter("eth/prop/txns/out/packets") + propTxnOutTrafficMeter = metrics.NewMeter("eth/prop/txns/out/traffic") + propHashInPacketsMeter = metrics.NewMeter("eth/prop/hashes/in/packets") + propHashInTrafficMeter = metrics.NewMeter("eth/prop/hashes/in/traffic") + propHashOutPacketsMeter = metrics.NewMeter("eth/prop/hashes/out/packets") + propHashOutTrafficMeter = metrics.NewMeter("eth/prop/hashes/out/traffic") + propBlockInPacketsMeter = metrics.NewMeter("eth/prop/blocks/in/packets") + propBlockInTrafficMeter = metrics.NewMeter("eth/prop/blocks/in/traffic") + propBlockOutPacketsMeter = metrics.NewMeter("eth/prop/blocks/out/packets") + propBlockOutTrafficMeter = metrics.NewMeter("eth/prop/blocks/out/traffic") + reqHashInPacketsMeter = metrics.NewMeter("eth/req/hashes/in/packets") + reqHashInTrafficMeter = metrics.NewMeter("eth/req/hashes/in/traffic") + reqHashOutPacketsMeter = metrics.NewMeter("eth/req/hashes/out/packets") + reqHashOutTrafficMeter = metrics.NewMeter("eth/req/hashes/out/traffic") + reqBlockInPacketsMeter = metrics.NewMeter("eth/req/blocks/in/packets") + reqBlockInTrafficMeter = metrics.NewMeter("eth/req/blocks/in/traffic") + reqBlockOutPacketsMeter = metrics.NewMeter("eth/req/blocks/out/packets") + reqBlockOutTrafficMeter = metrics.NewMeter("eth/req/blocks/out/traffic") + reqHeaderInPacketsMeter = metrics.NewMeter("eth/req/headers/in/packets") + reqHeaderInTrafficMeter = metrics.NewMeter("eth/req/headers/in/traffic") + reqHeaderOutPacketsMeter = metrics.NewMeter("eth/req/headers/out/packets") + reqHeaderOutTrafficMeter = metrics.NewMeter("eth/req/headers/out/traffic") + reqBodyInPacketsMeter = metrics.NewMeter("eth/req/bodies/in/packets") + reqBodyInTrafficMeter = metrics.NewMeter("eth/req/bodies/in/traffic") + reqBodyOutPacketsMeter = metrics.NewMeter("eth/req/bodies/out/packets") + reqBodyOutTrafficMeter = metrics.NewMeter("eth/req/bodies/out/traffic") + reqStateInPacketsMeter = metrics.NewMeter("eth/req/states/in/packets") + reqStateInTrafficMeter = metrics.NewMeter("eth/req/states/in/traffic") + reqStateOutPacketsMeter = metrics.NewMeter("eth/req/states/out/packets") + reqStateOutTrafficMeter = metrics.NewMeter("eth/req/states/out/traffic") + reqReceiptInPacketsMeter = metrics.NewMeter("eth/req/receipts/in/packets") + reqReceiptInTrafficMeter = metrics.NewMeter("eth/req/receipts/in/traffic") + reqReceiptOutPacketsMeter = metrics.NewMeter("eth/req/receipts/out/packets") + reqReceiptOutTrafficMeter = metrics.NewMeter("eth/req/receipts/out/traffic")*/ + miscInPacketsMeter = metrics.NewMeter("les/misc/in/packets") + miscInTrafficMeter = metrics.NewMeter("les/misc/in/traffic") + miscOutPacketsMeter = metrics.NewMeter("les/misc/out/packets") + miscOutTrafficMeter = metrics.NewMeter("les/misc/out/traffic") +) + +// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of +// accumulating the above defined metrics based on the data stream contents. +type meteredMsgReadWriter struct { + p2p.MsgReadWriter // Wrapped message stream to meter + version int // Protocol version to select correct meters +} + +// newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the +// metrics system is disabled, this fucntion returns the original object. +func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter { + if !metrics.Enabled { + return rw + } + return &meteredMsgReadWriter{MsgReadWriter: rw} +} + +// Init sets the protocol version used by the stream to know which meters to +// increment in case of overlapping message ids between protocol versions. +func (rw *meteredMsgReadWriter) Init(version int) { + rw.version = version +} + +func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { + // Read the message and short circuit in case of an error + msg, err := rw.MsgReadWriter.ReadMsg() + if err != nil { + return msg, err + } + // Account for the data traffic + packets, traffic := miscInPacketsMeter, miscInTrafficMeter + packets.Mark(1) + traffic.Mark(int64(msg.Size)) + + return msg, err +} + +func (rw *meteredMsgReadWriter) WriteMsg(msg p2p.Msg) error { + // Account for the data traffic + packets, traffic := miscOutPacketsMeter, miscOutTrafficMeter + packets.Mark(1) + traffic.Mark(int64(msg.Size)) + + // Send the packet to the p2p layer + return rw.MsgReadWriter.WriteMsg(msg) +} diff --git a/les/odr.go b/les/odr.go new file mode 100644 index 000000000..2674ba6a1 --- /dev/null +++ b/les/odr.go @@ -0,0 +1,247 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . +package les + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "golang.org/x/net/context" +) + +var ( + softRequestTimeout = time.Millisecond * 500 + hardRequestTimeout = time.Second * 10 + retryPeers = time.Second * 1 +) + +// peerDropFn is a callback type for dropping a peer detected as malicious. +type peerDropFn func(id string) + +type LesOdr struct { + light.OdrBackend + db ethdb.Database + stop chan struct{} + removePeer peerDropFn + mlock, clock sync.Mutex + sentReqs map[uint64]*sentReq + peers *odrPeerSet + lastReqID uint64 +} + +func NewLesOdr(db ethdb.Database) *LesOdr { + return &LesOdr{ + db: db, + stop: make(chan struct{}), + peers: newOdrPeerSet(), + sentReqs: make(map[uint64]*sentReq), + } +} + +func (odr *LesOdr) Stop() { + close(odr.stop) +} + +func (odr *LesOdr) Database() ethdb.Database { + return odr.db +} + +// validatorFunc is a function that processes a message and returns true if +// it was a meaningful answer to a given request +type validatorFunc func(ethdb.Database, *Msg) bool + +// sentReq is a request waiting for an answer that satisfies its valFunc +type sentReq struct { + valFunc validatorFunc + sentTo map[*peer]chan struct{} + lock sync.RWMutex // protects acces to sentTo + answered chan struct{} // closed and set to nil when any peer answers it +} + +// RegisterPeer registers a new LES peer to the ODR capable peer set +func (self *LesOdr) RegisterPeer(p *peer) error { + return self.peers.register(p) +} + +// UnregisterPeer removes a peer from the ODR capable peer set +func (self *LesOdr) UnregisterPeer(p *peer) { + self.peers.unregister(p) +} + +const ( + MsgBlockBodies = iota + MsgCode + MsgReceipts + MsgProofs + MsgHeaderProofs +) + +// Msg encodes a LES message that delivers reply data for a request +type Msg struct { + MsgType int + ReqID uint64 + Obj interface{} +} + +// Deliver is called by the LES protocol manager to deliver ODR reply messages to waiting requests +func (self *LesOdr) Deliver(peer *peer, msg *Msg) error { + var delivered chan struct{} + self.mlock.Lock() + req, ok := self.sentReqs[msg.ReqID] + self.mlock.Unlock() + if ok { + req.lock.Lock() + delivered, ok = req.sentTo[peer] + req.lock.Unlock() + } + + if !ok { + return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) + } + + if req.valFunc(self.db, msg) { + close(delivered) + req.lock.Lock() + if req.answered != nil { + close(req.answered) + req.answered = nil + } + req.lock.Unlock() + return nil + } + return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID) +} + +func (self *LesOdr) requestPeer(req *sentReq, peer *peer, delivered, timeout chan struct{}, reqWg *sync.WaitGroup) { + stime := mclock.Now() + defer func() { + req.lock.Lock() + delete(req.sentTo, peer) + req.lock.Unlock() + reqWg.Done() + }() + + select { + case <-delivered: + servTime := uint64(mclock.Now() - stime) + self.peers.updateTimeout(peer, false) + self.peers.updateServTime(peer, servTime) + return + case <-time.After(softRequestTimeout): + close(timeout) + if self.peers.updateTimeout(peer, true) { + self.removePeer(peer.id) + } + case <-self.stop: + return + } + + select { + case <-delivered: + servTime := uint64(mclock.Now() - stime) + self.peers.updateServTime(peer, servTime) + return + case <-time.After(hardRequestTimeout): + self.removePeer(peer.id) + case <-self.stop: + return + } +} + +// networkRequest sends a request to known peers until an answer is received +// or the context is cancelled +func (self *LesOdr) networkRequest(ctx context.Context, lreq LesOdrRequest) error { + answered := make(chan struct{}) + req := &sentReq{ + valFunc: lreq.Valid, + sentTo: make(map[*peer]chan struct{}), + answered: answered, // reply delivered by any peer + } + reqID := self.getNextReqID() + self.mlock.Lock() + self.sentReqs[reqID] = req + self.mlock.Unlock() + + reqWg := new(sync.WaitGroup) + reqWg.Add(1) + defer reqWg.Done() + go func() { + reqWg.Wait() + self.mlock.Lock() + delete(self.sentReqs, reqID) + self.mlock.Unlock() + }() + + exclude := make(map[*peer]struct{}) + for { + if peer := self.peers.bestPeer(lreq, exclude); peer == nil { + select { + case <-ctx.Done(): + return ctx.Err() + case <-req.answered: + return nil + case <-time.After(retryPeers): + } + } else { + exclude[peer] = struct{}{} + delivered := make(chan struct{}) + timeout := make(chan struct{}) + req.lock.Lock() + req.sentTo[peer] = delivered + req.lock.Unlock() + reqWg.Add(1) + cost := lreq.GetCost(peer) + peer.fcServer.SendRequest(reqID, cost) + go self.requestPeer(req, peer, delivered, timeout, reqWg) + lreq.Request(reqID, peer) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-answered: + return nil + case <-timeout: + } + } + } +} + +// Retrieve tries to fetch an object from the local db, then from the LES network. +// If the network retrieval was successful, it stores the object in local db. +func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { + lreq := LesRequest(req) + err = self.networkRequest(ctx, lreq) + if err == nil { + // retrieved from network, store in db + req.StoreResult(self.db) + } else { + glog.V(logger.Debug).Infof("networkRequest err = %v", err) + } + return +} + +func (self *LesOdr) getNextReqID() uint64 { + self.clock.Lock() + defer self.clock.Unlock() + + self.lastReqID++ + return self.lastReqID +} diff --git a/les/odr_peerset.go b/les/odr_peerset.go new file mode 100644 index 000000000..0323ce27f --- /dev/null +++ b/les/odr_peerset.go @@ -0,0 +1,119 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . +package les + +import ( + "sync" +) + +const dropTimeoutRatio = 20 + +type odrPeerInfo struct { + reqTimeSum, reqTimeCnt, reqCnt, timeoutCnt uint64 +} + +// odrPeerSet represents the collection of active peer participating in the block +// download procedure. +type odrPeerSet struct { + peers map[*peer]*odrPeerInfo + lock sync.RWMutex +} + +// newPeerSet creates a new peer set top track the active download sources. +func newOdrPeerSet() *odrPeerSet { + return &odrPeerSet{ + peers: make(map[*peer]*odrPeerInfo), + } +} + +// Register injects a new peer into the working set, or returns an error if the +// peer is already known. +func (ps *odrPeerSet) register(p *peer) error { + ps.lock.Lock() + defer ps.lock.Unlock() + + if _, ok := ps.peers[p]; ok { + return errAlreadyRegistered + } + ps.peers[p] = &odrPeerInfo{} + return nil +} + +// Unregister removes a remote peer from the active set, disabling any further +// actions to/from that particular entity. +func (ps *odrPeerSet) unregister(p *peer) error { + ps.lock.Lock() + defer ps.lock.Unlock() + + if _, ok := ps.peers[p]; !ok { + return errNotRegistered + } + delete(ps.peers, p) + return nil +} + +func (ps *odrPeerSet) peerPriority(p *peer, info *odrPeerInfo, req LesOdrRequest) uint64 { + tm := p.fcServer.CanSend(req.GetCost(p)) + if info.reqTimeCnt > 0 { + tm += info.reqTimeSum / info.reqTimeCnt + } + return tm +} + +func (ps *odrPeerSet) bestPeer(req LesOdrRequest, exclude map[*peer]struct{}) *peer { + var best *peer + var bpv uint64 + ps.lock.Lock() + defer ps.lock.Unlock() + + for p, info := range ps.peers { + if _, ok := exclude[p]; !ok { + pv := ps.peerPriority(p, info, req) + if best == nil || pv < bpv { + best = p + bpv = pv + } + } + } + return best +} + +func (ps *odrPeerSet) updateTimeout(p *peer, timeout bool) (drop bool) { + ps.lock.Lock() + defer ps.lock.Unlock() + + if info, ok := ps.peers[p]; ok { + info.reqCnt++ + if timeout { + // check ratio before increase to allow an extra timeout + if info.timeoutCnt*dropTimeoutRatio >= info.reqCnt { + return true + } + info.timeoutCnt++ + } + } + return false +} + +func (ps *odrPeerSet) updateServTime(p *peer, servTime uint64) { + ps.lock.Lock() + defer ps.lock.Unlock() + + if info, ok := ps.peers[p]; ok { + info.reqTimeSum += servTime + info.reqTimeCnt++ + } +} diff --git a/les/odr_requests.go b/les/odr_requests.go new file mode 100644 index 000000000..bf0346977 --- /dev/null +++ b/les/odr_requests.go @@ -0,0 +1,325 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package light implements on-demand retrieval capable state and chain objects +// for the Ethereum Light Client. +package les + +import ( + "bytes" + "encoding/binary" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +type LesOdrRequest interface { + GetCost(*peer) uint64 + Request(uint64, *peer) error + Valid(ethdb.Database, *Msg) bool // if true, keeps the retrieved object +} + +func LesRequest(req light.OdrRequest) LesOdrRequest { + switch r := req.(type) { + case *light.BlockRequest: + return (*BlockRequest)(r) + case *light.ReceiptsRequest: + return (*ReceiptsRequest)(r) + case *light.TrieRequest: + return (*TrieRequest)(r) + case *light.CodeRequest: + return (*CodeRequest)(r) + case *light.ChtRequest: + return (*ChtRequest)(r) + default: + return nil + } +} + +// BlockRequest is the ODR request type for block bodies +type BlockRequest light.BlockRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (self *BlockRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetBlockBodiesMsg, 1) +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (self *BlockRequest) Request(reqID uint64, peer *peer) error { + glog.V(logger.Debug).Infof("ODR: requesting body of block %08x from peer %v", self.Hash[:4], peer.id) + return peer.RequestBodies(reqID, self.GetCost(peer), []common.Hash{self.Hash}) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (self *BlockRequest) Valid(db ethdb.Database, msg *Msg) bool { + glog.V(logger.Debug).Infof("ODR: validating body of block %08x", self.Hash[:4]) + if msg.MsgType != MsgBlockBodies { + glog.V(logger.Debug).Infof("ODR: invalid message type") + return false + } + bodies := msg.Obj.([]*types.Body) + if len(bodies) != 1 { + glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(bodies)) + return false + } + body := bodies[0] + header := core.GetHeader(db, self.Hash, self.Number) + if header == nil { + glog.V(logger.Debug).Infof("ODR: header not found for block %08x", self.Hash[:4]) + return false + } + txHash := types.DeriveSha(types.Transactions(body.Transactions)) + if header.TxHash != txHash { + glog.V(logger.Debug).Infof("ODR: header.TxHash %08x does not match received txHash %08x", header.TxHash[:4], txHash[:4]) + return false + } + uncleHash := types.CalcUncleHash(body.Uncles) + if header.UncleHash != uncleHash { + glog.V(logger.Debug).Infof("ODR: header.UncleHash %08x does not match received uncleHash %08x", header.UncleHash[:4], uncleHash[:4]) + return false + } + data, err := rlp.EncodeToBytes(body) + if err != nil { + glog.V(logger.Debug).Infof("ODR: body RLP encode error: %v", err) + return false + } + self.Rlp = data + glog.V(logger.Debug).Infof("ODR: validation successful") + return true +} + +// ReceiptsRequest is the ODR request type for block receipts by block hash +type ReceiptsRequest light.ReceiptsRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (self *ReceiptsRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetReceiptsMsg, 1) +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (self *ReceiptsRequest) Request(reqID uint64, peer *peer) error { + glog.V(logger.Debug).Infof("ODR: requesting receipts for block %08x from peer %v", self.Hash[:4], peer.id) + return peer.RequestReceipts(reqID, self.GetCost(peer), []common.Hash{self.Hash}) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (self *ReceiptsRequest) Valid(db ethdb.Database, msg *Msg) bool { + glog.V(logger.Debug).Infof("ODR: validating receipts for block %08x", self.Hash[:4]) + if msg.MsgType != MsgReceipts { + glog.V(logger.Debug).Infof("ODR: invalid message type") + return false + } + receipts := msg.Obj.([]types.Receipts) + if len(receipts) != 1 { + glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(receipts)) + return false + } + hash := types.DeriveSha(receipts[0]) + header := core.GetHeader(db, self.Hash, self.Number) + if header == nil { + glog.V(logger.Debug).Infof("ODR: header not found for block %08x", self.Hash[:4]) + return false + } + if !bytes.Equal(header.ReceiptHash[:], hash[:]) { + glog.V(logger.Debug).Infof("ODR: header receipts hash %08x does not match calculated RLP hash %08x", header.ReceiptHash[:4], hash[:4]) + return false + } + self.Receipts = receipts[0] + glog.V(logger.Debug).Infof("ODR: validation successful") + return true +} + +type ProofReq struct { + BHash common.Hash + AccKey, Key []byte + FromLevel uint +} + +// ODR request type for state/storage trie entries, see LesOdrRequest interface +type TrieRequest light.TrieRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (self *TrieRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetProofsMsg, 1) +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (self *TrieRequest) Request(reqID uint64, peer *peer) error { + glog.V(logger.Debug).Infof("ODR: requesting trie root %08x key %08x from peer %v", self.Id.Root[:4], self.Key[:4], peer.id) + req := &ProofReq{ + BHash: self.Id.BlockHash, + AccKey: self.Id.AccKey, + Key: self.Key, + } + return peer.RequestProofs(reqID, self.GetCost(peer), []*ProofReq{req}) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (self *TrieRequest) Valid(db ethdb.Database, msg *Msg) bool { + glog.V(logger.Debug).Infof("ODR: validating trie root %08x key %08x", self.Id.Root[:4], self.Key[:4]) + + if msg.MsgType != MsgProofs { + glog.V(logger.Debug).Infof("ODR: invalid message type") + return false + } + proofs := msg.Obj.([][]rlp.RawValue) + if len(proofs) != 1 { + glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(proofs)) + return false + } + _, err := trie.VerifyProof(self.Id.Root, self.Key, proofs[0]) + if err != nil { + glog.V(logger.Debug).Infof("ODR: merkle proof verification error: %v", err) + return false + } + self.Proof = proofs[0] + glog.V(logger.Debug).Infof("ODR: validation successful") + return true +} + +type CodeReq struct { + BHash common.Hash + AccKey []byte +} + +// ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface +type CodeRequest light.CodeRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (self *CodeRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetCodeMsg, 1) +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (self *CodeRequest) Request(reqID uint64, peer *peer) error { + glog.V(logger.Debug).Infof("ODR: requesting node data for hash %08x from peer %v", self.Hash[:4], peer.id) + req := &CodeReq{ + BHash: self.Id.BlockHash, + AccKey: self.Id.AccKey, + } + return peer.RequestCode(reqID, self.GetCost(peer), []*CodeReq{req}) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (self *CodeRequest) Valid(db ethdb.Database, msg *Msg) bool { + glog.V(logger.Debug).Infof("ODR: validating node data for hash %08x", self.Hash[:4]) + if msg.MsgType != MsgCode { + glog.V(logger.Debug).Infof("ODR: invalid message type") + return false + } + reply := msg.Obj.([][]byte) + if len(reply) != 1 { + glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(reply)) + return false + } + data := reply[0] + hash := crypto.Sha3Hash(data) + if !bytes.Equal(self.Hash[:], hash[:]) { + glog.V(logger.Debug).Infof("ODR: requested hash %08x does not match received data hash %08x", self.Hash[:4], hash[:4]) + return false + } + self.Data = data + glog.V(logger.Debug).Infof("ODR: validation successful") + return true +} + +type ChtReq struct { + ChtNum, BlockNum, FromLevel uint64 +} + +type ChtResp struct { + Header *types.Header + Proof []rlp.RawValue +} + +// ODR request type for requesting headers by Canonical Hash Trie, see LesOdrRequest interface +type ChtRequest light.ChtRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (self *ChtRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetHeaderProofsMsg, 1) +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (self *ChtRequest) Request(reqID uint64, peer *peer) error { + glog.V(logger.Debug).Infof("ODR: requesting CHT #%d block #%d from peer %v", self.ChtNum, self.BlockNum, peer.id) + req := &ChtReq{ + ChtNum: self.ChtNum, + BlockNum: self.BlockNum, + } + return peer.RequestHeaderProofs(reqID, self.GetCost(peer), []*ChtReq{req}) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (self *ChtRequest) Valid(db ethdb.Database, msg *Msg) bool { + glog.V(logger.Debug).Infof("ODR: validating CHT #%d block #%d", self.ChtNum, self.BlockNum) + + if msg.MsgType != MsgHeaderProofs { + glog.V(logger.Debug).Infof("ODR: invalid message type") + return false + } + proofs := msg.Obj.([]ChtResp) + if len(proofs) != 1 { + glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(proofs)) + return false + } + proof := proofs[0] + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], self.BlockNum) + value, err := trie.VerifyProof(self.ChtRoot, encNumber[:], proof.Proof) + if err != nil { + glog.V(logger.Debug).Infof("ODR: CHT merkle proof verification error: %v", err) + return false + } + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + glog.V(logger.Debug).Infof("ODR: error decoding CHT node: %v", err) + return false + } + if node.Hash != proof.Header.Hash() { + glog.V(logger.Debug).Infof("ODR: CHT header hash does not match") + return false + } + + self.Proof = proof.Proof + self.Header = proof.Header + self.Td = node.Td + glog.V(logger.Debug).Infof("ODR: validation successful") + return true +} diff --git a/les/odr_test.go b/les/odr_test.go new file mode 100644 index 000000000..bd52a82dd --- /dev/null +++ b/les/odr_test.go @@ -0,0 +1,222 @@ +package les + +import ( + "bytes" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/net/context" +) + +type odrTestFn func(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte + +func TestOdrGetBlockLes1(t *testing.T) { testOdr(t, 1, 1, odrGetBlock) } + +func odrGetBlock(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { + var block *types.Block + if bc != nil { + block = bc.GetBlockByHash(bhash) + } else { + block, _ = lc.GetBlockByHash(ctx, bhash) + } + if block == nil { + return nil + } + rlp, _ := rlp.EncodeToBytes(block) + return rlp +} + +func TestOdrGetReceiptsLes1(t *testing.T) { testOdr(t, 1, 1, odrGetReceipts) } + +func odrGetReceipts(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { + var receipts types.Receipts + if bc != nil { + receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + } else { + receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + } + if receipts == nil { + return nil + } + rlp, _ := rlp.EncodeToBytes(receipts) + return rlp +} + +func TestOdrAccountsLes1(t *testing.T) { testOdr(t, 1, 1, odrAccounts) } + +func odrAccounts(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { + dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") + acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} + + var res []byte + for _, addr := range acc { + if bc != nil { + header := bc.GetHeaderByHash(bhash) + st, err := state.New(header.Root, db) + if err == nil { + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) + } + } else { + header := lc.GetHeaderByHash(bhash) + st := light.NewLightState(light.StateTrieID(header), lc.Odr()) + bal, err := st.GetBalance(ctx, addr) + if err == nil { + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) + } + } + } + + return res +} + +func TestOdrContractCallLes1(t *testing.T) { testOdr(t, 1, 2, odrContractCall) } + +// fullcallmsg is the message type used for call transations. +type fullcallmsg struct { + from *state.StateObject + to *common.Address + gas, gasPrice *big.Int + value *big.Int + data []byte +} + +// accessor boilerplate to implement core.Message +func (m fullcallmsg) From() (common.Address, error) { return m.from.Address(), nil } +func (m fullcallmsg) FromFrontier() (common.Address, error) { return m.from.Address(), nil } +func (m fullcallmsg) Nonce() uint64 { return 0 } +func (m fullcallmsg) CheckNonce() bool { return false } +func (m fullcallmsg) To() *common.Address { return m.to } +func (m fullcallmsg) GasPrice() *big.Int { return m.gasPrice } +func (m fullcallmsg) Gas() *big.Int { return m.gas } +func (m fullcallmsg) Value() *big.Int { return m.value } +func (m fullcallmsg) Data() []byte { return m.data } + +// callmsg is the message type used for call transations. +type lightcallmsg struct { + from *light.StateObject + to *common.Address + gas, gasPrice *big.Int + value *big.Int + data []byte +} + +// accessor boilerplate to implement core.Message +func (m lightcallmsg) From() (common.Address, error) { return m.from.Address(), nil } +func (m lightcallmsg) FromFrontier() (common.Address, error) { return m.from.Address(), nil } +func (m lightcallmsg) Nonce() uint64 { return 0 } +func (m lightcallmsg) CheckNonce() bool { return false } +func (m lightcallmsg) To() *common.Address { return m.to } +func (m lightcallmsg) GasPrice() *big.Int { return m.gasPrice } +func (m lightcallmsg) Gas() *big.Int { return m.gas } +func (m lightcallmsg) Value() *big.Int { return m.value } +func (m lightcallmsg) Data() []byte { return m.data } + +func odrContractCall(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { + data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") + + var res []byte + for i := 0; i < 3; i++ { + data[35] = byte(i) + if bc != nil { + header := bc.GetHeaderByHash(bhash) + statedb, err := state.New(header.Root, db) + if err == nil { + from := statedb.GetOrNewStateObject(testBankAddress) + from.SetBalance(common.MaxBig) + + msg := fullcallmsg{ + from: from, + gas: big.NewInt(100000), + gasPrice: big.NewInt(0), + value: big.NewInt(0), + data: data, + to: &testContractAddr, + } + + vmenv := core.NewEnv(statedb, config, bc, msg, header, config.VmConfig) + gp := new(core.GasPool).AddGas(common.MaxBig) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + res = append(res, ret...) + } + } else { + header := lc.GetHeaderByHash(bhash) + state := light.NewLightState(light.StateTrieID(header), lc.Odr()) + from, err := state.GetOrNewStateObject(ctx, testBankAddress) + if err == nil { + from.SetBalance(common.MaxBig) + + msg := lightcallmsg{ + from: from, + gas: big.NewInt(100000), + gasPrice: big.NewInt(0), + value: big.NewInt(0), + data: data, + to: &testContractAddr, + } + + vmenv := light.NewEnv(ctx, state, config, lc, msg, header, config.VmConfig) + gp := new(core.GasPool).AddGas(common.MaxBig) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + if vmenv.Error() == nil { + res = append(res, ret...) + } + } + } + } + return res +} + +func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { + // Assemble the test environment + pm, db, odr := newTestProtocolManagerMust(t, false, 4, testChainGen) + lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) + _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) + select { + case <-time.After(time.Millisecond * 100): + case err := <-err1: + t.Fatalf("peer 1 handshake error: %v", err) + case err := <-err2: + t.Fatalf("peer 1 handshake error: %v", err) + } + + lpm.synchronise(lpeer) + + test := func(expFail uint64) { + for i := uint64(0); i <= pm.blockchain.CurrentHeader().GetNumberU64(); i++ { + bhash := core.GetCanonicalHash(db, i) + b1 := fn(light.NoOdr, db, pm.chainConfig, pm.blockchain.(*core.BlockChain), nil, bhash) + ctx, _ := context.WithTimeout(context.Background(), 200*time.Millisecond) + b2 := fn(ctx, ldb, lpm.chainConfig, nil, lpm.blockchain.(*light.LightChain), bhash) + eq := bytes.Equal(b1, b2) + exp := i < expFail + if exp && !eq { + t.Errorf("odr mismatch") + } + if !exp && eq { + t.Errorf("unexpected odr match") + } + } + } + + // temporarily remove peer to test odr fails + odr.UnregisterPeer(lpeer) + // expect retrievals to fail (except genesis block) without a les peer + test(expFail) + odr.RegisterPeer(lpeer) + // expect all retrievals to pass + test(5) + odr.UnregisterPeer(lpeer) + // still expect all retrievals to pass, now data should be cached locally + test(5) +} diff --git a/les/peer.go b/les/peer.go new file mode 100644 index 000000000..dbddbb020 --- /dev/null +++ b/les/peer.go @@ -0,0 +1,584 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "errors" + "fmt" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" +) + +var ( + errClosed = errors.New("peer set is closed") + errAlreadyRegistered = errors.New("peer is already registered") + errNotRegistered = errors.New("peer is not registered") +) + +const maxHeadInfoLen = 20 + +type peer struct { + *p2p.Peer + + rw p2p.MsgReadWriter + + version int // Protocol version negotiated + network int // Network ID being on + + id string + + firstHeadInfo, headInfo *announceData + headInfoLen int + lock sync.RWMutex + + announceChn chan announceData + + fcClient *flowcontrol.ClientNode // nil if the peer is server only + fcServer *flowcontrol.ServerNode // nil if the peer is client only + fcServerParams *flowcontrol.ServerParams + fcCosts requestCostTable +} + +func newPeer(version, network int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { + id := p.ID() + + return &peer{ + Peer: p, + rw: rw, + version: version, + network: network, + id: fmt.Sprintf("%x", id[:8]), + announceChn: make(chan announceData, 20), + } +} + +// Info gathers and returns a collection of metadata known about a peer. +func (p *peer) Info() *eth.PeerInfo { + return ð.PeerInfo{ + Version: p.version, + Difficulty: p.Td(), + Head: fmt.Sprintf("%x", p.Head()), + } +} + +// Head retrieves a copy of the current head (most recent) hash of the peer. +func (p *peer) Head() (hash common.Hash) { + p.lock.RLock() + defer p.lock.RUnlock() + + copy(hash[:], p.headInfo.Hash[:]) + return hash +} + +func (p *peer) HeadAndTd() (hash common.Hash, td *big.Int) { + p.lock.RLock() + defer p.lock.RUnlock() + + copy(hash[:], p.headInfo.Hash[:]) + return hash, p.headInfo.Td +} + +func (p *peer) headBlockInfo() blockInfo { + p.lock.RLock() + defer p.lock.RUnlock() + + return blockInfo{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td} +} + +func (p *peer) addNotify(announce *announceData) bool { + p.lock.Lock() + defer p.lock.Unlock() + + if announce.Td.Cmp(p.headInfo.Td) < 1 { + return false + } + if p.headInfoLen >= maxHeadInfoLen { + //return false + p.firstHeadInfo = p.firstHeadInfo.next + p.headInfoLen-- + } + if announce.haveHeaders == 0 { + hh := p.headInfo.Number - announce.ReorgDepth + if p.headInfo.haveHeaders < hh { + hh = p.headInfo.haveHeaders + } + announce.haveHeaders = hh + } + p.headInfo.next = announce + p.headInfo = announce + p.headInfoLen++ + return true +} + +func (p *peer) gotHeader(hash common.Hash, number uint64, td *big.Int) bool { + h := p.firstHeadInfo + ptr := 0 + for h != nil { + if h.Hash == hash { + if h.Number != number || h.Td.Cmp(td) != 0 { + return false + } + h.headKnown = true + h.haveHeaders = h.Number + p.firstHeadInfo = h + p.headInfoLen -= ptr + last := h + h = h.next + // propagate haveHeaders through the chain + for h != nil { + hh := last.Number - h.ReorgDepth + if last.haveHeaders < hh { + hh = last.haveHeaders + } + if hh > h.haveHeaders { + h.haveHeaders = hh + } else { + return true + } + last = h + h = h.next + } + return true + } + h = h.next + ptr++ + } + return true +} + +// Td retrieves the current total difficulty of a peer. +func (p *peer) Td() *big.Int { + p.lock.RLock() + defer p.lock.RUnlock() + + return new(big.Int).Set(p.headInfo.Td) +} + +func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error { + type req struct { + ReqID uint64 + Data interface{} + } + return p2p.Send(w, msgcode, req{reqID, data}) +} + +func sendResponse(w p2p.MsgWriter, msgcode, reqID, bv uint64, data interface{}) error { + type resp struct { + ReqID, BV uint64 + Data interface{} + } + return p2p.Send(w, msgcode, resp{reqID, bv, data}) +} + +func (p *peer) GetRequestCost(msgcode uint64, amount int) uint64 { + cost := p.fcCosts[msgcode].baseCost + p.fcCosts[msgcode].reqCost*uint64(amount) + if cost > p.fcServerParams.BufLimit { + cost = p.fcServerParams.BufLimit + } + return cost +} + +// SendAnnounce announces the availability of a number of blocks through +// a hash notification. +func (p *peer) SendAnnounce(request announceData) error { + return p2p.Send(p.rw, AnnounceMsg, request) +} + +// SendBlockHeaders sends a batch of block headers to the remote peer. +func (p *peer) SendBlockHeaders(reqID, bv uint64, headers []*types.Header) error { + return sendResponse(p.rw, BlockHeadersMsg, reqID, bv, headers) +} + +// SendBlockBodiesRLP sends a batch of block contents to the remote peer from +// an already RLP encoded format. +func (p *peer) SendBlockBodiesRLP(reqID, bv uint64, bodies []rlp.RawValue) error { + return sendResponse(p.rw, BlockBodiesMsg, reqID, bv, bodies) +} + +// SendCodeRLP sends a batch of arbitrary internal data, corresponding to the +// hashes requested. +func (p *peer) SendCode(reqID, bv uint64, data [][]byte) error { + return sendResponse(p.rw, CodeMsg, reqID, bv, data) +} + +// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the +// ones requested from an already RLP encoded format. +func (p *peer) SendReceiptsRLP(reqID, bv uint64, receipts []rlp.RawValue) error { + return sendResponse(p.rw, ReceiptsMsg, reqID, bv, receipts) +} + +// SendProofs sends a batch of merkle proofs, corresponding to the ones requested. +func (p *peer) SendProofs(reqID, bv uint64, proofs proofsData) error { + return sendResponse(p.rw, ProofsMsg, reqID, bv, proofs) +} + +// SendHeaderProofs sends a batch of header proofs, corresponding to the ones requested. +func (p *peer) SendHeaderProofs(reqID, bv uint64, proofs []ChtResp) error { + return sendResponse(p.rw, HeaderProofsMsg, reqID, bv, proofs) +} + +// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the +// specified header query, based on the hash of an origin block. +func (p *peer) RequestHeadersByHash(reqID, cost uint64, origin common.Hash, amount int, skip int, reverse bool) error { + glog.V(logger.Debug).Infof("%v fetching %d headers from %x, skipping %d (reverse = %v)", p, amount, origin[:4], skip, reverse) + return sendRequest(p.rw, GetBlockHeadersMsg, reqID, cost, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) +} + +// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the +// specified header query, based on the number of an origin block. +func (p *peer) RequestHeadersByNumber(reqID, cost, origin uint64, amount int, skip int, reverse bool) error { + glog.V(logger.Debug).Infof("%v fetching %d headers from #%d, skipping %d (reverse = %v)", p, amount, origin, skip, reverse) + return sendRequest(p.rw, GetBlockHeadersMsg, reqID, cost, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) +} + +// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes +// specified. +func (p *peer) RequestBodies(reqID, cost uint64, hashes []common.Hash) error { + glog.V(logger.Debug).Infof("%v fetching %d block bodies", p, len(hashes)) + return sendRequest(p.rw, GetBlockBodiesMsg, reqID, cost, hashes) +} + +// RequestCode fetches a batch of arbitrary data from a node's known state +// data, corresponding to the specified hashes. +func (p *peer) RequestCode(reqID, cost uint64, reqs []*CodeReq) error { + glog.V(logger.Debug).Infof("%v fetching %v state data", p, len(reqs)) + return sendRequest(p.rw, GetCodeMsg, reqID, cost, reqs) +} + +// RequestReceipts fetches a batch of transaction receipts from a remote node. +func (p *peer) RequestReceipts(reqID, cost uint64, hashes []common.Hash) error { + glog.V(logger.Debug).Infof("%v fetching %v receipts", p, len(hashes)) + return sendRequest(p.rw, GetReceiptsMsg, reqID, cost, hashes) +} + +// RequestProofs fetches a batch of merkle proofs from a remote node. +func (p *peer) RequestProofs(reqID, cost uint64, reqs []*ProofReq) error { + glog.V(logger.Debug).Infof("%v fetching %v proofs", p, len(reqs)) + return sendRequest(p.rw, GetProofsMsg, reqID, cost, reqs) +} + +// RequestHeaderProofs fetches a batch of header merkle proofs from a remote node. +func (p *peer) RequestHeaderProofs(reqID, cost uint64, reqs []*ChtReq) error { + glog.V(logger.Debug).Infof("%v fetching %v header proofs", p, len(reqs)) + return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs) +} + +func (p *peer) SendTxs(cost uint64, txs types.Transactions) error { + glog.V(logger.Debug).Infof("%v relaying %v txs", p, len(txs)) + p.fcServer.SendRequest(0, cost) + return p2p.Send(p.rw, SendTxMsg, txs) +} + +type keyValueEntry struct { + Key string + Value rlp.RawValue +} +type keyValueList []keyValueEntry +type keyValueMap map[string]rlp.RawValue + +func (l keyValueList) add(key string, val interface{}) keyValueList { + var entry keyValueEntry + entry.Key = key + if val == nil { + val = uint64(0) + } + enc, err := rlp.EncodeToBytes(val) + if err == nil { + entry.Value = enc + } + return append(l, entry) +} + +func (l keyValueList) decode() keyValueMap { + m := make(keyValueMap) + for _, entry := range l { + m[entry.Key] = entry.Value + } + return m +} + +func (m keyValueMap) get(key string, val interface{}) error { + enc, ok := m[key] + if !ok { + return errResp(ErrHandshakeMissingKey, "%s", key) + } + if val == nil { + return nil + } + return rlp.DecodeBytes(enc, val) +} + +func (p *peer) sendReceiveHandshake(sendList keyValueList) (keyValueList, error) { + // Send out own handshake in a new thread + errc := make(chan error, 1) + go func() { + errc <- p2p.Send(p.rw, StatusMsg, sendList) + }() + // In the mean time retrieve the remote status message + msg, err := p.rw.ReadMsg() + if err != nil { + return nil, err + } + if msg.Code != StatusMsg { + return nil, errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) + } + if msg.Size > ProtocolMaxMsgSize { + return nil, errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + } + // Decode the handshake + var recvList keyValueList + if err := msg.Decode(&recvList); err != nil { + return nil, errResp(ErrDecode, "msg %v: %v", msg, err) + } + if err := <-errc; err != nil { + return nil, err + } + return recvList, nil +} + +// Handshake executes the les protocol handshake, negotiating version number, +// network IDs, difficulties, head and genesis blocks. +func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { + p.lock.Lock() + defer p.lock.Unlock() + + var send keyValueList + send = send.add("protocolVersion", uint64(p.version)) + send = send.add("networkId", uint64(p.network)) + send = send.add("headTd", td) + send = send.add("headHash", head) + send = send.add("headNum", headNum) + send = send.add("genesisHash", genesis) + if server != nil { + send = send.add("serveHeaders", nil) + send = send.add("serveChainSince", uint64(0)) + send = send.add("serveStateSince", uint64(0)) + send = send.add("txRelay", nil) + send = send.add("flowControl/BL", server.defParams.BufLimit) + send = send.add("flowControl/MRR", server.defParams.MinRecharge) + list := server.fcCostStats.getCurrentList() + send = send.add("flowControl/MRC", list) + p.fcCosts = list.decode() + } + recvList, err := p.sendReceiveHandshake(send) + if err != nil { + return err + } + recv := recvList.decode() + + var rGenesis, rHash common.Hash + var rVersion, rNetwork, rNum uint64 + var rTd *big.Int + + if err := recv.get("protocolVersion", &rVersion); err != nil { + return err + } + if err := recv.get("networkId", &rNetwork); err != nil { + return err + } + if err := recv.get("headTd", &rTd); err != nil { + return err + } + if err := recv.get("headHash", &rHash); err != nil { + return err + } + if err := recv.get("headNum", &rNum); err != nil { + return err + } + if err := recv.get("genesisHash", &rGenesis); err != nil { + return err + } + + if rGenesis != genesis { + return errResp(ErrGenesisBlockMismatch, "%x (!= %x)", rGenesis, genesis) + } + if int(rNetwork) != p.network { + return errResp(ErrNetworkIdMismatch, "%d (!= %d)", rNetwork, p.network) + } + if int(rVersion) != p.version { + return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version) + } + if server != nil { + if recv.get("serveStateSince", nil) == nil { + return errResp(ErrUselessPeer, "wanted client, got server") + } + p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) + } else { + if recv.get("serveChainSince", nil) != nil { + return errResp(ErrUselessPeer, "peer cannot serve chain") + } + if recv.get("serveStateSince", nil) != nil { + return errResp(ErrUselessPeer, "peer cannot serve state") + } + if recv.get("txRelay", nil) != nil { + return errResp(ErrUselessPeer, "peer cannot relay transactions") + } + params := &flowcontrol.ServerParams{} + if err := recv.get("flowControl/BL", ¶ms.BufLimit); err != nil { + return err + } + if err := recv.get("flowControl/MRR", ¶ms.MinRecharge); err != nil { + return err + } + var MRC RequestCostList + if err := recv.get("flowControl/MRC", &MRC); err != nil { + return err + } + p.fcServerParams = params + p.fcServer = flowcontrol.NewServerNode(params) + p.fcCosts = MRC.decode() + } + + p.firstHeadInfo = &announceData{Td: rTd, Hash: rHash, Number: rNum} + p.headInfo = p.firstHeadInfo + p.headInfoLen = 1 + return nil +} + +// String implements fmt.Stringer. +func (p *peer) String() string { + return fmt.Sprintf("Peer %s [%s]", p.id, + fmt.Sprintf("les/%d", p.version), + ) +} + +// peerSet represents the collection of active peers currently participating in +// the Light Ethereum sub-protocol. +type peerSet struct { + peers map[string]*peer + lock sync.RWMutex + closed bool +} + +// newPeerSet creates a new peer set to track the active participants. +func newPeerSet() *peerSet { + return &peerSet{ + peers: make(map[string]*peer), + } +} + +// Register injects a new peer into the working set, or returns an error if the +// peer is already known. +func (ps *peerSet) Register(p *peer) error { + ps.lock.Lock() + defer ps.lock.Unlock() + + if ps.closed { + return errClosed + } + if _, ok := ps.peers[p.id]; ok { + return errAlreadyRegistered + } + ps.peers[p.id] = p + return nil +} + +// Unregister removes a remote peer from the active set, disabling any further +// actions to/from that particular entity. +func (ps *peerSet) Unregister(id string) error { + ps.lock.Lock() + defer ps.lock.Unlock() + + if _, ok := ps.peers[id]; !ok { + return errNotRegistered + } + delete(ps.peers, id) + return nil +} + +// AllPeerIDs returns a list of all registered peer IDs +func (ps *peerSet) AllPeerIDs() []string { + ps.lock.RLock() + defer ps.lock.RUnlock() + + res := make([]string, len(ps.peers)) + idx := 0 + for id, _ := range ps.peers { + res[idx] = id + idx++ + } + return res +} + +// Peer retrieves the registered peer with the given id. +func (ps *peerSet) Peer(id string) *peer { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return ps.peers[id] +} + +// Len returns if the current number of peers in the set. +func (ps *peerSet) Len() int { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return len(ps.peers) +} + +// BestPeer retrieves the known peer with the currently highest total difficulty. +func (ps *peerSet) BestPeer() *peer { + ps.lock.RLock() + defer ps.lock.RUnlock() + + var ( + bestPeer *peer + bestTd *big.Int + ) + for _, p := range ps.peers { + if td := p.Td(); bestPeer == nil || td.Cmp(bestTd) > 0 { + bestPeer, bestTd = p, td + } + } + return bestPeer +} + +// AllPeers returns all peers in a list +func (ps *peerSet) AllPeers() []*peer { + ps.lock.RLock() + defer ps.lock.RUnlock() + + list := make([]*peer, len(ps.peers)) + i := 0 + for _, peer := range ps.peers { + list[i] = peer + i++ + } + return list +} + +// Close disconnects all peers. +// No new peers can be registered after Close has returned. +func (ps *peerSet) Close() { + ps.lock.Lock() + defer ps.lock.Unlock() + + for _, p := range ps.peers { + p.Disconnect(p2p.DiscQuitting) + } + ps.closed = true +} diff --git a/les/protocol.go b/les/protocol.go new file mode 100644 index 000000000..3d2de64e1 --- /dev/null +++ b/les/protocol.go @@ -0,0 +1,198 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "fmt" + "io" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" +) + +// Constants to match up protocol versions and messages +const ( + lpv1 = 1 +) + +// Supported versions of the les protocol (first is primary). +var ProtocolVersions = []uint{lpv1} + +// Number of implemented message corresponding to different protocol versions. +var ProtocolLengths = []uint64{15} + +const ( + NetworkId = 1 + ProtocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message +) + +// les protocol message codes +const ( + // Protocol messages belonging to LPV1 + StatusMsg = 0x00 + AnnounceMsg = 0x01 + GetBlockHeadersMsg = 0x02 + BlockHeadersMsg = 0x03 + GetBlockBodiesMsg = 0x04 + BlockBodiesMsg = 0x05 + GetReceiptsMsg = 0x06 + ReceiptsMsg = 0x07 + GetProofsMsg = 0x08 + ProofsMsg = 0x09 + GetCodeMsg = 0x0a + CodeMsg = 0x0b + SendTxMsg = 0x0c + GetHeaderProofsMsg = 0x0d + HeaderProofsMsg = 0x0e +) + +type errCode int + +const ( + ErrMsgTooLarge = iota + ErrDecode + ErrInvalidMsgCode + ErrProtocolVersionMismatch + ErrNetworkIdMismatch + ErrGenesisBlockMismatch + ErrNoStatusMsg + ErrExtraStatusMsg + ErrSuspendedPeer + ErrUselessPeer + ErrRequestRejected + ErrUnexpectedResponse + ErrInvalidResponse + ErrTooManyTimeouts + ErrHandshakeMissingKey +) + +func (e errCode) String() string { + return errorToString[int(e)] +} + +// XXX change once legacy code is out +var errorToString = map[int]string{ + ErrMsgTooLarge: "Message too long", + ErrDecode: "Invalid message", + ErrInvalidMsgCode: "Invalid message code", + ErrProtocolVersionMismatch: "Protocol version mismatch", + ErrNetworkIdMismatch: "NetworkId mismatch", + ErrGenesisBlockMismatch: "Genesis block mismatch", + ErrNoStatusMsg: "No status message", + ErrExtraStatusMsg: "Extra status message", + ErrSuspendedPeer: "Suspended peer", + ErrRequestRejected: "Request rejected", + ErrUnexpectedResponse: "Unexpected response", + ErrInvalidResponse: "Invalid response", + ErrTooManyTimeouts: "Too many request timeouts", + ErrHandshakeMissingKey: "Key missing from handshake message", +} + +type chainManager interface { + GetBlockHashesFromHash(hash common.Hash, amount uint64) (hashes []common.Hash) + GetBlock(hash common.Hash) (block *types.Block) + Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) +} + +// announceData is the network packet for the block announcements. +type announceData struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced + Td *big.Int // Total difficulty of one particular block being announced + ReorgDepth uint64 + Update keyValueList + + haveHeaders uint64 // we have the headers of the remote peer's chain up to this number + headKnown bool + requested bool + next *announceData +} + +type blockInfo struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced + Td *big.Int // Total difficulty of one particular block being announced +} + +// getBlockHashesData is the network packet for the hash based hash retrieval. +type getBlockHashesData struct { + Hash common.Hash + Amount uint64 +} + +// getBlockHeadersData represents a block header query. +type getBlockHeadersData struct { + Origin hashOrNumber // Block from which to retrieve headers + Amount uint64 // Maximum number of headers to retrieve + Skip uint64 // Blocks to skip between consecutive headers + Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +} + +// hashOrNumber is a combined field for specifying an origin block. +type hashOrNumber struct { + Hash common.Hash // Block hash from which to retrieve headers (excludes Number) + Number uint64 // Block hash from which to retrieve headers (excludes Hash) +} + +// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the +// two contained union fields. +func (hn *hashOrNumber) EncodeRLP(w io.Writer) error { + if hn.Hash == (common.Hash{}) { + return rlp.Encode(w, hn.Number) + } + if hn.Number != 0 { + return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number) + } + return rlp.Encode(w, hn.Hash) +} + +// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents +// into either a block hash or a block number. +func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { + _, size, _ := s.Kind() + origin, err := s.Raw() + if err == nil { + switch { + case size == 32: + err = rlp.DecodeBytes(origin, &hn.Hash) + case size <= 8: + err = rlp.DecodeBytes(origin, &hn.Number) + default: + err = fmt.Errorf("invalid input size %d for origin", size) + } + } + return err +} + +// newBlockData is the network packet for the block propagation message. +type newBlockData struct { + Block *types.Block + TD *big.Int +} + +// blockBodiesData is the network packet for block content distribution. +type blockBodiesData []*types.Body + +// CodeData is the network response packet for a node data retrieval. +type CodeData []struct { + Value []byte +} + +type proofsData [][]rlp.RawValue diff --git a/les/request_test.go b/les/request_test.go new file mode 100644 index 000000000..df02afb32 --- /dev/null +++ b/les/request_test.go @@ -0,0 +1,94 @@ +package les + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" + "golang.org/x/net/context" +) + +var testBankSecureTrieKey = secAddr(testBankAddress) + +func secAddr(addr common.Address) []byte { + return crypto.Keccak256(addr[:]) +} + +type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest + +func TestBlockAccessLes1(t *testing.T) { testAccess(t, 1, tfBlockAccess) } + +func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { + return &light.BlockRequest{Hash: bhash, Number: number} +} + +func TestReceiptsAccessLes1(t *testing.T) { testAccess(t, 1, tfReceiptsAccess) } + +func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { + return &light.ReceiptsRequest{Hash: bhash, Number: number} +} + +func TestTrieEntryAccessLes1(t *testing.T) { testAccess(t, 1, tfTrieEntryAccess) } + +func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { + return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} +} + +func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } + +func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { + header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) + if header.GetNumberU64() < testContractDeployed { + return nil + } + sti := light.StateTrieID(header) + ci := light.StorageTrieID(sti, testContractAddr, common.Hash{}) + return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} +} + +func testAccess(t *testing.T, protocol int, fn accessTestFn) { + // Assemble the test environment + pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) + _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) + select { + case <-time.After(time.Millisecond * 100): + case err := <-err1: + t.Fatalf("peer 1 handshake error: %v", err) + case err := <-err2: + t.Fatalf("peer 1 handshake error: %v", err) + } + + lpm.synchronise(lpeer) + + test := func(expFail uint64) { + for i := uint64(0); i <= pm.blockchain.CurrentHeader().GetNumberU64(); i++ { + bhash := core.GetCanonicalHash(db, i) + if req := fn(ldb, bhash, i); req != nil { + ctx, _ := context.WithTimeout(context.Background(), 200*time.Millisecond) + err := odr.Retrieve(ctx, req) + got := err == nil + exp := i < expFail + if exp && !got { + t.Errorf("object retrieval failed") + } + if !exp && got { + t.Errorf("unexpected object retrieval success") + } + } + } + } + + // temporarily remove peer to test odr fails + odr.UnregisterPeer(lpeer) + // expect retrievals to fail (except genesis block) without a les peer + test(0) + odr.RegisterPeer(lpeer) + // expect all retrievals to pass + test(5) + odr.UnregisterPeer(lpeer) +} diff --git a/les/server.go b/les/server.go new file mode 100644 index 000000000..bc5dd0837 --- /dev/null +++ b/les/server.go @@ -0,0 +1,401 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package les implements the Light Ethereum Subprotocol. +package les + +import ( + "encoding/binary" + "fmt" + "math" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +type LesServer struct { + protocolManager *ProtocolManager + fcManager *flowcontrol.ClientManager // nil if our node is client only + fcCostStats *requestCostStats + defParams *flowcontrol.ServerParams +} + +func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { + pm, err := NewProtocolManager(config.ChainConfig, false, config.NetworkId, eth.EventMux(), eth.Pow(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil) + if err != nil { + return nil, err + } + pm.blockLoop() + + srv := &LesServer{protocolManager: pm} + pm.server = srv + + srv.defParams = &flowcontrol.ServerParams{ + BufLimit: 300000000, + MinRecharge: 50000, + } + srv.fcManager = flowcontrol.NewClientManager(uint64(config.LightServ), 10, 1000000000) + srv.fcCostStats = newCostStats(eth.ChainDb()) + return srv, nil +} + +func (s *LesServer) Protocols() []p2p.Protocol { + return s.protocolManager.SubProtocols +} + +func (s *LesServer) Start() { + s.protocolManager.Start() +} + +func (s *LesServer) Stop() { + s.fcCostStats.store() + s.fcManager.Stop() + go func() { + <-s.protocolManager.noMorePeers + }() + s.protocolManager.Stop() +} + +type requestCosts struct { + baseCost, reqCost uint64 +} + +type requestCostTable map[uint64]*requestCosts + +type RequestCostList []struct { + MsgCode, BaseCost, ReqCost uint64 +} + +func (list RequestCostList) decode() requestCostTable { + table := make(requestCostTable) + for _, e := range list { + table[e.MsgCode] = &requestCosts{ + baseCost: e.BaseCost, + reqCost: e.ReqCost, + } + } + return table +} + +func (table requestCostTable) encode() RequestCostList { + list := make(RequestCostList, len(table)) + for idx, code := range reqList { + list[idx].MsgCode = code + list[idx].BaseCost = table[code].baseCost + list[idx].ReqCost = table[code].reqCost + } + return list +} + +type linReg struct { + sumX, sumY, sumXX, sumXY float64 + cnt uint64 +} + +const linRegMaxCnt = 100000 + +func (l *linReg) add(x, y float64) { + if l.cnt >= linRegMaxCnt { + sub := float64(l.cnt+1-linRegMaxCnt) / linRegMaxCnt + l.sumX -= l.sumX * sub + l.sumY -= l.sumY * sub + l.sumXX -= l.sumXX * sub + l.sumXY -= l.sumXY * sub + l.cnt = linRegMaxCnt - 1 + } + l.cnt++ + l.sumX += x + l.sumY += y + l.sumXX += x * x + l.sumXY += x * y +} + +func (l *linReg) calc() (b, m float64) { + if l.cnt == 0 { + return 0, 0 + } + cnt := float64(l.cnt) + d := cnt*l.sumXX - l.sumX*l.sumX + if d < 0.001 { + return l.sumY / cnt, 0 + } + m = (cnt*l.sumXY - l.sumX*l.sumY) / d + b = (l.sumY / cnt) - (m * l.sumX / cnt) + return b, m +} + +func (l *linReg) toBytes() []byte { + var arr [40]byte + binary.BigEndian.PutUint64(arr[0:8], math.Float64bits(l.sumX)) + binary.BigEndian.PutUint64(arr[8:16], math.Float64bits(l.sumY)) + binary.BigEndian.PutUint64(arr[16:24], math.Float64bits(l.sumXX)) + binary.BigEndian.PutUint64(arr[24:32], math.Float64bits(l.sumXY)) + binary.BigEndian.PutUint64(arr[32:40], l.cnt) + return arr[:] +} + +func linRegFromBytes(data []byte) *linReg { + if len(data) != 40 { + return nil + } + l := &linReg{} + l.sumX = math.Float64frombits(binary.BigEndian.Uint64(data[0:8])) + l.sumY = math.Float64frombits(binary.BigEndian.Uint64(data[8:16])) + l.sumXX = math.Float64frombits(binary.BigEndian.Uint64(data[16:24])) + l.sumXY = math.Float64frombits(binary.BigEndian.Uint64(data[24:32])) + l.cnt = binary.BigEndian.Uint64(data[32:40]) + return l +} + +type requestCostStats struct { + lock sync.RWMutex + db ethdb.Database + stats map[uint64]*linReg +} + +type requestCostStatsRlp []struct { + MsgCode uint64 + Data []byte +} + +var rcStatsKey = []byte("_requestCostStats") + +func newCostStats(db ethdb.Database) *requestCostStats { + stats := make(map[uint64]*linReg) + for _, code := range reqList { + stats[code] = &linReg{cnt: 100} + } + + if db != nil { + data, err := db.Get(rcStatsKey) + var statsRlp requestCostStatsRlp + if err == nil { + err = rlp.DecodeBytes(data, &statsRlp) + } + if err == nil { + for _, r := range statsRlp { + if stats[r.MsgCode] != nil { + if l := linRegFromBytes(r.Data); l != nil { + stats[r.MsgCode] = l + } + } + } + } + } + + return &requestCostStats{ + db: db, + stats: stats, + } +} + +func (s *requestCostStats) store() { + s.lock.Lock() + defer s.lock.Unlock() + + statsRlp := make(requestCostStatsRlp, len(reqList)) + for i, code := range reqList { + statsRlp[i].MsgCode = code + statsRlp[i].Data = s.stats[code].toBytes() + } + + if data, err := rlp.EncodeToBytes(statsRlp); err == nil { + s.db.Put(rcStatsKey, data) + } +} + +func (s *requestCostStats) getCurrentList() RequestCostList { + s.lock.Lock() + defer s.lock.Unlock() + + list := make(RequestCostList, len(reqList)) + //fmt.Println("RequestCostList") + for idx, code := range reqList { + b, m := s.stats[code].calc() + //fmt.Println(code, s.stats[code].cnt, b/1000000, m/1000000) + if m < 0 { + b += m + m = 0 + } + if b < 0 { + b = 0 + } + + list[idx].MsgCode = code + list[idx].BaseCost = uint64(b * 2) + list[idx].ReqCost = uint64(m * 2) + } + return list +} + +func (s *requestCostStats) update(msgCode, reqCnt, cost uint64) { + s.lock.Lock() + defer s.lock.Unlock() + + c, ok := s.stats[msgCode] + if !ok || reqCnt == 0 { + return + } + c.add(float64(reqCnt), float64(cost)) +} + +func (pm *ProtocolManager) blockLoop() { + pm.wg.Add(1) + sub := pm.eventMux.Subscribe(core.ChainHeadEvent{}) + newCht := make(chan struct{}, 10) + newCht <- struct{}{} + go func() { + var mu sync.Mutex + var lastHead *types.Header + lastBroadcastTd := common.Big0 + for { + select { + case ev := <-sub.Chan(): + peers := pm.peers.AllPeers() + if len(peers) > 0 { + header := ev.Data.(core.ChainHeadEvent).Block.Header() + hash := header.Hash() + number := header.GetNumberU64() + td := core.GetTd(pm.chainDb, hash, number) + if td != nil && td.Cmp(lastBroadcastTd) > 0 { + var reorg uint64 + if lastHead != nil { + reorg = lastHead.GetNumberU64() - core.FindCommonAncestor(pm.chainDb, header, lastHead).GetNumberU64() + } + lastHead = header + lastBroadcastTd = td + //fmt.Println("BROADCAST", number, hash, td, reorg) + announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} + for _, p := range peers { + select { + case p.announceChn <- announce: + default: + pm.removePeer(p.id) + } + } + } + } + newCht <- struct{}{} + case <-newCht: + go func() { + mu.Lock() + more := makeCht(pm.chainDb) + mu.Unlock() + if more { + time.Sleep(time.Millisecond * 10) + newCht <- struct{}{} + } + }() + case <-pm.quitSync: + sub.Unsubscribe() + pm.wg.Done() + return + } + } + }() +} + +var ( + lastChtKey = []byte("LastChtNumber") // chtNum (uint64 big endian) + chtPrefix = []byte("cht") // chtPrefix + chtNum (uint64 big endian) -> trie root hash + chtConfirmations = light.ChtFrequency / 2 +) + +func getChtRoot(db ethdb.Database, num uint64) common.Hash { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], num) + data, _ := db.Get(append(chtPrefix, encNumber[:]...)) + return common.BytesToHash(data) +} + +func storeChtRoot(db ethdb.Database, num uint64, root common.Hash) { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], num) + db.Put(append(chtPrefix, encNumber[:]...), root[:]) +} + +func makeCht(db ethdb.Database) bool { + headHash := core.GetHeadBlockHash(db) + headNum := core.GetBlockNumber(db, headHash) + + var newChtNum uint64 + if headNum > chtConfirmations { + newChtNum = (headNum - chtConfirmations) / light.ChtFrequency + } + + var lastChtNum uint64 + data, _ := db.Get(lastChtKey) + if len(data) == 8 { + lastChtNum = binary.BigEndian.Uint64(data[:]) + } + if newChtNum <= lastChtNum { + return false + } + + var t *trie.Trie + if lastChtNum > 0 { + var err error + t, err = trie.New(getChtRoot(db, lastChtNum), db) + if err != nil { + lastChtNum = 0 + } + } + if lastChtNum == 0 { + t, _ = trie.New(common.Hash{}, db) + } + + for num := lastChtNum * light.ChtFrequency; num < (lastChtNum+1)*light.ChtFrequency; num++ { + hash := core.GetCanonicalHash(db, num) + if hash == (common.Hash{}) { + panic("Canonical hash not found") + } + td := core.GetTd(db, hash, num) + if td == nil { + panic("TD not found") + } + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], num) + var node light.ChtNode + node.Hash = hash + node.Td = td + data, _ := rlp.EncodeToBytes(node) + t.Update(encNumber[:], data) + } + + root, err := t.Commit() + if err != nil { + lastChtNum = 0 + } else { + lastChtNum++ + fmt.Printf("CHT %d %064x\n", lastChtNum, root) + storeChtRoot(db, lastChtNum, root) + var data [8]byte + binary.BigEndian.PutUint64(data[:], lastChtNum) + db.Put(lastChtKey, data[:]) + } + + return newChtNum > lastChtNum +} diff --git a/les/sync.go b/les/sync.go new file mode 100644 index 000000000..f92f8ce04 --- /dev/null +++ b/les/sync.go @@ -0,0 +1,84 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "time" + + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/light" + "golang.org/x/net/context" +) + +const ( + //forceSyncCycle = 10 * time.Second // Time interval to force syncs, even if few peers are available + minDesiredPeerCount = 5 // Amount of peers desired to start syncing +) + +// syncer is responsible for periodically synchronising with the network, both +// downloading hashes and blocks as well as handling the announcement handler. +func (pm *ProtocolManager) syncer() { + // Start and ensure cleanup of sync mechanisms + //pm.fetcher.Start() + //defer pm.fetcher.Stop() + defer pm.downloader.Terminate() + + // Wait for different events to fire synchronisation operations + //forceSync := time.Tick(forceSyncCycle) + for { + select { + case <-pm.newPeerCh: +/* // Make sure we have peers to select from, then sync + if pm.peers.Len() < minDesiredPeerCount { + break + } + go pm.synchronise(pm.peers.BestPeer()) +*/ + /*case <-forceSync: + // Force a sync even if not enough peers are present + go pm.synchronise(pm.peers.BestPeer()) + */ + case <-pm.noMorePeers: + return + } + } +} + +func (pm *ProtocolManager) needToSync(peerHead blockInfo) bool { + head := pm.blockchain.CurrentHeader() + currentTd := core.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64()) + return currentTd != nil && peerHead.Td.Cmp(currentTd) > 0 +} + +// synchronise tries to sync up our local block chain with a remote peer. +func (pm *ProtocolManager) synchronise(peer *peer) { + // Short circuit if no peers are available + if peer == nil { + return + } + + // Make sure the peer's TD is higher than our own. + if !pm.needToSync(peer.headBlockInfo()) { + return + } + + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + pm.blockchain.(*light.LightChain).SyncCht(ctx) + + pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync) +} diff --git a/les/txrelay.go b/les/txrelay.go new file mode 100644 index 000000000..2df2fa0a9 --- /dev/null +++ b/les/txrelay.go @@ -0,0 +1,156 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . +package les + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" +) + +type ltrInfo struct { + tx *types.Transaction + sentTo map[*peer]struct{} +} + +type LesTxRelay struct { + txSent map[common.Hash]*ltrInfo + txPending map[common.Hash]struct{} + ps *peerSet + peerList []*peer + peerStartPos int + lock sync.RWMutex +} + +func NewLesTxRelay() *LesTxRelay { + return &LesTxRelay{ + txSent: make(map[common.Hash]*ltrInfo), + txPending: make(map[common.Hash]struct{}), + ps: newPeerSet(), + } +} + +func (self *LesTxRelay) addPeer(p *peer) { + self.lock.Lock() + defer self.lock.Unlock() + + self.ps.Register(p) + self.peerList = self.ps.AllPeers() +} + +func (self *LesTxRelay) removePeer(id string) { + self.lock.Lock() + defer self.lock.Unlock() + + self.ps.Unregister(id) + self.peerList = self.ps.AllPeers() +} + +// send sends a list of transactions to at most a given number of peers at +// once, never resending any particular transaction to the same peer twice +func (self *LesTxRelay) send(txs types.Transactions, count int) { + sendTo := make(map[*peer]types.Transactions) + + self.peerStartPos++ // rotate the starting position of the peer list + if self.peerStartPos >= len(self.peerList) { + self.peerStartPos = 0 + } + + for _, tx := range txs { + hash := tx.Hash() + ltr, ok := self.txSent[hash] + if !ok { + ltr = <rInfo{ + tx: tx, + sentTo: make(map[*peer]struct{}), + } + self.txSent[hash] = ltr + self.txPending[hash] = struct{}{} + } + + if len(self.peerList) > 0 { + cnt := count + pos := self.peerStartPos + for { + peer := self.peerList[pos] + if _, ok := ltr.sentTo[peer]; !ok { + sendTo[peer] = append(sendTo[peer], tx) + ltr.sentTo[peer] = struct{}{} + cnt-- + } + if cnt == 0 { + break // sent it to the desired number of peers + } + pos++ + if pos == len(self.peerList) { + pos = 0 + } + if pos == self.peerStartPos { + break // tried all available peers + } + } + } + } + + for p, list := range sendTo { + cost := p.GetRequestCost(SendTxMsg, len(list)) + go func(p *peer, list types.Transactions, cost uint64) { + p.fcServer.SendRequest(0, cost) + p.SendTxs(cost, list) + }(p, list, cost) + } +} + +func (self *LesTxRelay) Send(txs types.Transactions) { + self.lock.Lock() + defer self.lock.Unlock() + + self.send(txs, 3) +} + +func (self *LesTxRelay) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) { + self.lock.Lock() + defer self.lock.Unlock() + + for _, hash := range mined { + delete(self.txPending, hash) + } + + for _, hash := range rollback { + self.txPending[hash] = struct{}{} + } + + if len(self.txPending) > 0 { + txs := make(types.Transactions, len(self.txPending)) + i := 0 + for hash, _ := range self.txPending { + txs[i] = self.txSent[hash].tx + i++ + } + self.send(txs, 1) + } +} + +func (self *LesTxRelay) Discard(hashes []common.Hash) { + self.lock.Lock() + defer self.lock.Unlock() + + for _, hash := range hashes { + delete(self.txSent, hash) + delete(self.txPending, hash) + } +} -- cgit v1.2.3