aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover/udp.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover/udp.go')
-rw-r--r--p2p/discover/udp.go486
1 files changed, 486 insertions, 0 deletions
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
new file mode 100644
index 000000000..e9ede1397
--- /dev/null
+++ b/p2p/discover/udp.go
@@ -0,0 +1,486 @@
+package discover
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var log = logger.NewLogger("P2P Discovery")
+
+const Version = 3
+
+// Errors
+var (
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errBadVersion = errors.New("version mismatch")
+ errUnsolicitedReply = errors.New("unsolicited reply")
+ errUnknownNode = errors.New("unknown node")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
+)
+
+// Timeouts
+const (
+ respTimeout = 300 * time.Millisecond
+ sendTimeout = 300 * time.Millisecond
+ expiration = 20 * time.Second
+
+ refreshInterval = 1 * time.Hour
+)
+
+// RPC packet types
+const (
+ pingPacket = iota + 1 // zero is 'reserved'
+ pongPacket
+ findnodePacket
+ neighborsPacket
+)
+
+// RPC request structures
+type (
+ ping struct {
+ Version uint // must match Version
+ IP string // our IP
+ Port uint16 // our port
+ Expiration uint64
+ }
+
+ // reply to Ping
+ pong struct {
+ ReplyTok []byte
+ Expiration uint64
+ }
+
+ findnode struct {
+ // Id to look up. The responding node will send back nodes
+ // closest to the target.
+ Target NodeID
+ Expiration uint64
+ }
+
+ // reply to findnode
+ neighbors struct {
+ Nodes []*Node
+ Expiration uint64
+ }
+)
+
+type rpcNode struct {
+ IP string
+ Port uint16
+ ID NodeID
+}
+
+type packet interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+ ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+ WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+ Close() error
+ LocalAddr() net.Addr
+}
+
+// udp implements the RPC protocol.
+type udp struct {
+ conn conn
+ priv *ecdsa.PrivateKey
+
+ addpending chan *pending
+ gotreply chan reply
+
+ closing chan struct{}
+ nat nat.Interface
+
+ *Table
+}
+
+// pending represents a pending reply.
+//
+// some implementations of the protocol wish to send more than one
+// reply packet to findnode. in general, any neighbors packet cannot
+// be matched up with a specific findnode packet.
+//
+// our implementation handles this by storing a callback function for
+// each pending reply. incoming packets from a node are dispatched
+// to all the callback functions for that node.
+type pending struct {
+ // these fields must match in the reply.
+ from NodeID
+ ptype byte
+
+ // time when the request must complete
+ deadline time.Time
+
+ // callback is called when a matching reply arrives. if it returns
+ // true, the callback is removed from the pending reply queue.
+ // if it returns false, the reply is considered incomplete and
+ // the callback will be invoked again for the next matching reply.
+ callback func(resp interface{}) (done bool)
+
+ // errc receives nil when the callback indicates completion or an
+ // error if no further reply is received within the timeout.
+ errc chan<- error
+}
+
+type reply struct {
+ from NodeID
+ ptype byte
+ data interface{}
+ // loop indicates whether there was
+ // a matching request by sending on this channel.
+ matched chan<- bool
+}
+
+// ListenUDP returns a new table that listens for UDP packets on laddr.
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
+ addr, err := net.ResolveUDPAddr("udp", laddr)
+ if err != nil {
+ return nil, err
+ }
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+ tab, _ := newUDP(priv, conn, natm)
+ log.Infoln("Listening,", tab.self)
+ return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
+ udp := &udp{
+ conn: c,
+ priv: priv,
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addpending: make(chan *pending),
+ }
+ realaddr := c.LocalAddr().(*net.UDPAddr)
+ if natm != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := natm.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
+ go udp.loop()
+ go udp.readLoop()
+ return udp.Table, udp
+}
+
+func (t *udp) close() {
+ close(t.closing)
+ t.conn.Close()
+ // TODO: wait for the loops to end.
+}
+
+// ping sends a ping message to the given node and waits for a reply.
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ // TODO: maybe check for ReplyTo field in callback to measure RTT
+ errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+ t.send(toaddr, pingPacket, ping{
+ Version: Version,
+ IP: t.self.IP.String(),
+ Port: uint16(t.self.TCPPort),
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return <-errc
+}
+
+func (t *udp) waitping(from NodeID) error {
+ return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
+// findnode sends a findnode request to the given node and waits until
+// the node has sent up to k neighbors.
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+ nodes := make([]*Node, 0, bucketSize)
+ nreceived := 0
+ errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
+ reply := r.(*neighbors)
+ for _, n := range reply.Nodes {
+ nreceived++
+ if n.isValid() {
+ nodes = append(nodes, n)
+ }
+ }
+ return nreceived >= bucketSize
+ })
+ t.send(toaddr, findnodePacket, findnode{
+ Target: target,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ err := <-errc
+ return nodes, err
+}
+
+// pending adds a reply callback to the pending reply queue.
+// see the documentation of type pending for a detailed explanation.
+func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+ ch := make(chan error, 1)
+ p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ select {
+ case t.addpending <- p:
+ // loop will handle it
+ case <-t.closing:
+ ch <- errClosed
+ }
+ return ch
+}
+
+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+ matched := make(chan bool)
+ select {
+ case t.gotreply <- reply{from, ptype, req, matched}:
+ // loop will handle it
+ return <-matched
+ case <-t.closing:
+ return false
+ }
+}
+
+// loop runs in its own goroutin. it keeps track of
+// the refresh timer and the pending reply queue.
+func (t *udp) loop() {
+ var (
+ pending []*pending
+ nextDeadline time.Time
+ timeout = time.NewTimer(0)
+ refresh = time.NewTicker(refreshInterval)
+ )
+ <-timeout.C // ignore first timeout
+ defer refresh.Stop()
+ defer timeout.Stop()
+
+ rearmTimeout := func() {
+ if len(pending) == 0 || nextDeadline == pending[0].deadline {
+ return
+ }
+ nextDeadline = pending[0].deadline
+ timeout.Reset(nextDeadline.Sub(time.Now()))
+ }
+
+ for {
+ select {
+ case <-refresh.C:
+ go t.refresh()
+
+ case <-t.closing:
+ for _, p := range pending {
+ p.errc <- errClosed
+ }
+ pending = nil
+ return
+
+ case p := <-t.addpending:
+ p.deadline = time.Now().Add(respTimeout)
+ pending = append(pending, p)
+ rearmTimeout()
+
+ case r := <-t.gotreply:
+ var matched bool
+ for i := 0; i < len(pending); i++ {
+ if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+ matched = true
+ if p.callback(r.data) {
+ // callback indicates the request is done, remove it.
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i--
+ }
+ }
+ }
+ r.matched <- matched
+
+ case now := <-timeout.C:
+ // notify and remove callbacks whose deadline is in the past.
+ i := 0
+ for ; i < len(pending) && now.After(pending[i].deadline); i++ {
+ pending[i].errc <- errTimeout
+ }
+ if i > 0 {
+ copy(pending, pending[i:])
+ pending = pending[:len(pending)-i]
+ }
+ rearmTimeout()
+ }
+ }
+}
+
+const (
+ macSize = 256 / 8
+ sigSize = 520 / 8
+ headSize = macSize + sigSize // space of packet frame data
+)
+
+var headSpace = make([]byte, headSize)
+
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+ packet, err := encodePacket(t.priv, ptype, req)
+ if err != nil {
+ return err
+ }
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+ b := new(bytes.Buffer)
+ b.Write(headSpace)
+ b.WriteByte(ptype)
+ if err := rlp.Encode(b, req); err != nil {
+ log.Errorln("error encoding packet:", err)
+ return nil, err
+ }
+ packet := b.Bytes()
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
+ if err != nil {
+ log.Errorln("could not sign packet:", err)
+ return nil, err
+ }
+ copy(packet[macSize:], sig)
+ // add the hash to the front. Note: this doesn't protect the
+ // packet in any way. Our public key will be part of this hash in
+ // The future.
+ copy(packet, crypto.Sha3(packet[macSize:]))
+ return packet, nil
+}
+
+// readLoop runs in its own goroutine. it handles incoming UDP packets.
+func (t *udp) readLoop() {
+ defer t.conn.Close()
+ buf := make([]byte, 4096) // TODO: good buffer size
+ for {
+ nbytes, from, err := t.conn.ReadFromUDP(buf)
+ if err != nil {
+ return
+ }
+ packet, fromID, hash, err := decodePacket(buf[:nbytes])
+ if err != nil {
+ log.Debugf("Bad packet from %v: %v\n", from, err)
+ continue
+ }
+ log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+ go func() {
+ if err := packet.handle(t, from, fromID, hash); err != nil {
+ log.Debugf("error handling %T from %v: %v", packet, from, err)
+ }
+ }()
+ }
+}
+
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
+ if len(buf) < headSize+1 {
+ return nil, NodeID{}, nil, errPacketTooSmall
+ }
+ hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
+ shouldhash := crypto.Sha3(buf[macSize:])
+ if !bytes.Equal(hash, shouldhash) {
+ return nil, NodeID{}, nil, errBadHash
+ }
+ fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
+ if err != nil {
+ return nil, NodeID{}, hash, err
+ }
+ var req packet
+ switch ptype := sigdata[0]; ptype {
+ case pingPacket:
+ req = new(ping)
+ case pongPacket:
+ req = new(pong)
+ case findnodePacket:
+ req = new(findnode)
+ case neighborsPacket:
+ req = new(neighbors)
+ default:
+ return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
+ }
+ err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+ return req, fromID, hash, err
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ if req.Version != Version {
+ return errBadVersion
+ }
+ t.send(from, pongPacket, pong{
+ ReplyTok: mac,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ if !t.handleReply(fromID, pingPacket, req) {
+ // Note: we're ignoring the provided IP address right now
+ t.bond(true, fromID, from, req.Port)
+ }
+ return nil
+}
+
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ if !t.handleReply(fromID, pongPacket, req) {
+ return errUnsolicitedReply
+ }
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ if t.db.get(fromID) == nil {
+ // No bond exists, we don't process the packet. This prevents
+ // an attack vector where the discovery protocol could be used
+ // to amplify traffic in a DDOS attack. A malicious actor
+ // would send a findnode request with the IP address and UDP
+ // port of the target as the source address. The recipient of
+ // the findnode packet would then send a neighbors packet
+ // (which is a much bigger packet than findnode) to the victim.
+ return errUnknownNode
+ }
+ t.mutex.Lock()
+ closest := t.closest(req.Target, bucketSize).entries
+ t.mutex.Unlock()
+
+ t.send(from, neighborsPacket, neighbors{
+ Nodes: closest,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ if !t.handleReply(fromID, neighborsPacket, req) {
+ return errUnsolicitedReply
+ }
+ return nil
+}
+
+func expired(ts uint64) bool {
+ return time.Unix(int64(ts), 0).Before(time.Now())
+}