aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/handshake.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/handshake.go')
-rw-r--r--p2p/handshake.go68
1 files changed, 42 insertions, 26 deletions
diff --git a/p2p/handshake.go b/p2p/handshake.go
index 5a259cd76..43361364f 100644
--- a/p2p/handshake.go
+++ b/p2p/handshake.go
@@ -68,50 +68,61 @@ type protoHandshake struct {
// setupConn starts a protocol session on the given connection.
// It runs the encryption handshake and the protocol handshake.
// If dial is non-nil, the connection the local node is the initiator.
-func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+// If atcap is true, the connection will be disconnected with DiscTooManyPeers
+// after the key exchange.
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
if dial == nil {
- return setupInboundConn(fd, prv, our)
+ return setupInboundConn(fd, prv, our, atcap)
} else {
- return setupOutboundConn(fd, prv, our, dial)
+ return setupOutboundConn(fd, prv, our, dial, atcap)
}
}
-func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) {
secrets, err := receiverEncHandshake(fd, prv, nil)
if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err)
}
-
- // Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets)
- rhs, err := readProtocolHandshake(rw, our)
+ if atcap {
+ SendItems(rw, discMsg, DiscTooManyPeers)
+ return nil, errors.New("we have too many peers")
+ }
+ // Run the protocol handshake using authenticated messages.
+ rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil {
return nil, err
}
- if rhs.ID != secrets.RemoteID {
- return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
- }
- // TODO: validate that handshake node ID matches
if err := Send(rw, handshakeMsg, our); err != nil {
- return nil, fmt.Errorf("protocol write error: %v", err)
+ return nil, fmt.Errorf("protocol handshake write error: %v", err)
}
return &conn{rw, rhs}, nil
}
-func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err)
}
-
- // Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets)
- if err := Send(rw, handshakeMsg, our); err != nil {
- return nil, fmt.Errorf("protocol write error: %v", err)
+ if atcap {
+ SendItems(rw, discMsg, DiscTooManyPeers)
+ return nil, errors.New("we have too many peers")
}
- rhs, err := readProtocolHandshake(rw, our)
+ // Run the protocol handshake using authenticated messages.
+ //
+ // Note that even though writing the handshake is first, we prefer
+ // returning the handshake read error. If the remote side
+ // disconnects us early with a valid reason, we should return it
+ // as the error so it can be tracked elsewhere.
+ werr := make(chan error)
+ go func() { werr <- Send(rw, handshakeMsg, our) }()
+ rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil {
- return nil, fmt.Errorf("protocol handshake read error: %v", err)
+ return nil, err
+ }
+ if err := <-werr; err != nil {
+ return nil, fmt.Errorf("protocol handshake write error: %v", err)
}
if rhs.ID != dial.ID {
return nil, errors.New("dialed node id mismatch")
@@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) {
return xor
}
-func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
- // read and handle remote handshake
- msg, err := r.ReadMsg()
+func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
+ msg, err := rw.ReadMsg()
if err != nil {
return nil, err
}
if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
- var reason DiscReason
+ var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
- return nil, reason
+ return nil, reason[0]
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
@@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e
}
// validate handshake info
if hs.Version != our.Version {
- return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
+ SendItems(rw, discMsg, DiscIncompatibleVersion)
+ return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
}
if (hs.ID == discover.NodeID{}) {
- return nil, newPeerError(errPubkeyInvalid, "missing")
+ SendItems(rw, discMsg, DiscInvalidIdentity)
+ return nil, errors.New("invalid public key in handshake")
+ }
+ if hs.ID != wantID {
+ SendItems(rw, discMsg, DiscUnexpectedIdentity)
+ return nil, errors.New("handshake node ID does not match encryption handshake")
}
return &hs, nil
}