aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover/udp.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r--p2p/discover/udp.go88
1 files changed, 60 insertions, 28 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 60436952d..524c6e498 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr
}
+// Config holds Table-related settings.
+type Config struct {
+ // These settings are required and configure the UDP listener:
+ PrivateKey *ecdsa.PrivateKey
+
+ // These settings are optional:
+ AnnounceAddr *net.UDPAddr // local address announced in the DHT
+ NodeDBPath string // if set, the node database is stored at this filesystem location
+ NetRestrict *netutil.Netlist // network whitelist
+ Bootnodes []*Node // list of bootstrap nodes
+ Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
+}
+
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
- tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
+func ListenUDP(c conn, cfg Config) (*Table, error) {
+ tab, _, err := newUDP(c, cfg)
if err != nil {
return nil, err
}
@@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
+func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
- priv: priv,
- netrestrict: netrestrict,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
+ realaddr := c.LocalAddr().(*net.UDPAddr)
+ if cfg.AnnounceAddr != nil {
+ realaddr = cfg.AnnounceAddr
+ }
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
+ tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
udp.Table = tab
go udp.loop()
- go udp.readLoop(unhandled)
+ go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil
}
@@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
- // TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
- t.send(toaddr, pingPacket, &ping{
+ req := &ping{
Version: Version,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
+ }
+ packet, hash, err := encodePacket(t.priv, pingPacket, req)
+ if err != nil {
+ return err
+ }
+ errc := t.pending(toid, pongPacket, func(p interface{}) bool {
+ return bytes.Equal(p.(*pong).ReplyTok, hash)
})
+ t.write(toaddr, req.name(), packet)
return <-errc
}
@@ -447,40 +470,45 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
- packet, err := encodePacket(t.priv, ptype, req)
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+ packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
- return err
+ return hash, err
}
- _, err = t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
+ return hash, t.write(toaddr, req.name(), packet)
+}
+
+func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+ _, err := t.conn.WriteToUDP(packet, toaddr)
+ log.Trace(">> "+what, "addr", toaddr, "err", err)
return err
}
-func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
- packet := b.Bytes()
+ packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil {
log.Error("Can't sign discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// The future.
- copy(packet, crypto.Keccak256(packet[macSize:]))
- return packet, nil
+ hash = crypto.Keccak256(packet[macSize:])
+ copy(packet, hash)
+ return packet, hash, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
-func (t *udp) readLoop(unhandled chan ReadPacket) {
+func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close()
if unhandled != nil {
defer close(unhandled)
@@ -585,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired(req.Expiration) {
return errExpired
}
- if t.db.node(fromID) == nil {
+ if !t.db.hasBond(fromID) {
// No bond exists, we don't process the packet. This prevents
// an attack vector where the discovery protocol could be used
// to amplify traffic in a DDOS attack. A malicious actor
@@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
- for i, n := range closest {
- if netutil.CheckRelayIP(from.IP, n.IP) != nil {
- continue
+ for _, n := range closest {
+ if netutil.CheckRelayIP(from.IP, n.IP) == nil {
+ p.Nodes = append(p.Nodes, nodeToRPC(n))
}
- p.Nodes = append(p.Nodes, nodeToRPC(n))
- if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
+ if len(p.Nodes) == maxNeighbors {
t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
+ sent = true
}
}
+ if len(p.Nodes) > 0 || !sent {
+ t.send(from, neighborsPacket, &p)
+ }
return nil
}