From 736e632215d49dd7bc61126f78dda4bad12768ea Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Fri, 27 Feb 2015 03:06:55 +0000
Subject: p2p: use RLPx frames for messaging

---
 p2p/handshake.go   | 27 ++++++++++-----------------
 p2p/message.go     | 19 +++++++++++++++++++
 p2p/peer.go        | 21 ++++++++++++---------
 p2p/peer_test.go   | 36 +++++++++++++++++++-----------------
 p2p/server.go      |  7 ++++---
 p2p/server_test.go | 13 +++++++++----
 6 files changed, 73 insertions(+), 50 deletions(-)

diff --git a/p2p/handshake.go b/p2p/handshake.go
index 17f572dea..10ef970dc 100644
--- a/p2p/handshake.go
+++ b/p2p/handshake.go
@@ -32,14 +32,10 @@ const (
 )
 
 type conn struct {
-	*frameRW
+	MsgReadWriter
 	*protoHandshake
 }
 
-func newConn(fd net.Conn, hs *protoHandshake) *conn {
-	return &conn{newFrameRW(fd, msgWriteTimeout), hs}
-}
-
 // encHandshake contains the state of the encryption handshake.
 type encHandshake struct {
 	remoteID             discover.NodeID
@@ -115,17 +111,16 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
 
 	// Run the protocol handshake using authenticated messages.
 	// TODO: move buffering setup here (out of newFrameRW)
-	phsrw := newRlpxFrameRW(fd, secrets)
-	rhs, err := readProtocolHandshake(phsrw, our)
+	rw := newRlpxFrameRW(fd, secrets)
+	rhs, err := readProtocolHandshake(rw, our)
 	if err != nil {
 		return nil, err
 	}
-	if err := writeProtocolHandshake(phsrw, our); err != nil {
+	// TODO: validate that handshake node ID matches
+	if err := writeProtocolHandshake(rw, our); err != nil {
 		return nil, fmt.Errorf("protocol write error: %v", err)
 	}
-
-	rw := newFrameRW(fd, msgWriteTimeout)
-	return &conn{rw, rhs}, nil
+	return &conn{&lockedRW{wrapped: rw}, rhs}, nil
 }
 
 func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
@@ -136,20 +131,18 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
 
 	// Run the protocol handshake using authenticated messages.
 	// TODO: move buffering setup here (out of newFrameRW)
-	phsrw := newRlpxFrameRW(fd, secrets)
-	if err := writeProtocolHandshake(phsrw, our); err != nil {
+	rw := newRlpxFrameRW(fd, secrets)
+	if err := writeProtocolHandshake(rw, our); err != nil {
 		return nil, fmt.Errorf("protocol write error: %v", err)
 	}
-	rhs, err := readProtocolHandshake(phsrw, our)
+	rhs, err := readProtocolHandshake(rw, our)
 	if err != nil {
 		return nil, fmt.Errorf("protocol handshake read error: %v", err)
 	}
 	if rhs.ID != dial.ID {
 		return nil, errors.New("dialed node id mismatch")
 	}
-
-	rw := newFrameRW(fd, msgWriteTimeout)
-	return &conn{rw, rhs}, nil
+	return &conn{&lockedRW{wrapped: rw}, rhs}, nil
 }
 
 // outboundEncHandshake negotiates a session token on conn.
diff --git a/p2p/message.go b/p2p/message.go
index 7adad4b09..d61faad13 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -119,6 +119,25 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
 	return w.WriteMsg(NewMsg(code, data...))
 }
 
+// lockedRW wraps a MsgReadWriter with locks around
+// ReadMsg and WriteMsg.
+type lockedRW struct {
+	rmu, wmu sync.Mutex
+	wrapped  MsgReadWriter
+}
+
+func (rw *lockedRW) ReadMsg() (Msg, error) {
+	rw.rmu.Lock()
+	defer rw.rmu.Unlock()
+	return rw.wrapped.ReadMsg()
+}
+
+func (rw *lockedRW) WriteMsg(msg Msg) error {
+	rw.wmu.Lock()
+	defer rw.wmu.Unlock()
+	return rw.wrapped.WriteMsg(msg)
+}
+
 // frameRW is a MsgReadWriter that reads and writes devp2p message frames.
 // As required by the interface, ReadMsg and WriteMsg can be called from
 // multiple goroutines.
