aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/peer.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/peer.go')
-rw-r--r--p2p/peer.go229
1 files changed, 59 insertions, 170 deletions
diff --git a/p2p/peer.go b/p2p/peer.go
index fd5bec7d5..b9bf0fd73 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -33,37 +33,14 @@ const (
peersMsg = 0x05
)
-// handshake is the RLP structure of the protocol handshake.
-type handshake struct {
- Version uint64
- Name string
- Caps []Cap
- ListenPort uint64
- NodeID discover.NodeID
-}
-
// Peer represents a connected remote node.
type Peer struct {
// Peers have all the log methods.
// Use them to display messages related to the peer.
*logger.Logger
- infoMu sync.Mutex
- name string
- caps []Cap
-
- ourID, remoteID *discover.NodeID
- ourName string
-
- rw *frameRW
-
- // These fields maintain the running protocols.
- protocols []Protocol
- runlock sync.RWMutex // protects running
- running map[string]*proto
-
- // disables protocol handshake, for testing
- noHandshake bool
+ rw *conn
+ running map[string]*protoRW
protoWG sync.WaitGroup
protoErr chan error
@@ -73,36 +50,27 @@ type Peer struct {
// NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
- conn, _ := net.Pipe()
- peer := newPeer(conn, nil, "", nil, &id)
- peer.setHandshakeInfo(name, caps)
+ pipe, _ := net.Pipe()
+ conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
+ peer := newPeer(conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
// ID returns the node's public key.
func (p *Peer) ID() discover.NodeID {
- return *p.remoteID
+ return p.rw.ID
}
// Name returns the node name that the remote node advertised.
func (p *Peer) Name() string {
- // this needs a lock because the information is part of the
- // protocol handshake.
- p.infoMu.Lock()
- name := p.name
- p.infoMu.Unlock()
- return name
+ return p.rw.Name
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
- // this needs a lock because the information is part of the
- // protocol handshake.
- p.infoMu.Lock()
- caps := p.caps
- p.infoMu.Unlock()
- return caps
+ // TODO: maybe return copy
+ return p.rw.Caps
}
// RemoteAddr returns the remote address of the network connection.
@@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
- return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
+ return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
}
-func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
- logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
- return &Peer{
- Logger: logger.NewLogger(logtag),
- rw: newFrameRW(conn, msgWriteTimeout),
- ourID: ourID,
- ourName: ourName,
- remoteID: remoteID,
- protocols: protocols,
- running: make(map[string]*proto),
- disc: make(chan DiscReason),
- protoErr: make(chan error),
- closed: make(chan struct{}),
+func newPeer(conn *conn, protocols []Protocol) *Peer {
+ logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
+ p := &Peer{
+ Logger: logger.NewLogger(logtag),
+ rw: conn,
+ running: matchProtocols(protocols, conn.Caps, conn),
+ disc: make(chan DiscReason),
+ protoErr: make(chan error),
+ closed: make(chan struct{}),
}
-}
-
-func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
- p.infoMu.Lock()
- p.name = name
- p.caps = caps
- p.infoMu.Unlock()
+ return p
}
func (p *Peer) run() DiscReason {
@@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
defer p.closeProtocols()
defer close(p.closed)
+ p.startProtocols()
go func() { readErr <- p.readLoop() }()
- if !p.noHandshake {
- if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
- p.DebugDetailf("Protocol handshake error: %v\n", err)
- p.rw.Close()
- return DiscProtocolError
- }
- }
-
// Wait for an error or disconnect.
var reason DiscReason
select {
@@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
}
func (p *Peer) readLoop() error {
- if !p.noHandshake {
- if err := readProtocolHandshake(p, p.rw); err != nil {
- return err
- }
- }
for {
msg, err := p.rw.ReadMsg()
if err != nil {
@@ -249,105 +195,51 @@ func (p *Peer) handle(msg Msg) error {
return nil
}
-func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
- // read and handle remote handshake
- msg, err := rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Code == discMsg {
- // disconnect before protocol handshake is valid according to the
- // spec and we send it ourself if Server.addPeer fails.
- var reason DiscReason
- rlp.Decode(msg.Payload, &reason)
- return discRequestedError(reason)
- }
- if msg.Code != handshakeMsg {
- return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
- }
- if msg.Size > baseProtocolMaxMsgSize {
- return newPeerError(errInvalidMsg, "message too big")
- }
- var hs handshake
- if err := msg.Decode(&hs); err != nil {
- return err
- }
- // validate handshake info
- if hs.Version != baseProtocolVersion {
- return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
- baseProtocolVersion, hs.Version)
- }
- if hs.NodeID == *p.remoteID {
- return newPeerError(errPubkeyForbidden, "node ID mismatch")
- }
- // TODO: remove Caps with empty name
- p.setHandshakeInfo(hs.Name, hs.Caps)
- p.startSubprotocols(hs.Caps)
- return nil
-}
-
-func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
- var caps []interface{}
- for _, proto := range ps {
- caps = append(caps, proto.cap())
- }
- return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
-}
-
-// startProtocols starts matching named subprotocols.
-func (p *Peer) startSubprotocols(caps []Cap) {
+// matchProtocols creates structures for matching named subprotocols.
+func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
sort.Sort(capsByName(caps))
- p.runlock.Lock()
- defer p.runlock.Unlock()
offset := baseProtocolLength
+ result := make(map[string]*protoRW)
outer:
for _, cap := range caps {
- for _, proto := range p.protocols {
- if proto.Name == cap.Name &&
- proto.Version == cap.Version &&
- p.running[cap.Name] == nil {
- p.running[cap.Name] = p.startProto(offset, proto)
+ for _, proto := range protocols {
+ if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
+ result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
offset += proto.Length
continue outer
}
}
}
+ return result
}
-func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
- p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
- rw := &proto{
- name: impl.Name,
- in: make(chan Msg),
- offset: offset,
- maxcode: impl.Length,
- w: p.rw,
+func (p *Peer) startProtocols() {
+ for _, proto := range p.running {
+ proto := proto
+ p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
+ p.protoWG.Add(1)
+ go func() {
+ err := proto.Run(p, proto)
+ if err == nil {
+ p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
+ err = errors.New("protocol returned")
+ } else {
+ p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
+ }
+ select {
+ case p.protoErr <- err:
+ case <-p.closed:
+ }
+ p.protoWG.Done()
+ }()
}
- p.protoWG.Add(1)
- go func() {
- err := impl.Run(p, rw)
- if err == nil {
- p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
- err = errors.New("protocol returned")
- } else {
- p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
- }
- select {
- case p.protoErr <- err:
- case <-p.closed:
- }
- p.protoWG.Done()
- }()
- return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
-func (p *Peer) getProto(code uint64) (*proto, error) {
- p.runlock.RLock()
- defer p.runlock.RUnlock()
+func (p *Peer) getProto(code uint64) (*protoRW, error) {
for _, proto := range p.running {
- if code >= proto.offset && code < proto.offset+proto.maxcode {
+ if code >= proto.offset && code < proto.offset+proto.Length {
return proto, nil
}
}
@@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
}
func (p *Peer) closeProtocols() {
- p.runlock.RLock()
for _, p := range p.running {
close(p.in)
}
- p.runlock.RUnlock()
p.protoWG.Wait()
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
- p.runlock.RLock()
proto, ok := p.running[protoName]
- p.runlock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
- if msg.Code >= proto.maxcode {
+ if msg.Code >= proto.Length {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.rw.WriteMsg(msg)
}
-type proto struct {
- name string
- in chan Msg
- maxcode, offset uint64
- w MsgWriter
+type protoRW struct {
+ Protocol
+
+ in chan Msg
+ offset uint64
+ w MsgWriter
}
-func (rw *proto) WriteMsg(msg Msg) error {
- if msg.Code >= rw.maxcode {
+func (rw *protoRW) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.w.WriteMsg(msg)
}
-func (rw *proto) ReadMsg() (Msg, error) {
+func (rw *protoRW) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF