diff options
Diffstat (limited to 'p2p/handshake.go')
-rw-r--r-- | p2p/handshake.go | 68 |
1 files changed, 42 insertions, 26 deletions
diff --git a/p2p/handshake.go b/p2p/handshake.go index 031064407..43361364f 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -68,50 +68,61 @@ type protoHandshake struct { // setupConn starts a protocol session on the given connection. // It runs the encryption handshake and the protocol handshake. // If dial is non-nil, the connection the local node is the initiator. -func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { +// If atcap is true, the connection will be disconnected with DiscTooManyPeers +// after the key exchange. +func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) { if dial == nil { - return setupInboundConn(fd, prv, our) + return setupInboundConn(fd, prv, our, atcap) } else { - return setupOutboundConn(fd, prv, our, dial) + return setupOutboundConn(fd, prv, our, dial, atcap) } } -func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) { +func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) { secrets, err := receiverEncHandshake(fd, prv, nil) if err != nil { return nil, fmt.Errorf("encryption handshake failed: %v", err) } - - // Run the protocol handshake using authenticated messages. rw := newRlpxFrameRW(fd, secrets) - rhs, err := readProtocolHandshake(rw, our) + if atcap { + SendItems(rw, discMsg, DiscTooManyPeers) + return nil, errors.New("we have too many peers") + } + // Run the protocol handshake using authenticated messages. + rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) if err != nil { return nil, err } - if rhs.ID != secrets.RemoteID { - return nil, errors.New("node ID in protocol handshake does not match encryption handshake") - } - // TODO: validate that handshake node ID matches if err := Send(rw, handshakeMsg, our); err != nil { - return nil, fmt.Errorf("protocol write error: %v", err) + return nil, fmt.Errorf("protocol handshake write error: %v", err) } return &conn{rw, rhs}, nil } -func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { +func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) { secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil) if err != nil { return nil, fmt.Errorf("encryption handshake failed: %v", err) } - - // Run the protocol handshake using authenticated messages. rw := newRlpxFrameRW(fd, secrets) - if err := Send(rw, handshakeMsg, our); err != nil { - return nil, fmt.Errorf("protocol write error: %v", err) + if atcap { + SendItems(rw, discMsg, DiscTooManyPeers) + return nil, errors.New("we have too many peers") } - rhs, err := readProtocolHandshake(rw, our) + // Run the protocol handshake using authenticated messages. + // + // Note that even though writing the handshake is first, we prefer + // returning the handshake read error. If the remote side + // disconnects us early with a valid reason, we should return it + // as the error so it can be tracked elsewhere. + werr := make(chan error) + go func() { werr <- Send(rw, handshakeMsg, our) }() + rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) if err != nil { - return nil, fmt.Errorf("protocol handshake read error: %v", err) + return nil, err + } + if err := <-werr; err != nil { + return nil, fmt.Errorf("protocol handshake write error: %v", err) } if rhs.ID != dial.ID { return nil, errors.New("dialed node id mismatch") @@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) { return xor } -func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) { - // read and handle remote handshake - msg, err := r.ReadMsg() +func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) { + msg, err := rw.ReadMsg() if err != nil { return nil, err } if msg.Code == discMsg { // disconnect before protocol handshake is valid according to the // spec and we send it ourself if Server.addPeer fails. - var reason DiscReason + var reason [1]DiscReason rlp.Decode(msg.Payload, &reason) - return nil, discRequestedError(reason) + return nil, reason[0] } if msg.Code != handshakeMsg { return nil, fmt.Errorf("expected handshake, got %x", msg.Code) @@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e } // validate handshake info if hs.Version != our.Version { - return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version) + SendItems(rw, discMsg, DiscIncompatibleVersion) + return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version) } if (hs.ID == discover.NodeID{}) { - return nil, newPeerError(errPubkeyInvalid, "missing") + SendItems(rw, discMsg, DiscInvalidIdentity) + return nil, errors.New("invalid public key in handshake") + } + if hs.ID != wantID { + SendItems(rw, discMsg, DiscUnexpectedIdentity) + return nil, errors.New("handshake node ID does not match encryption handshake") } return &hs, nil } |