aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover/udp.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r--p2p/discover/udp.go222
1 files changed, 131 insertions, 91 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 37a044902..5ce4c43dc 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -67,6 +67,8 @@ const (
// RPC request structures
type (
ping struct {
+ senderKey *ecdsa.PublicKey // filled in by preverify
+
Version uint
From, To rpcEndpoint
Expiration uint64
@@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode {
return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())}
}
+// packet is implemented by all protocol messages.
type packet interface {
- handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error
+ // preverify checks whether the packet is valid and should be handled at all.
+ preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error
+ // handle handles the packet.
+ handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte)
+ // name returns the name of the packet for logging purposes.
name() string
}
@@ -177,43 +184,48 @@ type udp struct {
tab *Table
wg sync.WaitGroup
- addpending chan *pending
- gotreply chan reply
- closing chan struct{}
+ addReplyMatcher chan *replyMatcher
+ gotreply chan reply
+ closing chan struct{}
}
// pending represents a pending reply.
//
-// some implementations of the protocol wish to send more than one
-// reply packet to findnode. in general, any neighbors packet cannot
+// Some implementations of the protocol wish to send more than one
+// reply packet to findnode. In general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
-// our implementation handles this by storing a callback function for
-// each pending reply. incoming packets from a node are dispatched
-// to all the callback functions for that node.
-type pending struct {
+// Our implementation handles this by storing a callback function for
+// each pending reply. Incoming packets from a node are dispatched
+// to all callback functions for that node.
+type replyMatcher struct {
// these fields must match in the reply.
from enode.ID
+ ip net.IP
ptype byte
// time when the request must complete
deadline time.Time
- // callback is called when a matching reply arrives. if it returns
- // true, the callback is removed from the pending reply queue.
- // if it returns false, the reply is considered incomplete and
- // the callback will be invoked again for the next matching reply.
- callback func(resp interface{}) (done bool)
+ // callback is called when a matching reply arrives. If it returns matched == true, the
+ // reply was acceptable. The second return value indicates whether the callback should
+ // be removed from the pending reply queue. If it returns false, the reply is considered
+ // incomplete and the callback will be invoked again for the next matching reply.
+ callback replyMatchFunc
// errc receives nil when the callback indicates completion or an
// error if no further reply is received within the timeout.
errc chan<- error
}
+type replyMatchFunc func(interface{}) (matched bool, requestDone bool)
+
type reply struct {
from enode.ID
+ ip net.IP
ptype byte
- data interface{}
+ data packet
+
// loop indicates whether there was
// a matching request by sending on this channel.
matched chan<- bool
@@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
udp := &udp{
- conn: c,
- priv: cfg.PrivateKey,
- netrestrict: cfg.NetRestrict,
- localNode: ln,
- db: ln.Database(),
- closing: make(chan struct{}),
- gotreply: make(chan reply),
- addpending: make(chan *pending),
+ conn: c,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
+ localNode: ln,
+ db: ln.Database(),
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addReplyMatcher: make(chan *replyMatcher),
}
tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
if err != nil {
@@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
errc <- err
return errc
}
- errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- ok := bytes.Equal(p.(*pong).ReplyTok, hash)
- if ok && callback != nil {
+ // Add a matcher for the reply to the pending reply queue. Pongs are matched if they
+ // reference the ping we're about to send.
+ errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) {
+ matched = bytes.Equal(p.(*pong).ReplyTok, hash)
+ if matched && callback != nil {
callback()
}
- return ok
+ return matched, matched
})
+ // Send the packet.
t.localNode.UDPContact(toaddr)
- t.write(toaddr, req.name(), packet)
+ t.write(toaddr, toid, req.name(), packet)
return errc
}
-func (t *udp) waitping(from enode.ID) error {
- return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
-}
-
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// If we haven't seen a ping from the destination node for a while, it won't remember
// our endpoint proof and reject findnode. Solicit a ping first.
- if time.Since(t.db.LastPingReceived(toid)) > bondExpiration {
+ if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
t.ping(toid, toaddr)
- t.waitping(toid)
+ // Wait for them to ping back and process our pong.
+ time.Sleep(respTimeout)
}
+ // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
+ // active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize)
nreceived := 0
- errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) {
reply := r.(*neighbors)
for _, rn := range reply.Nodes {
nreceived++
@@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]
}
nodes = append(nodes, n)
}
- return nreceived >= bucketSize
+ return true, nreceived >= bucketSize
})
- t.send(toaddr, findnodePacket, &findnode{
+ t.send(toaddr, toid, findnodePacket, &findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nodes, <-errc
}
-// pending adds a reply callback to the pending reply queue.
-// see the documentation of type pending for a detailed explanation.
-func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error {
+// pending adds a reply matcher to the pending reply queue.
+// see the documentation of type replyMatcher for a detailed explanation.
+func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error {
ch := make(chan error, 1)
- p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select {
- case t.addpending <- p:
+ case t.addReplyMatcher <- p:
// loop will handle it
case <-t.closing:
ch <- errClosed
@@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool)
return ch
}
-func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
+// handleReply dispatches a reply packet, invoking reply matchers. It returns
+// whether any matcher considered the packet acceptable.
+func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool {
matched := make(chan bool, 1)
select {
- case t.gotreply <- reply{from, ptype, req, matched}:
+ case t.gotreply <- reply{from, fromIP, ptype, req, matched}:
// loop will handle it
return <-matched
case <-t.closing:
@@ -385,8 +401,8 @@ func (t *udp) loop() {
var (
plist = list.New()
timeout = time.NewTimer(0)
- nextTimeout *pending // head of plist when timeout was last reset
- contTimeouts = 0 // number of continuous timeouts to do NTP checks
+ nextTimeout *replyMatcher // head of plist when timeout was last reset
+ contTimeouts = 0 // number of continuous timeouts to do NTP checks
ntpWarnTime = time.Unix(0, 0)
)
<-timeout.C // ignore first timeout
@@ -399,7 +415,7 @@ func (t *udp) loop() {
// Start the timer so it fires when the next pending reply has expired.
now := time.Now()
for el := plist.Front(); el != nil; el = el.Next() {
- nextTimeout = el.Value.(*pending)
+ nextTimeout = el.Value.(*replyMatcher)
if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
timeout.Reset(dist)
return
@@ -420,25 +436,23 @@ func (t *udp) loop() {
select {
case <-t.closing:
for el := plist.Front(); el != nil; el = el.Next() {
- el.Value.(*pending).errc <- errClosed
+ el.Value.(*replyMatcher).errc <- errClosed
}
return
- case p := <-t.addpending:
+ case p := <-t.addReplyMatcher:
p.deadline = time.Now().Add(respTimeout)
plist.PushBack(p)
case r := <-t.gotreply:
- var matched bool
+ var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
- if p.from == r.from && p.ptype == r.ptype {
- matched = true
- // Remove the matcher if its callback indicates
- // that all replies have been received. This is
- // required for packet types that expect multiple
- // reply packets.
- if p.callback(r.data) {
+ p := el.Value.(*replyMatcher)
+ if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) {
+ ok, requestDone := p.callback(r.data)
+ matched = matched || ok
+ // Remove the matcher if callback indicates that all replies have been received.
+ if requestDone {
p.errc <- nil
plist.Remove(el)
}
@@ -453,7 +467,7 @@ func (t *udp) loop() {
// Notify and remove callbacks whose deadline is in the past.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
+ p := el.Value.(*replyMatcher)
if now.After(p.deadline) || now.Equal(p.deadline) {
p.errc <- errTimeout
plist.Remove(el)
@@ -504,17 +518,17 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) {
packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
return hash, err
}
- return hash, t.write(toaddr, req.name(), packet)
+ return hash, t.write(toaddr, toid, req.name(), packet)
}
-func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+what, "addr", toaddr, "err", err)
+ log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err
}
@@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) {
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
- packet, fromID, hash, err := decodePacket(buf)
+ packet, fromKey, hash, err := decodePacket(buf)
if err != nil {
log.Debug("Bad discv4 packet", "addr", from, "err", err)
return err
}
- err = packet.handle(t, from, fromID, hash)
- log.Trace("<< "+packet.name(), "addr", from, "err", err)
+ fromID := fromKey.id()
+ if err == nil {
+ err = packet.preverify(t, from, fromID, fromKey)
+ }
+ log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err)
+ if err == nil {
+ packet.handle(t, from, fromID, hash)
+ }
return err
}
@@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
return req, fromKey, hash, err
}
-func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+// Packet Handlers
+
+func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
key, err := decodePubkey(fromKey)
if err != nil {
- return fmt.Errorf("invalid public key: %v", err)
+ return errors.New("invalid public key")
}
- t.send(from, pongPacket, &pong{
+ req.senderKey = key
+ return nil
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Reply.
+ t.send(from, fromID, pongPacket, &pong{
To: makeEndpoint(from, req.From.TCP),
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
- t.handleReply(n.ID(), pingPacket, req)
- if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
- t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
+
+ // Ping back if our last pong on file is too far in the past.
+ n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port))
+ if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
+ t.sendPing(fromID, from, func() {
+ t.tab.addThroughPing(n)
+ })
} else {
t.tab.addThroughPing(n)
}
+
+ // Update node database and endpoint predictor.
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPingReceived(n.ID(), time.Now())
- return nil
}
func (req *ping) name() string { return "PING/v4" }
-func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if !t.handleReply(fromID, pongPacket, req) {
+ if !t.handleReply(fromID, from.IP, pongPacket, req) {
return errUnsolicitedReply
}
- t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPongReceived(fromID, time.Now())
return nil
}
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
+ t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
+}
+
func (req *pong) name() string { return "PONG/v4" }
-func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration {
+ if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
// No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address
@@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
// findnode) to the victim.
return errUnknownNode
}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Determine closest nodes.
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.tab.mutex.Lock()
closest := t.tab.closest(target, bucketSize).entries
t.tab.mutex.Unlock()
- p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
- var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
+ p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == maxNeighbors {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
sent = true
}
}
if len(p.Nodes) > 0 || !sent {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
}
- return nil
}
func (req *findnode) name() string { return "FINDNODE/v4" }
-func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.handleReply(fromKey.id(), neighborsPacket, req) {
+ if !t.handleReply(fromID, from.IP, neighborsPacket, req) {
return errUnsolicitedReply
}
return nil
}
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+}
+
func (req *neighbors) name() string { return "NEIGHBORS/v4" }
func expired(ts uint64) bool {