diff options
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r-- | p2p/discover/udp.go | 88 |
1 files changed, 60 insertions, 28 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 60436952d..524c6e498 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -216,9 +216,22 @@ type ReadPacket struct { Addr *net.UDPAddr } +// Config holds Table-related settings. +type Config struct { + // These settings are required and configure the UDP listener: + PrivateKey *ecdsa.PrivateKey + + // These settings are optional: + AnnounceAddr *net.UDPAddr // local address announced in the DHT + NodeDBPath string // if set, the node database is stored at this filesystem location + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel +} + // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { - tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict) +func ListenUDP(c conn, cfg Config) (*Table, error) { + tab, _, err := newUDP(c, cfg) if err != nil { return nil, err } @@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { +func newUDP(c conn, cfg Config) (*Table, *udp, error) { udp := &udp{ conn: c, - priv: priv, - netrestrict: netrestrict, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, closing: make(chan struct{}), gotreply: make(chan reply), addpending: make(chan *pending), } + realaddr := c.LocalAddr().(*net.UDPAddr) + if cfg.AnnounceAddr != nil { + realaddr = cfg.AnnounceAddr + } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) + tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) if err != nil { return nil, nil, err } udp.Table = tab go udp.loop() - go udp.readLoop(unhandled) + go udp.readLoop(cfg.Unhandled) return udp.Table, udp, nil } @@ -256,14 +273,20 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { - // TODO: maybe check for ReplyTo field in callback to measure RTT - errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) - t.send(toaddr, pingPacket, &ping{ + req := &ping{ Version: Version, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), + } + packet, hash, err := encodePacket(t.priv, pingPacket, req) + if err != nil { + return err + } + errc := t.pending(toid, pongPacket, func(p interface{}) bool { + return bytes.Equal(p.(*pong).ReplyTok, hash) }) + t.write(toaddr, req.name(), packet) return <-errc } @@ -447,40 +470,45 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { - packet, err := encodePacket(t.priv, ptype, req) +func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { + packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { - return err + return hash, err } - _, err = t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+req.name(), "addr", toaddr, "err", err) + return hash, t.write(toaddr, req.name(), packet) +} + +func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { + _, err := t.conn.WriteToUDP(packet, toaddr) + log.Trace(">> "+what, "addr", toaddr, "err", err) return err } -func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { +func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { b := new(bytes.Buffer) b.Write(headSpace) b.WriteByte(ptype) if err := rlp.Encode(b, req); err != nil { log.Error("Can't encode discv4 packet", "err", err) - return nil, err + return nil, nil, err } - packet := b.Bytes() + packet = b.Bytes() sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) if err != nil { log.Error("Can't sign discv4 packet", "err", err) - return nil, err + return nil, nil, err } copy(packet[macSize:], sig) // add the hash to the front. Note: this doesn't protect the // packet in any way. Our public key will be part of this hash in // The future. - copy(packet, crypto.Keccak256(packet[macSize:])) - return packet, nil + hash = crypto.Keccak256(packet[macSize:]) + copy(packet, hash) + return packet, hash, nil } // readLoop runs in its own goroutine. it handles incoming UDP packets. -func (t *udp) readLoop(unhandled chan ReadPacket) { +func (t *udp) readLoop(unhandled chan<- ReadPacket) { defer t.conn.Close() if unhandled != nil { defer close(unhandled) @@ -585,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor @@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte t.mutex.Unlock() p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. - for i, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP) != nil { - continue + for _, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(n)) } - p.Nodes = append(p.Nodes, nodeToRPC(n)) - if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { + if len(p.Nodes) == maxNeighbors { t.send(from, neighborsPacket, &p) p.Nodes = p.Nodes[:0] + sent = true } } + if len(p.Nodes) > 0 || !sent { + t.send(from, neighborsPacket, &p) + } return nil } |