From f38052c499c1fee61423efeddb1f52677f1442e9 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 4 Nov 2014 13:21:44 +0100 Subject: p2p: rework protocol API --- p2p/protocol.go | 353 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 196 insertions(+), 157 deletions(-) (limited to 'p2p/protocol.go') diff --git a/p2p/protocol.go b/p2p/protocol.go index 5d05ced7d..ccc275287 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -2,43 +2,101 @@ package p2p import ( "bytes" - "fmt" "net" "sort" - "sync" "time" + + "github.com/ethereum/go-ethereum/ethutil" ) +// Protocol is implemented by P2P subprotocols. type Protocol interface { - Start() - Stop() - HandleIn(*Msg, chan *Msg) - HandleOut(*Msg) bool + // Start is called when the protocol becomes active. + // It should read and write messages from rw. + // Messages must be fully consumed. + // + // The connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Start(peer *Peer, rw MsgReadWriter) error + + // Offset should return the number of message codes + // used by the protocol. Offset() MsgCode - Name() string +} + +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + WriteMsg(Msg) error +} + +// MsgReadWriter is passed to protocols. Protocol implementations can +// use it to write messages back to a connected peer. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +type MsgHandler func(code MsgCode, data *ethutil.Value) error + +// MsgLoop reads messages off the given reader and +// calls the handler function for each decoded message until +// it returns an error or the peer connection is closed. +// +// If a message is larger than the given maximum size, RunProtocol +// returns an appropriate error.n +func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error { + for { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Size > maxsize { + return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) + } + value, err := msg.Data() + if err != nil { + return err + } + if err := handler(msg.Code, value); err != nil { + return err + } + } +} + +// the ÐΞVp2p base protocol +type baseProtocol struct { + rw MsgReadWriter + peer *Peer +} + +type bpMsg struct { + code MsgCode + data *ethutil.Value } const ( - P2PVersion = 0 - pingTimeout = 2 - pingGracePeriod = 2 + p2pVersion = 0 + pingTimeout = 2 * time.Second + pingGracePeriod = 2 * time.Second ) const ( - HandshakeMsg = iota - DiscMsg - PingMsg - PongMsg - GetPeersMsg - PeersMsg - offset = 16 + // message codes + handshakeMsg = iota + discMsg + pingMsg + pongMsg + getPeersMsg + peersMsg ) -type ProtocolState uint8 - const ( - nullState = iota - handshakeReceived + baseProtocolOffset MsgCode = 16 + baseProtocolMaxMsgSize = 500 * 1024 ) type DiscReason byte @@ -62,7 +120,7 @@ const ( DiscSubprotocolError = 0x10 ) -var discReasonToString = map[DiscReason]string{ +var discReasonToString = [DiscSubprotocolError + 1]string{ DiscRequested: "Disconnect requested", DiscNetworkError: "Network error", DiscProtocolError: "Breach of protocol", @@ -82,197 +140,178 @@ func (d DiscReason) String() string { if len(discReasonToString) < int(d) { return "Unknown" } - return discReasonToString[d] } -type BaseProtocol struct { - peer *Peer - state ProtocolState - stateLock sync.RWMutex +func (bp *baseProtocol) Ping() { } -func NewBaseProtocol(peer *Peer) *BaseProtocol { - self := &BaseProtocol{ - peer: peer, - } - - return self +func (bp *baseProtocol) Offset() MsgCode { + return baseProtocolOffset } -func (self *BaseProtocol) Start() { - if self.peer != nil { - self.peer.Write("", self.peer.Server().Handshake()) - go self.peer.Messenger().PingPong( - pingTimeout*time.Second, - pingGracePeriod*time.Second, - self.Ping, - self.Timeout, - ) +func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error { + bp.peer, bp.rw = peer, rw + + // Do the handshake. + // TODO: disconnect is valid before handshake, too. + rw.WriteMsg(bp.peer.server.handshakeMsg()) + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != handshakeMsg { + return NewPeerError(ProtocolBreach, " first message must be handshake") + } + data, err := msg.Data() + if err != nil { + return NewPeerError(InvalidMsg, "%v", err) + } + if err := bp.handleHandshake(data); err != nil { + return err } -} -func (self *BaseProtocol) Stop() { + msgin := make(chan bpMsg) + done := make(chan error, 1) + go func() { + done <- MsgLoop(rw, baseProtocolMaxMsgSize, + func(code MsgCode, data *ethutil.Value) error { + msgin <- bpMsg{code, data} + return nil + }) + }() + return bp.loop(msgin, done) } -func (self *BaseProtocol) Ping() { - msg, _ := NewMsg(PingMsg) - self.peer.Write("", msg) +func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { + logger.Debugf("pingpong keepalive started at %v\n", time.Now()) + messenger := bp.rw.(*proto).messenger + pingTimer := time.NewTimer(pingTimeout) + pinged := true + + for { + select { + case msg := <-msgin: + if err := bp.handle(msg.code, msg.data); err != nil { + return err + } + case err := <-quit: + return err + case <-messenger.pulse: + pingTimer.Reset(pingTimeout) + pinged = false + case <-pingTimer.C: + if pinged { + return NewPeerError(PingTimeout, "") + } + logger.Debugf("pinging at %v\n", time.Now()) + if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil { + return NewPeerError(WriteError, "%v", err) + } + pinged = true + pingTimer.Reset(pingTimeout) + } + } } -func (self *BaseProtocol) Timeout() { - self.peerError(PingTimeout, "") -} +func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { + switch code { + case handshakeMsg: + return NewPeerError(ProtocolBreach, " extra handshake received") -func (self *BaseProtocol) Name() string { - return "" -} + case discMsg: + logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) + bp.peer.server.PeerDisconnect() <- DisconnectRequest{ + addr: bp.peer.Address, + reason: DiscRequested, + } -func (self *BaseProtocol) Offset() MsgCode { - return offset -} + case pingMsg: + return bp.rw.WriteMsg(NewMsg(pongMsg)) -func (self *BaseProtocol) CheckState(state ProtocolState) bool { - self.stateLock.RLock() - self.stateLock.RUnlock() - if self.state != state { - return false - } else { - return true - } -} + case pongMsg: + // reply for ping -func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { - if msg.Code() == HandshakeMsg { - self.handleHandshake(msg) - } else { - if !self.CheckState(handshakeReceived) { - self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) - close(response) - return - } - switch msg.Code() { - case DiscMsg: - logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) - self.peer.Server().PeerDisconnect() <- DisconnectRequest{ - addr: self.peer.Address, - reason: DiscRequested, - } - case PingMsg: - out, _ := NewMsg(PongMsg) - response <- out - case PongMsg: - case GetPeersMsg: - // Peer asked for list of connected peers - if out, err := self.peer.Server().PeersMessage(); err != nil { - response <- out + case getPeersMsg: + // Peer asked for list of connected peers. + peersRLP := bp.peer.server.encodedPeerList() + if peersRLP != nil { + msg := Msg{ + Code: peersMsg, + Size: uint32(len(peersRLP)), + Payload: bytes.NewReader(peersRLP), } - case PeersMsg: - self.handlePeers(msg) - default: - self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) + return bp.rw.WriteMsg(msg) } - } - close(response) -} -func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { - // somewhat overly paranoid - allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) - return -} + case peersMsg: + bp.handlePeers(data) -func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { - err := NewPeerError(errorCode, format, v...) - logger.Warnln(err) - fmt.Println(self.peer, err) - if self.peer != nil { - self.peer.PeerErrorChan() <- err + default: + return NewPeerError(InvalidMsgCode, "unknown message code %v", code) } + return nil } -func (self *BaseProtocol) handlePeers(msg *Msg) { - it := msg.Data().NewIterator() +func (bp *baseProtocol) handlePeers(data *ethutil.Value) { + it := data.NewIterator() for it.Next() { ip := net.IP(it.Value().Get(0).Bytes()) port := it.Value().Get(1).Uint() address := &net.TCPAddr{IP: ip, Port: int(port)} - go self.peer.Server().PeerConnect(address) + go bp.peer.server.PeerConnect(address) } } -func (self *BaseProtocol) handleHandshake(msg *Msg) { - self.stateLock.Lock() - defer self.stateLock.Unlock() - if self.state != nullState { - self.peerError(ProtocolBreach, "extra handshake") - return - } - - c := msg.Data() - +func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { var ( - p2pVersion = c.Get(0).Uint() - id = c.Get(1).Str() - caps = c.Get(2) - port = c.Get(3).Uint() - pubkey = c.Get(4).Bytes() + remoteVersion = c.Get(0).Uint() + id = c.Get(1).Str() + caps = c.Get(2) + port = c.Get(3).Uint() + pubkey = c.Get(4).Bytes() ) - fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) - // Check correctness of p2p protocol version - if p2pVersion != P2PVersion { - self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) - return + if remoteVersion != p2pVersion { + return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion) } // Handle the pub key (validation, uniqueness) if len(pubkey) == 0 { - self.peerError(PubkeyMissing, "not supplied in handshake.") - return + return NewPeerError(PubkeyMissing, "not supplied in handshake.") } if len(pubkey) != 64 { - self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) - return + return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) } - // Self connect detection - if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { - self.peerError(PubkeyForbidden, "not allowed to connect to self") - return + // self connect detection + if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { + return NewPeerError(PubkeyForbidden, "not allowed to connect to bp") } // register pubkey on server. this also sets the pubkey on the peer (need lock) - if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { - self.peerError(PubkeyForbidden, err.Error()) - return + if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { + return NewPeerError(PubkeyForbidden, err.Error()) } // check port - if self.peer.Inbound { + if bp.peer.Inbound { uint16port := uint16(port) - if self.peer.Port > 0 && self.peer.Port != uint16port { - self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) - return + if bp.peer.Port > 0 && bp.peer.Port != uint16port { + return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) } else { - self.peer.Port = uint16port + bp.peer.Port = uint16port } } capsIt := caps.NewIterator() for capsIt.Next() { cap := capsIt.Value().Str() - self.peer.Caps = append(self.peer.Caps, cap) + bp.peer.Caps = append(bp.peer.Caps, cap) } - sort.Strings(self.peer.Caps) - self.peer.Messenger().AddProtocols(self.peer.Caps) - - self.peer.Id = id - - self.state = handshakeReceived - - //p.ethereum.PushPeer(p) - // p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) - return + sort.Strings(bp.peer.Caps) + bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps) + bp.peer.Id = id + return nil } -- cgit v1.2.3