diff options
Diffstat (limited to 'p2p/peer.go')
-rw-r--r-- | p2p/peer.go | 466 |
1 files changed, 156 insertions, 310 deletions
diff --git a/p2p/peer.go b/p2p/peer.go index 86c4d7ab5..fb027c834 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,8 +1,7 @@ package p2p import ( - "bufio" - "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -11,159 +10,78 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) -// peerAddr is the structure of a peer list element. -// It is also a valid net.Addr. -type peerAddr struct { - IP net.IP - Port uint64 - Pubkey []byte // optional -} - -func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { - n := addr.Network() - if n != "tcp" && n != "tcp4" && n != "tcp6" { - // for testing with non-TCP - return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} - } - ta := addr.(*net.TCPAddr) - return &peerAddr{ta.IP, uint64(ta.Port), pubkey} -} - -func (d peerAddr) Network() string { - if d.IP.To4() != nil { - return "tcp4" - } else { - return "tcp6" - } -} +const ( + baseProtocolVersion = 3 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 -func (d peerAddr) String() string { - return fmt.Sprintf("%v:%d", d.IP, d.Port) -} + disconnectGracePeriod = 2 * time.Second + pingInterval = 15 * time.Second +) -func (d peerAddr) RlpData() interface{} { - return []interface{}{d.IP, d.Port, d.Pubkey} -} +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) -// Peer represents a remote peer. +// Peer represents a connected remote node. type Peer struct { // Peers have all the log methods. // Use them to display messages related to the peer. *logger.Logger - infolock sync.Mutex - identity ClientIdentity - caps []Cap - listenAddr *peerAddr // what remote peer is listening on - dialAddr *peerAddr // non-nil if dialing - - // The mutex protects the connection - // so only one protocol can write at a time. - writeMu sync.Mutex - conn net.Conn - bufconn *bufio.ReadWriter - - // These fields maintain the running protocols. - protocols []Protocol - runBaseProtocol bool // for testing - - runlock sync.RWMutex // protects running - running map[string]*proto + rw *conn + running map[string]*protoRW protoWG sync.WaitGroup protoErr chan error closed chan struct{} disc chan DiscReason - - activity event.TypeMux // for activity events - - slot int // index into Server peer list - - // These fields are kept so base protocol can access them. - // TODO: this should be one or more interfaces - ourID ClientIdentity // client id of the Server - ourListenAddr *peerAddr // listen addr of Server, nil if not listening - newPeerAddr chan<- *peerAddr // tell server about received peers - otherPeers func() []*Peer // should return the list of all peers - pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey } // NewPeer returns a peer for testing purposes. -func NewPeer(id ClientIdentity, caps []Cap) *Peer { - conn, _ := net.Pipe() - peer := newPeer(conn, nil, nil) - peer.setHandshakeInfo(id, nil, caps) - close(peer.closed) +func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { + pipe, _ := net.Pipe() + conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps}) + peer := newPeer(conn, nil) + close(peer.closed) // ensures Disconnect doesn't block return peer } -func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { - p := newPeer(conn, server.Protocols, dialAddr) - p.ourID = server.Identity - p.newPeerAddr = server.peerConnect - p.otherPeers = server.Peers - p.pubkeyHook = server.verifyPeer - p.runBaseProtocol = true - - // laddr can be updated concurrently by NAT traversal. - // newServerPeer must be called with the server lock held. - if server.laddr != nil { - p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) - } - return p -} - -func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { - p := &Peer{ - Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), - conn: conn, - dialAddr: dialAddr, - bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - protocols: protocols, - running: make(map[string]*proto), - disc: make(chan DiscReason), - protoErr: make(chan error), - closed: make(chan struct{}), - } - return p +// ID returns the node's public key. +func (p *Peer) ID() discover.NodeID { + return p.rw.ID } -// Identity returns the client identity of the remote peer. The -// identity can be nil if the peer has not yet completed the -// handshake. -func (p *Peer) Identity() ClientIdentity { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.identity +// Name returns the node name that the remote node advertised. +func (p *Peer) Name() string { + return p.rw.Name } // Caps returns the capabilities (supported subprotocols) of the remote peer. func (p *Peer) Caps() []Cap { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.caps -} - -func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { - p.infolock.Lock() - p.identity = id - p.listenAddr = laddr - p.caps = caps - p.infolock.Unlock() + // TODO: maybe return copy + return p.rw.Caps } // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() + return p.rw.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.rw.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -177,198 +95,166 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - kind := "inbound" - p.infolock.Lock() - if p.dialAddr != nil { - kind = "outbound" - } - p.infolock.Unlock() - return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) + return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) } -const ( - // maximum amount of time allowed for reading a message - msgReadTimeout = 5 * time.Second - // maximum amount of time allowed for writing a message - msgWriteTimeout = 5 * time.Second - // messages smaller than this many bytes will be read at - // once before passing them to a protocol. - wholePayloadSize = 64 * 1024 -) - -var ( - inactivityTimeout = 2 * time.Second - disconnectGracePeriod = 2 * time.Second -) +func newPeer(conn *conn, protocols []Protocol) *Peer { + logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr()) + p := &Peer{ + Logger: logger.NewLogger(logtag), + rw: conn, + running: matchProtocols(protocols, conn.Caps, conn), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} -func (p *Peer) loop() (reason DiscReason, err error) { - defer p.activity.Stop() +func (p *Peer) run() DiscReason { + var readErr = make(chan error, 1) defer p.closeProtocols() defer close(p.closed) - defer p.conn.Close() - - // read loop - readMsg := make(chan Msg) - readErr := make(chan error) - readNext := make(chan bool, 1) - protoDone := make(chan struct{}, 1) - go p.readLoop(readMsg, readErr, readNext) - readNext <- true - - if p.runBaseProtocol { - p.startBaseProtocol() - } + p.startProtocols() + go func() { readErr <- p.readLoop() }() + + ping := time.NewTicker(pingInterval) + defer ping.Stop() + + // Wait for an error or disconnect. + var reason DiscReason loop: for { select { - case msg := <-readMsg: - // a new message has arrived. - var wait bool - if wait, err = p.dispatch(msg, protoDone); err != nil { - p.Errorf("msg dispatch error: %v\n", err) - reason = discReasonForError(err) - break loop - } - if !wait { - // Msg has already been read completely, continue with next message. - readNext <- true - } - p.activity.Post(time.Now()) - case <-protoDone: - // protocol has consumed the message payload, - // we can continue reading from the socket. - readNext <- true - + case <-ping.C: + go func() { + if err := EncodeMsg(p.rw, pingMsg, nil); err != nil { + p.protoErr <- err + return + } + }() case err := <-readErr: - // read failed. there is no need to run the - // polite disconnect sequence because the connection - // is probably dead anyway. - // TODO: handle write errors as well - return DiscNetworkError, err - case err = <-p.protoErr: + // We rely on protocols to abort if there is a write error. It + // might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.rw.Close() + return DiscNetworkError + case err := <-p.protoErr: reason = discReasonForError(err) break loop case reason = <-p.disc: break loop } } + p.politeDisconnect(reason) - // wait for read loop to return. - close(readNext) + // Wait for readLoop. It will end because conn is now closed. <-readErr - // tell the remote end to disconnect + p.Debugf("Disconnected: %v\n", reason) + return reason +} + +func (p *Peer) politeDisconnect(reason DiscReason) { done := make(chan struct{}) go func() { - p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) - p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) - io.Copy(ioutil.Discard, p.conn) + EncodeMsg(p.rw, discMsg, uint(reason)) + // Wait for the other side to close the connection. + // Discard any data that they send until then. + io.Copy(ioutil.Discard, p.rw) close(done) }() select { case <-done: case <-time.After(disconnectGracePeriod): } - return reason, err + p.rw.Close() } -func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { - for _ = range unblock { - p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) - if msg, err := readMsg(p.bufconn); err != nil { - errc <- err - } else { - msgc <- msg +func (p *Peer) readLoop() error { + for { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + if err = p.handle(msg); err != nil { + return err } } - close(errc) -} - -func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { - proto, err := p.getProto(msg.Code) - if err != nil { - return false, err - } - if msg.Size <= wholePayloadSize { - // optimization: msg is small enough, read all - // of it and move on to the next message - buf, err := ioutil.ReadAll(msg.Payload) + return nil +} + +func (p *Peer) handle(msg Msg) error { + switch { + case msg.Code == pingMsg: + msg.Discard() + go EncodeMsg(p.rw, pongMsg) + case msg.Code == discMsg: + var reason DiscReason + // no need to discard or for error checking, we'll close the + // connection after this. + rlp.Decode(msg.Payload, &reason) + p.Disconnect(DiscRequested) + return discRequestedError(reason) + case msg.Code < baseProtocolLength: + // ignore other base protocol messages + return msg.Discard() + default: + // it's a subprotocol message + proto, err := p.getProto(msg.Code) if err != nil { - return false, err + return fmt.Errorf("msg code out of range: %v", msg.Code) } - msg.Payload = bytes.NewReader(buf) - proto.in <- msg - } else { - wait = true - pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} - msg.Payload = pr proto.in <- msg } - return wait, nil -} - -func (p *Peer) startBaseProtocol() { - p.runlock.Lock() - defer p.runlock.Unlock() - p.running[""] = p.startProto(0, Protocol{ - Length: baseProtocolLength, - Run: runBaseProtocol, - }) + return nil } -// startProtocols starts matching named subprotocols. -func (p *Peer) startSubprotocols(caps []Cap) { +// matchProtocols creates structures for matching named subprotocols. +func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW { sort.Sort(capsByName(caps)) - - p.runlock.Lock() - defer p.runlock.Unlock() offset := baseProtocolLength + result := make(map[string]*protoRW) outer: for _, cap := range caps { - for _, proto := range p.protocols { - if proto.Name == cap.Name && - proto.Version == cap.Version && - p.running[cap.Name] == nil { - p.running[cap.Name] = p.startProto(offset, proto) + for _, proto := range protocols { + if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil { + result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw} offset += proto.Length continue outer } } } + return result } -func (p *Peer) startProto(offset uint64, impl Protocol) *proto { - rw := &proto{ - in: make(chan Msg), - offset: offset, - maxcode: impl.Length, - peer: p, +func (p *Peer) startProtocols() { + for _, proto := range p.running { + proto := proto + p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version) + p.protoWG.Add(1) + go func() { + err := proto.Run(p, proto) + if err == nil { + p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version) + err = errors.New("protocol returned") + } else { + p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() } - p.protoWG.Add(1) - go func() { - err := impl.Run(p, rw) - if err == nil { - p.Infof("protocol %q returned", impl.Name) - err = newPeerError(errMisc, "protocol returned") - } else { - p.Errorf("protocol %q error: %v\n", impl.Name, err) - } - select { - case p.protoErr <- err: - case <-p.closed: - } - p.protoWG.Done() - }() - return rw } // getProto finds the protocol responsible for handling // the given message code. -func (p *Peer) getProto(code uint64) (*proto, error) { - p.runlock.RLock() - defer p.runlock.RUnlock() +func (p *Peer) getProto(code uint64) (*protoRW, error) { for _, proto := range p.running { - if code >= proto.offset && code < proto.offset+proto.maxcode { + if code >= proto.offset && code < proto.offset+proto.Length { return proto, nil } } @@ -376,60 +262,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) { } func (p *Peer) closeProtocols() { - p.runlock.RLock() for _, p := range p.running { close(p.in) } - p.runlock.RUnlock() p.protoWG.Wait() } // writeProtoMsg sends the given message on behalf of the given named protocol. +// this exists because of Server.Broadcast. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { - p.runlock.RLock() proto, ok := p.running[protoName] - p.runlock.RUnlock() if !ok { return fmt.Errorf("protocol %s not handled by peer", protoName) } - if msg.Code >= proto.maxcode { + if msg.Code >= proto.Length { return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) } msg.Code += proto.offset - return p.writeMsg(msg, msgWriteTimeout) + return p.rw.WriteMsg(msg) } -// writeMsg writes a message to the connection. -func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { - p.writeMu.Lock() - defer p.writeMu.Unlock() - p.conn.SetWriteDeadline(time.Now().Add(timeout)) - if err := writeMsg(p.bufconn, msg); err != nil { - return newPeerError(errWrite, "%v", err) - } - return p.bufconn.Flush() -} +type protoRW struct { + Protocol -type proto struct { - name string - in chan Msg - maxcode, offset uint64 - peer *Peer + in chan Msg + offset uint64 + w MsgWriter } -func (rw *proto) WriteMsg(msg Msg) error { - if msg.Code >= rw.maxcode { +func (rw *protoRW) WriteMsg(msg Msg) error { + if msg.Code >= rw.Length { return newPeerError(errInvalidMsgCode, "not handled") } msg.Code += rw.offset - return rw.peer.writeMsg(msg, msgWriteTimeout) -} - -func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { - return rw.WriteMsg(NewMsg(code, data)) + return rw.w.WriteMsg(msg) } -func (rw *proto) ReadMsg() (Msg, error) { +func (rw *protoRW) ReadMsg() (Msg, error) { msg, ok := <-rw.in if !ok { return msg, io.EOF @@ -437,26 +306,3 @@ func (rw *proto) ReadMsg() (Msg, error) { msg.Code -= rw.offset return msg, nil } - -// eofSignal wraps a reader with eof signaling. the eof channel is -// closed when the wrapped reader returns an error or when count bytes -// have been read. -// -type eofSignal struct { - wrapped io.Reader - count int64 - eof chan<- struct{} -} - -// note: when using eofSignal to detect whether a message payload -// has been read, Read might not be called for zero sized messages. - -func (r *eofSignal) Read(buf []byte) (int, error) { - n, err := r.wrapped.Read(buf) - r.count -= int64(n) - if (err != nil || r.count <= 0) && r.eof != nil { - r.eof <- struct{}{} // tell Peer that msg has been consumed - r.eof = nil - } - return n, err -} |