aboutsummaryrefslogtreecommitdiffstats
path: root/les
diff options
context:
space:
mode:
authorZsolt Felfoldi <zsfelfoldi@gmail.com>2016-10-14 11:51:29 +0800
committerFelix Lange <fjl@twurst.com>2016-11-09 09:12:53 +0800
commit9f8d192991c4f68fa14c91366722bbca601da117 (patch)
tree5c1e089673d3f0208cd4a8208623bb95f29622c9 /les
parent760fd65487614b7a61443cd9371015925795f40f (diff)
downloaddexon-9f8d192991c4f68fa14c91366722bbca601da117.tar
dexon-9f8d192991c4f68fa14c91366722bbca601da117.tar.gz
dexon-9f8d192991c4f68fa14c91366722bbca601da117.tar.bz2
dexon-9f8d192991c4f68fa14c91366722bbca601da117.tar.lz
dexon-9f8d192991c4f68fa14c91366722bbca601da117.tar.xz
dexon-9f8d192991c4f68fa14c91366722bbca601da117.tar.zst
dexon-9f8d192991c4f68fa14c91366722bbca601da117.zip
les: light client protocol and API
Diffstat (limited to 'les')
-rw-r--r--les/api_backend.go144
-rw-r--r--les/backend.go218
-rw-r--r--les/fetcher.go295
-rw-r--r--les/flowcontrol/control.go172
-rw-r--r--les/flowcontrol/manager.go223
-rw-r--r--les/handler.go854
-rw-r--r--les/handler_test.go322
-rw-r--r--les/helper_test.go318
-rw-r--r--les/metrics.go111
-rw-r--r--les/odr.go247
-rw-r--r--les/odr_peerset.go119
-rw-r--r--les/odr_requests.go325
-rw-r--r--les/odr_test.go222
-rw-r--r--les/peer.go584
-rw-r--r--les/protocol.go198
-rw-r--r--les/request_test.go94
-rw-r--r--les/server.go401
-rw-r--r--les/sync.go84
-rw-r--r--les/txrelay.go156
19 files changed, 5087 insertions, 0 deletions
diff --git a/les/api_backend.go b/les/api_backend.go
new file mode 100644
index 000000000..d50b3ea33
--- /dev/null
+++ b/les/api_backend.go
@@ -0,0 +1,144 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/accounts"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/eth/gasprice"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/internal/ethapi"
+ "github.com/ethereum/go-ethereum/light"
+ rpc "github.com/ethereum/go-ethereum/rpc"
+ "golang.org/x/net/context"
+)
+
+type LesApiBackend struct {
+ eth *LightEthereum
+ gpo *gasprice.LightPriceOracle
+}
+
+func (b *LesApiBackend) SetHead(number uint64) {
+ b.eth.blockchain.SetHead(number)
+}
+
+func (b *LesApiBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) {
+ if blockNr == rpc.LatestBlockNumber || blockNr == rpc.PendingBlockNumber {
+ return b.eth.blockchain.CurrentHeader(), nil
+ }
+
+ return b.eth.blockchain.GetHeaderByNumberOdr(ctx, uint64(blockNr))
+}
+
+func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error) {
+ header, err := b.HeaderByNumber(ctx, blockNr)
+ if header == nil || err != nil {
+ return nil, err
+ }
+ return b.GetBlock(ctx, header.Hash())
+}
+
+func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) {
+ header, err := b.HeaderByNumber(ctx, blockNr)
+ if header == nil || err != nil {
+ return nil, nil, err
+ }
+ return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil
+}
+
+func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) {
+ return b.eth.blockchain.GetBlockByHash(ctx, blockHash)
+}
+
+func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) {
+ return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash))
+}
+
+func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int {
+ return b.eth.blockchain.GetTdByHash(blockHash)
+}
+
+func (b *LesApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
+ stateDb := state.(*light.LightState).Copy()
+ addr, _ := msg.From()
+ from, err := stateDb.GetOrNewStateObject(ctx, addr)
+ if err != nil {
+ return nil, nil, err
+ }
+ from.SetBalance(common.MaxBig)
+ env := light.NewEnv(ctx, stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig)
+ return env, env.Error, nil
+}
+
+func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
+ return b.eth.txPool.Add(ctx, signedTx)
+}
+
+func (b *LesApiBackend) RemoveTx(txHash common.Hash) {
+ b.eth.txPool.RemoveTx(txHash)
+}
+
+func (b *LesApiBackend) GetPoolTransactions() types.Transactions {
+ return b.eth.txPool.GetTransactions()
+}
+
+func (b *LesApiBackend) GetPoolTransaction(txHash common.Hash) *types.Transaction {
+ return b.eth.txPool.GetTransaction(txHash)
+}
+
+func (b *LesApiBackend) GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) {
+ return b.eth.txPool.GetNonce(ctx, addr)
+}
+
+func (b *LesApiBackend) Stats() (pending int, queued int) {
+ return b.eth.txPool.Stats(), 0
+}
+
+func (b *LesApiBackend) TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) {
+ return b.eth.txPool.Content()
+}
+
+func (b *LesApiBackend) Downloader() *downloader.Downloader {
+ return b.eth.Downloader()
+}
+
+func (b *LesApiBackend) ProtocolVersion() int {
+ return b.eth.LesVersion() + 10000
+}
+
+func (b *LesApiBackend) SuggestPrice(ctx context.Context) (*big.Int, error) {
+ return b.gpo.SuggestPrice(ctx)
+}
+
+func (b *LesApiBackend) ChainDb() ethdb.Database {
+ return b.eth.chainDb
+}
+
+func (b *LesApiBackend) EventMux() *event.TypeMux {
+ return b.eth.eventMux
+}
+
+func (b *LesApiBackend) AccountManager() *accounts.Manager {
+ return b.eth.accountManager
+}
diff --git a/les/backend.go b/les/backend.go
new file mode 100644
index 000000000..8011a4b31
--- /dev/null
+++ b/les/backend.go
@@ -0,0 +1,218 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/ethereum/ethash"
+ "github.com/ethereum/go-ethereum/accounts"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/compiler"
+ "github.com/ethereum/go-ethereum/common/httpclient"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/eth"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/eth/filters"
+ "github.com/ethereum/go-ethereum/eth/gasprice"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/internal/ethapi"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/node"
+ "github.com/ethereum/go-ethereum/p2p"
+ rpc "github.com/ethereum/go-ethereum/rpc"
+)
+
+type LightEthereum struct {
+ odr *LesOdr
+ relay *LesTxRelay
+ chainConfig *core.ChainConfig
+ // Channel for shutting down the service
+ shutdownChan chan bool
+ // Handlers
+ txPool *light.TxPool
+ blockchain *light.LightChain
+ protocolManager *ProtocolManager
+ // DB interfaces
+ chainDb ethdb.Database // Block chain database
+
+ ApiBackend *LesApiBackend
+
+ eventMux *event.TypeMux
+ pow *ethash.Ethash
+ httpclient *httpclient.HTTPClient
+ accountManager *accounts.Manager
+ solcPath string
+ solc *compiler.Solidity
+
+ NatSpec bool
+ PowTest bool
+ netVersionId int
+ netRPCService *ethapi.PublicNetAPI
+}
+
+func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
+ chainDb, err := eth.CreateDB(ctx, config, "lightchaindata")
+ if err != nil {
+ return nil, err
+ }
+ if err := eth.SetupGenesisBlock(&chainDb, config); err != nil {
+ return nil, err
+ }
+ pow, err := eth.CreatePoW(config)
+ if err != nil {
+ return nil, err
+ }
+
+ odr := NewLesOdr(chainDb)
+ relay := NewLesTxRelay()
+ eth := &LightEthereum{
+ odr: odr,
+ relay: relay,
+ chainDb: chainDb,
+ eventMux: ctx.EventMux,
+ accountManager: ctx.AccountManager,
+ pow: pow,
+ shutdownChan: make(chan bool),
+ httpclient: httpclient.New(config.DocRoot),
+ netVersionId: config.NetworkId,
+ NatSpec: config.NatSpec,
+ PowTest: config.PowTest,
+ solcPath: config.SolcPath,
+ }
+
+ if config.ChainConfig == nil {
+ return nil, errors.New("missing chain config")
+ }
+ eth.chainConfig = config.ChainConfig
+ eth.chainConfig.VmConfig = vm.Config{
+ EnableJit: config.EnableJit,
+ ForceJit: config.ForceJit,
+ }
+ eth.blockchain, err = light.NewLightChain(odr, eth.chainConfig, eth.pow, eth.eventMux)
+ if err != nil {
+ if err == core.ErrNoGenesis {
+ return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`)
+ }
+ return nil, err
+ }
+
+ eth.txPool = light.NewTxPool(eth.chainConfig, eth.eventMux, eth.blockchain, eth.relay)
+ if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, config.LightMode, config.NetworkId, eth.eventMux, eth.pow, eth.blockchain, nil, chainDb, odr, relay); err != nil {
+ return nil, err
+ }
+
+ eth.ApiBackend = &LesApiBackend{eth, nil}
+ eth.ApiBackend.gpo = gasprice.NewLightPriceOracle(eth.ApiBackend)
+ return eth, nil
+}
+
+type LightDummyAPI struct{}
+
+// Etherbase is the address that mining rewards will be send to
+func (s *LightDummyAPI) Etherbase() (common.Address, error) {
+ return common.Address{}, fmt.Errorf("not supported")
+}
+
+// Coinbase is the address that mining rewards will be send to (alias for Etherbase)
+func (s *LightDummyAPI) Coinbase() (common.Address, error) {
+ return common.Address{}, fmt.Errorf("not supported")
+}
+
+// Hashrate returns the POW hashrate
+func (s *LightDummyAPI) Hashrate() *rpc.HexNumber {
+ return rpc.NewHexNumber(0)
+}
+
+// Mining returns an indication if this node is currently mining.
+func (s *LightDummyAPI) Mining() bool {
+ return false
+}
+
+// APIs returns the collection of RPC services the ethereum package offers.
+// NOTE, some of these services probably need to be moved to somewhere else.
+func (s *LightEthereum) APIs() []rpc.API {
+ return append(ethapi.GetAPIs(s.ApiBackend, s.solcPath), []rpc.API{
+ {
+ Namespace: "eth",
+ Version: "1.0",
+ Service: &LightDummyAPI{},
+ Public: true,
+ }, {
+ Namespace: "eth",
+ Version: "1.0",
+ Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux),
+ Public: true,
+ }, {
+ Namespace: "eth",
+ Version: "1.0",
+ Service: filters.NewPublicFilterAPI(s.ApiBackend, true),
+ Public: true,
+ }, {
+ Namespace: "net",
+ Version: "1.0",
+ Service: s.netRPCService,
+ Public: true,
+ },
+ }...)
+}
+
+func (s *LightEthereum) ResetWithGenesisBlock(gb *types.Block) {
+ s.blockchain.ResetWithGenesisBlock(gb)
+}
+
+func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchain }
+func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool }
+func (s *LightEthereum) LesVersion() int { return int(s.protocolManager.SubProtocols[0].Version) }
+func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
+
+// Protocols implements node.Service, returning all the currently configured
+// network protocols to start.
+func (s *LightEthereum) Protocols() []p2p.Protocol {
+ return s.protocolManager.SubProtocols
+}
+
+// Start implements node.Service, starting all internal goroutines needed by the
+// Ethereum protocol implementation.
+func (s *LightEthereum) Start(srvr *p2p.Server) error {
+ s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.netVersionId)
+ s.protocolManager.Start()
+ return nil
+}
+
+// Stop implements node.Service, terminating all internal goroutines used by the
+// Ethereum protocol.
+func (s *LightEthereum) Stop() error {
+ s.odr.Stop()
+ s.blockchain.Stop()
+ s.protocolManager.Stop()
+ s.txPool.Stop()
+
+ s.eventMux.Stop()
+
+ time.Sleep(time.Millisecond * 200)
+ s.chainDb.Close()
+ close(s.shutdownChan)
+
+ return nil
+}
diff --git a/les/fetcher.go b/les/fetcher.go
new file mode 100644
index 000000000..3fa5cf0e2
--- /dev/null
+++ b/les/fetcher.go
@@ -0,0 +1,295 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "math/big"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+type lightFetcher struct {
+ pm *ProtocolManager
+ odr *LesOdr
+ chain BlockChain
+
+ headAnnouncedMu sync.Mutex
+ headAnnouncedBy map[common.Hash][]*peer
+ currentTd *big.Int
+ deliverChn chan fetchResponse
+ reqMu sync.RWMutex
+ requested map[uint64]fetchRequest
+ timeoutChn chan uint64
+ notifyChn chan bool // true if initiated from outside
+ syncing bool
+ syncDone chan struct{}
+}
+
+type fetchRequest struct {
+ hash common.Hash
+ amount uint64
+ peer *peer
+}
+
+type fetchResponse struct {
+ reqID uint64
+ headers []*types.Header
+}
+
+func newLightFetcher(pm *ProtocolManager) *lightFetcher {
+ f := &lightFetcher{
+ pm: pm,
+ chain: pm.blockchain,
+ odr: pm.odr,
+ headAnnouncedBy: make(map[common.Hash][]*peer),
+ deliverChn: make(chan fetchResponse, 100),
+ requested: make(map[uint64]fetchRequest),
+ timeoutChn: make(chan uint64),
+ notifyChn: make(chan bool, 100),
+ syncDone: make(chan struct{}),
+ currentTd: big.NewInt(0),
+ }
+ go f.syncLoop()
+ return f
+}
+
+func (f *lightFetcher) notify(p *peer, head *announceData) {
+ var headHash common.Hash
+ if head == nil {
+ // initial notify
+ headHash = p.Head()
+ } else {
+ if core.GetTd(f.pm.chainDb, head.Hash, head.Number) != nil {
+ head.haveHeaders = head.Number
+ }
+ //fmt.Println("notify", p.id, head.Number, head.ReorgDepth, head.haveHeaders)
+ if !p.addNotify(head) {
+ //fmt.Println("addNotify fail")
+ f.pm.removePeer(p.id)
+ }
+ headHash = head.Hash
+ }
+ f.headAnnouncedMu.Lock()
+ f.headAnnouncedBy[headHash] = append(f.headAnnouncedBy[headHash], p)
+ f.headAnnouncedMu.Unlock()
+ f.notifyChn <- true
+}
+
+func (f *lightFetcher) gotHeader(header *types.Header) {
+ f.headAnnouncedMu.Lock()
+ defer f.headAnnouncedMu.Unlock()
+
+ hash := header.Hash()
+ peerList := f.headAnnouncedBy[hash]
+ if peerList == nil {
+ return
+ }
+ number := header.GetNumberU64()
+ td := core.GetTd(f.pm.chainDb, hash, number)
+ for _, peer := range peerList {
+ peer.lock.Lock()
+ ok := peer.gotHeader(hash, number, td)
+ peer.lock.Unlock()
+ if !ok {
+ //fmt.Println("gotHeader fail")
+ f.pm.removePeer(peer.id)
+ }
+ }
+ delete(f.headAnnouncedBy, hash)
+}
+
+func (f *lightFetcher) nextRequest() (*peer, *announceData) {
+ var bestPeer *peer
+ bestTd := f.currentTd
+ for _, peer := range f.pm.peers.AllPeers() {
+ peer.lock.RLock()
+ if !peer.headInfo.requested && (peer.headInfo.Td.Cmp(bestTd) > 0 ||
+ (bestPeer != nil && peer.headInfo.Td.Cmp(bestTd) == 0 && peer.headInfo.haveHeaders > bestPeer.headInfo.haveHeaders)) {
+ bestPeer = peer
+ bestTd = peer.headInfo.Td
+ }
+ peer.lock.RUnlock()
+ }
+ if bestPeer == nil {
+ return nil, nil
+ }
+ bestPeer.lock.Lock()
+ res := bestPeer.headInfo
+ res.requested = true
+ bestPeer.lock.Unlock()
+ for _, peer := range f.pm.peers.AllPeers() {
+ if peer != bestPeer {
+ peer.lock.Lock()
+ if peer.headInfo.Hash == bestPeer.headInfo.Hash && peer.headInfo.haveHeaders == bestPeer.headInfo.haveHeaders {
+ peer.headInfo.requested = true
+ }
+ peer.lock.Unlock()
+ }
+ }
+ return bestPeer, res
+}
+
+func (f *lightFetcher) deliverHeaders(reqID uint64, headers []*types.Header) {
+ f.deliverChn <- fetchResponse{reqID: reqID, headers: headers}
+}
+
+func (f *lightFetcher) requestedID(reqID uint64) bool {
+ f.reqMu.RLock()
+ _, ok := f.requested[reqID]
+ f.reqMu.RUnlock()
+ return ok
+}
+
+func (f *lightFetcher) request(p *peer, block *announceData) {
+ //fmt.Println("request", p.id, block.Number, block.haveHeaders)
+ amount := block.Number - block.haveHeaders
+ if amount == 0 {
+ return
+ }
+ if amount > 100 {
+ f.syncing = true
+ go func() {
+ //fmt.Println("f.pm.synchronise(p)")
+ f.pm.synchronise(p)
+ //fmt.Println("sync done")
+ f.syncDone <- struct{}{}
+ }()
+ return
+ }
+
+ reqID := f.odr.getNextReqID()
+ f.reqMu.Lock()
+ f.requested[reqID] = fetchRequest{hash: block.Hash, amount: amount, peer: p}
+ f.reqMu.Unlock()
+ cost := p.GetRequestCost(GetBlockHeadersMsg, int(amount))
+ p.fcServer.SendRequest(reqID, cost)
+ go p.RequestHeadersByHash(reqID, cost, block.Hash, int(amount), 0, true)
+ go func() {
+ time.Sleep(hardRequestTimeout)
+ f.timeoutChn <- reqID
+ }()
+}
+
+func (f *lightFetcher) processResponse(req fetchRequest, resp fetchResponse) bool {
+ if uint64(len(resp.headers)) != req.amount || resp.headers[0].Hash() != req.hash {
+ return false
+ }
+ headers := make([]*types.Header, req.amount)
+ for i, header := range resp.headers {
+ headers[int(req.amount)-1-i] = header
+ }
+ if _, err := f.chain.InsertHeaderChain(headers, 1); err != nil {
+ return false
+ }
+ for _, header := range headers {
+ td := core.GetTd(f.pm.chainDb, header.Hash(), header.GetNumberU64())
+ if td == nil {
+ return false
+ }
+ if td.Cmp(f.currentTd) > 0 {
+ f.currentTd = td
+ }
+ f.gotHeader(header)
+ }
+ return true
+}
+
+func (f *lightFetcher) checkSyncedHeaders() {
+ //fmt.Println("checkSyncedHeaders()")
+ for _, peer := range f.pm.peers.AllPeers() {
+ peer.lock.Lock()
+ h := peer.firstHeadInfo
+ remove := false
+ loop:
+ for h != nil {
+ if td := core.GetTd(f.pm.chainDb, h.Hash, h.Number); td != nil {
+ //fmt.Println(" found", h.Number)
+ ok := peer.gotHeader(h.Hash, h.Number, td)
+ if !ok {
+ remove = true
+ break loop
+ }
+ if td.Cmp(f.currentTd) > 0 {
+ f.currentTd = td
+ }
+ }
+ h = h.next
+ }
+ peer.lock.Unlock()
+ if remove {
+ //fmt.Println("checkSync fail")
+ f.pm.removePeer(peer.id)
+ }
+ }
+}
+
+func (f *lightFetcher) syncLoop() {
+ f.pm.wg.Add(1)
+ defer f.pm.wg.Done()
+
+ srtoNotify := false
+ for {
+ select {
+ case <-f.pm.quitSync:
+ return
+ case ext := <-f.notifyChn:
+ //fmt.Println("<-f.notifyChn", f.syncing, ext, srtoNotify)
+ s := srtoNotify
+ srtoNotify = false
+ if !f.syncing && !(ext && s) {
+ if p, r := f.nextRequest(); r != nil {
+ srtoNotify = true
+ go func() {
+ time.Sleep(softRequestTimeout)
+ f.notifyChn <- false
+ }()
+ f.request(p, r)
+ }
+ }
+ case reqID := <-f.timeoutChn:
+ f.reqMu.Lock()
+ req, ok := f.requested[reqID]
+ if ok {
+ delete(f.requested, reqID)
+ }
+ f.reqMu.Unlock()
+ if ok {
+ //fmt.Println("hard timeout")
+ f.pm.removePeer(req.peer.id)
+ }
+ case resp := <-f.deliverChn:
+ //fmt.Println("<-f.deliverChn", f.syncing)
+ f.reqMu.Lock()
+ req, ok := f.requested[resp.reqID]
+ delete(f.requested, resp.reqID)
+ f.reqMu.Unlock()
+ if !ok || !(f.syncing || f.processResponse(req, resp)) {
+ //fmt.Println("processResponse fail")
+ f.pm.removePeer(req.peer.id)
+ }
+ case <-f.syncDone:
+ //fmt.Println("<-f.syncDone", f.syncing)
+ f.checkSyncedHeaders()
+ f.syncing = false
+ }
+ }
+}
diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go
new file mode 100644
index 000000000..1b569db0b
--- /dev/null
+++ b/les/flowcontrol/control.go
@@ -0,0 +1,172 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package flowcontrol implements a client side flow control mechanism
+package flowcontrol
+
+import (
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+)
+
+const fcTimeConst = 1000000
+
+type ServerParams struct {
+ BufLimit, MinRecharge uint64
+}
+
+type ClientNode struct {
+ params *ServerParams
+ bufValue uint64
+ lastTime int64
+ lock sync.Mutex
+ cm *ClientManager
+ cmNode *cmNode
+}
+
+func NewClientNode(cm *ClientManager, params *ServerParams) *ClientNode {
+ node := &ClientNode{
+ cm: cm,
+ params: params,
+ bufValue: params.BufLimit,
+ lastTime: getTime(),
+ }
+ node.cmNode = cm.addNode(node)
+ return node
+}
+
+func (peer *ClientNode) Remove(cm *ClientManager) {
+ cm.removeNode(peer.cmNode)
+}
+
+func (peer *ClientNode) recalcBV(time int64) {
+ dt := uint64(time - peer.lastTime)
+ if time < peer.lastTime {
+ dt = 0
+ }
+ peer.bufValue += peer.params.MinRecharge * dt / fcTimeConst
+ if peer.bufValue > peer.params.BufLimit {
+ peer.bufValue = peer.params.BufLimit
+ }
+ peer.lastTime = time
+}
+
+func (peer *ClientNode) AcceptRequest() (uint64, bool) {
+ peer.lock.Lock()
+ defer peer.lock.Unlock()
+
+ time := getTime()
+ peer.recalcBV(time)
+ return peer.bufValue, peer.cm.accept(peer.cmNode, time)
+}
+
+func (peer *ClientNode) RequestProcessed(cost uint64) (bv, realCost uint64) {
+ peer.lock.Lock()
+ defer peer.lock.Unlock()
+
+ time := getTime()
+ peer.recalcBV(time)
+ peer.bufValue -= cost
+ peer.recalcBV(time)
+ rcValue, rcost := peer.cm.processed(peer.cmNode, time)
+ if rcValue < peer.params.BufLimit {
+ bv := peer.params.BufLimit - rcValue
+ if bv > peer.bufValue {
+ peer.bufValue = bv
+ }
+ }
+ return peer.bufValue, rcost
+}
+
+type ServerNode struct {
+ bufEstimate uint64
+ lastTime int64
+ params *ServerParams
+ sumCost uint64 // sum of req costs sent to this server
+ pending map[uint64]uint64 // value = sumCost after sending the given req
+ lock sync.Mutex
+}
+
+func NewServerNode(params *ServerParams) *ServerNode {
+ return &ServerNode{
+ bufEstimate: params.BufLimit,
+ lastTime: getTime(),
+ params: params,
+ pending: make(map[uint64]uint64),
+ }
+}
+
+func getTime() int64 {
+ return int64(mclock.Now())
+}
+
+func (peer *ServerNode) recalcBLE(time int64) {
+ dt := uint64(time - peer.lastTime)
+ if time < peer.lastTime {
+ dt = 0
+ }
+ peer.bufEstimate += peer.params.MinRecharge * dt / fcTimeConst
+ if peer.bufEstimate > peer.params.BufLimit {
+ peer.bufEstimate = peer.params.BufLimit
+ }
+ peer.lastTime = time
+}
+
+func (peer *ServerNode) canSend(maxCost uint64) uint64 {
+ if peer.bufEstimate >= maxCost {
+ return 0
+ }
+ return (maxCost - peer.bufEstimate) * fcTimeConst / peer.params.MinRecharge
+}
+
+func (peer *ServerNode) CanSend(maxCost uint64) uint64 {
+ peer.lock.Lock()
+ defer peer.lock.Unlock()
+
+ return peer.canSend(maxCost)
+}
+
+// blocks until request can be sent
+func (peer *ServerNode) SendRequest(reqID, maxCost uint64) {
+ peer.lock.Lock()
+ defer peer.lock.Unlock()
+
+ peer.recalcBLE(getTime())
+ for peer.bufEstimate < maxCost {
+ time.Sleep(time.Duration(peer.canSend(maxCost)))
+ peer.recalcBLE(getTime())
+ }
+ peer.bufEstimate -= maxCost
+ peer.sumCost += maxCost
+ if reqID >= 0 {
+ peer.pending[reqID] = peer.sumCost
+ }
+}
+
+func (peer *ServerNode) GotReply(reqID, bv uint64) {
+ peer.lock.Lock()
+ defer peer.lock.Unlock()
+
+ sc, ok := peer.pending[reqID]
+ if !ok {
+ return
+ }
+ delete(peer.pending, reqID)
+ peer.bufEstimate = bv - (peer.sumCost - sc)
+ peer.lastTime = getTime()
+}
diff --git a/les/flowcontrol/manager.go b/les/flowcontrol/manager.go
new file mode 100644
index 000000000..c0469e7b6
--- /dev/null
+++ b/les/flowcontrol/manager.go
@@ -0,0 +1,223 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package flowcontrol implements a client side flow control mechanism
+package flowcontrol
+
+import (
+ "sync"
+ "time"
+)
+
+const rcConst = 1000000
+
+type cmNode struct {
+ node *ClientNode
+ lastUpdate int64
+ reqAccepted int64
+ serving, recharging bool
+ rcWeight uint64
+ rcValue, rcDelta int64
+ finishRecharge, startValue int64
+}
+
+func (node *cmNode) update(time int64) {
+ dt := time - node.lastUpdate
+ node.rcValue += node.rcDelta * dt / rcConst
+ node.lastUpdate = time
+ if node.recharging && time >= node.finishRecharge {
+ node.recharging = false
+ node.rcDelta = 0
+ node.rcValue = 0
+ }
+}
+
+func (node *cmNode) set(serving bool, simReqCnt, sumWeight uint64) {
+ if node.serving && !serving {
+ node.recharging = true
+ sumWeight += node.rcWeight
+ }
+ node.serving = serving
+ if node.recharging && serving {
+ node.recharging = false
+ sumWeight -= node.rcWeight
+ }
+
+ node.rcDelta = 0
+ if serving {
+ node.rcDelta = int64(rcConst / simReqCnt)
+ }
+ if node.recharging {
+ node.rcDelta = -int64(node.node.cm.rcRecharge * node.rcWeight / sumWeight)
+ node.finishRecharge = node.lastUpdate + node.rcValue*rcConst/(-node.rcDelta)
+ }
+}
+
+type ClientManager struct {
+ lock sync.Mutex
+ nodes map[*cmNode]struct{}
+ simReqCnt, sumWeight, rcSumValue uint64
+ maxSimReq, maxRcSum uint64
+ rcRecharge uint64
+ resumeQueue chan chan bool
+ time int64
+}
+
+func NewClientManager(rcTarget, maxSimReq, maxRcSum uint64) *ClientManager {
+ cm := &ClientManager{
+ nodes: make(map[*cmNode]struct{}),
+ resumeQueue: make(chan chan bool),
+ rcRecharge: rcConst * rcConst / (100*rcConst/rcTarget - rcConst),
+ maxSimReq: maxSimReq,
+ maxRcSum: maxRcSum,
+ }
+ go cm.queueProc()
+ return cm
+}
+
+func (self *ClientManager) Stop() {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ // signal any waiting accept routines to return false
+ self.nodes = make(map[*cmNode]struct{})
+ close(self.resumeQueue)
+}
+
+func (self *ClientManager) addNode(cnode *ClientNode) *cmNode {
+ time := getTime()
+ node := &cmNode{
+ node: cnode,
+ lastUpdate: time,
+ finishRecharge: time,
+ rcWeight: 1,
+ }
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.nodes[node] = struct{}{}
+ self.update(getTime())
+ return node
+}
+
+func (self *ClientManager) removeNode(node *cmNode) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ time := getTime()
+ self.stop(node, time)
+ delete(self.nodes, node)
+ self.update(time)
+}
+
+// recalc sumWeight
+func (self *ClientManager) updateNodes(time int64) (rce bool) {
+ var sumWeight, rcSum uint64
+ for node, _ := range self.nodes {
+ rc := node.recharging
+ node.update(time)
+ if rc && !node.recharging {
+ rce = true
+ }
+ if node.recharging {
+ sumWeight += node.rcWeight
+ }
+ rcSum += uint64(node.rcValue)
+ }
+ self.sumWeight = sumWeight
+ self.rcSumValue = rcSum
+ return
+}
+
+func (self *ClientManager) update(time int64) {
+ for {
+ firstTime := time
+ for node, _ := range self.nodes {
+ if node.recharging && node.finishRecharge < firstTime {
+ firstTime = node.finishRecharge
+ }
+ }
+ if self.updateNodes(firstTime) {
+ for node, _ := range self.nodes {
+ if node.recharging {
+ node.set(node.serving, self.simReqCnt, self.sumWeight)
+ }
+ }
+ } else {
+ self.time = time
+ return
+ }
+ }
+}
+
+func (self *ClientManager) canStartReq() bool {
+ return self.simReqCnt < self.maxSimReq && self.rcSumValue < self.maxRcSum
+}
+
+func (self *ClientManager) queueProc() {
+ for rc := range self.resumeQueue {
+ for {
+ time.Sleep(time.Millisecond * 10)
+ self.lock.Lock()
+ self.update(getTime())
+ cs := self.canStartReq()
+ self.lock.Unlock()
+ if cs {
+ break
+ }
+ }
+ close(rc)
+ }
+}
+
+func (self *ClientManager) accept(node *cmNode, time int64) bool {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.update(time)
+ if !self.canStartReq() {
+ resume := make(chan bool)
+ self.lock.Unlock()
+ self.resumeQueue <- resume
+ <-resume
+ self.lock.Lock()
+ if _, ok := self.nodes[node]; !ok {
+ return false // reject if node has been removed or manager has been stopped
+ }
+ }
+ self.simReqCnt++
+ node.set(true, self.simReqCnt, self.sumWeight)
+ node.startValue = node.rcValue
+ self.update(self.time)
+ return true
+}
+
+func (self *ClientManager) stop(node *cmNode, time int64) {
+ if node.serving {
+ self.update(time)
+ self.simReqCnt--
+ node.set(false, self.simReqCnt, self.sumWeight)
+ self.update(time)
+ }
+}
+
+func (self *ClientManager) processed(node *cmNode, time int64) (rcValue, rcCost uint64) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.stop(node, time)
+ return uint64(node.rcValue), uint64(node.rcValue - node.startValue)
+}
diff --git a/les/handler.go b/les/handler.go
new file mode 100644
index 000000000..ef18af4d8
--- /dev/null
+++ b/les/handler.go
@@ -0,0 +1,854 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "math/big"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/pow"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+ softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
+ estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header
+
+ ethVersion = 63 // equivalent eth version for the downloader
+
+ MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
+ MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request
+ MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request
+ MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request
+ MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request
+ MaxHeaderProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request
+ MaxTxSend = 64 // Amount of transactions to be send per request
+
+ disableClientRemovePeer = true
+)
+
+// errIncompatibleConfig is returned if the requested protocols and configs are
+// not compatible (low protocol version restrictions and high requirements).
+var errIncompatibleConfig = errors.New("incompatible configuration")
+
+func errResp(code errCode, format string, v ...interface{}) error {
+ return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
+}
+
+type hashFetcherFn func(common.Hash) error
+
+type BlockChain interface {
+ HasHeader(hash common.Hash) bool
+ GetHeader(hash common.Hash, number uint64) *types.Header
+ GetHeaderByHash(hash common.Hash) *types.Header
+ CurrentHeader() *types.Header
+ GetTdByHash(hash common.Hash) *big.Int
+ InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
+ Rollback(chain []common.Hash)
+ Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash)
+ GetHeaderByNumber(number uint64) *types.Header
+ GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash
+ LastBlockHash() common.Hash
+ Genesis() *types.Block
+}
+
+type txPool interface {
+ // AddTransactions should add the given transactions to the pool.
+ AddBatch([]*types.Transaction)
+}
+
+type ProtocolManager struct {
+ lightSync bool
+ txpool txPool
+ txrelay *LesTxRelay
+ networkId int
+ chainConfig *core.ChainConfig
+ blockchain BlockChain
+ chainDb ethdb.Database
+ odr *LesOdr
+ server *LesServer
+
+ downloader *downloader.Downloader
+ fetcher *lightFetcher
+ peers *peerSet
+
+ SubProtocols []p2p.Protocol
+
+ eventMux *event.TypeMux
+
+ // channels for fetcher, syncer, txsyncLoop
+ newPeerCh chan *peer
+ quitSync chan struct{}
+ noMorePeers chan struct{}
+
+ syncMu sync.Mutex
+ syncing bool
+ syncDone chan struct{}
+
+ // wait group is used for graceful shutdowns during downloading
+ // and processing
+ wg sync.WaitGroup
+}
+
+// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable
+// with the ethereum network.
+func NewProtocolManager(chainConfig *core.ChainConfig, lightSync bool, networkId int, mux *event.TypeMux, pow pow.PoW, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay) (*ProtocolManager, error) {
+ // Create the protocol manager with the base fields
+ manager := &ProtocolManager{
+ lightSync: lightSync,
+ eventMux: mux,
+ blockchain: blockchain,
+ chainConfig: chainConfig,
+ chainDb: chainDb,
+ networkId: networkId,
+ txpool: txpool,
+ txrelay: txrelay,
+ odr: odr,
+ peers: newPeerSet(),
+ newPeerCh: make(chan *peer),
+ quitSync: make(chan struct{}),
+ noMorePeers: make(chan struct{}),
+ }
+ // Initiate a sub-protocol for every implemented version we can handle
+ manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions))
+ for i, version := range ProtocolVersions {
+ // Compatible, initialize the sub-protocol
+ version := version // Closure for the run
+ manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{
+ Name: "les",
+ Version: version,
+ Length: ProtocolLengths[i],
+ Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ peer := manager.newPeer(int(version), networkId, p, rw)
+ select {
+ case manager.newPeerCh <- peer:
+ manager.wg.Add(1)
+ defer manager.wg.Done()
+ return manager.handle(peer)
+ case <-manager.quitSync:
+ return p2p.DiscQuitting
+ }
+ },
+ NodeInfo: func() interface{} {
+ return manager.NodeInfo()
+ },
+ PeerInfo: func(id discover.NodeID) interface{} {
+ if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil {
+ return p.Info()
+ }
+ return nil
+ },
+ })
+ }
+ if len(manager.SubProtocols) == 0 {
+ return nil, errIncompatibleConfig
+ }
+
+ removePeer := manager.removePeer
+ if disableClientRemovePeer {
+ removePeer = func(id string) {}
+ }
+
+ if lightSync {
+ glog.V(logger.Debug).Infof("LES: create downloader")
+ manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash,
+ nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash,
+ blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer)
+ manager.fetcher = newLightFetcher(manager)
+ }
+
+ if odr != nil {
+ odr.removePeer = removePeer
+ }
+
+ /*validator := func(block *types.Block, parent *types.Block) error {
+ return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)
+ }
+ heighter := func() uint64 {
+ return chainman.LastBlockNumberU64()
+ }
+ manager.fetcher = fetcher.New(chainman.GetBlockNoOdr, validator, nil, heighter, chainman.InsertChain, manager.removePeer)
+ */
+ return manager, nil
+}
+
+func (pm *ProtocolManager) removePeer(id string) {
+ // Short circuit if the peer was already removed
+ peer := pm.peers.Peer(id)
+ if peer == nil {
+ return
+ }
+ glog.V(logger.Debug).Infoln("Removing peer", id)
+
+ // Unregister the peer from the downloader and Ethereum peer set
+ glog.V(logger.Debug).Infof("LES: unregister peer %v", id)
+ if pm.lightSync {
+ pm.downloader.UnregisterPeer(id)
+ pm.odr.UnregisterPeer(peer)
+ if pm.txrelay != nil {
+ pm.txrelay.removePeer(id)
+ }
+ }
+ if err := pm.peers.Unregister(id); err != nil {
+ glog.V(logger.Error).Infoln("Removal failed:", err)
+ }
+ // Hard disconnect at the networking layer
+ if peer != nil {
+ peer.Peer.Disconnect(p2p.DiscUselessPeer)
+ }
+}
+
+func (pm *ProtocolManager) Start() {
+ if pm.lightSync {
+ // start sync handler
+ go pm.syncer()
+ } else {
+ go func() {
+ for range pm.newPeerCh {
+ }
+ }()
+ }
+}
+
+func (pm *ProtocolManager) Stop() {
+ // Showing a log message. During download / process this could actually
+ // take between 5 to 10 seconds and therefor feedback is required.
+ glog.V(logger.Info).Infoln("Stopping light ethereum protocol handler...")
+
+ // Quit the sync loop.
+ // After this send has completed, no new peers will be accepted.
+ pm.noMorePeers <- struct{}{}
+
+ close(pm.quitSync) // quits syncer, fetcher
+
+ // Disconnect existing sessions.
+ // This also closes the gate for any new registrations on the peer set.
+ // sessions which are already established but not added to pm.peers yet
+ // will exit when they try to register.
+ pm.peers.Close()
+
+ // Wait for any process action
+ pm.wg.Wait()
+
+ glog.V(logger.Info).Infoln("Light ethereum protocol handler stopped")
+}
+
+func (pm *ProtocolManager) newPeer(pv, nv int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
+ return newPeer(pv, nv, p, newMeteredMsgWriter(rw))
+}
+
+// handle is the callback invoked to manage the life cycle of a les peer. When
+// this function terminates, the peer is disconnected.
+func (pm *ProtocolManager) handle(p *peer) error {
+ glog.V(logger.Debug).Infof("%v: peer connected [%s]", p, p.Name())
+
+ // Execute the LES handshake
+ td, head, genesis := pm.blockchain.Status()
+ headNum := core.GetBlockNumber(pm.chainDb, head)
+ if err := p.Handshake(td, head, headNum, genesis, pm.server); err != nil {
+ glog.V(logger.Debug).Infof("%v: handshake failed: %v", p, err)
+ return err
+ }
+ if rw, ok := p.rw.(*meteredMsgReadWriter); ok {
+ rw.Init(p.version)
+ }
+ // Register the peer locally
+ glog.V(logger.Detail).Infof("%v: adding peer", p)
+ if err := pm.peers.Register(p); err != nil {
+ glog.V(logger.Error).Infof("%v: addition failed: %v", p, err)
+ return err
+ }
+ defer func() {
+ if pm.server != nil && pm.server.fcManager != nil && p.fcClient != nil {
+ p.fcClient.Remove(pm.server.fcManager)
+ }
+ pm.removePeer(p.id)
+ }()
+
+ // Register the peer in the downloader. If the downloader considers it banned, we disconnect
+ glog.V(logger.Debug).Infof("LES: register peer %v", p.id)
+ if pm.lightSync {
+ requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error {
+ reqID := pm.odr.getNextReqID()
+ cost := p.GetRequestCost(GetBlockHeadersMsg, amount)
+ p.fcServer.SendRequest(reqID, cost)
+ return p.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse)
+ }
+ requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error {
+ reqID := pm.odr.getNextReqID()
+ cost := p.GetRequestCost(GetBlockHeadersMsg, amount)
+ p.fcServer.SendRequest(reqID, cost)
+ return p.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse)
+ }
+ if err := pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd,
+ requestHeadersByHash, requestHeadersByNumber, nil, nil, nil); err != nil {
+ return err
+ }
+ pm.odr.RegisterPeer(p)
+ if pm.txrelay != nil {
+ pm.txrelay.addPeer(p)
+ }
+
+ pm.fetcher.notify(p, nil)
+ }
+
+ stop := make(chan struct{})
+ defer close(stop)
+ go func() {
+ // new block announce loop
+ for {
+ select {
+ case announce := <-p.announceChn:
+ p.SendAnnounce(announce)
+ //fmt.Println(" BROADCAST sent")
+ case <-stop:
+ return
+ }
+ }
+ }()
+
+ // main loop. handle incoming messages.
+ for {
+ if err := pm.handleMsg(p); err != nil {
+ glog.V(logger.Debug).Infof("%v: message handling failed: %v", p, err)
+ //fmt.Println("handleMsg err:", err)
+ return err
+ }
+ }
+}
+
+var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsMsg, SendTxMsg, GetHeaderProofsMsg}
+
+// handleMsg is invoked whenever an inbound message is received from a remote
+// peer. The remote connection is torn down upon returning any error.
+func (pm *ProtocolManager) handleMsg(p *peer) error {
+ // Read the next message from the remote peer, and ensure it's fully consumed
+ msg, err := p.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+
+ var costs *requestCosts
+ var reqCnt, maxReqs int
+
+ //fmt.Println("MSG", msg.Code, msg.Size)
+ if rc, ok := p.fcCosts[msg.Code]; ok { // check if msg is a supported request type
+ costs = rc
+ if p.fcClient == nil {
+ return errResp(ErrRequestRejected, "")
+ }
+ bv, ok := p.fcClient.AcceptRequest()
+ if !ok || bv < costs.baseCost {
+ return errResp(ErrRequestRejected, "")
+ }
+ maxReqs = 10000
+ if bv < pm.server.defParams.BufLimit {
+ d := bv - costs.baseCost
+ if d/10000 < costs.reqCost {
+ maxReqs = int(d / costs.reqCost)
+ }
+ }
+ }
+
+ if msg.Size > ProtocolMaxMsgSize {
+ return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
+ }
+ defer msg.Discard()
+
+ var deliverMsg *Msg
+
+ // Handle the message depending on its contents
+ switch msg.Code {
+ case StatusMsg:
+ glog.V(logger.Debug).Infof("LES: received StatusMsg from peer %v", p.id)
+ // Status messages should never arrive after the handshake
+ return errResp(ErrExtraStatusMsg, "uncontrolled status message")
+
+ // Block header query, collect the requested headers and reply
+ case AnnounceMsg:
+ var req announceData
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "%v: %v", msg, err)
+ }
+ //fmt.Println("RECEIVED", req.Number, req.Hash, req.Td, req.ReorgDepth)
+ pm.fetcher.notify(p, &req)
+
+ case GetBlockHeadersMsg:
+ glog.V(logger.Debug).Infof("LES: received GetBlockHeadersMsg from peer %v", p.id)
+ // Decode the complex header query
+ var req struct {
+ ReqID uint64
+ Query getBlockHeadersData
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "%v: %v", msg, err)
+ }
+
+ query := req.Query
+ if query.Amount > uint64(maxReqs) || query.Amount > MaxHeaderFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+
+ hashMode := query.Origin.Hash != (common.Hash{})
+
+ // Gather headers until the fetch or network limits is reached
+ var (
+ bytes common.StorageSize
+ headers []*types.Header
+ unknown bool
+ )
+ for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
+ // Retrieve the next header satisfying the query
+ var origin *types.Header
+ if hashMode {
+ origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+ } else {
+ origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
+ }
+ if origin == nil {
+ break
+ }
+ number := origin.Number.Uint64()
+ headers = append(headers, origin)
+ bytes += estHeaderRlpSize
+
+ // Advance to the next header of the query
+ switch {
+ case query.Origin.Hash != (common.Hash{}) && query.Reverse:
+ // Hash based traversal towards the genesis block
+ for i := 0; i < int(query.Skip)+1; i++ {
+ if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil {
+ query.Origin.Hash = header.ParentHash
+ number--
+ } else {
+ unknown = true
+ break
+ }
+ }
+ case query.Origin.Hash != (common.Hash{}) && !query.Reverse:
+ // Hash based traversal towards the leaf block
+ if header := pm.blockchain.GetHeaderByNumber(origin.Number.Uint64() + query.Skip + 1); header != nil {
+ if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash {
+ query.Origin.Hash = header.Hash()
+ } else {
+ unknown = true
+ }
+ } else {
+ unknown = true
+ }
+ case query.Reverse:
+ // Number based traversal towards the genesis block
+ if query.Origin.Number >= query.Skip+1 {
+ query.Origin.Number -= (query.Skip + 1)
+ } else {
+ unknown = true
+ }
+
+ case !query.Reverse:
+ // Number based traversal towards the leaf block
+ query.Origin.Number += (query.Skip + 1)
+ }
+ }
+
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + query.Amount*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, query.Amount, rcost)
+ return p.SendBlockHeaders(req.ReqID, bv, headers)
+
+ case BlockHeadersMsg:
+ if pm.downloader == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received BlockHeadersMsg from peer %v", p.id)
+ // A batch of headers arrived to one of our previous requests
+ var resp struct {
+ ReqID, BV uint64
+ Headers []*types.Header
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ if pm.fetcher.requestedID(resp.ReqID) {
+ pm.fetcher.deliverHeaders(resp.ReqID, resp.Headers)
+ } else {
+ err := pm.downloader.DeliverHeaders(p.id, resp.Headers)
+ if err != nil {
+ glog.V(logger.Debug).Infoln(err)
+ }
+ }
+
+ case GetBlockBodiesMsg:
+ glog.V(logger.Debug).Infof("LES: received GetBlockBodiesMsg from peer %v", p.id)
+ // Decode the retrieval message
+ var req struct {
+ ReqID uint64
+ Hashes []common.Hash
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ // Gather blocks until the fetch or network limits is reached
+ var (
+ bytes int
+ bodies []rlp.RawValue
+ )
+ reqCnt = len(req.Hashes)
+ if reqCnt > maxReqs || reqCnt > MaxBodyFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+ for _, hash := range req.Hashes {
+ if bytes >= softResponseLimit {
+ break
+ }
+ // Retrieve the requested block body, stopping if enough was found
+ if data := core.GetBodyRLP(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 {
+ bodies = append(bodies, data)
+ bytes += len(data)
+ }
+ }
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+ return p.SendBlockBodiesRLP(req.ReqID, bv, bodies)
+
+ case BlockBodiesMsg:
+ if pm.odr == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received BlockBodiesMsg from peer %v", p.id)
+ // A batch of block bodies arrived to one of our previous requests
+ var resp struct {
+ ReqID, BV uint64
+ Data []*types.Body
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ deliverMsg = &Msg{
+ MsgType: MsgBlockBodies,
+ ReqID: resp.ReqID,
+ Obj: resp.Data,
+ }
+
+ case GetCodeMsg:
+ glog.V(logger.Debug).Infof("LES: received GetCodeMsg from peer %v", p.id)
+ // Decode the retrieval message
+ var req struct {
+ ReqID uint64
+ Reqs []CodeReq
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ // Gather state data until the fetch or network limits is reached
+ var (
+ bytes int
+ data [][]byte
+ )
+ reqCnt = len(req.Reqs)
+ if reqCnt > maxReqs || reqCnt > MaxCodeFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+ for _, req := range req.Reqs {
+ // Retrieve the requested state entry, stopping if enough was found
+ if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
+ if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil {
+ sdata := trie.Get(req.AccKey)
+ var acc state.Account
+ if err := rlp.DecodeBytes(sdata, &acc); err == nil {
+ entry, _ := pm.chainDb.Get(acc.CodeHash)
+ if bytes+len(entry) >= softResponseLimit {
+ break
+ }
+ data = append(data, entry)
+ bytes += len(entry)
+ }
+ }
+ }
+ }
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+ return p.SendCode(req.ReqID, bv, data)
+
+ case CodeMsg:
+ if pm.odr == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received CodeMsg from peer %v", p.id)
+ // A batch of node state data arrived to one of our previous requests
+ var resp struct {
+ ReqID, BV uint64
+ Data [][]byte
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ deliverMsg = &Msg{
+ MsgType: MsgCode,
+ ReqID: resp.ReqID,
+ Obj: resp.Data,
+ }
+
+ case GetReceiptsMsg:
+ glog.V(logger.Debug).Infof("LES: received GetReceiptsMsg from peer %v", p.id)
+ // Decode the retrieval message
+ var req struct {
+ ReqID uint64
+ Hashes []common.Hash
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ // Gather state data until the fetch or network limits is reached
+ var (
+ bytes int
+ receipts []rlp.RawValue
+ )
+ reqCnt = len(req.Hashes)
+ if reqCnt > maxReqs || reqCnt > MaxReceiptFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+ for _, hash := range req.Hashes {
+ if bytes >= softResponseLimit {
+ break
+ }
+ // Retrieve the requested block's receipts, skipping if unknown to us
+ results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash))
+ if results == nil {
+ if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
+ continue
+ }
+ }
+ // If known, encode and queue for response packet
+ if encoded, err := rlp.EncodeToBytes(results); err != nil {
+ glog.V(logger.Error).Infof("failed to encode receipt: %v", err)
+ } else {
+ receipts = append(receipts, encoded)
+ bytes += len(encoded)
+ }
+ }
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+ return p.SendReceiptsRLP(req.ReqID, bv, receipts)
+
+ case ReceiptsMsg:
+ if pm.odr == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received ReceiptsMsg from peer %v", p.id)
+ // A batch of receipts arrived to one of our previous requests
+ var resp struct {
+ ReqID, BV uint64
+ Receipts []types.Receipts
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ deliverMsg = &Msg{
+ MsgType: MsgReceipts,
+ ReqID: resp.ReqID,
+ Obj: resp.Receipts,
+ }
+
+ case GetProofsMsg:
+ glog.V(logger.Debug).Infof("LES: received GetProofsMsg from peer %v", p.id)
+ // Decode the retrieval message
+ var req struct {
+ ReqID uint64
+ Reqs []ProofReq
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ // Gather state data until the fetch or network limits is reached
+ var (
+ bytes int
+ proofs proofsData
+ )
+ reqCnt = len(req.Reqs)
+ if reqCnt > maxReqs || reqCnt > MaxProofsFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+ for _, req := range req.Reqs {
+ if bytes >= softResponseLimit {
+ break
+ }
+ // Retrieve the requested state entry, stopping if enough was found
+ if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
+ if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil {
+ if len(req.AccKey) > 0 {
+ sdata := tr.Get(req.AccKey)
+ tr = nil
+ var acc state.Account
+ if err := rlp.DecodeBytes(sdata, &acc); err == nil {
+ tr, _ = trie.New(acc.Root, pm.chainDb)
+ }
+ }
+ if tr != nil {
+ proof := tr.Prove(req.Key)
+ proofs = append(proofs, proof)
+ bytes += len(proof)
+ }
+ }
+ }
+ }
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+ return p.SendProofs(req.ReqID, bv, proofs)
+
+ case ProofsMsg:
+ if pm.odr == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received ProofsMsg from peer %v", p.id)
+ // A batch of merkle proofs arrived to one of our previous requests
+ var resp struct {
+ ReqID, BV uint64
+ Data [][]rlp.RawValue
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ deliverMsg = &Msg{
+ MsgType: MsgProofs,
+ ReqID: resp.ReqID,
+ Obj: resp.Data,
+ }
+
+ case GetHeaderProofsMsg:
+ glog.V(logger.Debug).Infof("LES: received GetHeaderProofsMsg from peer %v", p.id)
+ // Decode the retrieval message
+ var req struct {
+ ReqID uint64
+ Reqs []ChtReq
+ }
+ if err := msg.Decode(&req); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ // Gather state data until the fetch or network limits is reached
+ var (
+ bytes int
+ proofs []ChtResp
+ )
+ reqCnt = len(req.Reqs)
+ if reqCnt > maxReqs || reqCnt > MaxHeaderProofsFetch {
+ return errResp(ErrRequestRejected, "")
+ }
+ for _, req := range req.Reqs {
+ if bytes >= softResponseLimit {
+ break
+ }
+
+ if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil {
+ if root := getChtRoot(pm.chainDb, req.ChtNum); root != (common.Hash{}) {
+ if tr, _ := trie.New(root, pm.chainDb); tr != nil {
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
+ proof := tr.Prove(encNumber[:])
+ proofs = append(proofs, ChtResp{Header: header, Proof: proof})
+ bytes += len(proof) + estHeaderRlpSize
+ }
+ }
+ }
+ }
+ bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+ return p.SendHeaderProofs(req.ReqID, bv, proofs)
+
+ case HeaderProofsMsg:
+ if pm.odr == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+
+ glog.V(logger.Debug).Infof("LES: received HeaderProofsMsg from peer %v", p.id)
+ var resp struct {
+ ReqID, BV uint64
+ Data []ChtResp
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ p.fcServer.GotReply(resp.ReqID, resp.BV)
+ deliverMsg = &Msg{
+ MsgType: MsgHeaderProofs,
+ ReqID: resp.ReqID,
+ Obj: resp.Data,
+ }
+
+ case SendTxMsg:
+ if pm.txpool == nil {
+ return errResp(ErrUnexpectedResponse, "")
+ }
+ // Transactions arrived, parse all of them and deliver to the pool
+ var txs []*types.Transaction
+ if err := msg.Decode(&txs); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ reqCnt = len(txs)
+ if reqCnt > maxReqs || reqCnt > MaxTxSend {
+ return errResp(ErrRequestRejected, "")
+ }
+ pm.txpool.AddBatch(txs)
+ _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
+ pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
+
+ default:
+ glog.V(logger.Debug).Infof("LES: received unknown message with code %d from peer %v", msg.Code, p.id)
+ return errResp(ErrInvalidMsgCode, "%v", msg.Code)
+ }
+
+ if deliverMsg != nil {
+ return pm.odr.Deliver(p, deliverMsg)
+ }
+
+ return nil
+}
+
+// NodeInfo retrieves some protocol metadata about the running host node.
+func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo {
+ return &eth.EthNodeInfo{
+ Network: self.networkId,
+ Difficulty: self.blockchain.GetTdByHash(self.blockchain.LastBlockHash()),
+ Genesis: self.blockchain.Genesis().Hash(),
+ Head: self.blockchain.LastBlockHash(),
+ }
+}
diff --git a/les/handler_test.go b/les/handler_test.go
new file mode 100644
index 000000000..2aa7a5590
--- /dev/null
+++ b/les/handler_test.go
@@ -0,0 +1,322 @@
+package les
+
+import (
+ "math/rand"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}) error {
+ type resp struct {
+ ReqID, BV uint64
+ Data interface{}
+ }
+ return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data})
+}
+
+// Tests that block headers can be retrieved from a remote chain based on user queries.
+func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) }
+
+func testGetBlockHeaders(t *testing.T, protocol int) {
+ pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil)
+ bc := pm.blockchain.(*core.BlockChain)
+ peer, _ := newTestPeer(t, "peer", protocol, pm, true)
+ defer peer.close()
+
+ // Create a "random" unknown hash for testing
+ var unknown common.Hash
+ for i, _ := range unknown {
+ unknown[i] = byte(i)
+ }
+ // Create a batch of tests for various scenarios
+ limit := uint64(MaxHeaderFetch)
+ tests := []struct {
+ query *getBlockHeadersData // The query to execute for header retrieval
+ expect []common.Hash // The hashes of the block whose headers are expected
+ }{
+ // A single random block should be retrievable by hash and number too
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1},
+ []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()},
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1},
+ []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()},
+ },
+ // Multiple headers should be retrievable in both directions
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3},
+ []common.Hash{
+ bc.GetBlockByNumber(limit / 2).Hash(),
+ bc.GetBlockByNumber(limit/2 + 1).Hash(),
+ bc.GetBlockByNumber(limit/2 + 2).Hash(),
+ },
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true},
+ []common.Hash{
+ bc.GetBlockByNumber(limit / 2).Hash(),
+ bc.GetBlockByNumber(limit/2 - 1).Hash(),
+ bc.GetBlockByNumber(limit/2 - 2).Hash(),
+ },
+ },
+ // Multiple headers with skip lists should be retrievable
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3},
+ []common.Hash{
+ bc.GetBlockByNumber(limit / 2).Hash(),
+ bc.GetBlockByNumber(limit/2 + 4).Hash(),
+ bc.GetBlockByNumber(limit/2 + 8).Hash(),
+ },
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true},
+ []common.Hash{
+ bc.GetBlockByNumber(limit / 2).Hash(),
+ bc.GetBlockByNumber(limit/2 - 4).Hash(),
+ bc.GetBlockByNumber(limit/2 - 8).Hash(),
+ },
+ },
+ // The chain endpoints should be retrievable
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1},
+ []common.Hash{bc.GetBlockByNumber(0).Hash()},
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1},
+ []common.Hash{bc.CurrentBlock().Hash()},
+ },
+ // Ensure protocol limits are honored
+ /*{
+ &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
+ bc.GetBlockHashesFromHash(bc.CurrentBlock().Hash(), limit),
+ },*/
+ // Check that requesting more than available is handled gracefully
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
+ []common.Hash{
+ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(),
+ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64()).Hash(),
+ },
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true},
+ []common.Hash{
+ bc.GetBlockByNumber(4).Hash(),
+ bc.GetBlockByNumber(0).Hash(),
+ },
+ },
+ // Check that requesting more than available is handled gracefully, even if mid skip
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
+ []common.Hash{
+ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(),
+ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 1).Hash(),
+ },
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
+ []common.Hash{
+ bc.GetBlockByNumber(4).Hash(),
+ bc.GetBlockByNumber(1).Hash(),
+ },
+ },
+ // Check that non existing headers aren't returned
+ {
+ &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1},
+ []common.Hash{},
+ }, {
+ &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1},
+ []common.Hash{},
+ },
+ }
+ // Run each of the tests and verify the results against the chain
+ var reqID uint64
+ for i, tt := range tests {
+ // Collect the headers to expect in the response
+ headers := []*types.Header{}
+ for _, hash := range tt.expect {
+ headers = append(headers, bc.GetHeaderByHash(hash))
+ }
+ // Send the hash request and verify the response
+ reqID++
+ cost := peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
+ sendRequest(peer.app, GetBlockHeadersMsg, reqID, cost, tt.query)
+ if err := expectResponse(peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil {
+ t.Errorf("test %d: headers mismatch: %v", i, err)
+ }
+ }
+}
+
+// Tests that block contents can be retrieved from a remote chain based on their hashes.
+func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) }
+
+func testGetBlockBodies(t *testing.T, protocol int) {
+ pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil)
+ bc := pm.blockchain.(*core.BlockChain)
+ peer, _ := newTestPeer(t, "peer", protocol, pm, true)
+ defer peer.close()
+
+ // Create a batch of tests for various scenarios
+ limit := MaxBodyFetch
+ tests := []struct {
+ random int // Number of blocks to fetch randomly from the chain
+ explicit []common.Hash // Explicitly requested blocks
+ available []bool // Availability of explicitly requested blocks
+ expected int // Total number of existing blocks to expect
+ }{
+ {1, nil, nil, 1}, // A single random block should be retrievable
+ {10, nil, nil, 10}, // Multiple random blocks should be retrievable
+ {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable
+ //{limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
+ {0, []common.Hash{bc.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable
+ {0, []common.Hash{bc.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable
+ {0, []common.Hash{common.Hash{}}, []bool{false}, 0}, // A non existent block should not be returned
+
+ // Existing and non-existing blocks interleaved should not cause problems
+ {0, []common.Hash{
+ common.Hash{},
+ bc.GetBlockByNumber(1).Hash(),
+ common.Hash{},
+ bc.GetBlockByNumber(10).Hash(),
+ common.Hash{},
+ bc.GetBlockByNumber(100).Hash(),
+ common.Hash{},
+ }, []bool{false, true, false, true, false, true, false}, 3},
+ }
+ // Run each of the tests and verify the results against the chain
+ var reqID uint64
+ for i, tt := range tests {
+ // Collect the hashes to request, and the response to expect
+ hashes, seen := []common.Hash{}, make(map[int64]bool)
+ bodies := []*types.Body{}
+
+ for j := 0; j < tt.random; j++ {
+ for {
+ num := rand.Int63n(int64(bc.CurrentBlock().NumberU64()))
+ if !seen[num] {
+ seen[num] = true
+
+ block := bc.GetBlockByNumber(uint64(num))
+ hashes = append(hashes, block.Hash())
+ if len(bodies) < tt.expected {
+ bodies = append(bodies, &types.Body{Transactions: block.Transactions(), Uncles: block.Uncles()})
+ }
+ break
+ }
+ }
+ }
+ for j, hash := range tt.explicit {
+ hashes = append(hashes, hash)
+ if tt.available[j] && len(bodies) < tt.expected {
+ block := bc.GetBlockByHash(hash)
+ bodies = append(bodies, &types.Body{Transactions: block.Transactions(), Uncles: block.Uncles()})
+ }
+ }
+ reqID++
+ // Send the hash request and verify the response
+ cost := peer.GetRequestCost(GetBlockBodiesMsg, len(hashes))
+ sendRequest(peer.app, GetBlockBodiesMsg, reqID, cost, hashes)
+ if err := expectResponse(peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil {
+ t.Errorf("test %d: bodies mismatch: %v", i, err)
+ }
+ }
+}
+
+// Tests that the contract codes can be retrieved based on account addresses.
+func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) }
+
+func testGetCode(t *testing.T, protocol int) {
+ // Assemble the test environment
+ pm, _, _ := newTestProtocolManagerMust(t, false, 4, testChainGen)
+ bc := pm.blockchain.(*core.BlockChain)
+ peer, _ := newTestPeer(t, "peer", protocol, pm, true)
+ defer peer.close()
+
+ var codereqs []*CodeReq
+ var codes [][]byte
+
+ for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
+ header := bc.GetHeaderByNumber(i)
+ req := &CodeReq{
+ BHash: header.Hash(),
+ AccKey: crypto.Keccak256(testContractAddr[:]),
+ }
+ codereqs = append(codereqs, req)
+ if i >= testContractDeployed {
+ codes = append(codes, testContractCodeDeployed)
+ }
+ }
+
+ cost := peer.GetRequestCost(GetCodeMsg, len(codereqs))
+ sendRequest(peer.app, GetCodeMsg, 42, cost, codereqs)
+ if err := expectResponse(peer.app, CodeMsg, 42, testBufLimit, codes); err != nil {
+ t.Errorf("codes mismatch: %v", err)
+ }
+}
+
+// Tests that the transaction receipts can be retrieved based on hashes.
+func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) }
+
+func testGetReceipt(t *testing.T, protocol int) {
+ // Assemble the test environment
+ pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen)
+ bc := pm.blockchain.(*core.BlockChain)
+ peer, _ := newTestPeer(t, "peer", protocol, pm, true)
+ defer peer.close()
+
+ // Collect the hashes to request, and the response to expect
+ hashes, receipts := []common.Hash{}, []types.Receipts{}
+ for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
+ block := bc.GetBlockByNumber(i)
+
+ hashes = append(hashes, block.Hash())
+ receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64()))
+ }
+ // Send the hash request and verify the response
+ cost := peer.GetRequestCost(GetReceiptsMsg, len(hashes))
+ sendRequest(peer.app, GetReceiptsMsg, 42, cost, hashes)
+ if err := expectResponse(peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil {
+ t.Errorf("receipts mismatch: %v", err)
+ }
+}
+
+// Tests that trie merkle proofs can be retrieved
+func TestGetProofsLes1(t *testing.T) { testGetReceipt(t, 1) }
+
+func testGetProofs(t *testing.T, protocol int) {
+ // Assemble the test environment
+ pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen)
+ bc := pm.blockchain.(*core.BlockChain)
+ peer, _ := newTestPeer(t, "peer", protocol, pm, true)
+ defer peer.close()
+
+ var proofreqs []ProofReq
+ var proofs [][]rlp.RawValue
+
+ accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr, common.Address{}}
+ for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
+ header := bc.GetHeaderByNumber(i)
+ root := header.Root
+ trie, _ := trie.New(root, db)
+
+ for _, acc := range accounts {
+ req := ProofReq{
+ BHash: header.Hash(),
+ Key: acc[:],
+ }
+ proofreqs = append(proofreqs, req)
+
+ proof := trie.Prove(crypto.Keccak256(acc[:]))
+ proofs = append(proofs, proof)
+ }
+ }
+ // Send the proof request and verify the response
+ cost := peer.GetRequestCost(GetProofsMsg, len(proofreqs))
+ sendRequest(peer.app, GetProofsMsg, 42, cost, proofreqs)
+ if err := expectResponse(peer.app, ProofsMsg, 42, testBufLimit, proofs); err != nil {
+ t.Errorf("proofs mismatch: %v", err)
+ }
+}
diff --git a/les/helper_test.go b/les/helper_test.go
new file mode 100644
index 000000000..9817e5004
--- /dev/null
+++ b/les/helper_test.go
@@ -0,0 +1,318 @@
+// This file contains some shares testing functionality, common to multiple
+// different files and modules being tested.
+
+package les
+
+import (
+ "crypto/ecdsa"
+ "crypto/rand"
+ "math/big"
+ "sync"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/les/flowcontrol"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/params"
+)
+
+var (
+ testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey)
+ testBankFunds = big.NewInt(1000000)
+
+ acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+ acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
+ acc1Addr = crypto.PubkeyToAddress(acc1Key.PublicKey)
+ acc2Addr = crypto.PubkeyToAddress(acc2Key.PublicKey)
+
+ testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
+ testContractAddr common.Address
+ testContractCodeDeployed = testContractCode[16:]
+ testContractDeployed = uint64(2)
+
+ testBufLimit = uint64(100)
+)
+
+/*
+contract test {
+
+ uint256[100] data;
+
+ function Put(uint256 addr, uint256 value) {
+ data[addr] = value;
+ }
+
+ function Get(uint256 addr) constant returns (uint256 value) {
+ return data[addr];
+ }
+}
+*/
+
+func testChainGen(i int, block *core.BlockGen) {
+ switch i {
+ case 0:
+ // In block 1, the test bank sends account #1 some ether.
+ tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil).SignECDSA(testBankKey)
+ block.AddTx(tx)
+ case 1:
+ // In block 2, the test bank sends some more ether to account #1.
+ // acc1Addr passes it on to account #2.
+ // acc1Addr creates a test contract.
+ tx1, _ := types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testBankKey)
+ nonce := block.TxNonce(acc1Addr)
+ tx2, _ := types.NewTransaction(nonce, acc2Addr, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(acc1Key)
+ nonce++
+ tx3, _ := types.NewContractCreation(nonce, big.NewInt(0), big.NewInt(200000), big.NewInt(0), testContractCode).SignECDSA(acc1Key)
+ testContractAddr = crypto.CreateAddress(acc1Addr, nonce)
+ block.AddTx(tx1)
+ block.AddTx(tx2)
+ block.AddTx(tx3)
+ case 2:
+ // Block 3 is empty but was mined by account #2.
+ block.SetCoinbase(acc2Addr)
+ block.SetExtra([]byte("yeehaw"))
+ data := common.Hex2Bytes("C16431B900000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001")
+ tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), testContractAddr, big.NewInt(0), big.NewInt(100000), nil, data).SignECDSA(testBankKey)
+ block.AddTx(tx)
+ case 3:
+ // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
+ b2 := block.PrevBlock(1).Header()
+ b2.Extra = []byte("foo")
+ block.AddUncle(b2)
+ b3 := block.PrevBlock(2).Header()
+ b3.Extra = []byte("foo")
+ block.AddUncle(b3)
+ data := common.Hex2Bytes("C16431B900000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000002")
+ tx, _ := types.NewTransaction(block.TxNonce(testBankAddress), testContractAddr, big.NewInt(0), big.NewInt(100000), nil, data).SignECDSA(testBankKey)
+ block.AddTx(tx)
+ }
+}
+
+func testRCL() RequestCostList {
+ cl := make(RequestCostList, len(reqList))
+ for i, code := range reqList {
+ cl[i].MsgCode = code
+ cl[i].BaseCost = 0
+ cl[i].ReqCost = 0
+ }
+ return cl
+}
+
+// newTestProtocolManager creates a new protocol manager for testing purposes,
+// with the given number of blocks already known, and potential notification
+// channels for different events.
+func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr, error) {
+ var (
+ evmux = new(event.TypeMux)
+ pow = new(core.FakePow)
+ db, _ = ethdb.NewMemDatabase()
+ genesis = core.WriteGenesisBlockForTesting(db, core.GenesisAccount{Address: testBankAddress, Balance: testBankFunds})
+ chainConfig = &core.ChainConfig{HomesteadBlock: big.NewInt(0)} // homestead set to 0 because of chain maker
+ odr *LesOdr
+ chain BlockChain
+ )
+
+ if lightSync {
+ odr = NewLesOdr(db)
+ chain, _ = light.NewLightChain(odr, chainConfig, pow, evmux)
+ } else {
+ blockchain, _ := core.NewBlockChain(db, chainConfig, pow, evmux)
+ gchain, _ := core.GenerateChain(nil, genesis, db, blocks, generator)
+ if _, err := blockchain.InsertChain(gchain); err != nil {
+ panic(err)
+ }
+ chain = blockchain
+ }
+
+ pm, err := NewProtocolManager(chainConfig, lightSync, NetworkId, evmux, pow, chain, nil, db, odr, nil)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ if !lightSync {
+ srv := &LesServer{protocolManager: pm}
+ pm.server = srv
+
+ srv.defParams = &flowcontrol.ServerParams{
+ BufLimit: testBufLimit,
+ MinRecharge: 1,
+ }
+
+ srv.fcManager = flowcontrol.NewClientManager(50, 10, 1000000000)
+ srv.fcCostStats = newCostStats(nil)
+ }
+ pm.Start()
+ return pm, db, odr, nil
+}
+
+// newTestProtocolManagerMust creates a new protocol manager for testing purposes,
+// with the given number of blocks already known, and potential notification
+// channels for different events. In case of an error, the constructor force-
+// fails the test.
+func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr) {
+ pm, db, odr, err := newTestProtocolManager(lightSync, blocks, generator)
+ if err != nil {
+ t.Fatalf("Failed to create protocol manager: %v", err)
+ }
+ return pm, db, odr
+}
+
+// testTxPool is a fake, helper transaction pool for testing purposes
+type testTxPool struct {
+ pool []*types.Transaction // Collection of all transactions
+ added chan<- []*types.Transaction // Notification channel for new transactions
+
+ lock sync.RWMutex // Protects the transaction pool
+}
+
+// AddTransactions appends a batch of transactions to the pool, and notifies any
+// listeners if the addition channel is non nil
+func (p *testTxPool) AddBatch(txs []*types.Transaction) {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ p.pool = append(p.pool, txs...)
+ if p.added != nil {
+ p.added <- txs
+ }
+}
+
+// GetTransactions returns all the transactions known to the pool
+func (p *testTxPool) GetTransactions() types.Transactions {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ txs := make([]*types.Transaction, len(p.pool))
+ copy(txs, p.pool)
+
+ return txs
+}
+
+// newTestTransaction create a new dummy transaction.
+func newTestTransaction(from *ecdsa.PrivateKey, nonce uint64, datasize int) *types.Transaction {
+ tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), big.NewInt(100000), big.NewInt(0), make([]byte, datasize))
+ tx, _ = tx.SignECDSA(from)
+
+ return tx
+}
+
+// testPeer is a simulated peer to allow testing direct network calls.
+type testPeer struct {
+ net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
+ app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
+ *peer
+}
+
+// newTestPeer creates a new peer registered at the given protocol manager.
+func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) {
+ // Create a message pipe to communicate through
+ app, net := p2p.MsgPipe()
+
+ // Generate a random id and create the peer
+ var id discover.NodeID
+ rand.Read(id[:])
+
+ peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
+
+ // Start the peer on a new thread
+ errc := make(chan error, 1)
+ go func() {
+ select {
+ case pm.newPeerCh <- peer:
+ errc <- pm.handle(peer)
+ case <-pm.quitSync:
+ errc <- p2p.DiscQuitting
+ }
+ }()
+ tp := &testPeer{
+ app: app,
+ net: net,
+ peer: peer,
+ }
+ // Execute any implicitly requested handshakes and return
+ if shake {
+ td, head, genesis := pm.blockchain.Status()
+ headNum := pm.blockchain.CurrentHeader().Number.Uint64()
+ tp.handshake(t, td, head, headNum, genesis)
+ }
+ return tp, errc
+}
+
+func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, <-chan error, *peer, <-chan error) {
+ // Create a message pipe to communicate through
+ app, net := p2p.MsgPipe()
+
+ // Generate a random id and create the peer
+ var id discover.NodeID
+ rand.Read(id[:])
+
+ peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
+ peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app)
+
+ // Start the peer on a new thread
+ errc := make(chan error, 1)
+ errc2 := make(chan error, 1)
+ go func() {
+ select {
+ case pm.newPeerCh <- peer:
+ errc <- pm.handle(peer)
+ case <-pm.quitSync:
+ errc <- p2p.DiscQuitting
+ }
+ }()
+ go func() {
+ select {
+ case pm2.newPeerCh <- peer2:
+ errc2 <- pm2.handle(peer2)
+ case <-pm2.quitSync:
+ errc2 <- p2p.DiscQuitting
+ }
+ }()
+ return peer, errc, peer2, errc2
+}
+
+// handshake simulates a trivial handshake that expects the same state from the
+// remote side as we are simulating locally.
+func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash) {
+ var expList keyValueList
+ expList = expList.add("protocolVersion", uint64(p.version))
+ expList = expList.add("networkId", uint64(NetworkId))
+ expList = expList.add("headTd", td)
+ expList = expList.add("headHash", head)
+ expList = expList.add("headNum", headNum)
+ expList = expList.add("genesisHash", genesis)
+ sendList := make(keyValueList, len(expList))
+ copy(sendList, expList)
+ expList = expList.add("serveHeaders", nil)
+ expList = expList.add("serveChainSince", uint64(0))
+ expList = expList.add("serveStateSince", uint64(0))
+ expList = expList.add("txRelay", nil)
+ expList = expList.add("flowControl/BL", testBufLimit)
+ expList = expList.add("flowControl/MRR", uint64(1))
+ expList = expList.add("flowControl/MRC", testRCL())
+
+ if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
+ t.Fatalf("status recv: %v", err)
+ }
+ if err := p2p.Send(p.app, StatusMsg, sendList); err != nil {
+ t.Fatalf("status send: %v", err)
+ }
+
+ p.fcServerParams = &flowcontrol.ServerParams{
+ BufLimit: testBufLimit,
+ MinRecharge: 1,
+ }
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+ p.app.Close()
+}
diff --git a/les/metrics.go b/les/metrics.go
new file mode 100644
index 000000000..88e6726e2
--- /dev/null
+++ b/les/metrics.go
@@ -0,0 +1,111 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+var (
+ /* propTxnInPacketsMeter = metrics.NewMeter("eth/prop/txns/in/packets")
+ propTxnInTrafficMeter = metrics.NewMeter("eth/prop/txns/in/traffic")
+ propTxnOutPacketsMeter = metrics.NewMeter("eth/prop/txns/out/packets")
+ propTxnOutTrafficMeter = metrics.NewMeter("eth/prop/txns/out/traffic")
+ propHashInPacketsMeter = metrics.NewMeter("eth/prop/hashes/in/packets")
+ propHashInTrafficMeter = metrics.NewMeter("eth/prop/hashes/in/traffic")
+ propHashOutPacketsMeter = metrics.NewMeter("eth/prop/hashes/out/packets")
+ propHashOutTrafficMeter = metrics.NewMeter("eth/prop/hashes/out/traffic")
+ propBlockInPacketsMeter = metrics.NewMeter("eth/prop/blocks/in/packets")
+ propBlockInTrafficMeter = metrics.NewMeter("eth/prop/blocks/in/traffic")
+ propBlockOutPacketsMeter = metrics.NewMeter("eth/prop/blocks/out/packets")
+ propBlockOutTrafficMeter = metrics.NewMeter("eth/prop/blocks/out/traffic")
+ reqHashInPacketsMeter = metrics.NewMeter("eth/req/hashes/in/packets")
+ reqHashInTrafficMeter = metrics.NewMeter("eth/req/hashes/in/traffic")
+ reqHashOutPacketsMeter = metrics.NewMeter("eth/req/hashes/out/packets")
+ reqHashOutTrafficMeter = metrics.NewMeter("eth/req/hashes/out/traffic")
+ reqBlockInPacketsMeter = metrics.NewMeter("eth/req/blocks/in/packets")
+ reqBlockInTrafficMeter = metrics.NewMeter("eth/req/blocks/in/traffic")
+ reqBlockOutPacketsMeter = metrics.NewMeter("eth/req/blocks/out/packets")
+ reqBlockOutTrafficMeter = metrics.NewMeter("eth/req/blocks/out/traffic")
+ reqHeaderInPacketsMeter = metrics.NewMeter("eth/req/headers/in/packets")
+ reqHeaderInTrafficMeter = metrics.NewMeter("eth/req/headers/in/traffic")
+ reqHeaderOutPacketsMeter = metrics.NewMeter("eth/req/headers/out/packets")
+ reqHeaderOutTrafficMeter = metrics.NewMeter("eth/req/headers/out/traffic")
+ reqBodyInPacketsMeter = metrics.NewMeter("eth/req/bodies/in/packets")
+ reqBodyInTrafficMeter = metrics.NewMeter("eth/req/bodies/in/traffic")
+ reqBodyOutPacketsMeter = metrics.NewMeter("eth/req/bodies/out/packets")
+ reqBodyOutTrafficMeter = metrics.NewMeter("eth/req/bodies/out/traffic")
+ reqStateInPacketsMeter = metrics.NewMeter("eth/req/states/in/packets")
+ reqStateInTrafficMeter = metrics.NewMeter("eth/req/states/in/traffic")
+ reqStateOutPacketsMeter = metrics.NewMeter("eth/req/states/out/packets")
+ reqStateOutTrafficMeter = metrics.NewMeter("eth/req/states/out/traffic")
+ reqReceiptInPacketsMeter = metrics.NewMeter("eth/req/receipts/in/packets")
+ reqReceiptInTrafficMeter = metrics.NewMeter("eth/req/receipts/in/traffic")
+ reqReceiptOutPacketsMeter = metrics.NewMeter("eth/req/receipts/out/packets")
+ reqReceiptOutTrafficMeter = metrics.NewMeter("eth/req/receipts/out/traffic")*/
+ miscInPacketsMeter = metrics.NewMeter("les/misc/in/packets")
+ miscInTrafficMeter = metrics.NewMeter("les/misc/in/traffic")
+ miscOutPacketsMeter = metrics.NewMeter("les/misc/out/packets")
+ miscOutTrafficMeter = metrics.NewMeter("les/misc/out/traffic")
+)
+
+// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
+// accumulating the above defined metrics based on the data stream contents.
+type meteredMsgReadWriter struct {
+ p2p.MsgReadWriter // Wrapped message stream to meter
+ version int // Protocol version to select correct meters
+}
+
+// newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the
+// metrics system is disabled, this fucntion returns the original object.
+func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter {
+ if !metrics.Enabled {
+ return rw
+ }
+ return &meteredMsgReadWriter{MsgReadWriter: rw}
+}
+
+// Init sets the protocol version used by the stream to know which meters to
+// increment in case of overlapping message ids between protocol versions.
+func (rw *meteredMsgReadWriter) Init(version int) {
+ rw.version = version
+}
+
+func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {
+ // Read the message and short circuit in case of an error
+ msg, err := rw.MsgReadWriter.ReadMsg()
+ if err != nil {
+ return msg, err
+ }
+ // Account for the data traffic
+ packets, traffic := miscInPacketsMeter, miscInTrafficMeter
+ packets.Mark(1)
+ traffic.Mark(int64(msg.Size))
+
+ return msg, err
+}
+
+func (rw *meteredMsgReadWriter) WriteMsg(msg p2p.Msg) error {
+ // Account for the data traffic
+ packets, traffic := miscOutPacketsMeter, miscOutTrafficMeter
+ packets.Mark(1)
+ traffic.Mark(int64(msg.Size))
+
+ // Send the packet to the p2p layer
+ return rw.MsgReadWriter.WriteMsg(msg)
+}
diff --git a/les/odr.go b/les/odr.go
new file mode 100644
index 000000000..2674ba6a1
--- /dev/null
+++ b/les/odr.go
@@ -0,0 +1,247 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+package les
+
+import (
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "golang.org/x/net/context"
+)
+
+var (
+ softRequestTimeout = time.Millisecond * 500
+ hardRequestTimeout = time.Second * 10
+ retryPeers = time.Second * 1
+)
+
+// peerDropFn is a callback type for dropping a peer detected as malicious.
+type peerDropFn func(id string)
+
+type LesOdr struct {
+ light.OdrBackend
+ db ethdb.Database
+ stop chan struct{}
+ removePeer peerDropFn
+ mlock, clock sync.Mutex
+ sentReqs map[uint64]*sentReq
+ peers *odrPeerSet
+ lastReqID uint64
+}
+
+func NewLesOdr(db ethdb.Database) *LesOdr {
+ return &LesOdr{
+ db: db,
+ stop: make(chan struct{}),
+ peers: newOdrPeerSet(),
+ sentReqs: make(map[uint64]*sentReq),
+ }
+}
+
+func (odr *LesOdr) Stop() {
+ close(odr.stop)
+}
+
+func (odr *LesOdr) Database() ethdb.Database {
+ return odr.db
+}
+
+// validatorFunc is a function that processes a message and returns true if
+// it was a meaningful answer to a given request
+type validatorFunc func(ethdb.Database, *Msg) bool
+
+// sentReq is a request waiting for an answer that satisfies its valFunc
+type sentReq struct {
+ valFunc validatorFunc
+ sentTo map[*peer]chan struct{}
+ lock sync.RWMutex // protects acces to sentTo
+ answered chan struct{} // closed and set to nil when any peer answers it
+}
+
+// RegisterPeer registers a new LES peer to the ODR capable peer set
+func (self *LesOdr) RegisterPeer(p *peer) error {
+ return self.peers.register(p)
+}
+
+// UnregisterPeer removes a peer from the ODR capable peer set
+func (self *LesOdr) UnregisterPeer(p *peer) {
+ self.peers.unregister(p)
+}
+
+const (
+ MsgBlockBodies = iota
+ MsgCode
+ MsgReceipts
+ MsgProofs
+ MsgHeaderProofs
+)
+
+// Msg encodes a LES message that delivers reply data for a request
+type Msg struct {
+ MsgType int
+ ReqID uint64
+ Obj interface{}
+}
+
+// Deliver is called by the LES protocol manager to deliver ODR reply messages to waiting requests
+func (self *LesOdr) Deliver(peer *peer, msg *Msg) error {
+ var delivered chan struct{}
+ self.mlock.Lock()
+ req, ok := self.sentReqs[msg.ReqID]
+ self.mlock.Unlock()
+ if ok {
+ req.lock.Lock()
+ delivered, ok = req.sentTo[peer]
+ req.lock.Unlock()
+ }
+
+ if !ok {
+ return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
+ }
+
+ if req.valFunc(self.db, msg) {
+ close(delivered)
+ req.lock.Lock()
+ if req.answered != nil {
+ close(req.answered)
+ req.answered = nil
+ }
+ req.lock.Unlock()
+ return nil
+ }
+ return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID)
+}
+
+func (self *LesOdr) requestPeer(req *sentReq, peer *peer, delivered, timeout chan struct{}, reqWg *sync.WaitGroup) {
+ stime := mclock.Now()
+ defer func() {
+ req.lock.Lock()
+ delete(req.sentTo, peer)
+ req.lock.Unlock()
+ reqWg.Done()
+ }()
+
+ select {
+ case <-delivered:
+ servTime := uint64(mclock.Now() - stime)
+ self.peers.updateTimeout(peer, false)
+ self.peers.updateServTime(peer, servTime)
+ return
+ case <-time.After(softRequestTimeout):
+ close(timeout)
+ if self.peers.updateTimeout(peer, true) {
+ self.removePeer(peer.id)
+ }
+ case <-self.stop:
+ return
+ }
+
+ select {
+ case <-delivered:
+ servTime := uint64(mclock.Now() - stime)
+ self.peers.updateServTime(peer, servTime)
+ return
+ case <-time.After(hardRequestTimeout):
+ self.removePeer(peer.id)
+ case <-self.stop:
+ return
+ }
+}
+
+// networkRequest sends a request to known peers until an answer is received
+// or the context is cancelled
+func (self *LesOdr) networkRequest(ctx context.Context, lreq LesOdrRequest) error {
+ answered := make(chan struct{})
+ req := &sentReq{
+ valFunc: lreq.Valid,
+ sentTo: make(map[*peer]chan struct{}),
+ answered: answered, // reply delivered by any peer
+ }
+ reqID := self.getNextReqID()
+ self.mlock.Lock()
+ self.sentReqs[reqID] = req
+ self.mlock.Unlock()
+
+ reqWg := new(sync.WaitGroup)
+ reqWg.Add(1)
+ defer reqWg.Done()
+ go func() {
+ reqWg.Wait()
+ self.mlock.Lock()
+ delete(self.sentReqs, reqID)
+ self.mlock.Unlock()
+ }()
+
+ exclude := make(map[*peer]struct{})
+ for {
+ if peer := self.peers.bestPeer(lreq, exclude); peer == nil {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-req.answered:
+ return nil
+ case <-time.After(retryPeers):
+ }
+ } else {
+ exclude[peer] = struct{}{}
+ delivered := make(chan struct{})
+ timeout := make(chan struct{})
+ req.lock.Lock()
+ req.sentTo[peer] = delivered
+ req.lock.Unlock()
+ reqWg.Add(1)
+ cost := lreq.GetCost(peer)
+ peer.fcServer.SendRequest(reqID, cost)
+ go self.requestPeer(req, peer, delivered, timeout, reqWg)
+ lreq.Request(reqID, peer)
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-answered:
+ return nil
+ case <-timeout:
+ }
+ }
+ }
+}
+
+// Retrieve tries to fetch an object from the local db, then from the LES network.
+// If the network retrieval was successful, it stores the object in local db.
+func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) {
+ lreq := LesRequest(req)
+ err = self.networkRequest(ctx, lreq)
+ if err == nil {
+ // retrieved from network, store in db
+ req.StoreResult(self.db)
+ } else {
+ glog.V(logger.Debug).Infof("networkRequest err = %v", err)
+ }
+ return
+}
+
+func (self *LesOdr) getNextReqID() uint64 {
+ self.clock.Lock()
+ defer self.clock.Unlock()
+
+ self.lastReqID++
+ return self.lastReqID
+}
diff --git a/les/odr_peerset.go b/les/odr_peerset.go
new file mode 100644
index 000000000..0323ce27f
--- /dev/null
+++ b/les/odr_peerset.go
@@ -0,0 +1,119 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+package les
+
+import (
+ "sync"
+)
+
+const dropTimeoutRatio = 20
+
+type odrPeerInfo struct {
+ reqTimeSum, reqTimeCnt, reqCnt, timeoutCnt uint64
+}
+
+// odrPeerSet represents the collection of active peer participating in the block
+// download procedure.
+type odrPeerSet struct {
+ peers map[*peer]*odrPeerInfo
+ lock sync.RWMutex
+}
+
+// newPeerSet creates a new peer set top track the active download sources.
+func newOdrPeerSet() *odrPeerSet {
+ return &odrPeerSet{
+ peers: make(map[*peer]*odrPeerInfo),
+ }
+}
+
+// Register injects a new peer into the working set, or returns an error if the
+// peer is already known.
+func (ps *odrPeerSet) register(p *peer) error {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if _, ok := ps.peers[p]; ok {
+ return errAlreadyRegistered
+ }
+ ps.peers[p] = &odrPeerInfo{}
+ return nil
+}
+
+// Unregister removes a remote peer from the active set, disabling any further
+// actions to/from that particular entity.
+func (ps *odrPeerSet) unregister(p *peer) error {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if _, ok := ps.peers[p]; !ok {
+ return errNotRegistered
+ }
+ delete(ps.peers, p)
+ return nil
+}
+
+func (ps *odrPeerSet) peerPriority(p *peer, info *odrPeerInfo, req LesOdrRequest) uint64 {
+ tm := p.fcServer.CanSend(req.GetCost(p))
+ if info.reqTimeCnt > 0 {
+ tm += info.reqTimeSum / info.reqTimeCnt
+ }
+ return tm
+}
+
+func (ps *odrPeerSet) bestPeer(req LesOdrRequest, exclude map[*peer]struct{}) *peer {
+ var best *peer
+ var bpv uint64
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ for p, info := range ps.peers {
+ if _, ok := exclude[p]; !ok {
+ pv := ps.peerPriority(p, info, req)
+ if best == nil || pv < bpv {
+ best = p
+ bpv = pv
+ }
+ }
+ }
+ return best
+}
+
+func (ps *odrPeerSet) updateTimeout(p *peer, timeout bool) (drop bool) {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if info, ok := ps.peers[p]; ok {
+ info.reqCnt++
+ if timeout {
+ // check ratio before increase to allow an extra timeout
+ if info.timeoutCnt*dropTimeoutRatio >= info.reqCnt {
+ return true
+ }
+ info.timeoutCnt++
+ }
+ }
+ return false
+}
+
+func (ps *odrPeerSet) updateServTime(p *peer, servTime uint64) {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if info, ok := ps.peers[p]; ok {
+ info.reqTimeSum += servTime
+ info.reqTimeCnt++
+ }
+}
diff --git a/les/odr_requests.go b/les/odr_requests.go
new file mode 100644
index 000000000..bf0346977
--- /dev/null
+++ b/les/odr_requests.go
@@ -0,0 +1,325 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package light implements on-demand retrieval capable state and chain objects
+// for the Ethereum Light Client.
+package les
+
+import (
+ "bytes"
+ "encoding/binary"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+type LesOdrRequest interface {
+ GetCost(*peer) uint64
+ Request(uint64, *peer) error
+ Valid(ethdb.Database, *Msg) bool // if true, keeps the retrieved object
+}
+
+func LesRequest(req light.OdrRequest) LesOdrRequest {
+ switch r := req.(type) {
+ case *light.BlockRequest:
+ return (*BlockRequest)(r)
+ case *light.ReceiptsRequest:
+ return (*ReceiptsRequest)(r)
+ case *light.TrieRequest:
+ return (*TrieRequest)(r)
+ case *light.CodeRequest:
+ return (*CodeRequest)(r)
+ case *light.ChtRequest:
+ return (*ChtRequest)(r)
+ default:
+ return nil
+ }
+}
+
+// BlockRequest is the ODR request type for block bodies
+type BlockRequest light.BlockRequest
+
+// GetCost returns the cost of the given ODR request according to the serving
+// peer's cost table (implementation of LesOdrRequest)
+func (self *BlockRequest) GetCost(peer *peer) uint64 {
+ return peer.GetRequestCost(GetBlockBodiesMsg, 1)
+}
+
+// Request sends an ODR request to the LES network (implementation of LesOdrRequest)
+func (self *BlockRequest) Request(reqID uint64, peer *peer) error {
+ glog.V(logger.Debug).Infof("ODR: requesting body of block %08x from peer %v", self.Hash[:4], peer.id)
+ return peer.RequestBodies(reqID, self.GetCost(peer), []common.Hash{self.Hash})
+}
+
+// Valid processes an ODR request reply message from the LES network
+// returns true and stores results in memory if the message was a valid reply
+// to the request (implementation of LesOdrRequest)
+func (self *BlockRequest) Valid(db ethdb.Database, msg *Msg) bool {
+ glog.V(logger.Debug).Infof("ODR: validating body of block %08x", self.Hash[:4])
+ if msg.MsgType != MsgBlockBodies {
+ glog.V(logger.Debug).Infof("ODR: invalid message type")
+ return false
+ }
+ bodies := msg.Obj.([]*types.Body)
+ if len(bodies) != 1 {
+ glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(bodies))
+ return false
+ }
+ body := bodies[0]
+ header := core.GetHeader(db, self.Hash, self.Number)
+ if header == nil {
+ glog.V(logger.Debug).Infof("ODR: header not found for block %08x", self.Hash[:4])
+ return false
+ }
+ txHash := types.DeriveSha(types.Transactions(body.Transactions))
+ if header.TxHash != txHash {
+ glog.V(logger.Debug).Infof("ODR: header.TxHash %08x does not match received txHash %08x", header.TxHash[:4], txHash[:4])
+ return false
+ }
+ uncleHash := types.CalcUncleHash(body.Uncles)
+ if header.UncleHash != uncleHash {
+ glog.V(logger.Debug).Infof("ODR: header.UncleHash %08x does not match received uncleHash %08x", header.UncleHash[:4], uncleHash[:4])
+ return false
+ }
+ data, err := rlp.EncodeToBytes(body)
+ if err != nil {
+ glog.V(logger.Debug).Infof("ODR: body RLP encode error: %v", err)
+ return false
+ }
+ self.Rlp = data
+ glog.V(logger.Debug).Infof("ODR: validation successful")
+ return true
+}
+
+// ReceiptsRequest is the ODR request type for block receipts by block hash
+type ReceiptsRequest light.ReceiptsRequest
+
+// GetCost returns the cost of the given ODR request according to the serving
+// peer's cost table (implementation of LesOdrRequest)
+func (self *ReceiptsRequest) GetCost(peer *peer) uint64 {
+ return peer.GetRequestCost(GetReceiptsMsg, 1)
+}
+
+// Request sends an ODR request to the LES network (implementation of LesOdrRequest)
+func (self *ReceiptsRequest) Request(reqID uint64, peer *peer) error {
+ glog.V(logger.Debug).Infof("ODR: requesting receipts for block %08x from peer %v", self.Hash[:4], peer.id)
+ return peer.RequestReceipts(reqID, self.GetCost(peer), []common.Hash{self.Hash})
+}
+
+// Valid processes an ODR request reply message from the LES network
+// returns true and stores results in memory if the message was a valid reply
+// to the request (implementation of LesOdrRequest)
+func (self *ReceiptsRequest) Valid(db ethdb.Database, msg *Msg) bool {
+ glog.V(logger.Debug).Infof("ODR: validating receipts for block %08x", self.Hash[:4])
+ if msg.MsgType != MsgReceipts {
+ glog.V(logger.Debug).Infof("ODR: invalid message type")
+ return false
+ }
+ receipts := msg.Obj.([]types.Receipts)
+ if len(receipts) != 1 {
+ glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(receipts))
+ return false
+ }
+ hash := types.DeriveSha(receipts[0])
+ header := core.GetHeader(db, self.Hash, self.Number)
+ if header == nil {
+ glog.V(logger.Debug).Infof("ODR: header not found for block %08x", self.Hash[:4])
+ return false
+ }
+ if !bytes.Equal(header.ReceiptHash[:], hash[:]) {
+ glog.V(logger.Debug).Infof("ODR: header receipts hash %08x does not match calculated RLP hash %08x", header.ReceiptHash[:4], hash[:4])
+ return false
+ }
+ self.Receipts = receipts[0]
+ glog.V(logger.Debug).Infof("ODR: validation successful")
+ return true
+}
+
+type ProofReq struct {
+ BHash common.Hash
+ AccKey, Key []byte
+ FromLevel uint
+}
+
+// ODR request type for state/storage trie entries, see LesOdrRequest interface
+type TrieRequest light.TrieRequest
+
+// GetCost returns the cost of the given ODR request according to the serving
+// peer's cost table (implementation of LesOdrRequest)
+func (self *TrieRequest) GetCost(peer *peer) uint64 {
+ return peer.GetRequestCost(GetProofsMsg, 1)
+}
+
+// Request sends an ODR request to the LES network (implementation of LesOdrRequest)
+func (self *TrieRequest) Request(reqID uint64, peer *peer) error {
+ glog.V(logger.Debug).Infof("ODR: requesting trie root %08x key %08x from peer %v", self.Id.Root[:4], self.Key[:4], peer.id)
+ req := &ProofReq{
+ BHash: self.Id.BlockHash,
+ AccKey: self.Id.AccKey,
+ Key: self.Key,
+ }
+ return peer.RequestProofs(reqID, self.GetCost(peer), []*ProofReq{req})
+}
+
+// Valid processes an ODR request reply message from the LES network
+// returns true and stores results in memory if the message was a valid reply
+// to the request (implementation of LesOdrRequest)
+func (self *TrieRequest) Valid(db ethdb.Database, msg *Msg) bool {
+ glog.V(logger.Debug).Infof("ODR: validating trie root %08x key %08x", self.Id.Root[:4], self.Key[:4])
+
+ if msg.MsgType != MsgProofs {
+ glog.V(logger.Debug).Infof("ODR: invalid message type")
+ return false
+ }
+ proofs := msg.Obj.([][]rlp.RawValue)
+ if len(proofs) != 1 {
+ glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(proofs))
+ return false
+ }
+ _, err := trie.VerifyProof(self.Id.Root, self.Key, proofs[0])
+ if err != nil {
+ glog.V(logger.Debug).Infof("ODR: merkle proof verification error: %v", err)
+ return false
+ }
+ self.Proof = proofs[0]
+ glog.V(logger.Debug).Infof("ODR: validation successful")
+ return true
+}
+
+type CodeReq struct {
+ BHash common.Hash
+ AccKey []byte
+}
+
+// ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface
+type CodeRequest light.CodeRequest
+
+// GetCost returns the cost of the given ODR request according to the serving
+// peer's cost table (implementation of LesOdrRequest)
+func (self *CodeRequest) GetCost(peer *peer) uint64 {
+ return peer.GetRequestCost(GetCodeMsg, 1)
+}
+
+// Request sends an ODR request to the LES network (implementation of LesOdrRequest)
+func (self *CodeRequest) Request(reqID uint64, peer *peer) error {
+ glog.V(logger.Debug).Infof("ODR: requesting node data for hash %08x from peer %v", self.Hash[:4], peer.id)
+ req := &CodeReq{
+ BHash: self.Id.BlockHash,
+ AccKey: self.Id.AccKey,
+ }
+ return peer.RequestCode(reqID, self.GetCost(peer), []*CodeReq{req})
+}
+
+// Valid processes an ODR request reply message from the LES network
+// returns true and stores results in memory if the message was a valid reply
+// to the request (implementation of LesOdrRequest)
+func (self *CodeRequest) Valid(db ethdb.Database, msg *Msg) bool {
+ glog.V(logger.Debug).Infof("ODR: validating node data for hash %08x", self.Hash[:4])
+ if msg.MsgType != MsgCode {
+ glog.V(logger.Debug).Infof("ODR: invalid message type")
+ return false
+ }
+ reply := msg.Obj.([][]byte)
+ if len(reply) != 1 {
+ glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(reply))
+ return false
+ }
+ data := reply[0]
+ hash := crypto.Sha3Hash(data)
+ if !bytes.Equal(self.Hash[:], hash[:]) {
+ glog.V(logger.Debug).Infof("ODR: requested hash %08x does not match received data hash %08x", self.Hash[:4], hash[:4])
+ return false
+ }
+ self.Data = data
+ glog.V(logger.Debug).Infof("ODR: validation successful")
+ return true
+}
+
+type ChtReq struct {
+ ChtNum, BlockNum, FromLevel uint64
+}
+
+type ChtResp struct {
+ Header *types.Header
+ Proof []rlp.RawValue
+}
+
+// ODR request type for requesting headers by Canonical Hash Trie, see LesOdrRequest interface
+type ChtRequest light.ChtRequest
+
+// GetCost returns the cost of the given ODR request according to the serving
+// peer's cost table (implementation of LesOdrRequest)
+func (self *ChtRequest) GetCost(peer *peer) uint64 {
+ return peer.GetRequestCost(GetHeaderProofsMsg, 1)
+}
+
+// Request sends an ODR request to the LES network (implementation of LesOdrRequest)
+func (self *ChtRequest) Request(reqID uint64, peer *peer) error {
+ glog.V(logger.Debug).Infof("ODR: requesting CHT #%d block #%d from peer %v", self.ChtNum, self.BlockNum, peer.id)
+ req := &ChtReq{
+ ChtNum: self.ChtNum,
+ BlockNum: self.BlockNum,
+ }
+ return peer.RequestHeaderProofs(reqID, self.GetCost(peer), []*ChtReq{req})
+}
+
+// Valid processes an ODR request reply message from the LES network
+// returns true and stores results in memory if the message was a valid reply
+// to the request (implementation of LesOdrRequest)
+func (self *ChtRequest) Valid(db ethdb.Database, msg *Msg) bool {
+ glog.V(logger.Debug).Infof("ODR: validating CHT #%d block #%d", self.ChtNum, self.BlockNum)
+
+ if msg.MsgType != MsgHeaderProofs {
+ glog.V(logger.Debug).Infof("ODR: invalid message type")
+ return false
+ }
+ proofs := msg.Obj.([]ChtResp)
+ if len(proofs) != 1 {
+ glog.V(logger.Debug).Infof("ODR: invalid number of entries: %d", len(proofs))
+ return false
+ }
+ proof := proofs[0]
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], self.BlockNum)
+ value, err := trie.VerifyProof(self.ChtRoot, encNumber[:], proof.Proof)
+ if err != nil {
+ glog.V(logger.Debug).Infof("ODR: CHT merkle proof verification error: %v", err)
+ return false
+ }
+ var node light.ChtNode
+ if err := rlp.DecodeBytes(value, &node); err != nil {
+ glog.V(logger.Debug).Infof("ODR: error decoding CHT node: %v", err)
+ return false
+ }
+ if node.Hash != proof.Header.Hash() {
+ glog.V(logger.Debug).Infof("ODR: CHT header hash does not match")
+ return false
+ }
+
+ self.Proof = proof.Proof
+ self.Header = proof.Header
+ self.Td = node.Td
+ glog.V(logger.Debug).Infof("ODR: validation successful")
+ return true
+}
diff --git a/les/odr_test.go b/les/odr_test.go
new file mode 100644
index 000000000..bd52a82dd
--- /dev/null
+++ b/les/odr_test.go
@@ -0,0 +1,222 @@
+package les
+
+import (
+ "bytes"
+ "math/big"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/rlp"
+ "golang.org/x/net/context"
+)
+
+type odrTestFn func(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
+
+func TestOdrGetBlockLes1(t *testing.T) { testOdr(t, 1, 1, odrGetBlock) }
+
+func odrGetBlock(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
+ var block *types.Block
+ if bc != nil {
+ block = bc.GetBlockByHash(bhash)
+ } else {
+ block, _ = lc.GetBlockByHash(ctx, bhash)
+ }
+ if block == nil {
+ return nil
+ }
+ rlp, _ := rlp.EncodeToBytes(block)
+ return rlp
+}
+
+func TestOdrGetReceiptsLes1(t *testing.T) { testOdr(t, 1, 1, odrGetReceipts) }
+
+func odrGetReceipts(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
+ var receipts types.Receipts
+ if bc != nil {
+ receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
+ } else {
+ receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
+ }
+ if receipts == nil {
+ return nil
+ }
+ rlp, _ := rlp.EncodeToBytes(receipts)
+ return rlp
+}
+
+func TestOdrAccountsLes1(t *testing.T) { testOdr(t, 1, 1, odrAccounts) }
+
+func odrAccounts(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
+ dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
+ acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr}
+
+ var res []byte
+ for _, addr := range acc {
+ if bc != nil {
+ header := bc.GetHeaderByHash(bhash)
+ st, err := state.New(header.Root, db)
+ if err == nil {
+ bal := st.GetBalance(addr)
+ rlp, _ := rlp.EncodeToBytes(bal)
+ res = append(res, rlp...)
+ }
+ } else {
+ header := lc.GetHeaderByHash(bhash)
+ st := light.NewLightState(light.StateTrieID(header), lc.Odr())
+ bal, err := st.GetBalance(ctx, addr)
+ if err == nil {
+ rlp, _ := rlp.EncodeToBytes(bal)
+ res = append(res, rlp...)
+ }
+ }
+ }
+
+ return res
+}
+
+func TestOdrContractCallLes1(t *testing.T) { testOdr(t, 1, 2, odrContractCall) }
+
+// fullcallmsg is the message type used for call transations.
+type fullcallmsg struct {
+ from *state.StateObject
+ to *common.Address
+ gas, gasPrice *big.Int
+ value *big.Int
+ data []byte
+}
+
+// accessor boilerplate to implement core.Message
+func (m fullcallmsg) From() (common.Address, error) { return m.from.Address(), nil }
+func (m fullcallmsg) FromFrontier() (common.Address, error) { return m.from.Address(), nil }
+func (m fullcallmsg) Nonce() uint64 { return 0 }
+func (m fullcallmsg) CheckNonce() bool { return false }
+func (m fullcallmsg) To() *common.Address { return m.to }
+func (m fullcallmsg) GasPrice() *big.Int { return m.gasPrice }
+func (m fullcallmsg) Gas() *big.Int { return m.gas }
+func (m fullcallmsg) Value() *big.Int { return m.value }
+func (m fullcallmsg) Data() []byte { return m.data }
+
+// callmsg is the message type used for call transations.
+type lightcallmsg struct {
+ from *light.StateObject
+ to *common.Address
+ gas, gasPrice *big.Int
+ value *big.Int
+ data []byte
+}
+
+// accessor boilerplate to implement core.Message
+func (m lightcallmsg) From() (common.Address, error) { return m.from.Address(), nil }
+func (m lightcallmsg) FromFrontier() (common.Address, error) { return m.from.Address(), nil }
+func (m lightcallmsg) Nonce() uint64 { return 0 }
+func (m lightcallmsg) CheckNonce() bool { return false }
+func (m lightcallmsg) To() *common.Address { return m.to }
+func (m lightcallmsg) GasPrice() *big.Int { return m.gasPrice }
+func (m lightcallmsg) Gas() *big.Int { return m.gas }
+func (m lightcallmsg) Value() *big.Int { return m.value }
+func (m lightcallmsg) Data() []byte { return m.data }
+
+func odrContractCall(ctx context.Context, db ethdb.Database, config *core.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
+ data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000")
+
+ var res []byte
+ for i := 0; i < 3; i++ {
+ data[35] = byte(i)
+ if bc != nil {
+ header := bc.GetHeaderByHash(bhash)
+ statedb, err := state.New(header.Root, db)
+ if err == nil {
+ from := statedb.GetOrNewStateObject(testBankAddress)
+ from.SetBalance(common.MaxBig)
+
+ msg := fullcallmsg{
+ from: from,
+ gas: big.NewInt(100000),
+ gasPrice: big.NewInt(0),
+ value: big.NewInt(0),
+ data: data,
+ to: &testContractAddr,
+ }
+
+ vmenv := core.NewEnv(statedb, config, bc, msg, header, config.VmConfig)
+ gp := new(core.GasPool).AddGas(common.MaxBig)
+ ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
+ res = append(res, ret...)
+ }
+ } else {
+ header := lc.GetHeaderByHash(bhash)
+ state := light.NewLightState(light.StateTrieID(header), lc.Odr())
+ from, err := state.GetOrNewStateObject(ctx, testBankAddress)
+ if err == nil {
+ from.SetBalance(common.MaxBig)
+
+ msg := lightcallmsg{
+ from: from,
+ gas: big.NewInt(100000),
+ gasPrice: big.NewInt(0),
+ value: big.NewInt(0),
+ data: data,
+ to: &testContractAddr,
+ }
+
+ vmenv := light.NewEnv(ctx, state, config, lc, msg, header, config.VmConfig)
+ gp := new(core.GasPool).AddGas(common.MaxBig)
+ ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
+ if vmenv.Error() == nil {
+ res = append(res, ret...)
+ }
+ }
+ }
+ }
+ return res
+}
+
+func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
+ // Assemble the test environment
+ pm, db, odr := newTestProtocolManagerMust(t, false, 4, testChainGen)
+ lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil)
+ _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm)
+ select {
+ case <-time.After(time.Millisecond * 100):
+ case err := <-err1:
+ t.Fatalf("peer 1 handshake error: %v", err)
+ case err := <-err2:
+ t.Fatalf("peer 1 handshake error: %v", err)
+ }
+
+ lpm.synchronise(lpeer)
+
+ test := func(expFail uint64) {
+ for i := uint64(0); i <= pm.blockchain.CurrentHeader().GetNumberU64(); i++ {
+ bhash := core.GetCanonicalHash(db, i)
+ b1 := fn(light.NoOdr, db, pm.chainConfig, pm.blockchain.(*core.BlockChain), nil, bhash)
+ ctx, _ := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ b2 := fn(ctx, ldb, lpm.chainConfig, nil, lpm.blockchain.(*light.LightChain), bhash)
+ eq := bytes.Equal(b1, b2)
+ exp := i < expFail
+ if exp && !eq {
+ t.Errorf("odr mismatch")
+ }
+ if !exp && eq {
+ t.Errorf("unexpected odr match")
+ }
+ }
+ }
+
+ // temporarily remove peer to test odr fails
+ odr.UnregisterPeer(lpeer)
+ // expect retrievals to fail (except genesis block) without a les peer
+ test(expFail)
+ odr.RegisterPeer(lpeer)
+ // expect all retrievals to pass
+ test(5)
+ odr.UnregisterPeer(lpeer)
+ // still expect all retrievals to pass, now data should be cached locally
+ test(5)
+}
diff --git a/les/peer.go b/les/peer.go
new file mode 100644
index 000000000..dbddbb020
--- /dev/null
+++ b/les/peer.go
@@ -0,0 +1,584 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "errors"
+ "fmt"
+ "math/big"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth"
+ "github.com/ethereum/go-ethereum/les/flowcontrol"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var (
+ errClosed = errors.New("peer set is closed")
+ errAlreadyRegistered = errors.New("peer is already registered")
+ errNotRegistered = errors.New("peer is not registered")
+)
+
+const maxHeadInfoLen = 20
+
+type peer struct {
+ *p2p.Peer
+
+ rw p2p.MsgReadWriter
+
+ version int // Protocol version negotiated
+ network int // Network ID being on
+
+ id string
+
+ firstHeadInfo, headInfo *announceData
+ headInfoLen int
+ lock sync.RWMutex
+
+ announceChn chan announceData
+
+ fcClient *flowcontrol.ClientNode // nil if the peer is server only
+ fcServer *flowcontrol.ServerNode // nil if the peer is client only
+ fcServerParams *flowcontrol.ServerParams
+ fcCosts requestCostTable
+}
+
+func newPeer(version, network int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
+ id := p.ID()
+
+ return &peer{
+ Peer: p,
+ rw: rw,
+ version: version,
+ network: network,
+ id: fmt.Sprintf("%x", id[:8]),
+ announceChn: make(chan announceData, 20),
+ }
+}
+
+// Info gathers and returns a collection of metadata known about a peer.
+func (p *peer) Info() *eth.PeerInfo {
+ return &eth.PeerInfo{
+ Version: p.version,
+ Difficulty: p.Td(),
+ Head: fmt.Sprintf("%x", p.Head()),
+ }
+}
+
+// Head retrieves a copy of the current head (most recent) hash of the peer.
+func (p *peer) Head() (hash common.Hash) {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ copy(hash[:], p.headInfo.Hash[:])
+ return hash
+}
+
+func (p *peer) HeadAndTd() (hash common.Hash, td *big.Int) {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ copy(hash[:], p.headInfo.Hash[:])
+ return hash, p.headInfo.Td
+}
+
+func (p *peer) headBlockInfo() blockInfo {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ return blockInfo{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}
+}
+
+func (p *peer) addNotify(announce *announceData) bool {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ if announce.Td.Cmp(p.headInfo.Td) < 1 {
+ return false
+ }
+ if p.headInfoLen >= maxHeadInfoLen {
+ //return false
+ p.firstHeadInfo = p.firstHeadInfo.next
+ p.headInfoLen--
+ }
+ if announce.haveHeaders == 0 {
+ hh := p.headInfo.Number - announce.ReorgDepth
+ if p.headInfo.haveHeaders < hh {
+ hh = p.headInfo.haveHeaders
+ }
+ announce.haveHeaders = hh
+ }
+ p.headInfo.next = announce
+ p.headInfo = announce
+ p.headInfoLen++
+ return true
+}
+
+func (p *peer) gotHeader(hash common.Hash, number uint64, td *big.Int) bool {
+ h := p.firstHeadInfo
+ ptr := 0
+ for h != nil {
+ if h.Hash == hash {
+ if h.Number != number || h.Td.Cmp(td) != 0 {
+ return false
+ }
+ h.headKnown = true
+ h.haveHeaders = h.Number
+ p.firstHeadInfo = h
+ p.headInfoLen -= ptr
+ last := h
+ h = h.next
+ // propagate haveHeaders through the chain
+ for h != nil {
+ hh := last.Number - h.ReorgDepth
+ if last.haveHeaders < hh {
+ hh = last.haveHeaders
+ }
+ if hh > h.haveHeaders {
+ h.haveHeaders = hh
+ } else {
+ return true
+ }
+ last = h
+ h = h.next
+ }
+ return true
+ }
+ h = h.next
+ ptr++
+ }
+ return true
+}
+
+// Td retrieves the current total difficulty of a peer.
+func (p *peer) Td() *big.Int {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ return new(big.Int).Set(p.headInfo.Td)
+}
+
+func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error {
+ type req struct {
+ ReqID uint64
+ Data interface{}
+ }
+ return p2p.Send(w, msgcode, req{reqID, data})
+}
+
+func sendResponse(w p2p.MsgWriter, msgcode, reqID, bv uint64, data interface{}) error {
+ type resp struct {
+ ReqID, BV uint64
+ Data interface{}
+ }
+ return p2p.Send(w, msgcode, resp{reqID, bv, data})
+}
+
+func (p *peer) GetRequestCost(msgcode uint64, amount int) uint64 {
+ cost := p.fcCosts[msgcode].baseCost + p.fcCosts[msgcode].reqCost*uint64(amount)
+ if cost > p.fcServerParams.BufLimit {
+ cost = p.fcServerParams.BufLimit
+ }
+ return cost
+}
+
+// SendAnnounce announces the availability of a number of blocks through
+// a hash notification.
+func (p *peer) SendAnnounce(request announceData) error {
+ return p2p.Send(p.rw, AnnounceMsg, request)
+}
+
+// SendBlockHeaders sends a batch of block headers to the remote peer.
+func (p *peer) SendBlockHeaders(reqID, bv uint64, headers []*types.Header) error {
+ return sendResponse(p.rw, BlockHeadersMsg, reqID, bv, headers)
+}
+
+// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
+// an already RLP encoded format.
+func (p *peer) SendBlockBodiesRLP(reqID, bv uint64, bodies []rlp.RawValue) error {
+ return sendResponse(p.rw, BlockBodiesMsg, reqID, bv, bodies)
+}
+
+// SendCodeRLP sends a batch of arbitrary internal data, corresponding to the
+// hashes requested.
+func (p *peer) SendCode(reqID, bv uint64, data [][]byte) error {
+ return sendResponse(p.rw, CodeMsg, reqID, bv, data)
+}
+
+// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
+// ones requested from an already RLP encoded format.
+func (p *peer) SendReceiptsRLP(reqID, bv uint64, receipts []rlp.RawValue) error {
+ return sendResponse(p.rw, ReceiptsMsg, reqID, bv, receipts)
+}
+
+// SendProofs sends a batch of merkle proofs, corresponding to the ones requested.
+func (p *peer) SendProofs(reqID, bv uint64, proofs proofsData) error {
+ return sendResponse(p.rw, ProofsMsg, reqID, bv, proofs)
+}
+
+// SendHeaderProofs sends a batch of header proofs, corresponding to the ones requested.
+func (p *peer) SendHeaderProofs(reqID, bv uint64, proofs []ChtResp) error {
+ return sendResponse(p.rw, HeaderProofsMsg, reqID, bv, proofs)
+}
+
+// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
+// specified header query, based on the hash of an origin block.
+func (p *peer) RequestHeadersByHash(reqID, cost uint64, origin common.Hash, amount int, skip int, reverse bool) error {
+ glog.V(logger.Debug).Infof("%v fetching %d headers from %x, skipping %d (reverse = %v)", p, amount, origin[:4], skip, reverse)
+ return sendRequest(p.rw, GetBlockHeadersMsg, reqID, cost, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
+}
+
+// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
+// specified header query, based on the number of an origin block.
+func (p *peer) RequestHeadersByNumber(reqID, cost, origin uint64, amount int, skip int, reverse bool) error {
+ glog.V(logger.Debug).Infof("%v fetching %d headers from #%d, skipping %d (reverse = %v)", p, amount, origin, skip, reverse)
+ return sendRequest(p.rw, GetBlockHeadersMsg, reqID, cost, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
+}
+
+// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
+// specified.
+func (p *peer) RequestBodies(reqID, cost uint64, hashes []common.Hash) error {
+ glog.V(logger.Debug).Infof("%v fetching %d block bodies", p, len(hashes))
+ return sendRequest(p.rw, GetBlockBodiesMsg, reqID, cost, hashes)
+}
+
+// RequestCode fetches a batch of arbitrary data from a node's known state
+// data, corresponding to the specified hashes.
+func (p *peer) RequestCode(reqID, cost uint64, reqs []*CodeReq) error {
+ glog.V(logger.Debug).Infof("%v fetching %v state data", p, len(reqs))
+ return sendRequest(p.rw, GetCodeMsg, reqID, cost, reqs)
+}
+
+// RequestReceipts fetches a batch of transaction receipts from a remote node.
+func (p *peer) RequestReceipts(reqID, cost uint64, hashes []common.Hash) error {
+ glog.V(logger.Debug).Infof("%v fetching %v receipts", p, len(hashes))
+ return sendRequest(p.rw, GetReceiptsMsg, reqID, cost, hashes)
+}
+
+// RequestProofs fetches a batch of merkle proofs from a remote node.
+func (p *peer) RequestProofs(reqID, cost uint64, reqs []*ProofReq) error {
+ glog.V(logger.Debug).Infof("%v fetching %v proofs", p, len(reqs))
+ return sendRequest(p.rw, GetProofsMsg, reqID, cost, reqs)
+}
+
+// RequestHeaderProofs fetches a batch of header merkle proofs from a remote node.
+func (p *peer) RequestHeaderProofs(reqID, cost uint64, reqs []*ChtReq) error {
+ glog.V(logger.Debug).Infof("%v fetching %v header proofs", p, len(reqs))
+ return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs)
+}
+
+func (p *peer) SendTxs(cost uint64, txs types.Transactions) error {
+ glog.V(logger.Debug).Infof("%v relaying %v txs", p, len(txs))
+ p.fcServer.SendRequest(0, cost)
+ return p2p.Send(p.rw, SendTxMsg, txs)
+}
+
+type keyValueEntry struct {
+ Key string
+ Value rlp.RawValue
+}
+type keyValueList []keyValueEntry
+type keyValueMap map[string]rlp.RawValue
+
+func (l keyValueList) add(key string, val interface{}) keyValueList {
+ var entry keyValueEntry
+ entry.Key = key
+ if val == nil {
+ val = uint64(0)
+ }
+ enc, err := rlp.EncodeToBytes(val)
+ if err == nil {
+ entry.Value = enc
+ }
+ return append(l, entry)
+}
+
+func (l keyValueList) decode() keyValueMap {
+ m := make(keyValueMap)
+ for _, entry := range l {
+ m[entry.Key] = entry.Value
+ }
+ return m
+}
+
+func (m keyValueMap) get(key string, val interface{}) error {
+ enc, ok := m[key]
+ if !ok {
+ return errResp(ErrHandshakeMissingKey, "%s", key)
+ }
+ if val == nil {
+ return nil
+ }
+ return rlp.DecodeBytes(enc, val)
+}
+
+func (p *peer) sendReceiveHandshake(sendList keyValueList) (keyValueList, error) {
+ // Send out own handshake in a new thread
+ errc := make(chan error, 1)
+ go func() {
+ errc <- p2p.Send(p.rw, StatusMsg, sendList)
+ }()
+ // In the mean time retrieve the remote status message
+ msg, err := p.rw.ReadMsg()
+ if err != nil {
+ return nil, err
+ }
+ if msg.Code != StatusMsg {
+ return nil, errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
+ }
+ if msg.Size > ProtocolMaxMsgSize {
+ return nil, errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
+ }
+ // Decode the handshake
+ var recvList keyValueList
+ if err := msg.Decode(&recvList); err != nil {
+ return nil, errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ if err := <-errc; err != nil {
+ return nil, err
+ }
+ return recvList, nil
+}
+
+// Handshake executes the les protocol handshake, negotiating version number,
+// network IDs, difficulties, head and genesis blocks.
+func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ var send keyValueList
+ send = send.add("protocolVersion", uint64(p.version))
+ send = send.add("networkId", uint64(p.network))
+ send = send.add("headTd", td)
+ send = send.add("headHash", head)
+ send = send.add("headNum", headNum)
+ send = send.add("genesisHash", genesis)
+ if server != nil {
+ send = send.add("serveHeaders", nil)
+ send = send.add("serveChainSince", uint64(0))
+ send = send.add("serveStateSince", uint64(0))
+ send = send.add("txRelay", nil)
+ send = send.add("flowControl/BL", server.defParams.BufLimit)
+ send = send.add("flowControl/MRR", server.defParams.MinRecharge)
+ list := server.fcCostStats.getCurrentList()
+ send = send.add("flowControl/MRC", list)
+ p.fcCosts = list.decode()
+ }
+ recvList, err := p.sendReceiveHandshake(send)
+ if err != nil {
+ return err
+ }
+ recv := recvList.decode()
+
+ var rGenesis, rHash common.Hash
+ var rVersion, rNetwork, rNum uint64
+ var rTd *big.Int
+
+ if err := recv.get("protocolVersion", &rVersion); err != nil {
+ return err
+ }
+ if err := recv.get("networkId", &rNetwork); err != nil {
+ return err
+ }
+ if err := recv.get("headTd", &rTd); err != nil {
+ return err
+ }
+ if err := recv.get("headHash", &rHash); err != nil {
+ return err
+ }
+ if err := recv.get("headNum", &rNum); err != nil {
+ return err
+ }
+ if err := recv.get("genesisHash", &rGenesis); err != nil {
+ return err
+ }
+
+ if rGenesis != genesis {
+ return errResp(ErrGenesisBlockMismatch, "%x (!= %x)", rGenesis, genesis)
+ }
+ if int(rNetwork) != p.network {
+ return errResp(ErrNetworkIdMismatch, "%d (!= %d)", rNetwork, p.network)
+ }
+ if int(rVersion) != p.version {
+ return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version)
+ }
+ if server != nil {
+ if recv.get("serveStateSince", nil) == nil {
+ return errResp(ErrUselessPeer, "wanted client, got server")
+ }
+ p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
+ } else {
+ if recv.get("serveChainSince", nil) != nil {
+ return errResp(ErrUselessPeer, "peer cannot serve chain")
+ }
+ if recv.get("serveStateSince", nil) != nil {
+ return errResp(ErrUselessPeer, "peer cannot serve state")
+ }
+ if recv.get("txRelay", nil) != nil {
+ return errResp(ErrUselessPeer, "peer cannot relay transactions")
+ }
+ params := &flowcontrol.ServerParams{}
+ if err := recv.get("flowControl/BL", &params.BufLimit); err != nil {
+ return err
+ }
+ if err := recv.get("flowControl/MRR", &params.MinRecharge); err != nil {
+ return err
+ }
+ var MRC RequestCostList
+ if err := recv.get("flowControl/MRC", &MRC); err != nil {
+ return err
+ }
+ p.fcServerParams = params
+ p.fcServer = flowcontrol.NewServerNode(params)
+ p.fcCosts = MRC.decode()
+ }
+
+ p.firstHeadInfo = &announceData{Td: rTd, Hash: rHash, Number: rNum}
+ p.headInfo = p.firstHeadInfo
+ p.headInfoLen = 1
+ return nil
+}
+
+// String implements fmt.Stringer.
+func (p *peer) String() string {
+ return fmt.Sprintf("Peer %s [%s]", p.id,
+ fmt.Sprintf("les/%d", p.version),
+ )
+}
+
+// peerSet represents the collection of active peers currently participating in
+// the Light Ethereum sub-protocol.
+type peerSet struct {
+ peers map[string]*peer
+ lock sync.RWMutex
+ closed bool
+}
+
+// newPeerSet creates a new peer set to track the active participants.
+func newPeerSet() *peerSet {
+ return &peerSet{
+ peers: make(map[string]*peer),
+ }
+}
+
+// Register injects a new peer into the working set, or returns an error if the
+// peer is already known.
+func (ps *peerSet) Register(p *peer) error {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if ps.closed {
+ return errClosed
+ }
+ if _, ok := ps.peers[p.id]; ok {
+ return errAlreadyRegistered
+ }
+ ps.peers[p.id] = p
+ return nil
+}
+
+// Unregister removes a remote peer from the active set, disabling any further
+// actions to/from that particular entity.
+func (ps *peerSet) Unregister(id string) error {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ if _, ok := ps.peers[id]; !ok {
+ return errNotRegistered
+ }
+ delete(ps.peers, id)
+ return nil
+}
+
+// AllPeerIDs returns a list of all registered peer IDs
+func (ps *peerSet) AllPeerIDs() []string {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ res := make([]string, len(ps.peers))
+ idx := 0
+ for id, _ := range ps.peers {
+ res[idx] = id
+ idx++
+ }
+ return res
+}
+
+// Peer retrieves the registered peer with the given id.
+func (ps *peerSet) Peer(id string) *peer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return ps.peers[id]
+}
+
+// Len returns if the current number of peers in the set.
+func (ps *peerSet) Len() int {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return len(ps.peers)
+}
+
+// BestPeer retrieves the known peer with the currently highest total difficulty.
+func (ps *peerSet) BestPeer() *peer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ var (
+ bestPeer *peer
+ bestTd *big.Int
+ )
+ for _, p := range ps.peers {
+ if td := p.Td(); bestPeer == nil || td.Cmp(bestTd) > 0 {
+ bestPeer, bestTd = p, td
+ }
+ }
+ return bestPeer
+}
+
+// AllPeers returns all peers in a list
+func (ps *peerSet) AllPeers() []*peer {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ list := make([]*peer, len(ps.peers))
+ i := 0
+ for _, peer := range ps.peers {
+ list[i] = peer
+ i++
+ }
+ return list
+}
+
+// Close disconnects all peers.
+// No new peers can be registered after Close has returned.
+func (ps *peerSet) Close() {
+ ps.lock.Lock()
+ defer ps.lock.Unlock()
+
+ for _, p := range ps.peers {
+ p.Disconnect(p2p.DiscQuitting)
+ }
+ ps.closed = true
+}
diff --git a/les/protocol.go b/les/protocol.go
new file mode 100644
index 000000000..3d2de64e1
--- /dev/null
+++ b/les/protocol.go
@@ -0,0 +1,198 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "fmt"
+ "io"
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+ lpv1 = 1
+)
+
+// Supported versions of the les protocol (first is primary).
+var ProtocolVersions = []uint{lpv1}
+
+// Number of implemented message corresponding to different protocol versions.
+var ProtocolLengths = []uint64{15}
+
+const (
+ NetworkId = 1
+ ProtocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message
+)
+
+// les protocol message codes
+const (
+ // Protocol messages belonging to LPV1
+ StatusMsg = 0x00
+ AnnounceMsg = 0x01
+ GetBlockHeadersMsg = 0x02
+ BlockHeadersMsg = 0x03
+ GetBlockBodiesMsg = 0x04
+ BlockBodiesMsg = 0x05
+ GetReceiptsMsg = 0x06
+ ReceiptsMsg = 0x07
+ GetProofsMsg = 0x08
+ ProofsMsg = 0x09
+ GetCodeMsg = 0x0a
+ CodeMsg = 0x0b
+ SendTxMsg = 0x0c
+ GetHeaderProofsMsg = 0x0d
+ HeaderProofsMsg = 0x0e
+)
+
+type errCode int
+
+const (
+ ErrMsgTooLarge = iota
+ ErrDecode
+ ErrInvalidMsgCode
+ ErrProtocolVersionMismatch
+ ErrNetworkIdMismatch
+ ErrGenesisBlockMismatch
+ ErrNoStatusMsg
+ ErrExtraStatusMsg
+ ErrSuspendedPeer
+ ErrUselessPeer
+ ErrRequestRejected
+ ErrUnexpectedResponse
+ ErrInvalidResponse
+ ErrTooManyTimeouts
+ ErrHandshakeMissingKey
+)
+
+func (e errCode) String() string {
+ return errorToString[int(e)]
+}
+
+// XXX change once legacy code is out
+var errorToString = map[int]string{
+ ErrMsgTooLarge: "Message too long",
+ ErrDecode: "Invalid message",
+ ErrInvalidMsgCode: "Invalid message code",
+ ErrProtocolVersionMismatch: "Protocol version mismatch",
+ ErrNetworkIdMismatch: "NetworkId mismatch",
+ ErrGenesisBlockMismatch: "Genesis block mismatch",
+ ErrNoStatusMsg: "No status message",
+ ErrExtraStatusMsg: "Extra status message",
+ ErrSuspendedPeer: "Suspended peer",
+ ErrRequestRejected: "Request rejected",
+ ErrUnexpectedResponse: "Unexpected response",
+ ErrInvalidResponse: "Invalid response",
+ ErrTooManyTimeouts: "Too many request timeouts",
+ ErrHandshakeMissingKey: "Key missing from handshake message",
+}
+
+type chainManager interface {
+ GetBlockHashesFromHash(hash common.Hash, amount uint64) (hashes []common.Hash)
+ GetBlock(hash common.Hash) (block *types.Block)
+ Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash)
+}
+
+// announceData is the network packet for the block announcements.
+type announceData struct {
+ Hash common.Hash // Hash of one particular block being announced
+ Number uint64 // Number of one particular block being announced
+ Td *big.Int // Total difficulty of one particular block being announced
+ ReorgDepth uint64
+ Update keyValueList
+
+ haveHeaders uint64 // we have the headers of the remote peer's chain up to this number
+ headKnown bool
+ requested bool
+ next *announceData
+}
+
+type blockInfo struct {
+ Hash common.Hash // Hash of one particular block being announced
+ Number uint64 // Number of one particular block being announced
+ Td *big.Int // Total difficulty of one particular block being announced
+}
+
+// getBlockHashesData is the network packet for the hash based hash retrieval.
+type getBlockHashesData struct {
+ Hash common.Hash
+ Amount uint64
+}
+
+// getBlockHeadersData represents a block header query.
+type getBlockHeadersData struct {
+ Origin hashOrNumber // Block from which to retrieve headers
+ Amount uint64 // Maximum number of headers to retrieve
+ Skip uint64 // Blocks to skip between consecutive headers
+ Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis)
+}
+
+// hashOrNumber is a combined field for specifying an origin block.
+type hashOrNumber struct {
+ Hash common.Hash // Block hash from which to retrieve headers (excludes Number)
+ Number uint64 // Block hash from which to retrieve headers (excludes Hash)
+}
+
+// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the
+// two contained union fields.
+func (hn *hashOrNumber) EncodeRLP(w io.Writer) error {
+ if hn.Hash == (common.Hash{}) {
+ return rlp.Encode(w, hn.Number)
+ }
+ if hn.Number != 0 {
+ return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
+ }
+ return rlp.Encode(w, hn.Hash)
+}
+
+// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents
+// into either a block hash or a block number.
+func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
+ _, size, _ := s.Kind()
+ origin, err := s.Raw()
+ if err == nil {
+ switch {
+ case size == 32:
+ err = rlp.DecodeBytes(origin, &hn.Hash)
+ case size <= 8:
+ err = rlp.DecodeBytes(origin, &hn.Number)
+ default:
+ err = fmt.Errorf("invalid input size %d for origin", size)
+ }
+ }
+ return err
+}
+
+// newBlockData is the network packet for the block propagation message.
+type newBlockData struct {
+ Block *types.Block
+ TD *big.Int
+}
+
+// blockBodiesData is the network packet for block content distribution.
+type blockBodiesData []*types.Body
+
+// CodeData is the network response packet for a node data retrieval.
+type CodeData []struct {
+ Value []byte
+}
+
+type proofsData [][]rlp.RawValue
diff --git a/les/request_test.go b/les/request_test.go
new file mode 100644
index 000000000..df02afb32
--- /dev/null
+++ b/les/request_test.go
@@ -0,0 +1,94 @@
+package les
+
+import (
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/light"
+ "golang.org/x/net/context"
+)
+
+var testBankSecureTrieKey = secAddr(testBankAddress)
+
+func secAddr(addr common.Address) []byte {
+ return crypto.Keccak256(addr[:])
+}
+
+type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
+
+func TestBlockAccessLes1(t *testing.T) { testAccess(t, 1, tfBlockAccess) }
+
+func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
+ return &light.BlockRequest{Hash: bhash, Number: number}
+}
+
+func TestReceiptsAccessLes1(t *testing.T) { testAccess(t, 1, tfReceiptsAccess) }
+
+func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
+ return &light.ReceiptsRequest{Hash: bhash, Number: number}
+}
+
+func TestTrieEntryAccessLes1(t *testing.T) { testAccess(t, 1, tfTrieEntryAccess) }
+
+func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
+ return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey}
+}
+
+func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) }
+
+func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
+ header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))
+ if header.GetNumberU64() < testContractDeployed {
+ return nil
+ }
+ sti := light.StateTrieID(header)
+ ci := light.StorageTrieID(sti, testContractAddr, common.Hash{})
+ return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)}
+}
+
+func testAccess(t *testing.T, protocol int, fn accessTestFn) {
+ // Assemble the test environment
+ pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen)
+ lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil)
+ _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm)
+ select {
+ case <-time.After(time.Millisecond * 100):
+ case err := <-err1:
+ t.Fatalf("peer 1 handshake error: %v", err)
+ case err := <-err2:
+ t.Fatalf("peer 1 handshake error: %v", err)
+ }
+
+ lpm.synchronise(lpeer)
+
+ test := func(expFail uint64) {
+ for i := uint64(0); i <= pm.blockchain.CurrentHeader().GetNumberU64(); i++ {
+ bhash := core.GetCanonicalHash(db, i)
+ if req := fn(ldb, bhash, i); req != nil {
+ ctx, _ := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ err := odr.Retrieve(ctx, req)
+ got := err == nil
+ exp := i < expFail
+ if exp && !got {
+ t.Errorf("object retrieval failed")
+ }
+ if !exp && got {
+ t.Errorf("unexpected object retrieval success")
+ }
+ }
+ }
+ }
+
+ // temporarily remove peer to test odr fails
+ odr.UnregisterPeer(lpeer)
+ // expect retrievals to fail (except genesis block) without a les peer
+ test(0)
+ odr.RegisterPeer(lpeer)
+ // expect all retrievals to pass
+ test(5)
+ odr.UnregisterPeer(lpeer)
+}
diff --git a/les/server.go b/les/server.go
new file mode 100644
index 000000000..bc5dd0837
--- /dev/null
+++ b/les/server.go
@@ -0,0 +1,401 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package les implements the Light Ethereum Subprotocol.
+package les
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/les/flowcontrol"
+ "github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+type LesServer struct {
+ protocolManager *ProtocolManager
+ fcManager *flowcontrol.ClientManager // nil if our node is client only
+ fcCostStats *requestCostStats
+ defParams *flowcontrol.ServerParams
+}
+
+func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) {
+ pm, err := NewProtocolManager(config.ChainConfig, false, config.NetworkId, eth.EventMux(), eth.Pow(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ pm.blockLoop()
+
+ srv := &LesServer{protocolManager: pm}
+ pm.server = srv
+
+ srv.defParams = &flowcontrol.ServerParams{
+ BufLimit: 300000000,
+ MinRecharge: 50000,
+ }
+ srv.fcManager = flowcontrol.NewClientManager(uint64(config.LightServ), 10, 1000000000)
+ srv.fcCostStats = newCostStats(eth.ChainDb())
+ return srv, nil
+}
+
+func (s *LesServer) Protocols() []p2p.Protocol {
+ return s.protocolManager.SubProtocols
+}
+
+func (s *LesServer) Start() {
+ s.protocolManager.Start()
+}
+
+func (s *LesServer) Stop() {
+ s.fcCostStats.store()
+ s.fcManager.Stop()
+ go func() {
+ <-s.protocolManager.noMorePeers
+ }()
+ s.protocolManager.Stop()
+}
+
+type requestCosts struct {
+ baseCost, reqCost uint64
+}
+
+type requestCostTable map[uint64]*requestCosts
+
+type RequestCostList []struct {
+ MsgCode, BaseCost, ReqCost uint64
+}
+
+func (list RequestCostList) decode() requestCostTable {
+ table := make(requestCostTable)
+ for _, e := range list {
+ table[e.MsgCode] = &requestCosts{
+ baseCost: e.BaseCost,
+ reqCost: e.ReqCost,
+ }
+ }
+ return table
+}
+
+func (table requestCostTable) encode() RequestCostList {
+ list := make(RequestCostList, len(table))
+ for idx, code := range reqList {
+ list[idx].MsgCode = code
+ list[idx].BaseCost = table[code].baseCost
+ list[idx].ReqCost = table[code].reqCost
+ }
+ return list
+}
+
+type linReg struct {
+ sumX, sumY, sumXX, sumXY float64
+ cnt uint64
+}
+
+const linRegMaxCnt = 100000
+
+func (l *linReg) add(x, y float64) {
+ if l.cnt >= linRegMaxCnt {
+ sub := float64(l.cnt+1-linRegMaxCnt) / linRegMaxCnt
+ l.sumX -= l.sumX * sub
+ l.sumY -= l.sumY * sub
+ l.sumXX -= l.sumXX * sub
+ l.sumXY -= l.sumXY * sub
+ l.cnt = linRegMaxCnt - 1
+ }
+ l.cnt++
+ l.sumX += x
+ l.sumY += y
+ l.sumXX += x * x
+ l.sumXY += x * y
+}
+
+func (l *linReg) calc() (b, m float64) {
+ if l.cnt == 0 {
+ return 0, 0
+ }
+ cnt := float64(l.cnt)
+ d := cnt*l.sumXX - l.sumX*l.sumX
+ if d < 0.001 {
+ return l.sumY / cnt, 0
+ }
+ m = (cnt*l.sumXY - l.sumX*l.sumY) / d
+ b = (l.sumY / cnt) - (m * l.sumX / cnt)
+ return b, m
+}
+
+func (l *linReg) toBytes() []byte {
+ var arr [40]byte
+ binary.BigEndian.PutUint64(arr[0:8], math.Float64bits(l.sumX))
+ binary.BigEndian.PutUint64(arr[8:16], math.Float64bits(l.sumY))
+ binary.BigEndian.PutUint64(arr[16:24], math.Float64bits(l.sumXX))
+ binary.BigEndian.PutUint64(arr[24:32], math.Float64bits(l.sumXY))
+ binary.BigEndian.PutUint64(arr[32:40], l.cnt)
+ return arr[:]
+}
+
+func linRegFromBytes(data []byte) *linReg {
+ if len(data) != 40 {
+ return nil
+ }
+ l := &linReg{}
+ l.sumX = math.Float64frombits(binary.BigEndian.Uint64(data[0:8]))
+ l.sumY = math.Float64frombits(binary.BigEndian.Uint64(data[8:16]))
+ l.sumXX = math.Float64frombits(binary.BigEndian.Uint64(data[16:24]))
+ l.sumXY = math.Float64frombits(binary.BigEndian.Uint64(data[24:32]))
+ l.cnt = binary.BigEndian.Uint64(data[32:40])
+ return l
+}
+
+type requestCostStats struct {
+ lock sync.RWMutex
+ db ethdb.Database
+ stats map[uint64]*linReg
+}
+
+type requestCostStatsRlp []struct {
+ MsgCode uint64
+ Data []byte
+}
+
+var rcStatsKey = []byte("_requestCostStats")
+
+func newCostStats(db ethdb.Database) *requestCostStats {
+ stats := make(map[uint64]*linReg)
+ for _, code := range reqList {
+ stats[code] = &linReg{cnt: 100}
+ }
+
+ if db != nil {
+ data, err := db.Get(rcStatsKey)
+ var statsRlp requestCostStatsRlp
+ if err == nil {
+ err = rlp.DecodeBytes(data, &statsRlp)
+ }
+ if err == nil {
+ for _, r := range statsRlp {
+ if stats[r.MsgCode] != nil {
+ if l := linRegFromBytes(r.Data); l != nil {
+ stats[r.MsgCode] = l
+ }
+ }
+ }
+ }
+ }
+
+ return &requestCostStats{
+ db: db,
+ stats: stats,
+ }
+}
+
+func (s *requestCostStats) store() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ statsRlp := make(requestCostStatsRlp, len(reqList))
+ for i, code := range reqList {
+ statsRlp[i].MsgCode = code
+ statsRlp[i].Data = s.stats[code].toBytes()
+ }
+
+ if data, err := rlp.EncodeToBytes(statsRlp); err == nil {
+ s.db.Put(rcStatsKey, data)
+ }
+}
+
+func (s *requestCostStats) getCurrentList() RequestCostList {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ list := make(RequestCostList, len(reqList))
+ //fmt.Println("RequestCostList")
+ for idx, code := range reqList {
+ b, m := s.stats[code].calc()
+ //fmt.Println(code, s.stats[code].cnt, b/1000000, m/1000000)
+ if m < 0 {
+ b += m
+ m = 0
+ }
+ if b < 0 {
+ b = 0
+ }
+
+ list[idx].MsgCode = code
+ list[idx].BaseCost = uint64(b * 2)
+ list[idx].ReqCost = uint64(m * 2)
+ }
+ return list
+}
+
+func (s *requestCostStats) update(msgCode, reqCnt, cost uint64) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ c, ok := s.stats[msgCode]
+ if !ok || reqCnt == 0 {
+ return
+ }
+ c.add(float64(reqCnt), float64(cost))
+}
+
+func (pm *ProtocolManager) blockLoop() {
+ pm.wg.Add(1)
+ sub := pm.eventMux.Subscribe(core.ChainHeadEvent{})
+ newCht := make(chan struct{}, 10)
+ newCht <- struct{}{}
+ go func() {
+ var mu sync.Mutex
+ var lastHead *types.Header
+ lastBroadcastTd := common.Big0
+ for {
+ select {
+ case ev := <-sub.Chan():
+ peers := pm.peers.AllPeers()
+ if len(peers) > 0 {
+ header := ev.Data.(core.ChainHeadEvent).Block.Header()
+ hash := header.Hash()
+ number := header.GetNumberU64()
+ td := core.GetTd(pm.chainDb, hash, number)
+ if td != nil && td.Cmp(lastBroadcastTd) > 0 {
+ var reorg uint64
+ if lastHead != nil {
+ reorg = lastHead.GetNumberU64() - core.FindCommonAncestor(pm.chainDb, header, lastHead).GetNumberU64()
+ }
+ lastHead = header
+ lastBroadcastTd = td
+ //fmt.Println("BROADCAST", number, hash, td, reorg)
+ announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
+ for _, p := range peers {
+ select {
+ case p.announceChn <- announce:
+ default:
+ pm.removePeer(p.id)
+ }
+ }
+ }
+ }
+ newCht <- struct{}{}
+ case <-newCht:
+ go func() {
+ mu.Lock()
+ more := makeCht(pm.chainDb)
+ mu.Unlock()
+ if more {
+ time.Sleep(time.Millisecond * 10)
+ newCht <- struct{}{}
+ }
+ }()
+ case <-pm.quitSync:
+ sub.Unsubscribe()
+ pm.wg.Done()
+ return
+ }
+ }
+ }()
+}
+
+var (
+ lastChtKey = []byte("LastChtNumber") // chtNum (uint64 big endian)
+ chtPrefix = []byte("cht") // chtPrefix + chtNum (uint64 big endian) -> trie root hash
+ chtConfirmations = light.ChtFrequency / 2
+)
+
+func getChtRoot(db ethdb.Database, num uint64) common.Hash {
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], num)
+ data, _ := db.Get(append(chtPrefix, encNumber[:]...))
+ return common.BytesToHash(data)
+}
+
+func storeChtRoot(db ethdb.Database, num uint64, root common.Hash) {
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], num)
+ db.Put(append(chtPrefix, encNumber[:]...), root[:])
+}
+
+func makeCht(db ethdb.Database) bool {
+ headHash := core.GetHeadBlockHash(db)
+ headNum := core.GetBlockNumber(db, headHash)
+
+ var newChtNum uint64
+ if headNum > chtConfirmations {
+ newChtNum = (headNum - chtConfirmations) / light.ChtFrequency
+ }
+
+ var lastChtNum uint64
+ data, _ := db.Get(lastChtKey)
+ if len(data) == 8 {
+ lastChtNum = binary.BigEndian.Uint64(data[:])
+ }
+ if newChtNum <= lastChtNum {
+ return false
+ }
+
+ var t *trie.Trie
+ if lastChtNum > 0 {
+ var err error
+ t, err = trie.New(getChtRoot(db, lastChtNum), db)
+ if err != nil {
+ lastChtNum = 0
+ }
+ }
+ if lastChtNum == 0 {
+ t, _ = trie.New(common.Hash{}, db)
+ }
+
+ for num := lastChtNum * light.ChtFrequency; num < (lastChtNum+1)*light.ChtFrequency; num++ {
+ hash := core.GetCanonicalHash(db, num)
+ if hash == (common.Hash{}) {
+ panic("Canonical hash not found")
+ }
+ td := core.GetTd(db, hash, num)
+ if td == nil {
+ panic("TD not found")
+ }
+ var encNumber [8]byte
+ binary.BigEndian.PutUint64(encNumber[:], num)
+ var node light.ChtNode
+ node.Hash = hash
+ node.Td = td
+ data, _ := rlp.EncodeToBytes(node)
+ t.Update(encNumber[:], data)
+ }
+
+ root, err := t.Commit()
+ if err != nil {
+ lastChtNum = 0
+ } else {
+ lastChtNum++
+ fmt.Printf("CHT %d %064x\n", lastChtNum, root)
+ storeChtRoot(db, lastChtNum, root)
+ var data [8]byte
+ binary.BigEndian.PutUint64(data[:], lastChtNum)
+ db.Put(lastChtKey, data[:])
+ }
+
+ return newChtNum > lastChtNum
+}
diff --git a/les/sync.go b/les/sync.go
new file mode 100644
index 000000000..f92f8ce04
--- /dev/null
+++ b/les/sync.go
@@ -0,0 +1,84 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+ "time"
+
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/light"
+ "golang.org/x/net/context"
+)
+
+const (
+ //forceSyncCycle = 10 * time.Second // Time interval to force syncs, even if few peers are available
+ minDesiredPeerCount = 5 // Amount of peers desired to start syncing
+)
+
+// syncer is responsible for periodically synchronising with the network, both
+// downloading hashes and blocks as well as handling the announcement handler.
+func (pm *ProtocolManager) syncer() {
+ // Start and ensure cleanup of sync mechanisms
+ //pm.fetcher.Start()
+ //defer pm.fetcher.Stop()
+ defer pm.downloader.Terminate()
+
+ // Wait for different events to fire synchronisation operations
+ //forceSync := time.Tick(forceSyncCycle)
+ for {
+ select {
+ case <-pm.newPeerCh:
+/* // Make sure we have peers to select from, then sync
+ if pm.peers.Len() < minDesiredPeerCount {
+ break
+ }
+ go pm.synchronise(pm.peers.BestPeer())
+*/
+ /*case <-forceSync:
+ // Force a sync even if not enough peers are present
+ go pm.synchronise(pm.peers.BestPeer())
+ */
+ case <-pm.noMorePeers:
+ return
+ }
+ }
+}
+
+func (pm *ProtocolManager) needToSync(peerHead blockInfo) bool {
+ head := pm.blockchain.CurrentHeader()
+ currentTd := core.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64())
+ return currentTd != nil && peerHead.Td.Cmp(currentTd) > 0
+}
+
+// synchronise tries to sync up our local block chain with a remote peer.
+func (pm *ProtocolManager) synchronise(peer *peer) {
+ // Short circuit if no peers are available
+ if peer == nil {
+ return
+ }
+
+ // Make sure the peer's TD is higher than our own.
+ if !pm.needToSync(peer.headBlockInfo()) {
+ return
+ }
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ pm.blockchain.(*light.LightChain).SyncCht(ctx)
+
+ pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync)
+}
diff --git a/les/txrelay.go b/les/txrelay.go
new file mode 100644
index 000000000..2df2fa0a9
--- /dev/null
+++ b/les/txrelay.go
@@ -0,0 +1,156 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+package les
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+type ltrInfo struct {
+ tx *types.Transaction
+ sentTo map[*peer]struct{}
+}
+
+type LesTxRelay struct {
+ txSent map[common.Hash]*ltrInfo
+ txPending map[common.Hash]struct{}
+ ps *peerSet
+ peerList []*peer
+ peerStartPos int
+ lock sync.RWMutex
+}
+
+func NewLesTxRelay() *LesTxRelay {
+ return &LesTxRelay{
+ txSent: make(map[common.Hash]*ltrInfo),
+ txPending: make(map[common.Hash]struct{}),
+ ps: newPeerSet(),
+ }
+}
+
+func (self *LesTxRelay) addPeer(p *peer) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.ps.Register(p)
+ self.peerList = self.ps.AllPeers()
+}
+
+func (self *LesTxRelay) removePeer(id string) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.ps.Unregister(id)
+ self.peerList = self.ps.AllPeers()
+}
+
+// send sends a list of transactions to at most a given number of peers at
+// once, never resending any particular transaction to the same peer twice
+func (self *LesTxRelay) send(txs types.Transactions, count int) {
+ sendTo := make(map[*peer]types.Transactions)
+
+ self.peerStartPos++ // rotate the starting position of the peer list
+ if self.peerStartPos >= len(self.peerList) {
+ self.peerStartPos = 0
+ }
+
+ for _, tx := range txs {
+ hash := tx.Hash()
+ ltr, ok := self.txSent[hash]
+ if !ok {
+ ltr = &ltrInfo{
+ tx: tx,
+ sentTo: make(map[*peer]struct{}),
+ }
+ self.txSent[hash] = ltr
+ self.txPending[hash] = struct{}{}
+ }
+
+ if len(self.peerList) > 0 {
+ cnt := count
+ pos := self.peerStartPos
+ for {
+ peer := self.peerList[pos]
+ if _, ok := ltr.sentTo[peer]; !ok {
+ sendTo[peer] = append(sendTo[peer], tx)
+ ltr.sentTo[peer] = struct{}{}
+ cnt--
+ }
+ if cnt == 0 {
+ break // sent it to the desired number of peers
+ }
+ pos++
+ if pos == len(self.peerList) {
+ pos = 0
+ }
+ if pos == self.peerStartPos {
+ break // tried all available peers
+ }
+ }
+ }
+ }
+
+ for p, list := range sendTo {
+ cost := p.GetRequestCost(SendTxMsg, len(list))
+ go func(p *peer, list types.Transactions, cost uint64) {
+ p.fcServer.SendRequest(0, cost)
+ p.SendTxs(cost, list)
+ }(p, list, cost)
+ }
+}
+
+func (self *LesTxRelay) Send(txs types.Transactions) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ self.send(txs, 3)
+}
+
+func (self *LesTxRelay) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ for _, hash := range mined {
+ delete(self.txPending, hash)
+ }
+
+ for _, hash := range rollback {
+ self.txPending[hash] = struct{}{}
+ }
+
+ if len(self.txPending) > 0 {
+ txs := make(types.Transactions, len(self.txPending))
+ i := 0
+ for hash, _ := range self.txPending {
+ txs[i] = self.txSent[hash].tx
+ i++
+ }
+ self.send(txs, 1)
+ }
+}
+
+func (self *LesTxRelay) Discard(hashes []common.Hash) {
+ self.lock.Lock()
+ defer self.lock.Unlock()
+
+ for _, hash := range hashes {
+ delete(self.txSent, hash)
+ delete(self.txPending, hash)
+ }
+}