aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover')
-rw-r--r--p2p/discover/database.go17
-rw-r--r--p2p/discover/database_test.go18
-rw-r--r--p2p/discover/node.go6
-rw-r--r--p2p/discover/table.go486
-rw-r--r--p2p/discover/table_test.go171
-rw-r--r--p2p/discover/udp.go88
-rw-r--r--p2p/discover/udp_test.go29
7 files changed, 543 insertions, 272 deletions
diff --git a/p2p/discover/database.go b/p2p/discover/database.go
index b136609f2..6f98de9b4 100644
--- a/p2p/discover/database.go
+++ b/p2p/discover/database.go
@@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error {
}
// Skip the node if not expired yet (and not self)
if !bytes.Equal(id[:], db.self[:]) {
- if seen := db.lastPong(id); seen.After(threshold) {
+ if seen := db.bondTime(id); seen.After(threshold) {
continue
}
}
@@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
}
-// lastPong retrieves the time of the last successful contact from remote node.
-func (db *nodeDB) lastPong(id NodeID) time.Time {
+// bondTime retrieves the time of the last successful pong from remote node.
+func (db *nodeDB) bondTime(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
}
-// updateLastPong updates the last time a remote node successfully contacted.
-func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error {
+// hasBond reports whether the given node is considered bonded.
+func (db *nodeDB) hasBond(id NodeID) bool {
+ return time.Since(db.bondTime(id)) < nodeDBNodeExpiration
+}
+
+// updateBondTime updates the last pong time of a node.
+func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
}
@@ -327,7 +332,7 @@ seek:
if n.ID == db.self {
continue seek
}
- if now.Sub(db.lastPong(n.ID)) > maxAge {
+ if now.Sub(db.bondTime(n.ID)) > maxAge {
continue seek
}
for i := range nodes {
diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go
index be972fd2c..c4fa44d09 100644
--- a/p2p/discover/database_test.go
+++ b/p2p/discover/database_test.go
@@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node pong object
- if stored := db.lastPong(node.ID); stored.Unix() != 0 {
+ if stored := db.bondTime(node.ID); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored)
}
- if err := db.updateLastPong(node.ID, inst); err != nil {
+ if err := db.updateBondTime(node.ID, inst); err != nil {
t.Errorf("pong: failed to update: %v", err)
}
- if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() {
+ if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node findnode-failure object
@@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to insert lastPong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to insert bondTime: %v", i, err)
}
}
@@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update pong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
// Expire some of them, and check the rest
@@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update pong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
// Expire the nodes and make sure self has been evacuated too
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index fc928a91a..3b0c84115 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -29,6 +29,7 @@ import (
"regexp"
"strconv"
"strings"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
@@ -51,9 +52,8 @@ type Node struct {
// with ID.
sha common.Hash
- // whether this node is currently being pinged in order to replace
- // it in a bucket
- contested bool
+ // Time when the node was added to the table.
+ addedAt time.Time
}
// NewNode creates a new node. It is mostly meant to be used for
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index ec4eb94ad..6509326e6 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -23,10 +23,11 @@
package discover
import (
- "crypto/rand"
+ crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
+ mrand "math/rand"
"net"
"sort"
"sync"
@@ -35,29 +36,45 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
- alpha = 3 // Kademlia concurrency factor
- bucketSize = 16 // Kademlia bucket size
- hashBits = len(common.Hash{}) * 8
- nBuckets = hashBits + 1 // Number of buckets
-
- maxBondingPingPongs = 16
- maxFindnodeFailures = 5
-
- autoRefreshInterval = 1 * time.Hour
- seedCount = 30
- seedMaxAge = 5 * 24 * time.Hour
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ maxReplacements = 10 // Size of per-bucket replacement list
+
+ // We keep buckets for the upper 1/15 of distances because
+ // it's very unlikely we'll ever encounter a node that's closer.
+ hashBits = len(common.Hash{}) * 8
+ nBuckets = hashBits / 15 // Number of buckets
+ bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
+
+ // IP address limits.
+ bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
+ tableIPLimit, tableSubnet = 10, 24
+
+ maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
+ maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
+
+ refreshInterval = 30 * time.Minute
+ revalidateInterval = 10 * time.Second
+ copyNodesInterval = 30 * time.Second
+ seedMinTableTime = 5 * time.Minute
+ seedCount = 30
+ seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
- mutex sync.Mutex // protects buckets, their content, and nursery
+ mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
- db *nodeDB // database of known nodes
+ rand *mrand.Rand // source of randomness, periodically reseeded
+ ips netutil.DistinctNetSet
+ db *nodeDB // database of known nodes
refreshReq chan chan struct{}
+ initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
@@ -89,9 +106,13 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
-type bucket struct{ entries []*Node }
+type bucket struct {
+ entries []*Node // live entries, sorted by time of last contact
+ replacements []*Node // recently seen nodes to be used if revalidation fails
+ ips netutil.DistinctNetSet
+}
-func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) {
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID)
if err != nil {
@@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}),
+ initDone: make(chan struct{}),
closeReq: make(chan struct{}),
closed: make(chan struct{}),
+ rand: mrand.New(mrand.NewSource(0)),
+ ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
+ }
+ if err := tab.setFallbackNodes(bootnodes); err != nil {
+ return nil, err
}
for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{}
}
for i := range tab.buckets {
- tab.buckets[i] = new(bucket)
+ tab.buckets[i] = &bucket{
+ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
+ }
}
- go tab.refreshLoop()
+ tab.seedRand()
+ tab.loadSeedNodes(false)
+ // Start the background expiration goroutine after loading seeds so that the search for
+ // seed nodes also considers older nodes that would otherwise be removed by the
+ // expiration.
+ tab.db.ensureExpirer()
+ go tab.loop()
return tab, nil
}
+func (tab *Table) seedRand() {
+ var b [8]byte
+ crand.Read(b[:])
+
+ tab.mutex.Lock()
+ tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
+ tab.mutex.Unlock()
+}
+
// Self returns the local node.
// The returned node should not be modified by the caller.
func (tab *Table) Self() *Node {
@@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
// table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller.
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
+ if !tab.isInitDone() {
+ return 0
+ }
tab.mutex.Lock()
defer tab.mutex.Unlock()
- // TODO: tree-based buckets would help here
+
// Find all non-empty buckets and get a fresh slice of their entries.
var buckets [][]*Node
for _, b := range tab.buckets {
@@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return 0
}
// Shuffle the buckets.
- for i := uint32(len(buckets)) - 1; i > 0; i-- {
- j := randUint(i)
+ for i := len(buckets) - 1; i > 0; i-- {
+ j := tab.rand.Intn(len(buckets))
buckets[i], buckets[j] = buckets[j], buckets[i]
}
// Move head of each bucket into buf, removing buckets that become empty.
@@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return i + 1
}
-func randUint(max uint32) uint32 {
- if max == 0 {
- return 0
- }
- var b [4]byte
- rand.Read(b[:])
- return binary.BigEndian.Uint32(b[:]) % max
-}
-
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
select {
@@ -180,16 +218,15 @@ func (tab *Table) Close() {
}
}
-// SetFallbackNodes sets the initial points of contact. These nodes
+// setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
-func (tab *Table) SetFallbackNodes(nodes []*Node) error {
+func (tab *Table) setFallbackNodes(nodes []*Node) error {
for _, n := range nodes {
if err := n.validateComplete(); err != nil {
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
}
}
- tab.mutex.Lock()
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
@@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
cpy.sha = crypto.Keccak256Hash(n.ID[:])
tab.nursery = append(tab.nursery, &cpy)
}
- tab.mutex.Unlock()
- tab.refresh()
return nil
}
+// isInitDone returns whether the table's initial seeding procedure has completed.
+func (tab *Table) isInitDone() bool {
+ select {
+ case <-tab.initDone:
+ return true
+ default:
+ return false
+ }
+}
+
// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
func (tab *Table) Resolve(targetID NodeID) *Node {
@@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
return done
}
-// refreshLoop schedules doRefresh runs and coordinates shutdown.
-func (tab *Table) refreshLoop() {
+// loop schedules refresh, revalidate runs and coordinates shutdown.
+func (tab *Table) loop() {
var (
- timer = time.NewTicker(autoRefreshInterval)
- waiting []chan struct{} // accumulates waiting callers while doRefresh runs
- done chan struct{} // where doRefresh reports completion
+ revalidate = time.NewTimer(tab.nextRevalidateTime())
+ refresh = time.NewTicker(refreshInterval)
+ copyNodes = time.NewTicker(copyNodesInterval)
+ revalidateDone = make(chan struct{})
+ refreshDone = make(chan struct{}) // where doRefresh reports completion
+ waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
+ defer refresh.Stop()
+ defer revalidate.Stop()
+ defer copyNodes.Stop()
+
+ // Start initial refresh.
+ go tab.doRefresh(refreshDone)
+
loop:
for {
select {
- case <-timer.C:
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ case <-refresh.C:
+ tab.seedRand()
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
case req := <-tab.refreshReq:
waiting = append(waiting, req)
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
- case <-done:
+ case <-refreshDone:
for _, ch := range waiting {
close(ch)
}
- waiting = nil
- done = nil
+ waiting, refreshDone = nil, nil
+ case <-revalidate.C:
+ go tab.doRevalidate(revalidateDone)
+ case <-revalidateDone:
+ revalidate.Reset(tab.nextRevalidateTime())
+ case <-copyNodes.C:
+ go tab.copyBondedNodes()
case <-tab.closeReq:
break loop
}
@@ -349,8 +410,8 @@ loop:
if tab.net != nil {
tab.net.close()
}
- if done != nil {
- <-done
+ if refreshDone != nil {
+ <-refreshDone
}
for _, ch := range waiting {
close(ch)
@@ -365,38 +426,109 @@ loop:
func (tab *Table) doRefresh(done chan struct{}) {
defer close(done)
+ // Load nodes from the database and insert
+ // them. This should yield a few previously seen nodes that are
+ // (hopefully) still alive.
+ tab.loadSeedNodes(true)
+
+ // Run self lookup to discover new neighbor nodes.
+ tab.lookup(tab.self.ID, false)
+
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket.
- // We perform a lookup with a random target instead.
- var target NodeID
- rand.Read(target[:])
- result := tab.lookup(target, false)
- if len(result) > 0 {
- return
+ // We perform a few lookups with a random target instead.
+ for i := 0; i < 3; i++ {
+ var target NodeID
+ crand.Read(target[:])
+ tab.lookup(target, false)
}
+}
- // The table is empty. Load nodes from the database and insert
- // them. This should yield a few previously seen nodes that are
- // (hopefully) still alive.
+func (tab *Table) loadSeedNodes(bond bool) {
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
- seeds = tab.bondall(append(seeds, tab.nursery...))
+ seeds = append(seeds, tab.nursery...)
+ if bond {
+ seeds = tab.bondall(seeds)
+ }
+ for i := range seeds {
+ seed := seeds[i]
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }}
+ log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
+ tab.add(seed)
+ }
+}
- if len(seeds) == 0 {
- log.Debug("No discv4 seed nodes found")
+// doRevalidate checks that the last node in a random bucket is still live
+// and replaces or deletes the node if it isn't.
+func (tab *Table) doRevalidate(done chan<- struct{}) {
+ defer func() { done <- struct{}{} }()
+
+ last, bi := tab.nodeToRevalidate()
+ if last == nil {
+ // No non-empty bucket found.
+ return
+ }
+
+ // Ping the selected node and wait for a pong.
+ err := tab.ping(last.ID, last.addr())
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ b := tab.buckets[bi]
+ if err == nil {
+ // The node responded, move it to the front.
+ log.Debug("Revalidated node", "b", bi, "id", last.ID)
+ b.bump(last)
+ return
}
- for _, n := range seeds {
- age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
- log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
+ // No reply received, pick a replacement or delete the node if there aren't
+ // any replacements.
+ if r := tab.replace(b, last); r != nil {
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
+ } else {
+ log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
}
+}
+
+// nodeToRevalidate returns the last node in a random, non-empty bucket.
+func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
tab.mutex.Lock()
- tab.stuff(seeds)
- tab.mutex.Unlock()
+ defer tab.mutex.Unlock()
- // Finally, do a self lookup to fill up the buckets.
- tab.lookup(tab.self.ID, false)
+ for _, bi = range tab.rand.Perm(len(tab.buckets)) {
+ b := tab.buckets[bi]
+ if len(b.entries) > 0 {
+ last := b.entries[len(b.entries)-1]
+ return last, bi
+ }
+ }
+ return nil, 0
+}
+
+func (tab *Table) nextRevalidateTime() time.Duration {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
+}
+
+// copyBondedNodes adds nodes from the table to the database if they have been in the table
+// longer then minTableTime.
+func (tab *Table) copyBondedNodes() {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ now := time.Now()
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if now.Sub(n.addedAt) >= seedMinTableTime {
+ tab.db.updateNode(n)
+ }
+ }
+ }
}
// closest returns the n nodes in the table that are closest to the
@@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
if id == tab.self.ID {
return nil, errors.New("is self")
}
- // Retrieve a previously known node and any recent findnode failures
- node, fails := tab.db.node(id), 0
- if node != nil {
- fails = tab.db.findFails(id)
+ if pinged && !tab.isInitDone() {
+ return nil, errors.New("still initializing")
}
- // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch
+ // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
+ node, fails := tab.db.node(id), tab.db.findFails(id)
+ age := time.Since(tab.db.bondTime(id))
var result error
- age := time.Since(tab.db.lastPong(id))
- if node == nil || fails > 0 || age > nodeDBNodeExpiration {
+ if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock()
@@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n
}
}
+ // Add the node to the table even if the bonding ping/pong
+ // fails. It will be relaced quickly if it continues to be
+ // unresponsive.
if node != nil {
- // Add the node to the table even if the bonding ping/pong
- // fails. It will be relaced quickly if it continues to be
- // unresponsive.
tab.add(node)
tab.db.updateFindFails(id, 0)
}
@@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
}
// Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
- tab.db.updateNode(w.n)
close(w.done)
}
@@ -533,17 +663,19 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
if err := tab.net.ping(id, addr); err != nil {
return err
}
- tab.db.updateLastPong(id, time.Now())
-
- // Start the background expiration goroutine after the first
- // successful communication. Subsequent calls have no effect if it
- // is already running. We do this here instead of somewhere else
- // so that the search for seed nodes also considers older nodes
- // that would otherwise be removed by the expiration.
- tab.db.ensureExpirer()
+ tab.db.updateBondTime(id, time.Now())
return nil
}
+// bucket returns the bucket for the given node ID hash.
+func (tab *Table) bucket(sha common.Hash) *bucket {
+ d := logdist(tab.self.sha, sha)
+ if d <= bucketMinDistance {
+ return tab.buckets[0]
+ }
+ return tab.buckets[d-bucketMinDistance-1]
+}
+
// add attempts to add the given node its corresponding bucket. If the
// bucket has space available, adding the node succeeds immediately.
// Otherwise, the node is added if the least recently active node in
@@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) {
- b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock()
defer tab.mutex.Unlock()
- if b.bump(new) {
- return
- }
- var oldest *Node
- if len(b.entries) == bucketSize {
- oldest = b.entries[bucketSize-1]
- if oldest.contested {
- // The node is already being replaced, don't attempt
- // to replace it.
- return
- }
- oldest.contested = true
- // Let go of the mutex so other goroutines can access
- // the table while we ping the least recently active node.
- tab.mutex.Unlock()
- err := tab.ping(oldest.ID, oldest.addr())
- tab.mutex.Lock()
- oldest.contested = false
- if err == nil {
- // The node responded, don't replace it.
- return
- }
- }
- added := b.replace(new, oldest)
- if added && tab.nodeAddedHook != nil {
- tab.nodeAddedHook(new)
+
+ b := tab.bucket(new.sha)
+ if !tab.bumpOrAdd(b, new) {
+ // Node is not in table. Add it to the replacement list.
+ tab.addReplacement(b, new)
}
}
// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must hold tab.mutex.
+// if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) {
-outer:
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
for _, n := range nodes {
if n.ID == tab.self.ID {
continue // don't add self
}
- bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == n.ID {
- continue outer // already in bucket
- }
- }
- if len(bucket.entries) < bucketSize {
- bucket.entries = append(bucket.entries, n)
- if tab.nodeAddedHook != nil {
- tab.nodeAddedHook(n)
- }
+ b := tab.bucket(n.sha)
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
}
}
}
@@ -611,36 +715,72 @@ outer:
func (tab *Table) delete(node *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == node.ID {
- bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
- return
- }
- }
+
+ tab.deleteInBucket(tab.bucket(node.sha), node)
}
-func (b *bucket) replace(n *Node, last *Node) bool {
- // Don't add if b already contains n.
- for i := range b.entries {
- if b.entries[i].ID == n.ID {
- return false
- }
+func (tab *Table) addIP(b *bucket, ip net.IP) bool {
+ if netutil.IsLAN(ip) {
+ return true
}
- // Replace last if it is still the last entry or just add n if b
- // isn't full. If is no longer the last entry, it has either been
- // replaced with someone else or became active.
- if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
+ if !tab.ips.Add(ip) {
+ log.Debug("IP exceeds table limit", "ip", ip)
return false
}
- if len(b.entries) < bucketSize {
- b.entries = append(b.entries, nil)
+ if !b.ips.Add(ip) {
+ log.Debug("IP exceeds bucket limit", "ip", ip)
+ tab.ips.Remove(ip)
+ return false
}
- copy(b.entries[1:], b.entries)
- b.entries[0] = n
return true
}
+func (tab *Table) removeIP(b *bucket, ip net.IP) {
+ if netutil.IsLAN(ip) {
+ return
+ }
+ tab.ips.Remove(ip)
+ b.ips.Remove(ip)
+}
+
+func (tab *Table) addReplacement(b *bucket, n *Node) {
+ for _, e := range b.replacements {
+ if e.ID == n.ID {
+ return // already in list
+ }
+ }
+ if !tab.addIP(b, n.IP) {
+ return
+ }
+ var removed *Node
+ b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
+ if removed != nil {
+ tab.removeIP(b, removed.IP)
+ }
+}
+
+// replace removes n from the replacement list and replaces 'last' with it if it is the
+// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced
+// with someone else or became active.
+func (tab *Table) replace(b *bucket, last *Node) *Node {
+ if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
+ // Entry has moved, don't replace it.
+ return nil
+ }
+ // Still the last entry.
+ if len(b.replacements) == 0 {
+ tab.deleteInBucket(b, last)
+ return nil
+ }
+ r := b.replacements[tab.rand.Intn(len(b.replacements))]
+ b.replacements = deleteNode(b.replacements, r)
+ b.entries[len(b.entries)-1] = r
+ tab.removeIP(b, last.IP)
+ return r
+}
+
+// bump moves the given node to the front of the bucket entry list
+// if it is contained in that list.
func (b *bucket) bump(n *Node) bool {
for i := range b.entries {
if b.entries[i].ID == n.ID {
@@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
return false
}
+// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't
+// full. The return value is true if n is in the bucket.
+func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
+ if b.bump(n) {
+ return true
+ }
+ if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
+ return false
+ }
+ b.entries, _ = pushNode(b.entries, n, bucketSize)
+ b.replacements = deleteNode(b.replacements, n)
+ n.addedAt = time.Now()
+ if tab.nodeAddedHook != nil {
+ tab.nodeAddedHook(n)
+ }
+ return true
+}
+
+func (tab *Table) deleteInBucket(b *bucket, n *Node) {
+ b.entries = deleteNode(b.entries, n)
+ tab.removeIP(b, n.IP)
+}
+
+// pushNode adds n to the front of list, keeping at most max items.
+func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
+ if len(list) < max {
+ list = append(list, nil)
+ }
+ removed := list[len(list)-1]
+ copy(list[1:], list)
+ list[0] = n
+ return list, removed
+}
+
+// deleteNode removes n from list.
+func deleteNode(list []*Node, n *Node) []*Node {
+ for i := range list {
+ if list[i].ID == n.ID {
+ return append(list[:i], list[i+1:]...)
+ }
+ }
+ return list
+}
+
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 1037cc609..3ce48d299 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -20,6 +20,7 @@ import (
"crypto/ecdsa"
"fmt"
"math/rand"
+ "sync"
"net"
"reflect"
@@ -32,60 +33,65 @@ import (
)
func TestTable_pingReplace(t *testing.T) {
- doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
- transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "")
- defer tab.Close()
- pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ run := func(newNodeResponding, lastInBucketResponding bool) {
+ name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ testPingReplace(t, newNodeResponding, lastInBucketResponding)
+ })
+ }
- // fill up the sender's bucket.
- last := fillBucket(tab, 253)
+ run(true, true)
+ run(false, true)
+ run(true, false)
+ run(false, false)
+}
- // this call to bond should replace the last node
- // in its bucket if the node is not responding.
- transport.responding[last.ID] = lastInBucketIsResponding
- transport.responding[pingSender.ID] = newNodeIsResponding
- tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
- // first ping goes to sender (bonding pingback)
- if !transport.pinged[pingSender.ID] {
- t.Error("table did not ping back sender")
- }
- if newNodeIsResponding {
- // second ping goes to oldest node in bucket
- // to see whether it is still alive.
- if !transport.pinged[last.ID] {
- t.Error("table did not ping last node in bucket")
- }
- }
+ // Wait for init so bond is accepted.
+ <-tab.initDone
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
- if l := len(tab.buckets[253].entries); l != bucketSize {
- t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
- }
+ // fill up the sender's bucket.
+ pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ last := fillBucket(tab, pingSender)
- if lastInBucketIsResponding || !newNodeIsResponding {
- if !contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was removed")
- }
- if contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was added")
- }
- } else {
- if contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was not removed")
- }
- if !contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was not added")
- }
- }
+ // this call to bond should replace the last node
+ // in its bucket if the node is not responding.
+ transport.dead[last.ID] = !lastInBucketIsResponding
+ transport.dead[pingSender.ID] = !newNodeIsResponding
+ tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+ tab.doRevalidate(make(chan struct{}, 1))
+
+ // first ping goes to sender (bonding pingback)
+ if !transport.pinged[pingSender.ID] {
+ t.Error("table did not ping back sender")
+ }
+ if !transport.pinged[last.ID] {
+ // second ping goes to oldest node in bucket
+ // to see whether it is still alive.
+ t.Error("table did not ping last node in bucket")
}
- doit(true, true)
- doit(false, true)
- doit(true, false)
- doit(false, false)
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ wantSize := bucketSize
+ if !lastInBucketIsResponding && !newNodeIsResponding {
+ wantSize--
+ }
+ if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
+ t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
+ }
+ if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
+ t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
+ }
+ wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
+ if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
+ t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
+ }
}
func TestBucket_bumpNoDuplicates(t *testing.T) {
@@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
}
}
+// This checks that the table-wide IP limit is applied correctly.
+func TestTable_IPLimit(t *testing.T) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
+
+ for i := 0; i < tableIPLimit+1; i++ {
+ n := nodeAtDistance(tab.self.sha, i)
+ n.IP = net.IP{172, 0, 1, byte(i)}
+ tab.add(n)
+ }
+ if tab.len() > tableIPLimit {
+ t.Errorf("too many nodes in table")
+ }
+}
+
+// This checks that the table-wide IP limit is applied correctly.
+func TestTable_BucketIPLimit(t *testing.T) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
+
+ d := 3
+ for i := 0; i < bucketIPLimit+1; i++ {
+ n := nodeAtDistance(tab.self.sha, d)
+ n.IP = net.IP{172, 0, 1, byte(i)}
+ tab.add(n)
+ }
+ if tab.len() > bucketIPLimit {
+ t.Errorf("too many nodes in table")
+ }
+}
+
// fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their
// hashes.
-func fillBucket(tab *Table, ld int) (last *Node) {
- b := tab.buckets[ld]
+func fillBucket(tab *Table, n *Node) (last *Node) {
+ ld := logdist(tab.self.sha, n.sha)
+ b := tab.bucket(n.sha)
for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
}
@@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node)
n.sha = hashAtDistance(base, ld)
- n.IP = net.IP{10, 0, 2, byte(ld)}
+ n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n
}
-type pingRecorder struct{ responding, pinged map[NodeID]bool }
+type pingRecorder struct {
+ mu sync.Mutex
+ dead, pinged map[NodeID]bool
+}
func newPingRecorder() *pingRecorder {
- return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+ return &pingRecorder{
+ dead: make(map[NodeID]bool),
+ pinged: make(map[NodeID]bool),
+ }
}
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
- panic("findnode called on pingRecorder")
+ return nil, nil
}
func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings
}
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
t.pinged[toid] = true
- if t.responding[toid] {
- return nil
- } else {
+ if t.dead[toid] {
return errTimeout
+ } else {
+ return nil
}
}
@@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool {
// for any node table, Target and N
- tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close()
tab.stuff(test.All)
@@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
},
}
test := func(buf []*Node) bool {
- tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
+ <-tab.initDone
+
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
@@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0)
- tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "")
+ tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close()
// lookup on empty table returns no nodes
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 60436952d..524c6e498 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr
}
+// Config holds Table-related settings.
+type Config struct {
+ // These settings are required and configure the UDP listener:
+ PrivateKey *ecdsa.PrivateKey
+
+ // These settings are optional:
+ AnnounceAddr *net.UDPAddr // local address announced in the DHT
+ NodeDBPath string // if set, the node database is stored at this filesystem location
+ NetRestrict *netutil.Netlist // network whitelist
+ Bootnodes []*Node // list of bootstrap nodes
+ Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
+}
+
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
- tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
+func ListenUDP(c conn, cfg Config) (*Table, error) {
+ tab, _, err := newUDP(c, cfg)
if err != nil {
return nil, err
}
@@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
+func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
- priv: priv,
- netrestrict: netrestrict,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
+ realaddr := c.LocalAddr().(*net.UDPAddr)
+ if cfg.AnnounceAddr != nil {
+ realaddr = cfg.AnnounceAddr
+ }
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
+ tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
udp.Table = tab
go udp.loop()
- go udp.readLoop(unhandled)
+ go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil
}
@@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
- // TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
- t.send(toaddr, pingPacket, &ping{
+ req := &ping{
Version: Version,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
+ }
+ packet, hash, err := encodePacket(t.priv, pingPacket, req)
+ if err != nil {
+ return err
+ }
+ errc := t.pending(toid, pongPacket, func(p interface{}) bool {
+ return bytes.Equal(p.(*pong).ReplyTok, hash)
})
+ t.write(toaddr, req.name(), packet)
return <-errc
}
@@ -447,40 +470,45 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
- packet, err := encodePacket(t.priv, ptype, req)
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+ packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
- return err
+ return hash, err
}
- _, err = t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
+ return hash, t.write(toaddr, req.name(), packet)
+}
+
+func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+ _, err := t.conn.WriteToUDP(packet, toaddr)
+ log.Trace(">> "+what, "addr", toaddr, "err", err)
return err
}
-func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
- packet := b.Bytes()
+ packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil {
log.Error("Can't sign discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// The future.
- copy(packet, crypto.Keccak256(packet[macSize:]))
- return packet, nil
+ hash = crypto.Keccak256(packet[macSize:])
+ copy(packet, hash)
+ return packet, hash, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
-func (t *udp) readLoop(unhandled chan ReadPacket) {
+func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close()
if unhandled != nil {
defer close(unhandled)
@@ -585,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired(req.Expiration) {
return errExpired
}
- if t.db.node(fromID) == nil {
+ if !t.db.hasBond(fromID) {
// No bond exists, we don't process the packet. This prevents
// an attack vector where the discovery protocol could be used
// to amplify traffic in a DDOS attack. A malicious actor
@@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
- for i, n := range closest {
- if netutil.CheckRelayIP(from.IP, n.IP) != nil {
- continue
+ for _, n := range closest {
+ if netutil.CheckRelayIP(from.IP, n.IP) == nil {
+ p.Nodes = append(p.Nodes, nodeToRPC(n))
}
- p.Nodes = append(p.Nodes, nodeToRPC(n))
- if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
+ if len(p.Nodes) == maxNeighbors {
t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
+ sent = true
}
}
+ if len(p.Nodes) > 0 || !sent {
+ t.send(from, neighborsPacket, &p)
+ }
return nil
}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index b81caf839..db9804f7b 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- realaddr := test.pipe.LocalAddr().(*net.UDPAddr)
- test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil)
+ test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
+ // Wait for initial refresh so the table doesn't send unexpected findnode.
+ <-test.table.initDone
return test
}
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, err := encodePacket(test.remotekey, ptype, data)
+ enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
@@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) error {
+func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, _, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram)
if err != nil {
- return test.errorf("sent packet decode error: %v", err)
+ return hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return nil
+ return hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -246,12 +247,8 @@ func TestUDP_findnode(t *testing.T) {
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
- test.table.db.updateNode(NewNode(
- PubkeyID(&test.remotekey.PublicKey),
- test.remoteaddr.IP,
- uint16(test.remoteaddr.Port),
- 99,
- ))
+ test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now())
+
// check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
expected := test.table.closest(targetHash, bucketSize)
@@ -351,7 +348,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- test.waitPacketOut(func(p *ping) error {
+ hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
}
@@ -365,7 +362,7 @@ func TestUDP_successfulPing(t *testing.T) {
}
return nil
})
- test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+ test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the
// pong packet.