diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/discover/node_test.go | 25 | ||||
-rw-r--r-- | p2p/discover/table.go | 8 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 9 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 55 | ||||
-rw-r--r-- | p2p/peer.go | 87 | ||||
-rw-r--r-- | p2p/peer_test.go | 2 | ||||
-rw-r--r-- | p2p/server_test.go | 3 |
7 files changed, 112 insertions, 77 deletions
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go index b1babd989..795460c49 100644 --- a/p2p/discover/node_test.go +++ b/p2p/discover/node_test.go @@ -13,11 +13,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) -var ( - quickrand = rand.New(rand.NewSource(time.Now().Unix())) - quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand} -) - var parseNodeTests = []struct { rawurl string wantError string @@ -176,7 +171,7 @@ func TestNodeID_distcmp(t *testing.T) { bbig := new(big.Int).SetBytes(b[:]) return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig)) } - if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil { + if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg()); err != nil { t.Error(err) } } @@ -195,7 +190,7 @@ func TestNodeID_logdist(t *testing.T) { abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:]) return new(big.Int).Xor(abig, bbig).BitLen() } - if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil { + if err := quick.CheckEqual(logdist, logdistBig, quickcfg()); err != nil { t.Error(err) } } @@ -211,9 +206,10 @@ func TestNodeID_logdistEqual(t *testing.T) { func TestNodeID_hashAtDistance(t *testing.T) { // we don't use quick.Check here because its output isn't // very helpful when the test fails. - for i := 0; i < quickcfg.MaxCount; i++ { - a := gen(common.Hash{}, quickrand).(common.Hash) - dist := quickrand.Intn(len(common.Hash{}) * 8) + cfg := quickcfg() + for i := 0; i < cfg.MaxCount; i++ { + a := gen(common.Hash{}, cfg.Rand).(common.Hash) + dist := cfg.Rand.Intn(len(common.Hash{}) * 8) result := hashAtDistance(a, dist) actualdist := logdist(result, a) @@ -225,7 +221,14 @@ func TestNodeID_hashAtDistance(t *testing.T) { } } -// TODO: this can be dropped when we require Go >= 1.5 +func quickcfg() *quick.Config { + return &quick.Config{ + MaxCount: 5000, + Rand: rand.New(rand.NewSource(time.Now().Unix())), + } +} + +// TODO: The Generate method can be dropped when we require Go >= 1.5 // because testing/quick learned to generate arrays in 1.5. func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 4b7ddb775..f71320425 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -40,6 +40,8 @@ type Table struct { bonding map[NodeID]*bondproc bondslots chan struct{} // limits total number of active bonding processes + nodeAddedHook func(*Node) // for testing + net transport self *Node // metadata of the local node } @@ -431,6 +433,9 @@ func (tab *Table) pingreplace(new *Node, b *bucket) { } copy(b.entries[1:], b.entries) b.entries[0] = new + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(new) + } } // ping a remote endpoint and wait for a reply, also updating the node database @@ -466,6 +471,9 @@ outer: } if len(bucket.entries) < bucketSize { bucket.entries = append(bucket.entries, n) + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(n) + } } } } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index da398d137..829899916 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -9,6 +9,7 @@ import ( "reflect" "testing" "testing/quick" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -74,7 +75,7 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { t.Parallel() cfg := &quick.Config{ MaxCount: 1000, - Rand: quickrand, + Rand: rand.New(rand.NewSource(time.Now().Unix())), Values: func(args []reflect.Value, rand *rand.Rand) { // generate a random list of nodes. this will be the content of the bucket. n := rand.Intn(bucketSize-1) + 1 @@ -205,7 +206,7 @@ func TestTable_closest(t *testing.T) { } return true } - if err := quick.Check(test, quickcfg); err != nil { + if err := quick.Check(test, quickcfg()); err != nil { t.Error(err) } } @@ -213,7 +214,7 @@ func TestTable_closest(t *testing.T) { func TestTable_ReadRandomNodesGetAll(t *testing.T) { cfg := &quick.Config{ MaxCount: 200, - Rand: quickrand, + Rand: rand.New(rand.NewSource(time.Now().Unix())), Values: func(args []reflect.Value, rand *rand.Rand) { args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000))) }, @@ -221,7 +222,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { test := func(buf []*Node) bool { tab := newTable(nil, NodeID{}, &net.UDPAddr{}, "") for i := 0; i < len(buf); i++ { - ld := quickrand.Intn(len(tab.buckets)) + ld := cfg.Rand.Intn(len(tab.buckets)) tab.add([]*Node{nodeAtDistance(tab.self.sha, ld)}) } gotN := tab.ReadRandomNodes(buf) diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 11fa31d7c..b5d035a98 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -234,14 +234,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) { func TestUDP_successfulPing(t *testing.T) { test := newUDPTest(t) + added := make(chan *Node, 1) + test.table.nodeAddedHook = func(n *Node) { added <- n } defer test.table.Close() - done := make(chan struct{}) - go func() { - // The remote side sends a ping packet to initiate the exchange. - test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) - close(done) - }() + // The remote side sends a ping packet to initiate the exchange. + go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) // the ping is replied to. test.waitPacketOut(func(p *pong) { @@ -277,35 +275,26 @@ func TestUDP_successfulPing(t *testing.T) { }) test.packetIn(nil, pongPacket, &pong{Expiration: futureExp}) - // ping should return shortly after getting the pong packet. - <-done - - // check that the node was added. - rid := PubkeyID(&test.remotekey.PublicKey) - rnode := find(test.table, rid) - if rnode == nil { - t.Fatalf("node %v not found in table", rid) - } - if !bytes.Equal(rnode.IP, test.remoteaddr.IP) { - t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP) - } - if int(rnode.UDP) != test.remoteaddr.Port { - t.Errorf("node has wrong UDP port: got %v, want: %v", rnode.UDP, test.remoteaddr.Port) - } - if rnode.TCP != testRemote.TCP { - t.Errorf("node has wrong TCP port: got %v, want: %v", rnode.TCP, testRemote.TCP) - } -} - -func find(tab *Table, id NodeID) *Node { - for _, b := range tab.buckets { - for _, e := range b.entries { - if e.ID == id { - return e - } + // the node should be added to the table shortly after getting the + // pong packet. + select { + case n := <-added: + rid := PubkeyID(&test.remotekey.PublicKey) + if n.ID != rid { + t.Errorf("node has wrong ID: got %v, want %v", n.ID, rid) } + if !bytes.Equal(n.IP, test.remoteaddr.IP) { + t.Errorf("node has wrong IP: got %v, want: %v", n.IP, test.remoteaddr.IP) + } + if int(n.UDP) != test.remoteaddr.Port { + t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP, test.remoteaddr.Port) + } + if n.TCP != testRemote.TCP { + t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP, testRemote.TCP) + } + case <-time.After(2 * time.Second): + t.Errorf("node was not added within 2 seconds") } - return nil } // dgramPipe is a fake UDP socket. It queues all sent datagrams. diff --git a/p2p/peer.go b/p2p/peer.go index cbe5ccc84..40466cf84 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -115,41 +115,60 @@ func newPeer(conn *conn, protocols []Protocol) *Peer { } func (p *Peer) run() DiscReason { - readErr := make(chan error, 1) + var ( + writeStart = make(chan struct{}, 1) + writeErr = make(chan error, 1) + readErr = make(chan error, 1) + reason DiscReason + requested bool + ) p.wg.Add(2) go p.readLoop(readErr) go p.pingLoop() - p.startProtocols() + // Start all protocol handlers. + writeStart <- struct{}{} + p.startProtocols(writeStart, writeErr) // Wait for an error or disconnect. - var ( - reason DiscReason - requested bool - ) - select { - case err := <-readErr: - if r, ok := err.(DiscReason); ok { - reason = r - } else { - // Note: We rely on protocols to abort if there is a write - // error. It might be more robust to handle them here as well. - glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err) - reason = DiscNetworkError +loop: + for { + select { + case err := <-writeErr: + // A write finished. Allow the next write to start if + // there was no error. + if err != nil { + glog.V(logger.Detail).Infof("%v: write error: %v\n", p, err) + reason = DiscNetworkError + break loop + } + writeStart <- struct{}{} + case err := <-readErr: + if r, ok := err.(DiscReason); ok { + glog.V(logger.Debug).Infof("%v: remote requested disconnect: %v\n", p, r) + requested = true + reason = r + } else { + glog.V(logger.Detail).Infof("%v: read error: %v\n", p, err) + reason = DiscNetworkError + } + break loop + case err := <-p.protoErr: + reason = discReasonForError(err) + glog.V(logger.Debug).Infof("%v: protocol error: %v (%v)\n", p, err, reason) + break loop + case reason = <-p.disc: + glog.V(logger.Debug).Infof("%v: locally requested disconnect: %v\n", p, reason) + break loop } - case err := <-p.protoErr: - reason = discReasonForError(err) - case reason = <-p.disc: - requested = true } + close(p.closed) p.rw.close(reason) p.wg.Wait() - if requested { reason = DiscRequested } - glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) return reason } @@ -196,7 +215,6 @@ func (p *Peer) handle(msg Msg) error { // This is the last message. We don't need to discard or // check errors because, the connection will be closed after it. rlp.Decode(msg.Payload, &reason) - glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0]) return reason[0] case msg.Code < baseProtocolLength: // ignore other base protocol messages @@ -247,11 +265,13 @@ outer: return result } -func (p *Peer) startProtocols() { +func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) { p.wg.Add(len(p.running)) for _, proto := range p.running { proto := proto proto.closed = p.closed + proto.wstart = writeStart + proto.werr = writeErr glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version) go func() { err := proto.Run(p, proto) @@ -280,18 +300,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) { type protoRW struct { Protocol - in chan Msg - closed <-chan struct{} + in chan Msg // receices read messages + closed <-chan struct{} // receives when peer is shutting down + wstart <-chan struct{} // receives when write may start + werr chan<- error // for write results offset uint64 w MsgWriter } -func (rw *protoRW) WriteMsg(msg Msg) error { +func (rw *protoRW) WriteMsg(msg Msg) (err error) { if msg.Code >= rw.Length { return newPeerError(errInvalidMsgCode, "not handled") } msg.Code += rw.offset - return rw.w.WriteMsg(msg) + select { + case <-rw.wstart: + err = rw.w.WriteMsg(msg) + // Report write status back to Peer.run. It will initiate + // shutdown if the error is non-nil and unblock the next write + // otherwise. The calling protocol code should exit for errors + // as well but we don't want to rely on that. + rw.werr <- err + case <-rw.closed: + err = fmt.Errorf("shutting down") + } + return err } func (rw *protoRW) ReadMsg() (Msg, error) { diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 7b772e198..575d0ff79 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -121,7 +121,7 @@ func TestPeerDisconnect(t *testing.T) { } select { case reason := <-disc: - if reason != DiscQuitting { + if reason != DiscRequested { t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } case <-time.After(500 * time.Millisecond): diff --git a/p2p/server_test.go b/p2p/server_test.go index 01448cc7b..e8d21a188 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -117,7 +117,6 @@ func TestServerDial(t *testing.T) { t.Error("accept error:", err) return } - conn.Close() accepted <- conn }() @@ -134,6 +133,8 @@ func TestServerDial(t *testing.T) { select { case conn := <-accepted: + defer conn.Close() + select { case peer := <-connected: if peer.ID() != remid { |