aboutsummaryrefslogtreecommitdiffstats
path: root/swarm/network
diff options
context:
space:
mode:
Diffstat (limited to 'swarm/network')
-rw-r--r--swarm/network/fetcher.go24
-rw-r--r--swarm/network/fetcher_test.go54
-rw-r--r--swarm/network/stream/delivery.go6
-rw-r--r--swarm/network/stream/messages.go14
-rw-r--r--swarm/network/stream/stream.go2
-rw-r--r--swarm/network/stream/streamer_test.go77
6 files changed, 146 insertions, 31 deletions
diff --git a/swarm/network/fetcher.go b/swarm/network/fetcher.go
index 5b4b61c7e..6aed57e22 100644
--- a/swarm/network/fetcher.go
+++ b/swarm/network/fetcher.go
@@ -32,6 +32,8 @@ var searchTimeout = 1 * time.Second
// Also used in stream delivery.
var RequestTimeout = 10 * time.Second
+var maxHopCount uint8 = 20 // maximum number of forwarded requests (hops), to make sure requests are not forwarded forever in peer loops
+
type RequestFunc func(context.Context, *Request) (*enode.ID, chan struct{}, error)
// Fetcher is created when a chunk is not found locally. It starts a request handler loop once and
@@ -44,7 +46,7 @@ type Fetcher struct {
protoRequestFunc RequestFunc // request function fetcher calls to issue retrieve request for a chunk
addr storage.Address // the address of the chunk to be fetched
offerC chan *enode.ID // channel of sources (peer node id strings)
- requestC chan struct{}
+ requestC chan uint8 // channel for incoming requests (with the hopCount value in it)
skipCheck bool
}
@@ -53,6 +55,7 @@ type Request struct {
Source *enode.ID // nodeID of peer to request from (can be nil)
SkipCheck bool // whether to offer the chunk first or deliver directly
peersToSkip *sync.Map // peers not to request chunk from (only makes sense if source is nil)
+ HopCount uint8 // number of forwarded requests (hops)
}
// NewRequest returns a new instance of Request based on chunk address skip check and
@@ -113,7 +116,7 @@ func NewFetcher(addr storage.Address, rf RequestFunc, skipCheck bool) *Fetcher {
addr: addr,
protoRequestFunc: rf,
offerC: make(chan *enode.ID),
- requestC: make(chan struct{}),
+ requestC: make(chan uint8),
skipCheck: skipCheck,
}
}
@@ -136,7 +139,7 @@ func (f *Fetcher) Offer(ctx context.Context, source *enode.ID) {
}
// Request is called when an upstream peer request the chunk as part of `RetrieveRequestMsg`, or from a local request through FileStore, and the node does not have the chunk locally.
-func (f *Fetcher) Request(ctx context.Context) {
+func (f *Fetcher) Request(ctx context.Context, hopCount uint8) {
// First we need to have this select to make sure that we return if context is done
select {
case <-ctx.Done():
@@ -144,10 +147,15 @@ func (f *Fetcher) Request(ctx context.Context) {
default:
}
+ if hopCount >= maxHopCount {
+ log.Debug("fetcher request hop count limit reached", "hops", hopCount)
+ return
+ }
+
// This select alone would not guarantee that we return of context is done, it could potentially
// push to offerC instead if offerC is available (see number 2 in https://golang.org/ref/spec#Select_statements)
select {
- case f.requestC <- struct{}{}:
+ case f.requestC <- hopCount + 1:
case <-ctx.Done():
}
}
@@ -161,6 +169,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
waitC <-chan time.Time // timer channel
sources []*enode.ID // known sources, ie. peers that offered the chunk
requested bool // true if the chunk was actually requested
+ hopCount uint8
)
gone := make(chan *enode.ID) // channel to signal that a peer we requested from disconnected
@@ -183,7 +192,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
doRequest = requested
// incoming request
- case <-f.requestC:
+ case hopCount = <-f.requestC:
log.Trace("new request", "request addr", f.addr)
// 2) chunk is requested, set requested flag
// launch a request iff none been launched yet
@@ -213,7 +222,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
// need to issue a new request
if doRequest {
var err error
- sources, err = f.doRequest(ctx, gone, peers, sources)
+ sources, err = f.doRequest(ctx, gone, peers, sources, hopCount)
if err != nil {
log.Info("unable to request", "request addr", f.addr, "err", err)
}
@@ -251,7 +260,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
// * the peer's address is added to the set of peers to skip
// * the peer's address is removed from prospective sources, and
// * a go routine is started that reports on the gone channel if the peer is disconnected (or terminated their streamer)
-func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSkip *sync.Map, sources []*enode.ID) ([]*enode.ID, error) {
+func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSkip *sync.Map, sources []*enode.ID, hopCount uint8) ([]*enode.ID, error) {
var i int
var sourceID *enode.ID
var quit chan struct{}
@@ -260,6 +269,7 @@ func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSki
Addr: f.addr,
SkipCheck: f.skipCheck,
peersToSkip: peersToSkip,
+ HopCount: hopCount,
}
foundSource := false
diff --git a/swarm/network/fetcher_test.go b/swarm/network/fetcher_test.go
index b2316b097..3a926f475 100644
--- a/swarm/network/fetcher_test.go
+++ b/swarm/network/fetcher_test.go
@@ -33,7 +33,7 @@ type mockRequester struct {
// requests []Request
requestC chan *Request // when a request is coming it is pushed to requestC
waitTimes []time.Duration // with waitTimes[i] you can define how much to wait on the ith request (optional)
- ctr int //counts the number of requests
+ count int //counts the number of requests
quitC chan struct{}
}
@@ -47,9 +47,9 @@ func newMockRequester(waitTimes ...time.Duration) *mockRequester {
func (m *mockRequester) doRequest(ctx context.Context, request *Request) (*enode.ID, chan struct{}, error) {
waitTime := time.Duration(0)
- if m.ctr < len(m.waitTimes) {
- waitTime = m.waitTimes[m.ctr]
- m.ctr++
+ if m.count < len(m.waitTimes) {
+ waitTime = m.waitTimes[m.count]
+ m.count++
}
time.Sleep(waitTime)
m.requestC <- request
@@ -83,7 +83,7 @@ func TestFetcherSingleRequest(t *testing.T) {
go fetcher.run(ctx, peersToSkip)
rctx := context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
select {
case request := <-requester.requestC:
@@ -100,6 +100,11 @@ func TestFetcherSingleRequest(t *testing.T) {
t.Fatalf("request.peersToSkip does not contain peer returned by the request function")
}
+ // hopCount in the forwarded request should be incremented
+ if request.HopCount != 1 {
+ t.Fatalf("Expected request.HopCount 1 got %v", request.HopCount)
+ }
+
// fetch should trigger a request, if it doesn't happen in time, test should fail
case <-time.After(200 * time.Millisecond):
t.Fatalf("fetch timeout")
@@ -123,7 +128,7 @@ func TestFetcherCancelStopsFetcher(t *testing.T) {
rctx, rcancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer rcancel()
// we call Request with an active context
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
// fetcher should not initiate request, we can only check by waiting a bit and making sure no request is happening
select {
@@ -151,7 +156,7 @@ func TestFetcherCancelStopsRequest(t *testing.T) {
rcancel()
// we call Request with a cancelled context
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
// fetcher should not initiate request, we can only check by waiting a bit and making sure no request is happening
select {
@@ -162,7 +167,7 @@ func TestFetcherCancelStopsRequest(t *testing.T) {
// if there is another Request with active context, there should be a request, because the fetcher itself is not cancelled
rctx = context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
select {
case <-requester.requestC:
@@ -200,7 +205,7 @@ func TestFetcherOfferUsesSource(t *testing.T) {
// call Request after the Offer
rctx = context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
// there should be exactly 1 request coming from fetcher
var request *Request
@@ -241,7 +246,7 @@ func TestFetcherOfferAfterRequestUsesSourceFromContext(t *testing.T) {
// call Request first
rctx := context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
// there should be a request coming from fetcher
var request *Request
@@ -296,7 +301,7 @@ func TestFetcherRetryOnTimeout(t *testing.T) {
// call the fetch function with an active context
rctx := context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
// after 100ms the first request should be initiated
time.Sleep(100 * time.Millisecond)
@@ -338,7 +343,7 @@ func TestFetcherFactory(t *testing.T) {
fetcher := fetcherFactory.New(context.Background(), addr, peersToSkip)
- fetcher.Request(context.Background())
+ fetcher.Request(context.Background(), 0)
// check if the created fetchFunction really starts a fetcher and initiates a request
select {
@@ -368,7 +373,7 @@ func TestFetcherRequestQuitRetriesRequest(t *testing.T) {
go fetcher.run(ctx, peersToSkip)
rctx := context.Background()
- fetcher.Request(rctx)
+ fetcher.Request(rctx, 0)
select {
case <-requester.requestC:
@@ -457,3 +462,26 @@ func TestRequestSkipPeerPermanent(t *testing.T) {
t.Errorf("peer not skipped")
}
}
+
+func TestFetcherMaxHopCount(t *testing.T) {
+ requester := newMockRequester()
+ addr := make([]byte, 32)
+ fetcher := NewFetcher(addr, requester.doRequest, true)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ peersToSkip := &sync.Map{}
+
+ go fetcher.run(ctx, peersToSkip)
+
+ rctx := context.Background()
+ fetcher.Request(rctx, maxHopCount)
+
+ // if hopCount is already at max no request should be initiated
+ select {
+ case <-requester.requestC:
+ t.Fatalf("cancelled fetcher initiated request")
+ case <-time.After(200 * time.Millisecond):
+ }
+}
diff --git a/swarm/network/stream/delivery.go b/swarm/network/stream/delivery.go
index 431136ab1..c2adb1009 100644
--- a/swarm/network/stream/delivery.go
+++ b/swarm/network/stream/delivery.go
@@ -128,6 +128,7 @@ func (s *SwarmChunkServer) GetData(ctx context.Context, key []byte) ([]byte, err
type RetrieveRequestMsg struct {
Addr storage.Address
SkipCheck bool
+ HopCount uint8
}
func (d *Delivery) handleRetrieveRequestMsg(ctx context.Context, sp *Peer, req *RetrieveRequestMsg) error {
@@ -148,7 +149,9 @@ func (d *Delivery) handleRetrieveRequestMsg(ctx context.Context, sp *Peer, req *
var cancel func()
// TODO: do something with this hardcoded timeout, maybe use TTL in the future
- ctx, cancel = context.WithTimeout(context.WithValue(ctx, "peer", sp.ID().String()), network.RequestTimeout)
+ ctx = context.WithValue(ctx, "peer", sp.ID().String())
+ ctx = context.WithValue(ctx, "hopcount", req.HopCount)
+ ctx, cancel = context.WithTimeout(ctx, network.RequestTimeout)
go func() {
select {
@@ -247,6 +250,7 @@ func (d *Delivery) RequestFromPeers(ctx context.Context, req *network.Request) (
err := sp.SendPriority(ctx, &RetrieveRequestMsg{
Addr: req.Addr,
SkipCheck: req.SkipCheck,
+ HopCount: req.HopCount,
}, Top)
if err != nil {
return nil, nil, err
diff --git a/swarm/network/stream/messages.go b/swarm/network/stream/messages.go
index 1e47b7cf9..74c785d58 100644
--- a/swarm/network/stream/messages.go
+++ b/swarm/network/stream/messages.go
@@ -26,7 +26,7 @@ import (
bv "github.com/ethereum/go-ethereum/swarm/network/bitvector"
"github.com/ethereum/go-ethereum/swarm/spancontext"
"github.com/ethereum/go-ethereum/swarm/storage"
- opentracing "github.com/opentracing/opentracing-go"
+ "github.com/opentracing/opentracing-go"
)
var syncBatchTimeout = 30 * time.Second
@@ -197,10 +197,16 @@ func (p *Peer) handleOfferedHashesMsg(ctx context.Context, req *OfferedHashesMsg
if err != nil {
return err
}
+
hashes := req.Hashes
- want, err := bv.New(len(hashes) / HashSize)
+ lenHashes := len(hashes)
+ if lenHashes%HashSize != 0 {
+ return fmt.Errorf("error invalid hashes length (len: %v)", lenHashes)
+ }
+
+ want, err := bv.New(lenHashes / HashSize)
if err != nil {
- return fmt.Errorf("error initiaising bitvector of length %v: %v", len(hashes)/HashSize, err)
+ return fmt.Errorf("error initiaising bitvector of length %v: %v", lenHashes/HashSize, err)
}
ctr := 0
@@ -208,7 +214,7 @@ func (p *Peer) handleOfferedHashesMsg(ctx context.Context, req *OfferedHashesMsg
ctx, cancel := context.WithTimeout(ctx, syncBatchTimeout)
ctx = context.WithValue(ctx, "source", p.ID().String())
- for i := 0; i < len(hashes); i += HashSize {
+ for i := 0; i < lenHashes; i += HashSize {
hash := hashes[i : i+HashSize]
if wait := c.NeedData(ctx, hash); wait != nil {
diff --git a/swarm/network/stream/stream.go b/swarm/network/stream/stream.go
index 3b1b11d36..1eda06c6a 100644
--- a/swarm/network/stream/stream.go
+++ b/swarm/network/stream/stream.go
@@ -642,7 +642,7 @@ func (c *clientParams) clientCreated() {
// Spec is the spec of the streamer protocol
var Spec = &protocols.Spec{
Name: "stream",
- Version: 6,
+ Version: 7,
MaxMsgSize: 10 * 1024 * 1024,
Messages: []interface{}{
UnsubscribeMsg{},
diff --git a/swarm/network/stream/streamer_test.go b/swarm/network/stream/streamer_test.go
index 04366cd39..0bdebefa7 100644
--- a/swarm/network/stream/streamer_test.go
+++ b/swarm/network/stream/streamer_test.go
@@ -19,6 +19,7 @@ package stream
import (
"bytes"
"context"
+ "errors"
"strconv"
"testing"
"time"
@@ -56,11 +57,12 @@ func TestStreamerRequestSubscription(t *testing.T) {
}
var (
- hash0 = sha3.Sum256([]byte{0})
- hash1 = sha3.Sum256([]byte{1})
- hash2 = sha3.Sum256([]byte{2})
- hashesTmp = append(hash0[:], hash1[:]...)
- hashes = append(hashesTmp, hash2[:]...)
+ hash0 = sha3.Sum256([]byte{0})
+ hash1 = sha3.Sum256([]byte{1})
+ hash2 = sha3.Sum256([]byte{2})
+ hashesTmp = append(hash0[:], hash1[:]...)
+ hashes = append(hashesTmp, hash2[:]...)
+ corruptHashes = append(hashes[:40])
)
type testClient struct {
@@ -460,6 +462,71 @@ func TestStreamerUpstreamSubscribeLiveAndHistory(t *testing.T) {
}
}
+func TestStreamerDownstreamCorruptHashesMsgExchange(t *testing.T) {
+ tester, streamer, _, teardown, err := newStreamerTester(t, nil)
+ defer teardown()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ stream := NewStream("foo", "", true)
+
+ var tc *testClient
+
+ streamer.RegisterClientFunc("foo", func(p *Peer, t string, live bool) (Client, error) {
+ tc = newTestClient(t)
+ return tc, nil
+ })
+
+ node := tester.Nodes[0]
+
+ err = streamer.Subscribe(node.ID(), stream, NewRange(5, 8), Top)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ err = tester.TestExchanges(p2ptest.Exchange{
+ Label: "Subscribe message",
+ Expects: []p2ptest.Expect{
+ {
+ Code: 4,
+ Msg: &SubscribeMsg{
+ Stream: stream,
+ History: NewRange(5, 8),
+ Priority: Top,
+ },
+ Peer: node.ID(),
+ },
+ },
+ },
+ p2ptest.Exchange{
+ Label: "Corrupt offered hash message",
+ Triggers: []p2ptest.Trigger{
+ {
+ Code: 1,
+ Msg: &OfferedHashesMsg{
+ HandoverProof: &HandoverProof{
+ Handover: &Handover{},
+ },
+ Hashes: corruptHashes,
+ From: 5,
+ To: 8,
+ Stream: stream,
+ },
+ Peer: node.ID(),
+ },
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expectedError := errors.New("Message handler error: (msg code 1): error invalid hashes length (len: 40)")
+ if err := tester.TestDisconnected(&p2ptest.Disconnect{Peer: tester.Nodes[0].ID(), Error: expectedError}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestStreamerDownstreamOfferedHashesMsgExchange(t *testing.T) {
tester, streamer, _, teardown, err := newStreamerTester(t, nil)
defer teardown()