aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/node_test.go25
-rw-r--r--p2p/discover/table.go8
-rw-r--r--p2p/discover/table_test.go9
-rw-r--r--p2p/discover/udp_test.go55
-rw-r--r--p2p/peer.go87
-rw-r--r--p2p/peer_test.go2
-rw-r--r--p2p/server_test.go3
7 files changed, 112 insertions, 77 deletions
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
index b1babd989..795460c49 100644
--- a/p2p/discover/node_test.go
+++ b/p2p/discover/node_test.go
@@ -13,11 +13,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)
-var (
- quickrand = rand.New(rand.NewSource(time.Now().Unix()))
- quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
-)
-
var parseNodeTests = []struct {
rawurl string
wantError string
@@ -176,7 +171,7 @@ func TestNodeID_distcmp(t *testing.T) {
bbig := new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
}
- if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+ if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg()); err != nil {
t.Error(err)
}
}
@@ -195,7 +190,7 @@ func TestNodeID_logdist(t *testing.T) {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
- if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+ if err := quick.CheckEqual(logdist, logdistBig, quickcfg()); err != nil {
t.Error(err)
}
}
@@ -211,9 +206,10 @@ func TestNodeID_logdistEqual(t *testing.T) {
func TestNodeID_hashAtDistance(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
- for i := 0; i < quickcfg.MaxCount; i++ {
- a := gen(common.Hash{}, quickrand).(common.Hash)
- dist := quickrand.Intn(len(common.Hash{}) * 8)
+ cfg := quickcfg()
+ for i := 0; i < cfg.MaxCount; i++ {
+ a := gen(common.Hash{}, cfg.Rand).(common.Hash)
+ dist := cfg.Rand.Intn(len(common.Hash{}) * 8)
result := hashAtDistance(a, dist)
actualdist := logdist(result, a)
@@ -225,7 +221,14 @@ func TestNodeID_hashAtDistance(t *testing.T) {
}
}
-// TODO: this can be dropped when we require Go >= 1.5
+func quickcfg() *quick.Config {
+ return &quick.Config{
+ MaxCount: 5000,
+ Rand: rand.New(rand.NewSource(time.Now().Unix())),
+ }
+}
+
+// TODO: The Generate method can be dropped when we require Go >= 1.5
// because testing/quick learned to generate arrays in 1.5.
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 4b7ddb775..f71320425 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -40,6 +40,8 @@ type Table struct {
bonding map[NodeID]*bondproc
bondslots chan struct{} // limits total number of active bonding processes
+ nodeAddedHook func(*Node) // for testing
+
net transport
self *Node // metadata of the local node
}
@@ -431,6 +433,9 @@ func (tab *Table) pingreplace(new *Node, b *bucket) {
}
copy(b.entries[1:], b.entries)
b.entries[0] = new
+ if tab.nodeAddedHook != nil {
+ tab.nodeAddedHook(new)
+ }
}
// ping a remote endpoint and wait for a reply, also updating the node database
@@ -466,6 +471,9 @@ outer:
}
if len(bucket.entries) < bucketSize {
bucket.entries = append(bucket.entries, n)
+ if tab.nodeAddedHook != nil {
+ tab.nodeAddedHook(n)
+ }
}
}
}
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index da398d137..829899916 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -9,6 +9,7 @@ import (
"reflect"
"testing"
"testing/quick"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
@@ -74,7 +75,7 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
t.Parallel()
cfg := &quick.Config{
MaxCount: 1000,
- Rand: quickrand,
+ Rand: rand.New(rand.NewSource(time.Now().Unix())),
Values: func(args []reflect.Value, rand *rand.Rand) {
// generate a random list of nodes. this will be the content of the bucket.
n := rand.Intn(bucketSize-1) + 1
@@ -205,7 +206,7 @@ func TestTable_closest(t *testing.T) {
}
return true
}
- if err := quick.Check(test, quickcfg); err != nil {
+ if err := quick.Check(test, quickcfg()); err != nil {
t.Error(err)
}
}
@@ -213,7 +214,7 @@ func TestTable_closest(t *testing.T) {
func TestTable_ReadRandomNodesGetAll(t *testing.T) {
cfg := &quick.Config{
MaxCount: 200,
- Rand: quickrand,
+ Rand: rand.New(rand.NewSource(time.Now().Unix())),
Values: func(args []reflect.Value, rand *rand.Rand) {
args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000)))
},
@@ -221,7 +222,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
test := func(buf []*Node) bool {
tab := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
for i := 0; i < len(buf); i++ {
- ld := quickrand.Intn(len(tab.buckets))
+ ld := cfg.Rand.Intn(len(tab.buckets))
tab.add([]*Node{nodeAtDistance(tab.self.sha, ld)})
}
gotN := tab.ReadRandomNodes(buf)
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 11fa31d7c..b5d035a98 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -234,14 +234,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
func TestUDP_successfulPing(t *testing.T) {
test := newUDPTest(t)
+ added := make(chan *Node, 1)
+ test.table.nodeAddedHook = func(n *Node) { added <- n }
defer test.table.Close()
- done := make(chan struct{})
- go func() {
- // The remote side sends a ping packet to initiate the exchange.
- test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp})
- close(done)
- }()
+ // The remote side sends a ping packet to initiate the exchange.
+ go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp})
// the ping is replied to.
test.waitPacketOut(func(p *pong) {
@@ -277,35 +275,26 @@ func TestUDP_successfulPing(t *testing.T) {
})
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
- // ping should return shortly after getting the pong packet.
- <-done
-
- // check that the node was added.
- rid := PubkeyID(&test.remotekey.PublicKey)
- rnode := find(test.table, rid)
- if rnode == nil {
- t.Fatalf("node %v not found in table", rid)
- }
- if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
- t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
- }
- if int(rnode.UDP) != test.remoteaddr.Port {
- t.Errorf("node has wrong UDP port: got %v, want: %v", rnode.UDP, test.remoteaddr.Port)
- }
- if rnode.TCP != testRemote.TCP {
- t.Errorf("node has wrong TCP port: got %v, want: %v", rnode.TCP, testRemote.TCP)
- }
-}
-
-func find(tab *Table, id NodeID) *Node {
- for _, b := range tab.buckets {
- for _, e := range b.entries {
- if e.ID == id {
- return e
- }
+ // the node should be added to the table shortly after getting the
+ // pong packet.
+ select {
+ case n := <-added:
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ if n.ID != rid {
+ t.Errorf("node has wrong ID: got %v, want %v", n.ID, rid)
}
+ if !bytes.Equal(n.IP, test.remoteaddr.IP) {
+ t.Errorf("node has wrong IP: got %v, want: %v", n.IP, test.remoteaddr.IP)
+ }
+ if int(n.UDP) != test.remoteaddr.Port {
+ t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP, test.remoteaddr.Port)
+ }
+ if n.TCP != testRemote.TCP {
+ t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP, testRemote.TCP)
+ }
+ case <-time.After(2 * time.Second):
+ t.Errorf("node was not added within 2 seconds")
}
- return nil
}
// dgramPipe is a fake UDP socket. It queues all sent datagrams.
diff --git a/p2p/peer.go b/p2p/peer.go
index cbe5ccc84..40466cf84 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -115,41 +115,60 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
}
func (p *Peer) run() DiscReason {
- readErr := make(chan error, 1)
+ var (
+ writeStart = make(chan struct{}, 1)
+ writeErr = make(chan error, 1)
+ readErr = make(chan error, 1)
+ reason DiscReason
+ requested bool
+ )
p.wg.Add(2)
go p.readLoop(readErr)
go p.pingLoop()
- p.startProtocols()
+ // Start all protocol handlers.
+ writeStart <- struct{}{}
+ p.startProtocols(writeStart, writeErr)
// Wait for an error or disconnect.
- var (
- reason DiscReason
- requested bool
- )
- select {
- case err := <-readErr:
- if r, ok := err.(DiscReason); ok {
- reason = r
- } else {
- // Note: We rely on protocols to abort if there is a write
- // error. It might be more robust to handle them here as well.
- glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
- reason = DiscNetworkError
+loop:
+ for {
+ select {
+ case err := <-writeErr:
+ // A write finished. Allow the next write to start if
+ // there was no error.
+ if err != nil {
+ glog.V(logger.Detail).Infof("%v: write error: %v\n", p, err)
+ reason = DiscNetworkError
+ break loop
+ }
+ writeStart <- struct{}{}
+ case err := <-readErr:
+ if r, ok := err.(DiscReason); ok {
+ glog.V(logger.Debug).Infof("%v: remote requested disconnect: %v\n", p, r)
+ requested = true
+ reason = r
+ } else {
+ glog.V(logger.Detail).Infof("%v: read error: %v\n", p, err)
+ reason = DiscNetworkError
+ }
+ break loop
+ case err := <-p.protoErr:
+ reason = discReasonForError(err)
+ glog.V(logger.Debug).Infof("%v: protocol error: %v (%v)\n", p, err, reason)
+ break loop
+ case reason = <-p.disc:
+ glog.V(logger.Debug).Infof("%v: locally requested disconnect: %v\n", p, reason)
+ break loop
}
- case err := <-p.protoErr:
- reason = discReasonForError(err)
- case reason = <-p.disc:
- requested = true
}
+
close(p.closed)
p.rw.close(reason)
p.wg.Wait()
-
if requested {
reason = DiscRequested
}
- glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
return reason
}
@@ -196,7 +215,6 @@ func (p *Peer) handle(msg Msg) error {
// This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it.
rlp.Decode(msg.Payload, &reason)
- glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
return reason[0]
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
@@ -247,11 +265,13 @@ outer:
return result
}
-func (p *Peer) startProtocols() {
+func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
p.wg.Add(len(p.running))
for _, proto := range p.running {
proto := proto
proto.closed = p.closed
+ proto.wstart = writeStart
+ proto.werr = writeErr
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
go func() {
err := proto.Run(p, proto)
@@ -280,18 +300,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
type protoRW struct {
Protocol
- in chan Msg
- closed <-chan struct{}
+ in chan Msg // receices read messages
+ closed <-chan struct{} // receives when peer is shutting down
+ wstart <-chan struct{} // receives when write may start
+ werr chan<- error // for write results
offset uint64
w MsgWriter
}
-func (rw *protoRW) WriteMsg(msg Msg) error {
+func (rw *protoRW) WriteMsg(msg Msg) (err error) {
if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
- return rw.w.WriteMsg(msg)
+ select {
+ case <-rw.wstart:
+ err = rw.w.WriteMsg(msg)
+ // Report write status back to Peer.run. It will initiate
+ // shutdown if the error is non-nil and unblock the next write
+ // otherwise. The calling protocol code should exit for errors
+ // as well but we don't want to rely on that.
+ rw.werr <- err
+ case <-rw.closed:
+ err = fmt.Errorf("shutting down")
+ }
+ return err
}
func (rw *protoRW) ReadMsg() (Msg, error) {
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 7b772e198..575d0ff79 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -121,7 +121,7 @@ func TestPeerDisconnect(t *testing.T) {
}
select {
case reason := <-disc:
- if reason != DiscQuitting {
+ if reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
}
case <-time.After(500 * time.Millisecond):
diff --git a/p2p/server_test.go b/p2p/server_test.go
index 01448cc7b..e8d21a188 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -117,7 +117,6 @@ func TestServerDial(t *testing.T) {
t.Error("accept error:", err)
return
}
- conn.Close()
accepted <- conn
}()
@@ -134,6 +133,8 @@ func TestServerDial(t *testing.T) {
select {
case conn := <-accepted:
+ defer conn.Close()
+
select {
case peer := <-connected:
if peer.ID() != remid {