aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/discover/node.go43
-rw-r--r--p2p/discover/table.go183
-rw-r--r--p2p/discover/table_test.go170
-rw-r--r--p2p/discover/udp.go214
-rw-r--r--p2p/discover/udp_test.go422
5 files changed, 649 insertions, 383 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index e1130e0b5..99cb549a5 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -13,6 +13,8 @@ import (
"net/url"
"strconv"
"strings"
+ "sync"
+ "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/crypto"
@@ -30,7 +32,8 @@ type Node struct {
DiscPort int // UDP listening port for discovery protocol
TCPPort int // TCP listening port for RLPx
- active time.Time
+ // this must be set/read using atomic load and store.
+ activeStamp int64
}
func newNode(id NodeID, addr *net.UDPAddr) *Node {
@@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
IP: addr.IP,
DiscPort: addr.Port,
TCPPort: addr.Port,
- active: time.Now(),
}
}
@@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
}
+func (n *Node) bumpActive() {
+ stamp := time.Now().Unix()
+ atomic.StoreInt64(&n.activeStamp, stamp)
+}
+
+func (n *Node) active() time.Time {
+ stamp := atomic.LoadInt64(&n.activeStamp)
+ return time.Unix(stamp, 0)
+}
+
+func (n *Node) addr() *net.UDPAddr {
+ return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
+}
+
// The string representation of a Node is a URL.
// Please see ParseNode for a description of the format.
func (n *Node) String() string {
@@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
}
return b
}
+
+// nodeDB stores all nodes we know about.
+type nodeDB struct {
+ mu sync.RWMutex
+ byID map[NodeID]*Node
+}
+
+func (db *nodeDB) get(id NodeID) *Node {
+ db.mu.RLock()
+ defer db.mu.RUnlock()
+ return db.byID[id]
+}
+
+func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.byID == nil {
+ db.byID = make(map[NodeID]*Node)
+ }
+ n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
+ db.byID[n.ID] = n
+ return n
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 33b705a12..842f55d9f 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -14,9 +14,10 @@ import (
)
const (
- alpha = 3 // Kademlia concurrency factor
- bucketSize = 16 // Kademlia bucket size
- nBuckets = nodeIDBits + 1 // Number of buckets
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ nBuckets = nodeIDBits + 1 // Number of buckets
+ maxBondingPingPongs = 10
)
type Table struct {
@@ -24,27 +25,50 @@ type Table struct {
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
+ bondmu sync.Mutex
+ bonding map[NodeID]*bondproc
+ bondslots chan struct{} // limits total number of active bonding processes
+
net transport
self *Node // metadata of the local node
+ db *nodeDB
+}
+
+type bondproc struct {
+ err error
+ n *Node
+ done chan struct{}
}
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
- ping(*Node) error
- findnode(e *Node, target NodeID) ([]*Node, error)
+ ping(NodeID, *net.UDPAddr) error
+ waitping(NodeID) error
+ findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
close()
}
// bucket contains nodes, ordered by their last activity.
+// the entry that was most recently active is the last element
+// in entries.
type bucket struct {
lastLookup time.Time
entries []*Node
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
- tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+ tab := &Table{
+ net: t,
+ db: new(nodeDB),
+ self: newNode(ourID, ourAddr),
+ bonding: make(map[NodeID]*bondproc),
+ bondslots: make(chan struct{}, maxBondingPingPongs),
+ }
+ for i := 0; i < cap(tab.bondslots); i++ {
+ tab.bondslots <- struct{}{}
+ }
for i := range tab.buckets {
tab.buckets[i] = new(bucket)
}
@@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
asked[n.ID] = true
pendingQueries++
go func() {
- result, _ := tab.net.findnode(n, target)
- reply <- result
+ r, _ := tab.net.findnode(n.ID, n.addr(), target)
+ reply <- tab.bondall(r)
}()
}
}
@@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
// we have asked all closest nodes, stop the search
break
}
-
// wait for the next reply
for _, n := range <-reply {
- cn := n
- if !seen[n.ID] {
+ if n != nil && !seen[n.ID] {
seen[n.ID] = true
- result.push(cn, bucketSize)
+ result.push(n, bucketSize)
}
}
pendingQueries--
@@ -145,8 +167,9 @@ func (tab *Table) refresh() {
result := tab.Lookup(randomID(tab.self.ID, ld))
if len(result) == 0 {
// bootstrap the table with a self lookup
+ all := tab.bondall(tab.nursery)
tab.mutex.Lock()
- tab.add(tab.nursery)
+ tab.add(all)
tab.mutex.Unlock()
tab.Lookup(tab.self.ID)
// TODO: the Kademlia paper says that we're supposed to perform
@@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
return n
}
-// bumpOrAdd updates the activity timestamp for the given node and
-// attempts to insert the node into a bucket. The returned Node might
-// not be part of the table. The caller must hold tab.mutex.
-func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
- b := tab.buckets[logdist(tab.self.ID, node)]
- if n = b.bump(node); n == nil {
- n = newNode(node, from)
- if len(b.entries) == bucketSize {
- tab.pingReplace(n, b)
- } else {
- b.entries = append(b.entries, n)
+// bondall bonds with all given nodes concurrently and returns
+// those nodes for which bonding has probably succeeded.
+func (tab *Table) bondall(nodes []*Node) (result []*Node) {
+ rc := make(chan *Node, len(nodes))
+ for i := range nodes {
+ go func(n *Node) {
+ nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
+ rc <- nn
+ }(nodes[i])
+ }
+ for _ = range nodes {
+ if n := <-rc; n != nil {
+ result = append(result, n)
}
}
- return n
+ return result
}
-func (tab *Table) pingReplace(n *Node, b *bucket) {
- old := b.entries[bucketSize-1]
- go func() {
- if err := tab.net.ping(old); err == nil {
- // it responded, we don't need to replace it.
- return
+// bond ensures the local node has a bond with the given remote node.
+// It also attempts to insert the node into the table if bonding succeeds.
+// The caller must not hold tab.mutex.
+//
+// A bond is must be established before sending findnode requests.
+// Both sides must have completed a ping/pong exchange for a bond to
+// exist. The total number of active bonding processes is limited in
+// order to restrain network use.
+//
+// bond is meant to operate idempotently in that bonding with a remote
+// node which still remembers a previously established bond will work.
+// The remote node will simply not send a ping back, causing waitping
+// to time out.
+//
+// If pinged is true, the remote node has just pinged us and one half
+// of the process can be skipped.
+func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
+ var n *Node
+ if n = tab.db.get(id); n == nil {
+ tab.bondmu.Lock()
+ w := tab.bonding[id]
+ if w != nil {
+ // Wait for an existing bonding process to complete.
+ tab.bondmu.Unlock()
+ <-w.done
+ } else {
+ // Register a new bonding process.
+ w = &bondproc{done: make(chan struct{})}
+ tab.bonding[id] = w
+ tab.bondmu.Unlock()
+ // Do the ping/pong. The result goes into w.
+ tab.pingpong(w, pinged, id, addr, tcpPort)
+ // Unregister the process after it's done.
+ tab.bondmu.Lock()
+ delete(tab.bonding, id)
+ tab.bondmu.Unlock()
}
- // it didn't respond, replace the node if it is still the oldest node.
- tab.mutex.Lock()
- if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
- // slide down other entries and put the new one in front.
- // TODO: insert in correct position to keep the order
- copy(b.entries[1:], b.entries)
- b.entries[0] = n
+ n = w.n
+ if w.err != nil {
+ return nil, w.err
}
- tab.mutex.Unlock()
- }()
+ }
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
+ tab.pingreplace(n, b)
+ }
+ return n, nil
+}
+
+func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
+ <-tab.bondslots
+ defer func() { tab.bondslots <- struct{}{} }()
+ if w.err = tab.net.ping(id, addr); w.err != nil {
+ close(w.done)
+ return
+ }
+ if !pinged {
+ // Give the remote node a chance to ping us before we start
+ // sending findnode requests. If they still remember us,
+ // waitping will simply time out.
+ tab.net.waitping(id)
+ }
+ w.n = tab.db.add(id, addr, tcpPort)
+ close(w.done)
}
-// bump updates the activity timestamp for the given node.
-// The caller must hold tab.mutex.
-func (tab *Table) bump(node NodeID) {
- tab.buckets[logdist(tab.self.ID, node)].bump(node)
+func (tab *Table) pingreplace(new *Node, b *bucket) {
+ if len(b.entries) == bucketSize {
+ oldest := b.entries[bucketSize-1]
+ if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
+ // The node responded, we don't need to replace it.
+ return
+ }
+ } else {
+ // Add a slot at the end so the last entry doesn't
+ // fall off when adding the new node.
+ b.entries = append(b.entries, nil)
+ }
+ copy(b.entries[1:], b.entries)
+ b.entries[0] = new
}
// add puts the entries into the table if their corresponding
@@ -240,17 +323,17 @@ outer:
}
}
-func (b *bucket) bump(id NodeID) *Node {
- for i, n := range b.entries {
- if n.ID == id {
- n.active = time.Now()
+func (b *bucket) bump(n *Node) bool {
+ for i := range b.entries {
+ if b.entries[i].ID == n.ID {
+ n.bumpActive()
// move it to the front
copy(b.entries[1:], b.entries[:i+1])
b.entries[0] = n
- return n
+ return true
}
}
- return nil
+ return false
}
// nodesByDistance is a list of nodes, ordered by
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 08faea68e..95ec30bea 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -2,79 +2,68 @@ package discover
import (
"crypto/ecdsa"
- "errors"
"fmt"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
- "time"
"github.com/ethereum/go-ethereum/crypto"
)
-func TestTable_bumpOrAddBucketAssign(t *testing.T) {
- tab := newTable(nil, NodeID{}, &net.UDPAddr{})
- for i := 1; i < len(tab.buckets); i++ {
- tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
- }
- for i, b := range tab.buckets {
- if i > 0 && len(b.entries) != 1 {
- t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
+func TestTable_pingReplace(t *testing.T) {
+ doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
+ transport := newPingRecorder()
+ tab := newTable(transport, NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+ pingSender := randomID(tab.self.ID, 200)
+
+ // this gotPing should replace the last node
+ // if the last node is not responding.
+ transport.responding[last.ID] = lastInBucketIsResponding
+ transport.responding[pingSender] = newNodeIsResponding
+ tab.bond(true, pingSender, &net.UDPAddr{}, 0)
+
+ // first ping goes to sender (bonding pingback)
+ if !transport.pinged[pingSender] {
+ t.Error("table did not ping back sender")
+ }
+ if newNodeIsResponding {
+ // second ping goes to oldest node in bucket
+ // to see whether it is still alive.
+ if !transport.pinged[last.ID] {
+ t.Error("table did not ping last node in bucket")
+ }
}
- }
-}
-
-func TestTable_bumpOrAddPingReplace(t *testing.T) {
- pingC := make(pingC)
- tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
- last := fillBucket(tab, 200)
-
- // this bumpOrAdd should not replace the last node
- // because the node replies to ping.
- new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
- pinged := <-pingC
- if pinged != last.ID {
- t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
- }
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after gotPing: got %d, want %d", bucketSize, l)
+ }
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
- if l := len(tab.buckets[200].entries); l != bucketSize {
- t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
- }
- if !contains(tab.buckets[200].entries, last.ID) {
- t.Error("last entry was removed")
- }
- if contains(tab.buckets[200].entries, new.ID) {
- t.Error("new entry was added")
+ if lastInBucketIsResponding || !newNodeIsResponding {
+ if !contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was removed")
+ }
+ if contains(tab.buckets[200].entries, pingSender) {
+ t.Error("new entry was added")
+ }
+ } else {
+ if contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was not removed")
+ }
+ if !contains(tab.buckets[200].entries, pingSender) {
+ t.Error("new entry was not added")
+ }
+ }
}
-}
-
-func TestTable_bumpOrAddPingTimeout(t *testing.T) {
- tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
- last := fillBucket(tab, 200)
- // this bumpOrAdd should replace the last node
- // because the node does not reply to ping.
- new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
- // wait for async bucket update. damn. this needs to go away.
- time.Sleep(2 * time.Millisecond)
-
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
- if l := len(tab.buckets[200].entries); l != bucketSize {
- t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
- }
- if contains(tab.buckets[200].entries, last.ID) {
- t.Error("last entry was not removed")
- }
- if !contains(tab.buckets[200].entries, new.ID) {
- t.Error("new entry was not added")
- }
+ doit(true, true)
+ doit(false, true)
+ doit(false, true)
+ doit(false, false)
}
func fillBucket(tab *Table, ld int) (last *Node) {
@@ -85,44 +74,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
return b.entries[bucketSize-1]
}
-type pingC chan NodeID
+type pingRecorder struct{ responding, pinged map[NodeID]bool }
-func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+func newPingRecorder() *pingRecorder {
+ return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+}
+
+func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder")
}
-func (t pingC) close() {
+func (t *pingRecorder) close() {
panic("close called on pingRecorder")
}
-func (t pingC) ping(n *Node) error {
- if t == nil {
- return errTimeout
- }
- t <- n.ID
- return nil
+func (t *pingRecorder) waitping(from NodeID) error {
+ return nil // remote always pings
}
-
-func TestTable_bump(t *testing.T) {
- tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-
- // add an old entry and two recent ones
- oldactive := time.Now().Add(-2 * time.Minute)
- old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
- others := []*Node{
- &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
- &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
- }
- tab.add(append(others, old))
- if tab.buckets[200].entries[0] == old {
- t.Fatal("old entry is at front of bucket")
- }
-
- // bumping the old entry should move it to the front
- tab.bump(old.ID)
- if old.active == oldactive {
- t.Error("activity timestamp not updated")
- }
- if tab.buckets[200].entries[0] != old {
- t.Errorf("bumped entry did not move to the front of bucket")
+func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ t.pinged[toid] = true
+ if t.responding[toid] {
+ return nil
+ } else {
+ return errTimeout
}
}
@@ -210,7 +182,7 @@ func TestTable_Lookup(t *testing.T) {
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
}
// seed table with initial node (otherwise lookup will terminate immediately)
- tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+ tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})
results := tab.Lookup(target)
t.Logf("results:")
@@ -238,16 +210,16 @@ type findnodeOracle struct {
target NodeID
}
-func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
- t.t.Logf("findnode query at dist %d", n.DiscPort)
+func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+ t.t.Logf("findnode query at dist %d", toaddr.Port)
// current log distance is encoded in port number
var result []*Node
- switch n.DiscPort {
+ switch toaddr.Port {
case 0:
panic("query to node at distance 0")
default:
// TODO: add more randomness to distances
- next := n.DiscPort - 1
+ next := toaddr.Port - 1
for i := 0; i < bucketSize; i++ {
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
}
@@ -255,11 +227,9 @@ func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
return result, nil
}
-func (t findnodeOracle) close() {}
-
-func (t findnodeOracle) ping(n *Node) error {
- return errors.New("ping is not supported by this transport")
-}
+func (t findnodeOracle) close() {}
+func (t findnodeOracle) waitping(from NodeID) error { return nil }
+func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
func hasDuplicates(slice []*Node) bool {
seen := make(map[NodeID]bool)
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 738a01fb7..e9ede1397 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -20,12 +20,14 @@ const Version = 3
// Errors
var (
- errPacketTooSmall = errors.New("too small")
- errBadHash = errors.New("bad hash")
- errExpired = errors.New("expired")
- errBadVersion = errors.New("version mismatch")
- errTimeout = errors.New("RPC timeout")
- errClosed = errors.New("socket closed")
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errBadVersion = errors.New("version mismatch")
+ errUnsolicitedReply = errors.New("unsolicited reply")
+ errUnknownNode = errors.New("unknown node")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
)
// Timeouts
@@ -80,14 +82,27 @@ type rpcNode struct {
ID NodeID
}
+type packet interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+ ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+ WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+ Close() error
+ LocalAddr() net.Addr
+}
+
// udp implements the RPC protocol.
type udp struct {
- conn *net.UDPConn
- priv *ecdsa.PrivateKey
+ conn conn
+ priv *ecdsa.PrivateKey
+
addpending chan *pending
- replies chan reply
- closing chan struct{}
- nat nat.Interface
+ gotreply chan reply
+
+ closing chan struct{}
+ nat nat.Interface
*Table
}
@@ -124,6 +139,9 @@ type reply struct {
from NodeID
ptype byte
data interface{}
+ // loop indicates whether there was
+ // a matching request by sending on this channel.
+ matched chan<- bool
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
@@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
if err != nil {
return nil, err
}
+ tab, _ := newUDP(priv, conn, natm)
+ log.Infoln("Listening,", tab.self)
+ return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
udp := &udp{
- conn: conn,
+ conn: c,
priv: priv,
closing: make(chan struct{}),
+ gotreply: make(chan reply),
addpending: make(chan *pending),
- replies: make(chan reply),
}
-
- realaddr := conn.LocalAddr().(*net.UDPAddr)
+ realaddr := c.LocalAddr().(*net.UDPAddr)
if natm != nil {
if !realaddr.IP.IsLoopback() {
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
@@ -155,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
}
}
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
-
go udp.loop()
go udp.readLoop()
- log.Infoln("Listening, ", udp.self)
- return udp.Table, nil
+ return udp.Table, udp
}
func (t *udp) close() {
@@ -169,10 +190,10 @@ func (t *udp) close() {
}
// ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(e *Node) error {
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
- t.send(e, pingPacket, ping{
+ errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+ t.send(toaddr, pingPacket, ping{
Version: Version,
IP: t.self.IP.String(),
Port: uint16(t.self.TCPPort),
@@ -181,12 +202,16 @@ func (t *udp) ping(e *Node) error {
return <-errc
}
+func (t *udp) waitping(from NodeID) error {
+ return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
-func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
nodes := make([]*Node, 0, bucketSize)
nreceived := 0
- errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
reply := r.(*neighbors)
for _, n := range reply.Nodes {
nreceived++
@@ -196,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
}
return nreceived >= bucketSize
})
-
- t.send(to, findnodePacket, findnode{
+ t.send(toaddr, findnodePacket, findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
@@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
return ch
}
+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+ matched := make(chan bool)
+ select {
+ case t.gotreply <- reply{from, ptype, req, matched}:
+ // loop will handle it
+ return <-matched
+ case <-t.closing:
+ return false
+ }
+}
+
// loop runs in its own goroutin. it keeps track of
// the refresh timer and the pending reply queue.
func (t *udp) loop() {
@@ -249,6 +284,7 @@ func (t *udp) loop() {
for _, p := range pending {
p.errc <- errClosed
}
+ pending = nil
return
case p := <-t.addpending:
@@ -256,18 +292,21 @@ func (t *udp) loop() {
pending = append(pending, p)
rearmTimeout()
- case reply := <-t.replies:
- // run matching callbacks, remove if they return false.
+ case r := <-t.gotreply:
+ var matched bool
for i := 0; i < len(pending); i++ {
- p := pending[i]
- if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
- p.errc <- nil
- copy(pending[i:], pending[i+1:])
- pending = pending[:len(pending)-1]
- i--
+ if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+ matched = true
+ if p.callback(r.data) {
+ // callback indicates the request is done, remove it.
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i--
+ }
}
}
- rearmTimeout()
+ r.matched <- matched
case now := <-timeout.C:
// notify and remove callbacks whose deadline is in the past.
@@ -292,33 +331,38 @@ const (
var headSpace = make([]byte, headSize)
-func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+ packet, err := encodePacket(t.priv, ptype, req)
+ if err != nil {
+ return err
+ }
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Errorln("error encoding packet:", err)
- return err
+ return nil, err
}
-
packet := b.Bytes()
- sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
if err != nil {
log.Errorln("could not sign packet:", err)
- return err
+ return nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
- // the future.
+ // The future.
copy(packet, crypto.Sha3(packet[macSize:]))
-
- toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
- log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
- if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
- log.DebugDetailln("UDP send failed:", err)
- }
- return err
+ return packet, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
@@ -330,29 +374,34 @@ func (t *udp) readLoop() {
if err != nil {
return
}
- if err := t.packetIn(from, buf[:nbytes]); err != nil {
+ packet, fromID, hash, err := decodePacket(buf[:nbytes])
+ if err != nil {
log.Debugf("Bad packet from %v: %v\n", from, err)
+ continue
}
+ log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+ go func() {
+ if err := packet.handle(t, from, fromID, hash); err != nil {
+ log.Debugf("error handling %T from %v: %v", packet, from, err)
+ }
+ }()
}
}
-func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
if len(buf) < headSize+1 {
- return errPacketTooSmall
+ return nil, NodeID{}, nil, errPacketTooSmall
}
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
shouldhash := crypto.Sha3(buf[macSize:])
if !bytes.Equal(hash, shouldhash) {
- return errBadHash
+ return nil, NodeID{}, nil, errBadHash
}
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
if err != nil {
- return err
- }
-
- var req interface {
- handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ return nil, NodeID{}, hash, err
}
+ var req packet
switch ptype := sigdata[0]; ptype {
case pingPacket:
req = new(ping)
@@ -363,13 +412,10 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
case neighborsPacket:
req = new(neighbors)
default:
- return fmt.Errorf("unknown type: %d", ptype)
+ return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
}
- if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
- return err
- }
- log.DebugDetailf("<<< %v %T %v\n", from, req, req)
- return req.handle(t, from, fromID, hash)
+ err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+ return req, fromID, hash, err
}
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
@@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if req.Version != Version {
return errBadVersion
}
- t.mutex.Lock()
- // Note: we're ignoring the provided IP address right now
- n := t.bumpOrAdd(fromID, from)
- if req.Port != 0 {
- n.TCPPort = int(req.Port)
- }
- t.mutex.Unlock()
-
- t.send(n, pongPacket, pong{
+ t.send(from, pongPacket, pong{
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
+ if !t.handleReply(fromID, pingPacket, req) {
+ // Note: we're ignoring the provided IP address right now
+ t.bond(true, fromID, from, req.Port)
+ }
return nil
}
@@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if expired(req.Expiration) {
return errExpired
}
- t.mutex.Lock()
- t.bump(fromID)
- t.mutex.Unlock()
-
- t.replies <- reply{fromID, pongPacket, req}
+ if !t.handleReply(fromID, pongPacket, req) {
+ return errUnsolicitedReply
+ }
return nil
}
@@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired(req.Expiration) {
return errExpired
}
+ if t.db.get(fromID) == nil {
+ // No bond exists, we don't process the packet. This prevents
+ // an attack vector where the discovery protocol could be used
+ // to amplify traffic in a DDOS attack. A malicious actor
+ // would send a findnode request with the IP address and UDP
+ // port of the target as the source address. The recipient of
+ // the findnode packet would then send a neighbors packet
+ // (which is a much bigger packet than findnode) to the victim.
+ return errUnknownNode
+ }
t.mutex.Lock()
- e := t.bumpOrAdd(fromID, from)
closest := t.closest(req.Target, bucketSize).entries
t.mutex.Unlock()
- t.send(e, neighborsPacket, neighbors{
+ t.send(from, neighborsPacket, neighbors{
Nodes: closest,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
@@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
if expired(req.Expiration) {
return errExpired
}
- t.mutex.Lock()
- t.bump(fromID)
- t.add(req.Nodes)
- t.mutex.Unlock()
-
- t.replies <- reply{fromID, neighborsPacket, req}
+ if !t.handleReply(fromID, neighborsPacket, req) {
+ return errUnsolicitedReply
+ }
return nil
}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 0a8ff6358..c6c4d78e3 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -1,10 +1,18 @@
package discover
import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
"fmt"
+ "io"
logpkg "log"
"net"
"os"
+ "path"
+ "reflect"
+ "runtime"
+ "sync"
"testing"
"time"
@@ -15,197 +23,317 @@ func init() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
-func TestUDP_ping(t *testing.T) {
- t.Parallel()
-
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- defer n2.Close()
+type udpTest struct {
+ t *testing.T
+ pipe *dgramPipe
+ table *Table
+ udp *udp
+ sent [][]byte
+ localkey, remotekey *ecdsa.PrivateKey
+ remoteaddr *net.UDPAddr
+}
- if err := n1.net.ping(n2.self); err != nil {
- t.Fatalf("ping error: %v", err)
+func newUDPTest(t *testing.T) *udpTest {
+ test := &udpTest{
+ t: t,
+ pipe: newpipe(),
+ localkey: newkey(),
+ remotekey: newkey(),
+ remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
}
- if find(n2, n1.self.ID) == nil {
- t.Errorf("node 2 does not contain id of node 1")
+ test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
+ return test
+}
+
+// handles a packet as if it had been sent to the transport.
+func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
+ enc, err := encodePacket(test.remotekey, ptype, data)
+ if err != nil {
+ return test.errorf("packet (%d) encode error: %v", err)
}
- if e := find(n1, n2.self.ID); e != nil {
- t.Errorf("node 1 does contains id of node 2: %v", e)
+ test.sent = append(test.sent, enc)
+ err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
+ if err != wantError {
+ return test.errorf("error mismatch: got %q, want %q", err, wantError)
}
+ return nil
}
-func find(tab *Table, id NodeID) *Node {
- for _, b := range tab.buckets {
- for _, e := range b.entries {
- if e.ID == id {
- return e
- }
- }
+// waits for a packet to be sent by the transport.
+// validate should have type func(*udpTest, X) error, where X is a packet type.
+func (test *udpTest) waitPacketOut(validate interface{}) error {
+ dgram := test.pipe.waitPacketOut()
+ p, _, _, err := decodePacket(dgram)
+ if err != nil {
+ return test.errorf("sent packet decode error: %v", err)
}
+ fn := reflect.ValueOf(validate)
+ exptype := fn.Type().In(0)
+ if reflect.TypeOf(p) != exptype {
+ return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ }
+ fn.Call([]reflect.Value{reflect.ValueOf(p)})
return nil
}
-func TestUDP_findnode(t *testing.T) {
+func (test *udpTest) errorf(format string, args ...interface{}) error {
+ _, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
+ if ok {
+ file = path.Base(file)
+ } else {
+ file = "???"
+ line = 1
+ }
+ err := fmt.Errorf(format, args...)
+ fmt.Printf("\t%s:%d: %v\n", file, line, err)
+ test.t.Fail()
+ return err
+}
+
+// shared test variables
+var (
+ futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
+ testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
+)
+
+func TestUDP_packetErrors(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
+ test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
+ test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
+ test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
+ test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
+}
+
+func TestUDP_pingTimeout(t *testing.T) {
+ t.Parallel()
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+ toid := NodeID{1, 2, 3, 4}
+ if err := test.udp.ping(toid, toaddr); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+}
+
+func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel()
+ test := newUDPTest(t)
+ defer test.table.Close()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- defer n2.Close()
+ toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+ toid := NodeID{1, 2, 3, 4}
+ target := NodeID{4, 5, 6, 7}
+ result, err := test.udp.findnode(toid, toaddr, target)
+ if err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+ if len(result) > 0 {
+ t.Error("expected empty result, got", result)
+ }
+}
- // put a few nodes into n2. the exact distribution shouldn't
- // matter much, altough we need to take care not to overflow
- // any bucket.
- target := randomID(n1.self.ID, 100)
+func TestUDP_findnode(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ // put a few nodes into the table. their exact
+ // distribution shouldn't matter much, altough we need to
+ // take care not to overflow any bucket.
+ target := testTarget
nodes := &nodesByDistance{target: target}
for i := 0; i < bucketSize; i++ {
- n2.add([]*Node{&Node{
+ nodes.push(&Node{
IP: net.IP{1, 2, 3, byte(i)},
DiscPort: i + 2,
TCPPort: i + 2,
- ID: randomID(n2.self.ID, i+2),
- }})
+ ID: randomID(test.table.self.ID, i+2),
+ }, bucketSize)
}
- n2.add(nodes.entries)
- n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
- expected := n2.closest(target, bucketSize)
+ test.table.add(nodes.entries)
+
+ // ensure there's a bond with the test node,
+ // findnode won't be accepted otherwise.
+ test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
- err := runUDP(10, func() error {
- result, _ := n1.net.findnode(n2.self, target)
- if len(result) != bucketSize {
- return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+ // check that closest neighbors are returned.
+ test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
+ test.waitPacketOut(func(p *neighbors) {
+ expected := test.table.closest(testTarget, bucketSize)
+ if len(p.Nodes) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
- for i := range result {
- if result[i].ID != expected.entries[i].ID {
- return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+ for i := range p.Nodes {
+ if p.Nodes[i].ID != expected.entries[i].ID {
+ t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
}
}
- return nil
})
- if err != nil {
- t.Error(err)
- }
}
-func TestUDP_replytimeout(t *testing.T) {
- t.Parallel()
+func TestUDP_findnodeMultiReply(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
- // reserve a port so we don't talk to an existing service by accident
- addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
- fd, err := net.ListenUDP("udp", addr)
- if err != nil {
- t.Fatal(err)
- }
- defer fd.Close()
+ // queue a pending findnode request
+ resultc, errc := make(chan []*Node), make(chan error)
+ go func() {
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
+ if err != nil && len(ns) == 0 {
+ errc <- err
+ } else {
+ resultc <- ns
+ }
+ }()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+ // wait for the findnode to be sent.
+ // after it is sent, the transport is waiting for a reply
+ test.waitPacketOut(func(p *findnode) {
+ if p.Target != testTarget {
+ t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
+ }
+ })
- if err := n1.net.ping(n2); err != errTimeout {
- t.Error("expected timeout error, got", err)
+ // send the reply as two packets.
+ list := []*Node{
+ MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
+ MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
+ MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
+ MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
}
+ test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
+ test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
- if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
- t.Error("expected timeout error, got", err)
- } else if len(result) > 0 {
- t.Error("expected empty result, got", result)
+ // check that the sent neighbors are all returned by findnode
+ select {
+ case result := <-resultc:
+ if !reflect.DeepEqual(result, list) {
+ t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
+ }
+ case err := <-errc:
+ t.Errorf("findnode error: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Error("findnode did not return within 5 seconds")
}
}
-func TestUDP_findnodeMultiReply(t *testing.T) {
- t.Parallel()
+func TestUDP_successfulPing(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- udp2 := n2.net.(*udp)
- defer n1.Close()
- defer n2.Close()
-
- err := runUDP(10, func() error {
- nodes := make([]*Node, bucketSize)
- for i := range nodes {
- nodes[i] = &Node{
- IP: net.IP{1, 2, 3, 4},
- DiscPort: i + 1,
- TCPPort: i + 1,
- ID: randomID(n2.self.ID, i+1),
- }
- }
+ done := make(chan struct{})
+ go func() {
+ test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
+ close(done)
+ }()
- // ask N2 for neighbors. it will send an empty reply back.
- // the request will wait for up to bucketSize replies.
- resultc := make(chan []*Node)
- errc := make(chan error)
- go func() {
- ns, err := n1.net.findnode(n2.self, n1.self.ID)
- if err != nil {
- errc <- err
- } else {
- resultc <- ns
- }
- }()
-
- // send a few more neighbors packets to N1.
- // it should collect those.
- for end := 0; end < len(nodes); {
- off := end
- if end = end + 5; end > len(nodes) {
- end = len(nodes)
- }
- udp2.send(n1.self, neighborsPacket, neighbors{
- Nodes: nodes[off:end],
- Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
- })
+ // the ping is replied to.
+ test.waitPacketOut(func(p *pong) {
+ pinghash := test.sent[0][:macSize]
+ if !bytes.Equal(p.ReplyTok, pinghash) {
+ t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
}
+ })
- // check that they are all returned. we cannot just check for
- // equality because they might not be returned in the order they
- // were sent.
- var result []*Node
- select {
- case result = <-resultc:
- case err := <-errc:
- return err
- }
- if hasDuplicates(result) {
- return fmt.Errorf("result slice contains duplicates")
- }
- if len(result) != len(nodes) {
- return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
- }
- matched := make(map[NodeID]bool)
- for _, n := range result {
- for _, expn := range nodes {
- if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
- matched[n.ID] = true
- }
+ // remote is unknown, the table pings back.
+ test.waitPacketOut(func(p *ping) error { return nil })
+ test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+
+ // ping should return shortly after getting the pong packet.
+ <-done
+
+ // check that the node was added.
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ rnode := find(test.table, rid)
+ if rnode == nil {
+ t.Fatalf("node %v not found in table", rid)
+ }
+ if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
+ t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
+ }
+ if rnode.DiscPort != test.remoteaddr.Port {
+ t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
+ }
+ if rnode.TCPPort != 99 {
+ t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
+ }
+}
+
+func find(tab *Table, id NodeID) *Node {
+ for _, b := range tab.buckets {
+ for _, e := range b.entries {
+ if e.ID == id {
+ return e
}
}
- if len(matched) != len(nodes) {
- return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
- }
- return nil
- })
- if err != nil {
- t.Error(err)
}
+ return nil
}
-// runUDP runs a test n times and returns an error if the test failed
-// in all n runs. This is necessary because UDP is unreliable even for
-// connections on the local machine, causing test failures.
-func runUDP(n int, test func() error) error {
- errcount := 0
- errors := ""
- for i := 0; i < n; i++ {
- if err := test(); err != nil {
- errors += fmt.Sprintf("\n#%d: %v", i, err)
- errcount++
- }
+// dgramPipe is a fake UDP socket. It queues all sent datagrams.
+type dgramPipe struct {
+ mu *sync.Mutex
+ cond *sync.Cond
+ closing chan struct{}
+ closed bool
+ queue [][]byte
+}
+
+func newpipe() *dgramPipe {
+ mu := new(sync.Mutex)
+ return &dgramPipe{
+ closing: make(chan struct{}),
+ cond: &sync.Cond{L: mu},
+ mu: mu,
}
- if errcount == n {
- return fmt.Errorf("failed on all %d iterations:%s", n, errors)
+}
+
+// WriteToUDP queues a datagram.
+func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
+ msg := make([]byte, len(b))
+ copy(msg, b)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return 0, errors.New("closed")
+ }
+ c.queue = append(c.queue, msg)
+ c.cond.Signal()
+ return len(b), nil
+}
+
+// ReadFromUDP just hangs until the pipe is closed.
+func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+ <-c.closing
+ return 0, nil, io.EOF
+}
+
+func (c *dgramPipe) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if !c.closed {
+ close(c.closing)
+ c.closed = true
}
return nil
}
+
+func (c *dgramPipe) LocalAddr() net.Addr {
+ return &net.UDPAddr{}
+}
+
+func (c *dgramPipe) waitPacketOut() []byte {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for len(c.queue) == 0 {
+ c.cond.Wait()
+ }
+ p := c.queue[0]
+ copy(c.queue, c.queue[1:])
+ c.queue = c.queue[:len(c.queue)-1]
+ return p
+}