aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/peer.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/peer.go')
-rw-r--r--p2p/peer.go37
1 files changed, 21 insertions, 16 deletions
diff --git a/p2p/peer.go b/p2p/peer.go
index 1fa8264a3..b61cf96da 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -1,6 +1,7 @@
package p2p
import (
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -71,7 +72,8 @@ type Peer struct {
runlock sync.RWMutex // protects running
running map[string]*proto
- protocolHandshakeEnabled bool
+ // disables protocol handshake, for testing
+ noHandshake bool
protoWG sync.WaitGroup
protoErr chan error
@@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
- return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
+ return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
}
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
- logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
+ logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
return &Peer{
Logger: logger.NewLogger(logtag),
rw: newFrameRW(conn, msgWriteTimeout),
@@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
var readErr = make(chan error, 1)
defer p.closeProtocols()
defer close(p.closed)
- defer p.rw.Close()
- // start the read loop
go func() { readErr <- p.readLoop() }()
- if p.protocolHandshakeEnabled {
+ if !p.noHandshake {
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err)
+ p.rw.Close()
return DiscProtocolError
}
}
- // wait for an error or disconnect
+ // Wait for an error or disconnect.
var reason DiscReason
select {
case err := <-readErr:
// We rely on protocols to abort if there is a write error. It
// might be more robust to handle them here as well.
p.DebugDetailf("Read error: %v\n", err)
- reason = DiscNetworkError
+ p.rw.Close()
+ return DiscNetworkError
+
case err := <-p.protoErr:
reason = discReasonForError(err)
case reason = <-p.disc:
}
- if reason != DiscNetworkError {
- p.politeDisconnect(reason)
- }
+ p.politeDisconnect(reason)
+
+ // Wait for readLoop. It will end because conn is now closed.
+ <-readErr
p.Debugf("Disconnected: %v\n", reason)
return reason
}
@@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{})
go func() {
- // send reason
EncodeMsg(p.rw, discMsg, uint(reason))
- // discard any data that might arrive
+ // Wait for the other side to close the connection.
+ // Discard any data that they send until then.
io.Copy(ioutil.Discard, p.rw)
close(done)
}()
@@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
case <-done:
case <-time.After(disconnectGracePeriod):
}
+ p.rw.Close()
}
func (p *Peer) readLoop() error {
- if p.protocolHandshakeEnabled {
+ if !p.noHandshake {
if err := readProtocolHandshake(p, p.rw); err != nil {
return err
}
@@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
- return newPeerError(errMisc, "message too big")
+ return newPeerError(errInvalidMsg, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
@@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
err := impl.Run(p, rw)
if err == nil {
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
- err = newPeerError(errMisc, "protocol returned")
+ err = errors.New("protocol returned")
} else {
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
}