diff options
-rw-r--r-- | swarm/network/protocol.go | 6 | ||||
-rw-r--r-- | swarm/network/protocol_test.go | 13 |
2 files changed, 13 insertions, 6 deletions
diff --git a/swarm/network/protocol.go b/swarm/network/protocol.go index a4b29239c..74e7de126 100644 --- a/swarm/network/protocol.go +++ b/swarm/network/protocol.go @@ -168,7 +168,7 @@ func (b *Bzz) APIs() []rpc.API { func (b *Bzz) RunProtocol(spec *protocols.Spec, run func(*BzzPeer) error) func(*p2p.Peer, p2p.MsgReadWriter) error { return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { // wait for the bzz protocol to perform the handshake - handshake, _ := b.GetHandshake(p.ID()) + handshake, _ := b.GetOrCreateHandshake(p.ID()) defer b.removeHandshake(p.ID()) select { case <-handshake.done: @@ -213,7 +213,7 @@ func (b *Bzz) performHandshake(p *protocols.Peer, handshake *HandshakeMsg) error // runBzz is the p2p protocol run function for the bzz base protocol // that negotiates the bzz handshake func (b *Bzz) runBzz(p *p2p.Peer, rw p2p.MsgReadWriter) error { - handshake, _ := b.GetHandshake(p.ID()) + handshake, _ := b.GetOrCreateHandshake(p.ID()) if !<-handshake.init { return fmt.Errorf("%08x: bzz already started on peer %08x", b.localAddr.Over()[:4], p.ID().Bytes()[:4]) } @@ -303,7 +303,7 @@ func (b *Bzz) removeHandshake(peerID enode.ID) { } // GetHandshake returns the bzz handhake that the remote peer with peerID sent -func (b *Bzz) GetHandshake(peerID enode.ID) (*HandshakeMsg, bool) { +func (b *Bzz) GetOrCreateHandshake(peerID enode.ID) (*HandshakeMsg, bool) { b.mtx.Lock() defer b.mtx.Unlock() handshake, found := b.handshakes[peerID] diff --git a/swarm/network/protocol_test.go b/swarm/network/protocol_test.go index 58477a7b8..80d67e767 100644 --- a/swarm/network/protocol_test.go +++ b/swarm/network/protocol_test.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" @@ -224,7 +225,7 @@ func TestBzzHandshakeLightNode(t *testing.T) { for _, test := range lightNodeTests { t.Run(test.name, func(t *testing.T) { randomAddr := RandomAddr() - pt := newBzzHandshakeTester(t, 1, randomAddr, false) + pt := newBzzHandshakeTester(nil, 1, randomAddr, false) // TODO change signature - t is not used anywhere node := pt.Nodes[0] addr := NewAddr(node) @@ -237,8 +238,14 @@ func TestBzzHandshakeLightNode(t *testing.T) { t.Fatal(err) } - if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode { - t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode) + select { + + case <-pt.bzz.handshakes[node.ID()].done: + if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode { + t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode) + } + case <-time.After(10 * time.Second): + t.Fatal("test timeout") } }) } |