diff --git a/p2p/peer.go b/p2p/peer.go
index fb027c834..4982c4612 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -40,6 +40,7 @@ type Peer struct {
 	// Use them to display messages related to the peer.
 	*logger.Logger
 
+	conn    net.Conn
 	rw      *conn
 	running map[string]*protoRW
 
@@ -52,8 +53,9 @@ type Peer struct {
 // NewPeer returns a peer for testing purposes.
 func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
 	pipe, _ := net.Pipe()
-	conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
-	peer := newPeer(conn, nil)
+	msgpipe, _ := MsgPipe()
+	conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
+	peer := newPeer(pipe, conn, nil)
 	close(peer.closed) // ensures Disconnect doesn't block
 	return peer
 }
@@ -76,12 +78,12 @@ func (p *Peer) Caps() []Cap {
 
 // RemoteAddr returns the remote address of the network connection.
 func (p *Peer) RemoteAddr() net.Addr {
-	return p.rw.RemoteAddr()
+	return p.conn.RemoteAddr()
 }
 
 // LocalAddr returns the local address of the network connection.
 func (p *Peer) LocalAddr() net.Addr {
-	return p.rw.LocalAddr()
+	return p.conn.LocalAddr()
 }
 
 // Disconnect terminates the peer connection with the given reason.
@@ -98,10 +100,11 @@ func (p *Peer) String() string {
 	return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
 }
 
-func newPeer(conn *conn, protocols []Protocol) *Peer {
-	logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
+func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
+	logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr())
 	p := &Peer{
 		Logger:   logger.NewLogger(logtag),
+		conn:     fd,
 		rw:       conn,
 		running:  matchProtocols(protocols, conn.Caps, conn),
 		disc:     make(chan DiscReason),
@@ -138,7 +141,7 @@ loop:
 			// We rely on protocols to abort if there is a write error. It
 			// might be more robust to handle them here as well.
 			p.DebugDetailf("Read error: %v\n", err)
-			p.rw.Close()
+			p.conn.Close()
 			return DiscNetworkError
 		case err := <-p.protoErr:
 			reason = discReasonForError(err)
@@ -161,14 +164,14 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
 		EncodeMsg(p.rw, discMsg, uint(reason))
 		// Wait for the other side to close the connection.
 		// Discard any data that they send until then.
-		io.Copy(ioutil.Discard, p.rw)
+		io.Copy(ioutil.Discard, p.conn)
 		close(done)
 	}()
 	select {
 	case <-done:
 	case <-time.After(disconnectGracePeriod):
 	}
-	p.rw.Close()
+	p.conn.Close()
 }
 
 func (p *Peer) readLoop() error {
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index a1260adbd..1ba43bed5 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -3,6 +3,7 @@ package p2p
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"reflect"
@@ -29,8 +30,8 @@ var discard = Protocol{
 	},
 }
 
-func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
-	fd1, fd2 := net.Pipe()
+func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) {
+	fd1, _ := net.Pipe()
 	hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
 	hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
 	for _, p := range protos {
@@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
 		hs2.Caps = append(hs2.Caps, p.cap())
 	}
 
-	peer := newPeer(newConn(fd1, hs1), protos)
+	p1, p2 := MsgPipe()
+	peer := newPeer(fd1, &conn{p1, hs1}, protos)
 	errc := make(chan DiscReason, 1)
 	go func() { errc <- peer.run() }()
 
-	return newConn(fd2, hs2), peer, errc
+	return p1, &conn{p2, hs2}, peer, errc
 }
 
 func TestPeerProtoReadMsg(t *testing.T) {
@@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
 		},
 	}
 
-	rw, _, errc := testPeer([]Protocol{proto})
-	defer rw.Close()
+	closer, rw, _, errc := testPeer([]Protocol{proto})
+	defer closer.Close()
 
 	EncodeMsg(rw, baseProtocolLength+2, 1)
 	EncodeMsg(rw, baseProtocolLength+3, 2)
@@ -105,8 +107,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
 		},
 	}
 
