From e34d1341022a51d8a86c4836c91e4e0ded888d27 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Sat, 7 Feb 2015 00:13:22 +0100
Subject: p2p: fixes for actual connections

The unit test hooks were turned on 'in production'.
---
 p2p/message.go     |  4 ++--
 p2p/peer.go        | 37 +++++++++++++++++++++----------------
 p2p/peer_error.go  |  2 +-
 p2p/peer_test.go   | 19 ++++++++++---------
 p2p/server.go      |  4 +++-
 p2p/server_test.go |  1 +
 6 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/p2p/message.go b/p2p/message.go
index 6521d09c2..dfc33f349 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -174,10 +174,10 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
 	// read magic and payload size
 	start := make([]byte, 8)
 	if _, err = io.ReadFull(rw.bufconn, start); err != nil {
-		return msg, newPeerError(errRead, "%v", err)
+		return msg, err
 	}
 	if !bytes.HasPrefix(start, magicToken) {
-		return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+		return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
 	}
 	size := binary.BigEndian.Uint32(start[4:])
 
diff --git a/p2p/peer.go b/p2p/peer.go
index 1fa8264a3..b61cf96da 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -1,6 +1,7 @@
 package p2p
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -71,7 +72,8 @@ type Peer struct {
 	runlock   sync.RWMutex // protects running
 	running   map[string]*proto
 
-	protocolHandshakeEnabled bool
+	// disables protocol handshake, for testing
+	noHandshake bool
 
 	protoWG  sync.WaitGroup
 	protoErr chan error
@@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
 
 // String implements fmt.Stringer.
 func (p *Peer) String() string {
-	return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
+	return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
 }
 
 func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
-	logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
+	logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
 	return &Peer{
 		Logger:    logger.NewLogger(logtag),
 		rw:        newFrameRW(conn, msgWriteTimeout),
@@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
 	var readErr = make(chan error, 1)
 	defer p.closeProtocols()
 	defer close(p.closed)
-	defer p.rw.Close()
 
-	// start the read loop
 	go func() { readErr <- p.readLoop() }()
 
-	if p.protocolHandshakeEnabled {
+	if !p.noHandshake {
 		if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
 			p.DebugDetailf("Protocol handshake error: %v\n", err)
+			p.rw.Close()
 			return DiscProtocolError
 		}
 	}
 
-	// wait for an error or disconnect
+	// Wait for an error or disconnect.
 	var reason DiscReason
 	select {
 	case err := <-readErr:
 		// We rely on protocols to abort if there is a write error. It
 		// might be more robust to handle them here as well.
 		p.DebugDetailf("Read error: %v\n", err)
-		reason = DiscNetworkError
+		p.rw.Close()
+		return DiscNetworkError
+
 	case err := <-p.protoErr:
 		reason = discReasonForError(err)
 	case reason = <-p.disc:
 	}
-	if reason != DiscNetworkError {
-		p.politeDisconnect(reason)
-	}
+	p.politeDisconnect(reason)
+
+	// Wait for readLoop. It will end because conn is now closed.
+	<-readErr
 	p.Debugf("Disconnected: %v\n", reason)
 	return reason
 }
@@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
 func (p *Peer) politeDisconnect(reason DiscReason) {
 	done := make(chan struct{})
 	go func() {
-		// send reason
 		EncodeMsg(p.rw, discMsg, uint(reason))
-		// discard any data that might arrive
+		// Wait for the other side to close the connection.
+		// Discard any data that they send until then.
 		io.Copy(ioutil.Discard, p.rw)
 		close(done)
 	}()
@@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
 	case <-done:
 	case <-time.After(disconnectGracePeriod):
 	}
+	p.rw.Close()
 }
 
 func (p *Peer) readLoop() error {
-	if p.protocolHandshakeEnabled {
+	if !p.noHandshake {
 		if err := readProtocolHandshake(p, p.rw); err != nil {
 			return err
 		}
@@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
 		return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
 	}
 	if msg.Size > baseProtocolMaxMsgSize {
-		return newPeerError(errMisc, "message too big")
+		return newPeerError(errInvalidMsg, "message too big")
 	}
 	var hs handshake
 	if err := msg.Decode(&hs); err != nil {
@@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
 		err := impl.Run(p, rw)
 		if err == nil {
 			p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
-			err = newPeerError(errMisc, "protocol returned")
+			err = errors.New("protocol returned")
 		} else {
 			p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
 		}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index 9133768f9..0ff4f4b43 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -123,7 +123,7 @@ func discReasonForError(err error) DiscReason {
 		return DiscProtocolError
 	case errPingTimeout:
 		return DiscReadTimeout
-	case errRead, errWrite, errMisc:
+	case errRead, errWrite:
 		return DiscNetworkError
 	default:
 		return DiscSubprotocolError
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 76d856d3e..68c9910a2 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -30,10 +30,10 @@ var discard = Protocol{
 	},
 }
 
-func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
+func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
 	conn1, conn2 := net.Pipe()
 	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
-	peer.protocolHandshakeEnabled = handshake
+	peer.noHandshake = noHandshake
 	errc := make(chan DiscReason, 1)
 	go func() { errc <- peer.run() }()
 	return newFrameRW(conn2, msgWriteTimeout), peer, errc
@@ -61,7 +61,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
 		},
 	}
 
-	rw, peer, errc := testPeer(false, []Protocol{proto})
+	rw, peer, errc := testPeer(true, []Protocol{proto})
 	defer rw.Close()
 	peer.startSubprotocols([]Cap{proto.cap()})
 
@@ -100,7 +100,7 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
 		},
 	}
 
-	rw, peer, errc := testPeer(false, []Protocol{proto})
+	rw, peer, errc := testPeer(true, []Protocol{proto})
 	defer rw.Close()
 	peer.startSubprotocols([]Cap{proto.cap()})
 
@@ -130,7 +130,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 			return nil
 		},
 	}
-	rw, peer, _ := testPeer(false, []Protocol{proto})
+	rw, peer, _ := testPeer(true, []Protocol{proto})
 	defer rw.Close()
 	peer.startSubprotocols([]Cap{proto.cap()})
 
@@ -142,7 +142,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 func TestPeerWriteForBroadcast(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, peer, peerErr := testPeer(false, []Protocol{discard})
+	rw, peer, peerErr := testPeer(true, []Protocol{discard})
 	defer rw.Close()
 	peer.startSubprotocols([]Cap{discard.cap()})
 
@@ -179,7 +179,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
 func TestPeerPing(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, _ := testPeer(false, nil)
+	rw, _, _ := testPeer(true, nil)
 	defer rw.Close()
 	if err := EncodeMsg(rw, pingMsg); err != nil {
 		t.Fatal(err)
@@ -192,7 +192,7 @@ func TestPeerPing(t *testing.T) {
 func TestPeerDisconnect(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, disc := testPeer(false, nil)
+	rw, _, disc := testPeer(true, nil)
 	defer rw.Close()
 	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
 		t.Fatal(err)
@@ -233,7 +233,7 @@ func TestPeerHandshake(t *testing.T) {
 		{Name: "c", Version: 3, Length: 1, Run: run},
 		{Name: "d", Version: 4, Length: 1, Run: run},
 	}
-	rw, p, disc := testPeer(true, protocols)
+	rw, p, disc := testPeer(false, protocols)
 	p.remoteID = remote.ourID
 	defer rw.Close()
 
@@ -269,6 +269,7 @@ func TestPeerHandshake(t *testing.T) {
 	}
 
 	close(stop)
+	expectMsg(rw, discMsg, nil)
 	t.Logf("disc reason: %v", <-disc)
 }
 
diff --git a/p2p/server.go b/p2p/server.go
index 87be97a2f..c6d7fc2e8 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -408,7 +408,9 @@ func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
 		return
 	}
 
-	srv.newPeerHook(p)
+	if srv.newPeerHook != nil {
+		srv.newPeerHook(p)
+	}
 	p.run()
 	srv.removePeer(p)
 }
diff --git a/p2p/server_test.go b/p2p/server_test.go
index 89300cf1c..d1e1640fb 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -118,6 +118,7 @@ func TestServerBroadcast(t *testing.T) {
 	srv := startTestServer(t, func(p *Peer) {
 		p.protocols = []Protocol{discard}
 		p.startSubprotocols([]Cap{discard.cap()})
+		p.noHandshake = true
 		connected.Done()
 	})
 	defer srv.Stop()
-- 
cgit v1.2.3