aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover/udp.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r--p2p/discover/udp.go214
1 files changed, 130 insertions, 84 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 738a01fb7..e9ede1397 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -20,12 +20,14 @@ const Version = 3
// Errors
var (
- errPacketTooSmall = errors.New("too small")
- errBadHash = errors.New("bad hash")
- errExpired = errors.New("expired")
- errBadVersion = errors.New("version mismatch")
- errTimeout = errors.New("RPC timeout")
- errClosed = errors.New("socket closed")
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errBadVersion = errors.New("version mismatch")
+ errUnsolicitedReply = errors.New("unsolicited reply")
+ errUnknownNode = errors.New("unknown node")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
)
// Timeouts
@@ -80,14 +82,27 @@ type rpcNode struct {
ID NodeID
}
+type packet interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+ ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+ WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+ Close() error
+ LocalAddr() net.Addr
+}
+
// udp implements the RPC protocol.
type udp struct {
- conn *net.UDPConn
- priv *ecdsa.PrivateKey
+ conn conn
+ priv *ecdsa.PrivateKey
+
addpending chan *pending
- replies chan reply
- closing chan struct{}
- nat nat.Interface
+ gotreply chan reply
+
+ closing chan struct{}
+ nat nat.Interface
*Table
}
@@ -124,6 +139,9 @@ type reply struct {
from NodeID
ptype byte
data interface{}
+ // loop indicates whether there was
+ // a matching request by sending on this channel.
+ matched chan<- bool
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
@@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
if err != nil {
return nil, err
}
+ tab, _ := newUDP(priv, conn, natm)
+ log.Infoln("Listening,", tab.self)
+ return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
udp := &udp{
- conn: conn,
+ conn: c,
priv: priv,
closing: make(chan struct{}),
+ gotreply: make(chan reply),
addpending: make(chan *pending),
- replies: make(chan reply),
}
-
- realaddr := conn.LocalAddr().(*net.UDPAddr)
+ realaddr := c.LocalAddr().(*net.UDPAddr)
if natm != nil {
if !realaddr.IP.IsLoopback() {
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
@@ -155,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
}
}
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
-
go udp.loop()
go udp.readLoop()
- log.Infoln("Listening, ", udp.self)
- return udp.Table, nil
+ return udp.Table, udp
}
func (t *udp) close() {
@@ -169,10 +190,10 @@ func (t *udp) close() {
}
// ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(e *Node) error {
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
- t.send(e, pingPacket, ping{
+ errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+ t.send(toaddr, pingPacket, ping{
Version: Version,
IP: t.self.IP.String(),
Port: uint16(t.self.TCPPort),
@@ -181,12 +202,16 @@ func (t *udp) ping(e *Node) error {
return <-errc
}
+func (t *udp) waitping(from NodeID) error {
+ return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
-func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
nodes := make([]*Node, 0, bucketSize)
nreceived := 0
- errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
reply := r.(*neighbors)
for _, n := range reply.Nodes {
nreceived++
@@ -196,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
}
return nreceived >= bucketSize
})
-
- t.send(to, findnodePacket, findnode{
+ t.send(toaddr, findnodePacket, findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
@@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
return ch
}
+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+ matched := make(chan bool)
+ select {
+ case t.gotreply <- reply{from, ptype, req, matched}:
+ // loop will handle it
+ return <-matched
+ case <-t.closing:
+ return false
+ }
+}
+
// loop runs in its own goroutin. it keeps track of
// the refresh timer and the pending reply queue.
func (t *udp) loop() {
@@ -249,6 +284,7 @@ func (t *udp) loop() {
for _, p := range pending {
p.errc <- errClosed
}
+ pending = nil
return
case p := <-t.addpending:
@@ -256,18 +292,21 @@ func (t *udp) loop() {
pending = append(pending, p)
rearmTimeout()
- case reply := <-t.replies:
- // run matching callbacks, remove if they return false.
+ case r := <-t.gotreply:
+ var matched bool
for i := 0; i < len(pending); i++ {
- p := pending[i]
- if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
- p.errc <- nil
- copy(pending[i:], pending[i+1:])
- pending = pending[:len(pending)-1]
- i--
+ if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+ matched = true
+ if p.callback(r.data) {
+ // callback indicates the request is done, remove it.
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i--
+ }
}
}
- rearmTimeout()
+ r.matched <- matched
case now := <-timeout.C:
// notify and remove callbacks whose deadline is in the past.
@@ -292,33 +331,38 @@ const (
var headSpace = make([]byte, headSize)
-func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+ packet, err := encodePacket(t.priv, ptype, req)
+ if err != nil {
+ return err
+ }
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Errorln("error encoding packet:", err)
- return err
+ return nil, err
}
-
packet := b.Bytes()
- sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
if err != nil {
log.Errorln("could not sign packet:", err)
- return err
+ return nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
- // the future.
+ // The future.
copy(packet, crypto.Sha3(packet[macSize:]))
-
- toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
- log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
- if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
- log.DebugDetailln("UDP send failed:", err)
- }
- return err
+ return packet, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
@@ -330,29 +374,34 @@ func (t *udp) readLoop() {
if err != nil {
return
}
- if err := t.packetIn(from, buf[:nbytes]); err != nil {
+ packet, fromID, hash, err := decodePacket(buf[:nbytes])
+ if err != nil {
log.Debugf("Bad packet from %v: %v\n", from, err)
+ continue
}
+ log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+ go func() {
+ if err := packet.handle(t, from, fromID, hash); err != nil {
+ log.Debugf("error handling %T from %v: %v", packet, from, err)
+ }
+ }()
}
}
-func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
if len(buf) < headSize+1 {
- return errPacketTooSmall
+ return nil, NodeID{}, nil, errPacketTooSmall
}
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
shouldhash := crypto.Sha3(buf[macSize:])
if !bytes.Equal(hash, shouldhash) {
- return errBadHash
+ return nil, NodeID{}, nil, errBadHash
}
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
if err != nil {
- return err
- }
-
- var req interface {
- handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ return nil, NodeID{}, hash, err
}
+ var req packet
switch ptype := sigdata[0]; ptype {
case pingPacket:
req = new(ping)
@@ -363,13 +412,10 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
case neighborsPacket:
req = new(neighbors)
default:
- return fmt.Errorf("unknown type: %d", ptype)
+ return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
}
- if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
- return err
- }
- log.DebugDetailf("<<< %v %T %v\n", from, req, req)
- return req.handle(t, from, fromID, hash)
+ err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+ return req, fromID, hash, err
}
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
@@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if req.Version != Version {
return errBadVersion
}
- t.mutex.Lock()
- // Note: we're ignoring the provided IP address right now
- n := t.bumpOrAdd(fromID, from)
- if req.Port != 0 {
- n.TCPPort = int(req.Port)
- }
- t.mutex.Unlock()
-
- t.send(n, pongPacket, pong{
+ t.send(from, pongPacket, pong{
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
+ if !t.handleReply(fromID, pingPacket, req) {
+ // Note: we're ignoring the provided IP address right now
+ t.bond(true, fromID, from, req.Port)
+ }
return nil
}
@@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if expired(req.Expiration) {
return errExpired
}
- t.mutex.Lock()
- t.bump(fromID)
- t.mutex.Unlock()
-
- t.replies <- reply{fromID, pongPacket, req}
+ if !t.handleReply(fromID, pongPacket, req) {
+ return errUnsolicitedReply
+ }
return nil
}
@@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired(req.Expiration) {
return errExpired
}
+ if t.db.get(fromID) == nil {
+ // No bond exists, we don't process the packet. This prevents
+ // an attack vector where the discovery protocol could be used
+ // to amplify traffic in a DDOS attack. A malicious actor
+ // would send a findnode request with the IP address and UDP
+ // port of the target as the source address. The recipient of
+ // the findnode packet would then send a neighbors packet
+ // (which is a much bigger packet than findnode) to the victim.
+ return errUnknownNode
+ }
t.mutex.Lock()
- e := t.bumpOrAdd(fromID, from)
closest := t.closest(req.Target, bucketSize).entries
t.mutex.Unlock()
- t.send(e, neighborsPacket, neighbors{
+ t.send(from, neighborsPacket, neighbors{
Nodes: closest,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
@@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
if expired(req.Expiration) {
return errExpired
}
- t.mutex.Lock()
- t.bump(fromID)
- t.add(req.Nodes)
- t.mutex.Unlock()
-
- t.replies <- reply{fromID, neighborsPacket, req}
+ if !t.handleReply(fromID, neighborsPacket, req) {
+ return errUnsolicitedReply
+ }
return nil
}