diff options
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r-- | p2p/discover/udp.go | 218 |
1 files changed, 136 insertions, 82 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 69e9f3c2e..e9ede1397 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -16,13 +16,18 @@ import ( var log = logger.NewLogger("P2P Discovery") +const Version = 3 + // Errors var ( - errPacketTooSmall = errors.New("too small") - errBadHash = errors.New("bad hash") - errExpired = errors.New("expired") - errTimeout = errors.New("RPC timeout") - errClosed = errors.New("socket closed") + errPacketTooSmall = errors.New("too small") + errBadHash = errors.New("bad hash") + errExpired = errors.New("expired") + errBadVersion = errors.New("version mismatch") + errUnsolicitedReply = errors.New("unsolicited reply") + errUnknownNode = errors.New("unknown node") + errTimeout = errors.New("RPC timeout") + errClosed = errors.New("socket closed") ) // Timeouts @@ -45,6 +50,7 @@ const ( // RPC request structures type ( ping struct { + Version uint // must match Version IP string // our IP Port uint16 // our port Expiration uint64 @@ -76,14 +82,27 @@ type rpcNode struct { ID NodeID } +type packet interface { + handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error +} + +type conn interface { + ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) + WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) + Close() error + LocalAddr() net.Addr +} + // udp implements the RPC protocol. type udp struct { - conn *net.UDPConn - priv *ecdsa.PrivateKey + conn conn + priv *ecdsa.PrivateKey + addpending chan *pending - replies chan reply - closing chan struct{} - nat nat.Interface + gotreply chan reply + + closing chan struct{} + nat nat.Interface *Table } @@ -120,6 +139,9 @@ type reply struct { from NodeID ptype byte data interface{} + // loop indicates whether there was + // a matching request by sending on this channel. + matched chan<- bool } // ListenUDP returns a new table that listens for UDP packets on laddr. @@ -132,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table if err != nil { return nil, err } + tab, _ := newUDP(priv, conn, natm) + log.Infoln("Listening,", tab.self) + return tab, nil +} + +func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) { udp := &udp{ - conn: conn, + conn: c, priv: priv, closing: make(chan struct{}), + gotreply: make(chan reply), addpending: make(chan *pending), - replies: make(chan reply), } - - realaddr := conn.LocalAddr().(*net.UDPAddr) + realaddr := c.LocalAddr().(*net.UDPAddr) if natm != nil { if !realaddr.IP.IsLoopback() { go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") @@ -151,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table } } udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr) - go udp.loop() go udp.readLoop() - log.Infoln("Listening, ", udp.self) - return udp.Table, nil + return udp.Table, udp } func (t *udp) close() { @@ -165,10 +190,11 @@ func (t *udp) close() { } // ping sends a ping message to the given node and waits for a reply. -func (t *udp) ping(e *Node) error { +func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { // TODO: maybe check for ReplyTo field in callback to measure RTT - errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true }) - t.send(e, pingPacket, ping{ + errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) + t.send(toaddr, pingPacket, ping{ + Version: Version, IP: t.self.IP.String(), Port: uint16(t.self.TCPPort), Expiration: uint64(time.Now().Add(expiration).Unix()), @@ -176,12 +202,16 @@ func (t *udp) ping(e *Node) error { return <-errc } +func (t *udp) waitping(from NodeID) error { + return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) +} + // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. -func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) { +func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { nodes := make([]*Node, 0, bucketSize) nreceived := 0 - errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool { + errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { reply := r.(*neighbors) for _, n := range reply.Nodes { nreceived++ @@ -191,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) { } return nreceived >= bucketSize }) - - t.send(to, findnodePacket, findnode{ + t.send(toaddr, findnodePacket, findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) @@ -214,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <- return ch } +func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { + matched := make(chan bool) + select { + case t.gotreply <- reply{from, ptype, req, matched}: + // loop will handle it + return <-matched + case <-t.closing: + return false + } +} + // loop runs in its own goroutin. it keeps track of // the refresh timer and the pending reply queue. func (t *udp) loop() { @@ -244,6 +284,7 @@ func (t *udp) loop() { for _, p := range pending { p.errc <- errClosed } + pending = nil return case p := <-t.addpending: @@ -251,18 +292,21 @@ func (t *udp) loop() { pending = append(pending, p) rearmTimeout() - case reply := <-t.replies: - // run matching callbacks, remove if they return false. + case r := <-t.gotreply: + var matched bool for i := 0; i < len(pending); i++ { - p := pending[i] - if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) { - p.errc <- nil - copy(pending[i:], pending[i+1:]) - pending = pending[:len(pending)-1] - i-- + if p := pending[i]; p.from == r.from && p.ptype == r.ptype { + matched = true + if p.callback(r.data) { + // callback indicates the request is done, remove it. + p.errc <- nil + copy(pending[i:], pending[i+1:]) + pending = pending[:len(pending)-1] + i-- + } } } - rearmTimeout() + r.matched <- matched case now := <-timeout.C: // notify and remove callbacks whose deadline is in the past. @@ -287,33 +331,38 @@ const ( var headSpace = make([]byte, headSize) -func (t *udp) send(to *Node, ptype byte, req interface{}) error { +func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error { + packet, err := encodePacket(t.priv, ptype, req) + if err != nil { + return err + } + log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req) + if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { + log.DebugDetailln("UDP send failed:", err) + } + return err +} + +func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { b := new(bytes.Buffer) b.Write(headSpace) b.WriteByte(ptype) if err := rlp.Encode(b, req); err != nil { log.Errorln("error encoding packet:", err) - return err + return nil, err } - packet := b.Bytes() - sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv) + sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv) if err != nil { log.Errorln("could not sign packet:", err) - return err + return nil, err } copy(packet[macSize:], sig) // add the hash to the front. Note: this doesn't protect the // packet in any way. Our public key will be part of this hash in - // the future. + // The future. copy(packet, crypto.Sha3(packet[macSize:])) - - toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort} - log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req) - if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { - log.DebugDetailln("UDP send failed:", err) - } - return err + return packet, nil } // readLoop runs in its own goroutine. it handles incoming UDP packets. @@ -325,29 +374,34 @@ func (t *udp) readLoop() { if err != nil { return } - if err := t.packetIn(from, buf[:nbytes]); err != nil { + packet, fromID, hash, err := decodePacket(buf[:nbytes]) + if err != nil { log.Debugf("Bad packet from %v: %v\n", from, err) + continue } + log.DebugDetailf("<<< %v %T %v\n", from, packet, packet) + go func() { + if err := packet.handle(t, from, fromID, hash); err != nil { + log.Debugf("error handling %T from %v: %v", packet, from, err) + } + }() } } -func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error { +func decodePacket(buf []byte) (packet, NodeID, []byte, error) { if len(buf) < headSize+1 { - return errPacketTooSmall + return nil, NodeID{}, nil, errPacketTooSmall } hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] shouldhash := crypto.Sha3(buf[macSize:]) if !bytes.Equal(hash, shouldhash) { - return errBadHash + return nil, NodeID{}, nil, errBadHash } fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig) if err != nil { - return err - } - - var req interface { - handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error + return nil, NodeID{}, hash, err } + var req packet switch ptype := sigdata[0]; ptype { case pingPacket: req = new(ping) @@ -358,31 +412,27 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error { case neighborsPacket: req = new(neighbors) default: - return fmt.Errorf("unknown type: %d", ptype) + return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) } - if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil { - return err - } - log.DebugDetailf("<<< %v %T %v\n", from, req, req) - return req.handle(t, from, fromID, hash) + err = rlp.Decode(bytes.NewReader(sigdata[1:]), req) + return req, fromID, hash, err } func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { if expired(req.Expiration) { return errExpired } - t.mutex.Lock() - // Note: we're ignoring the provided IP address right now - n := t.bumpOrAdd(fromID, from) - if req.Port != 0 { - n.TCPPort = int(req.Port) + if req.Version != Version { + return errBadVersion } - t.mutex.Unlock() - - t.send(n, pongPacket, pong{ + t.send(from, pongPacket, pong{ ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) + if !t.handleReply(fromID, pingPacket, req) { + // Note: we're ignoring the provided IP address right now + t.bond(true, fromID, from, req.Port) + } return nil } @@ -390,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er if expired(req.Expiration) { return errExpired } - t.mutex.Lock() - t.bump(fromID) - t.mutex.Unlock() - - t.replies <- reply{fromID, pongPacket, req} + if !t.handleReply(fromID, pongPacket, req) { + return errUnsolicitedReply + } return nil } @@ -402,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } + if t.db.get(fromID) == nil { + // No bond exists, we don't process the packet. This prevents + // an attack vector where the discovery protocol could be used + // to amplify traffic in a DDOS attack. A malicious actor + // would send a findnode request with the IP address and UDP + // port of the target as the source address. The recipient of + // the findnode packet would then send a neighbors packet + // (which is a much bigger packet than findnode) to the victim. + return errUnknownNode + } t.mutex.Lock() - e := t.bumpOrAdd(fromID, from) closest := t.closest(req.Target, bucketSize).entries t.mutex.Unlock() - t.send(e, neighborsPacket, neighbors{ + t.send(from, neighborsPacket, neighbors{ Nodes: closest, Expiration: uint64(time.Now().Add(expiration).Unix()), }) @@ -418,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt if expired(req.Expiration) { return errExpired } - t.mutex.Lock() - t.bump(fromID) - t.add(req.Nodes) - t.mutex.Unlock() - - t.replies <- reply{fromID, neighborsPacket, req} + if !t.handleReply(fromID, neighborsPacket, req) { + return errUnsolicitedReply + } return nil } |