-	rw, _, errc := testPeer([]Protocol{proto})
-	defer rw.Close()
+	closer, rw, _, errc := testPeer([]Protocol{proto})
+	defer closer.Close()
 
 	EncodeMsg(rw, 18, make([]byte, msgsize))
 	select {
@@ -134,8 +136,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 			return nil
 		},
 	}
-	rw, _, _ := testPeer([]Protocol{proto})
-	defer rw.Close()
+	closer, rw, _, _ := testPeer([]Protocol{proto})
+	defer closer.Close()
 
 	if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
 		t.Error(err)
@@ -145,8 +147,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 func TestPeerWriteForBroadcast(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, peer, peerErr := testPeer([]Protocol{discard})
-	defer rw.Close()
+	closer, rw, peer, peerErr := testPeer([]Protocol{discard})
+	defer closer.Close()
 
 	// test write errors
 	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@@ -181,8 +183,8 @@ func TestPeerWriteForBroadcast(t *testing.T) {
 func TestPeerPing(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, _ := testPeer(nil)
-	defer rw.Close()
+	closer, rw, _, _ := testPeer(nil)
+	defer closer.Close()
 	if err := EncodeMsg(rw, pingMsg); err != nil {
 		t.Fatal(err)
 	}
@@ -194,15 +196,15 @@ func TestPeerPing(t *testing.T) {
 func TestPeerDisconnect(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, disc := testPeer(nil)
-	defer rw.Close()
+	closer, rw, _, disc := testPeer(nil)
+	defer closer.Close()
 	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
 		t.Fatal(err)
 	}
 	if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
 		t.Error(err)
 	}
-	rw.Close() // make test end faster
+	closer.Close() // make test end faster
 	if reason := <-disc; reason != DiscRequested {
 		t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
 	}
diff --git a/p2p/server.go b/p2p/server.go
index 3ea2538d1..e53e832aa 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -358,14 +358,15 @@ func (srv *Server) findPeers() {
 
 func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
 	// TODO: handle/store session token
-	fd.SetDeadline(time.Now().Add(handshakeTimeout))
+	// TODO: reenable deadlines
+	// fd.SetDeadline(time.Now().Add(handshakeTimeout))
 	conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
 	if err != nil {
 		fd.Close()
 		srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
 		return
 	}
-	p := newPeer(conn, srv.Protocols)
+	p := newPeer(fd, conn, srv.Protocols)
 	if ok, reason := srv.addPeer(conn.ID, p); !ok {
 		srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
 		p.politeDisconnect(reason)
@@ -375,7 +376,7 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
 	srvlog.Debugf("Added %v\n", p)
 	srvjslog.LogJson(&logger.P2PConnected{
 		RemoteId:            fmt.Sprintf("%x", conn.ID[:]),
-		RemoteAddress:       conn.RemoteAddr().String(),
+		RemoteAddress:       fd.RemoteAddr().String(),
 		RemoteVersionString: conn.Name,
 		NumConnections:      srv.PeerCount(),
 	})
diff --git a/p2p/server_test.go b/p2p/server_test.go
index c109fffb9..c348f5a9a 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -11,6 +11,7 @@ import (
 	"time"
 
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/crypto/sha3"
 	"github.com/ethereum/go-ethereum/p2p/discover"
 )
 
@@ -23,8 +24,14 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
 		newPeerHook: pf,
 		setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
 			id := randomID()
+			rw := newRlpxFrameRW(fd, secrets{
+				MAC:        zero16,
+				AES:        zero16,
+				IngressMAC: sha3.NewKeccak256(),
+				EgressMAC:  sha3.NewKeccak256(),
+			})
 			return &conn{
-				frameRW:        newFrameRW(fd, msgWriteTimeout),
+				MsgReadWriter:  rw,
 				protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
 			}, nil
 		},
@@ -143,9 +150,7 @@ func TestServerBroadcast(t *testing.T) {
 
 	// broadcast one message
 	srv.Broadcast("discard", 0, "foo")
-	goldbuf := new(bytes.Buffer)
-	writeMsg(goldbuf, NewMsg(16, "foo"))
-	golden := goldbuf.Bytes()
+	golden := unhex("66e94e166f0a2c3b884cfa59ca34")
 
 	// check that the message has been written everywhere
 	for i, conn := range conns {
-- 
cgit v1.2.3