aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- cmd/bootnode/main.go 7
-rw-r--r-- p2p/discover/node.go 6
-rw-r--r-- p2p/discover/table.go 484
-rw-r--r-- p2p/discover/table_test.go 171
-rw-r--r-- p2p/discover/udp.go 86
-rw-r--r-- p2p/discover/udp_test.go 21
-rw-r--r-- p2p/netutil/net.go 131
-rw-r--r-- p2p/netutil/net_test.go 89
-rw-r--r-- p2p/peer.go 6
-rw-r--r-- p2p/server.go 77
10 files changed, 801 insertions, 277 deletions
diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go
index ecfc6fc24..2e93cc04d 100644
--- a/cmd/bootnode/main.go
+++ b/cmd/bootnode/main.go
@@ -122,7 +122,12 @@ func main() {
utils.Fatalf("%v", err)
}
} else {
- if _, err := discover.ListenUDP(nodeKey, conn, realaddr, nil, "", restrictList); err != nil {
+ cfg := discover.Config{
+ PrivateKey: nodeKey,
+ AnnounceAddr: realaddr,
+ NetRestrict: restrictList,
+ }
+ if _, err := discover.ListenUDP(conn, cfg); err != nil {
utils.Fatalf("%v", err)
}
}
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index fc928a91a..3b0c84115 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -29,6 +29,7 @@ import (
"regexp"
"strconv"
"strings"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
@@ -51,9 +52,8 @@ type Node struct {
// with ID.
sha common.Hash
- // whether this node is currently being pinged in order to replace
- // it in a bucket
- contested bool
+ // Time when the node was added to the table.
+ addedAt time.Time
}
// NewNode creates a new node. It is mostly meant to be used for
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index ec4eb94ad..84c54dac1 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -23,10 +23,11 @@
package discover
import (
- "crypto/rand"
+ crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
+ mrand "math/rand"
"net"
"sort"
"sync"
@@ -35,29 +36,45 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
- alpha = 3 // Kademlia concurrency factor
- bucketSize = 16 // Kademlia bucket size
- hashBits = len(common.Hash{}) * 8
- nBuckets = hashBits + 1 // Number of buckets
-
- maxBondingPingPongs = 16
- maxFindnodeFailures = 5
-
- autoRefreshInterval = 1 * time.Hour
- seedCount = 30
- seedMaxAge = 5 * 24 * time.Hour
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ maxReplacements = 10 // Size of per-bucket replacement list
+
+ // We keep buckets for the upper 1/15 of distances because
+ // it's very unlikely we'll ever encounter a node that's closer.
+ hashBits = len(common.Hash{}) * 8
+ nBuckets = hashBits / 15 // Number of buckets
+ bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
+
+ // IP address limits.
+ bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
+ tableIPLimit, tableSubnet = 10, 24
+
+ maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
+ maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
+
+ refreshInterval = 30 * time.Minute
+ revalidateInterval = 10 * time.Second
+ copyNodesInterval = 30 * time.Second
+ seedMinTableTime = 5 * time.Minute
+ seedCount = 30
+ seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
- mutex sync.Mutex // protects buckets, their content, and nursery
+ mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
- db *nodeDB // database of known nodes
+ rand *mrand.Rand // source of randomness, periodically reseeded
+ ips netutil.DistinctNetSet
+ db *nodeDB // database of known nodes
refreshReq chan chan struct{}
+ initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
@@ -89,9 +106,13 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
-type bucket struct{ entries []*Node }
+type bucket struct {
+ entries []*Node // live entries, sorted by time of last contact
+ replacements []*Node // recently seen nodes to be used if revalidation fails
+ ips netutil.DistinctNetSet
+}
-func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) {
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID)
if err != nil {
@@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}),
+ initDone: make(chan struct{}),
closeReq: make(chan struct{}),
closed: make(chan struct{}),
+ rand: mrand.New(mrand.NewSource(0)),
+ ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
+ }
+ if err := tab.setFallbackNodes(bootnodes); err != nil {
+ return nil, err
}
for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{}
}
for i := range tab.buckets {
- tab.buckets[i] = new(bucket)
+ tab.buckets[i] = &bucket{
+ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
+ }
}
- go tab.refreshLoop()
+ tab.seedRand()
+ tab.loadSeedNodes(false)
+ // Start the background expiration goroutine after loading seeds so that the search for
+ // seed nodes also considers older nodes that would otherwise be removed by the
+ // expiration.
+ tab.db.ensureExpirer()
+ go tab.loop()
return tab, nil
}
+// seedRand reseeds tab.rand with eight bytes read from crypto/rand.
+// It takes tab.mutex itself, so the caller must not hold it.
+func (tab *Table) seedRand() {
+	var raw [8]byte
+	// The result of crand.Read is not checked; raw keeps whatever was read.
+	crand.Read(raw[:])
+	seed := int64(binary.BigEndian.Uint64(raw[:]))
+
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	tab.rand.Seed(seed)
+}
+
// Self returns the local node.
// The returned node should not be modified by the caller.
func (tab *Table) Self() *Node {
@@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
// table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller.
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
+ if !tab.isInitDone() {
+ return 0
+ }
tab.mutex.Lock()
defer tab.mutex.Unlock()
- // TODO: tree-based buckets would help here
+
// Find all non-empty buckets and get a fresh slice of their entries.
var buckets [][]*Node
for _, b := range tab.buckets {
@@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return 0
}
// Shuffle the buckets.
- for i := uint32(len(buckets)) - 1; i > 0; i-- {
- j := randUint(i)
+ for i := len(buckets) - 1; i > 0; i-- {
+ j := tab.rand.Intn(len(buckets))
buckets[i], buckets[j] = buckets[j], buckets[i]
}
// Move head of each bucket into buf, removing buckets that become empty.
@@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return i + 1
}
-func randUint(max uint32) uint32 {
- if max == 0 {
- return 0
- }
- var b [4]byte
- rand.Read(b[:])
- return binary.BigEndian.Uint32(b[:]) % max
-}
-
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
select {
@@ -180,16 +218,15 @@ func (tab *Table) Close() {
}
}
-// SetFallbackNodes sets the initial points of contact. These nodes
+// setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
-func (tab *Table) SetFallbackNodes(nodes []*Node) error {
+func (tab *Table) setFallbackNodes(nodes []*Node) error {
for _, n := range nodes {
if err := n.validateComplete(); err != nil {
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
}
}
- tab.mutex.Lock()
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
@@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
cpy.sha = crypto.Keccak256Hash(n.ID[:])
tab.nursery = append(tab.nursery, &cpy)
}
- tab.mutex.Unlock()
- tab.refresh()
return nil
}
+// isInitDone reports whether the table's initial seeding procedure has
+// completed, i.e. whether a receive on tab.initDone succeeds without blocking.
+func (tab *Table) isInitDone() bool {
+	select {
+	case <-tab.initDone:
+		return true
+	default:
+		// Not signalled yet; fall through without blocking.
+	}
+	return false
+}
+
// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
func (tab *Table) Resolve(targetID NodeID) *Node {
@@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
return done
}
-// refreshLoop schedules doRefresh runs and coordinates shutdown.
-func (tab *Table) refreshLoop() {
+// loop schedules refresh, revalidate runs and coordinates shutdown.
+func (tab *Table) loop() {
var (
- timer = time.NewTicker(autoRefreshInterval)
- waiting []chan struct{} // accumulates waiting callers while doRefresh runs
- done chan struct{} // where doRefresh reports completion
+ revalidate = time.NewTimer(tab.nextRevalidateTime())
+ refresh = time.NewTicker(refreshInterval)
+ copyNodes = time.NewTicker(copyNodesInterval)
+ revalidateDone = make(chan struct{})
+ refreshDone = make(chan struct{}) // where doRefresh reports completion
+ waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
+ defer refresh.Stop()
+ defer revalidate.Stop()
+ defer copyNodes.Stop()
+
+ // Start initial refresh.
+ go tab.doRefresh(refreshDone)
+
loop:
for {
select {
- case <-timer.C:
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ case <-refresh.C:
+ tab.seedRand()
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
case req := <-tab.refreshReq:
waiting = append(waiting, req)
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
- case <-done:
+ case <-refreshDone:
for _, ch := range waiting {
close(ch)
}
- waiting = nil
- done = nil
+ waiting, refreshDone = nil, nil
+ case <-revalidate.C:
+ go tab.doRevalidate(revalidateDone)
+ case <-revalidateDone:
+ revalidate.Reset(tab.nextRevalidateTime())
+ case <-copyNodes.C:
+ go tab.copyBondedNodes()
case <-tab.closeReq:
break loop
}
@@ -349,8 +410,8 @@ loop:
if tab.net != nil {
tab.net.close()
}
- if done != nil {
- <-done
+ if refreshDone != nil {
+ <-refreshDone
}
for _, ch := range waiting {
close(ch)
@@ -365,38 +426,109 @@ loop:
func (tab *Table) doRefresh(done chan struct{}) {
defer close(done)
+ // Load nodes from the database and insert
+ // them. This should yield a few previously seen nodes that are
+ // (hopefully) still alive.
+ tab.loadSeedNodes(true)
+
+ // Run self lookup to discover new neighbor nodes.
+ tab.lookup(tab.self.ID, false)
+
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket.
- // We perform a lookup with a random target instead.
- var target NodeID
- rand.Read(target[:])
- result := tab.lookup(target, false)
- if len(result) > 0 {
- return
+ // We perform a few lookups with a random target instead.
+ for i := 0; i < 3; i++ {
+ var target NodeID
+ crand.Read(target[:])
+ tab.lookup(target, false)
}
+}
- // The table is empty. Load nodes from the database and insert
- // them. This should yield a few previously seen nodes that are
- // (hopefully) still alive.
+func (tab *Table) loadSeedNodes(bond bool) {
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
- seeds = tab.bondall(append(seeds, tab.nursery...))
+ seeds = append(seeds, tab.nursery...)
+ if bond {
+ seeds = tab.bondall(seeds)
+ }
+ for i := range seeds {
+ seed := seeds[i]
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }}
+ log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
+ tab.add(seed)
+ }
+}
+
+// doRevalidate checks that the last node in a random bucket is still live
+// and replaces or deletes the node if it isn't.
+func (tab *Table) doRevalidate(done chan<- struct{}) {
+ defer func() { done <- struct{}{} }()
- if len(seeds) == 0 {
- log.Debug("No discv4 seed nodes found")
+ last, bi := tab.nodeToRevalidate()
+ if last == nil {
+ // No non-empty bucket found.
+ return
+ }
+
+ // Ping the selected node and wait for a pong.
+ err := tab.ping(last.ID, last.addr())
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ b := tab.buckets[bi]
+ if err == nil {
+ // The node responded, move it to the front.
+ log.Debug("Revalidated node", "b", bi, "id", last.ID)
+ b.bump(last)
+ return
+ }
+ // No reply received, pick a replacement or delete the node if there aren't
+ // any replacements.
+ if r := tab.replace(b, last); r != nil {
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
+ } else {
+ log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
}
- for _, n := range seeds {
- age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
- log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
+}
+
+// nodeToRevalidate returns the last entry of a randomly chosen non-empty
+// bucket together with that bucket's index. It returns (nil, 0) when every
+// bucket is empty. tab.mutex is taken for the duration of the call.
+func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+
+	// Visit the buckets in a random order drawn from tab.rand.
+	for _, idx := range tab.rand.Perm(len(tab.buckets)) {
+		if entries := tab.buckets[idx].entries; len(entries) > 0 {
+			return entries[len(entries)-1], idx
+		}
+	}
+	return nil, 0
+}
+
+func (tab *Table) nextRevalidateTime() time.Duration {
tab.mutex.Lock()
- tab.stuff(seeds)
- tab.mutex.Unlock()
+ defer tab.mutex.Unlock()
- // Finally, do a self lookup to fill up the buckets.
- tab.lookup(tab.self.ID, false)
+ return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
+}
+
+// copyBondedNodes adds nodes from the table to the database if they have been in the table
+// longer than seedMinTableTime.
+func (tab *Table) copyBondedNodes() {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ now := time.Now()
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if now.Sub(n.addedAt) >= seedMinTableTime {
+ tab.db.updateNode(n)
+ }
+ }
+ }
}
// closest returns the n nodes in the table that are closest to the
@@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
if id == tab.self.ID {
return nil, errors.New("is self")
}
- // Retrieve a previously known node and any recent findnode failures
- node, fails := tab.db.node(id), 0
- if node != nil {
- fails = tab.db.findFails(id)
+ if pinged && !tab.isInitDone() {
+ return nil, errors.New("still initializing")
}
- // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch
- var result error
+ // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
+ node, fails := tab.db.node(id), tab.db.findFails(id)
age := time.Since(tab.db.lastPong(id))
- if node == nil || fails > 0 || age > nodeDBNodeExpiration {
+ var result error
+ if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock()
@@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n
}
}
+ // Add the node to the table even if the bonding ping/pong
+ // fails. It will be replaced quickly if it continues to be
+ // unresponsive.
if node != nil {
- // Add the node to the table even if the bonding ping/pong
- // fails. It will be relaced quickly if it continues to be
- // unresponsive.
tab.add(node)
tab.db.updateFindFails(id, 0)
}
@@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
}
// Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
- tab.db.updateNode(w.n)
close(w.done)
}
@@ -534,16 +664,18 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
return err
}
tab.db.updateLastPong(id, time.Now())
-
- // Start the background expiration goroutine after the first
- // successful communication. Subsequent calls have no effect if it
- // is already running. We do this here instead of somewhere else
- // so that the search for seed nodes also considers older nodes
- // that would otherwise be removed by the expiration.
- tab.db.ensureExpirer()
return nil
}
+// bucket returns the bucket for the given node ID hash. Every log distance
+// up to and including bucketMinDistance maps to the first bucket; larger
+// distances map to consecutive buckets after it.
+func (tab *Table) bucket(sha common.Hash) *bucket {
+	idx := logdist(tab.self.sha, sha) - bucketMinDistance - 1
+	if idx < 0 {
+		// Closer than the closest tracked distance: share bucket 0.
+		idx = 0
+	}
+	return tab.buckets[idx]
+}
+
// add attempts to add the given node its corresponding bucket. If the
// bucket has space available, adding the node succeeds immediately.
// Otherwise, the node is added if the least recently active node in
@@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) {
- b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock()
defer tab.mutex.Unlock()
- if b.bump(new) {
- return
- }
- var oldest *Node
- if len(b.entries) == bucketSize {
- oldest = b.entries[bucketSize-1]
- if oldest.contested {
- // The node is already being replaced, don't attempt
- // to replace it.
- return
- }
- oldest.contested = true
- // Let go of the mutex so other goroutines can access
- // the table while we ping the least recently active node.
- tab.mutex.Unlock()
- err := tab.ping(oldest.ID, oldest.addr())
- tab.mutex.Lock()
- oldest.contested = false
- if err == nil {
- // The node responded, don't replace it.
- return
- }
- }
- added := b.replace(new, oldest)
- if added && tab.nodeAddedHook != nil {
- tab.nodeAddedHook(new)
+
+ b := tab.bucket(new.sha)
+ if !tab.bumpOrAdd(b, new) {
+ // Node is not in table. Add it to the replacement list.
+ tab.addReplacement(b, new)
}
}
// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must hold tab.mutex.
+// if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) {
-outer:
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
for _, n := range nodes {
if n.ID == tab.self.ID {
continue // don't add self
}
- bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == n.ID {
- continue outer // already in bucket
- }
- }
- if len(bucket.entries) < bucketSize {
- bucket.entries = append(bucket.entries, n)
- if tab.nodeAddedHook != nil {
- tab.nodeAddedHook(n)
- }
+ b := tab.bucket(n.sha)
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
}
}
}
@@ -611,36 +715,72 @@ outer:
func (tab *Table) delete(node *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == node.ID {
- bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
- return
- }
- }
+
+ tab.deleteInBucket(tab.bucket(node.sha), node)
}
-func (b *bucket) replace(n *Node, last *Node) bool {
- // Don't add if b already contains n.
- for i := range b.entries {
- if b.entries[i].ID == n.ID {
- return false
- }
+func (tab *Table) addIP(b *bucket, ip net.IP) bool {
+ if netutil.IsLAN(ip) {
+ return true
}
- // Replace last if it is still the last entry or just add n if b
- // isn't full. If is no longer the last entry, it has either been
- // replaced with someone else or became active.
- if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
+ if !tab.ips.Add(ip) {
+ log.Debug("IP exceeds table limit", "ip", ip)
return false
}
- if len(b.entries) < bucketSize {
- b.entries = append(b.entries, nil)
+ if !b.ips.Add(ip) {
+ log.Debug("IP exceeds bucket limit", "ip", ip)
+ tab.ips.Remove(ip)
+ return false
}
- copy(b.entries[1:], b.entries)
- b.entries[0] = n
return true
}
+// removeIP releases ip from both the table-wide and the per-bucket address
+// sets. Addresses for which netutil.IsLAN reports true are skipped, mirroring
+// addIP, which never records them.
+func (tab *Table) removeIP(b *bucket, ip net.IP) {
+	if !netutil.IsLAN(ip) {
+		tab.ips.Remove(ip)
+		b.ips.Remove(ip)
+	}
+}
+
+func (tab *Table) addReplacement(b *bucket, n *Node) {
+ for _, e := range b.replacements {
+ if e.ID == n.ID {
+ return // already in list
+ }
+ }
+ if !tab.addIP(b, n.IP) {
+ return
+ }
+ var removed *Node
+ b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
+ if removed != nil {
+ tab.removeIP(b, removed.IP)
+ }
+}
+
+// replace swaps 'last' for a randomly chosen node from b's replacement list,
+// provided 'last' is still the final entry of the bucket. If 'last' is no
+// longer the final entry (or the bucket is empty), it has either been replaced
+// by someone else or became active, and nothing is changed.
+//
+// If there are no replacements, 'last' is simply deleted from the bucket.
+// The return value is the node that took the place of 'last', or nil if no
+// replacement was made. The caller must hold tab.mutex.
+func (tab *Table) replace(b *bucket, last *Node) *Node {
+	// Guard against an empty bucket before indexing its final entry; the
+	// previous `len(b.entries) >= 0 &&` test was always true and panicked here.
+	if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
+		// Entry has moved, don't replace it.
+		return nil
+	}
+	// Still the last entry.
+	if len(b.replacements) == 0 {
+		tab.deleteInBucket(b, last)
+		return nil
+	}
+	r := b.replacements[tab.rand.Intn(len(b.replacements))]
+	b.replacements = deleteNode(b.replacements, r)
+	b.entries[len(b.entries)-1] = r
+	// r's IP was recorded when it joined the replacement list; only the
+	// outgoing node's IP has to be released.
+	tab.removeIP(b, last.IP)
+	return r
+}
+
+// bump moves the given node to the front of the bucket entry list
+// if it is contained in that list.
func (b *bucket) bump(n *Node) bool {
for i := range b.entries {
if b.entries[i].ID == n.ID {
@@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
return false
}
+// bumpOrAdd moves n to the front of the bucket entry list if it is already
+// present, or inserts it at the front when the bucket has room and addIP
+// accepts its address. It reports whether n is in the bucket afterwards.
+func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
+	switch {
+	case b.bump(n):
+		return true
+	case len(b.entries) >= bucketSize:
+		return false
+	case !tab.addIP(b, n.IP):
+		return false
+	}
+	// n is new to the bucket: insert it and drop it from the replacement list.
+	b.entries, _ = pushNode(b.entries, n, bucketSize)
+	b.replacements = deleteNode(b.replacements, n)
+	n.addedAt = time.Now()
+	if tab.nodeAddedHook != nil {
+		tab.nodeAddedHook(n)
+	}
+	return true
+}
+
+// deleteInBucket removes n from b's entry list and releases n.IP via removeIP.
+// NOTE(review): the IP is released whether or not n was found in b.entries.
+func (tab *Table) deleteInBucket(b *bucket, n *Node) {
+	tab.removeIP(b, n.IP)
+	b.entries = deleteNode(b.entries, n)
+}
+
+// pushNode adds n to the front of list, keeping at most max items.
+func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
+ if len(list) < max {
+ list = append(list, nil)
+ }
+ removed := list[len(list)-1]
+ copy(list[1:], list)
+ list[0] = n
+ return list, removed
+}
+
+// deleteNode removes the first element whose ID equals n.ID from list and
+// returns the resulting slice. The removal reuses list's backing array. If no
+// element matches, list is returned unchanged.
+func deleteNode(list []*Node, n *Node) []*Node {
+	for i, e := range list {
+		if e.ID == n.ID {
+			return append(list[:i], list[i+1:]...)
+		}
+	}
+	return list
+}
+
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 1037cc609..3ce48d299 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -20,6 +20,7 @@ import (
"crypto/ecdsa"
"fmt"
"math/rand"
+ "sync"
"net"
"reflect"
@@ -32,60 +33,65 @@ import (
)
func TestTable_pingReplace(t *testing.T) {
- doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
- transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "")
- defer tab.Close()
- pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ run := func(newNodeResponding, lastInBucketResponding bool) {
+ name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ testPingReplace(t, newNodeResponding, lastInBucketResponding)
+ })
+ }
- // fill up the sender's bucket.
- last := fillBucket(tab, 253)
+ run(true, true)
+ run(false, true)
+ run(true, false)
+ run(false, false)
+}
- // this call to bond should replace the last node
- // in its bucket if the node is not responding.
- transport.responding[last.ID] = lastInBucketIsResponding
- transport.responding[pingSender.ID] = newNodeIsResponding
- tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
- // first ping goes to sender (bonding pingback)
- if !transport.pinged[pingSender.ID] {
- t.Error("table did not ping back sender")
- }
- if newNodeIsResponding {
- // second ping goes to oldest node in bucket
- // to see whether it is still alive.
- if !transport.pinged[last.ID] {
- t.Error("table did not ping last node in bucket")
- }
- }
+ // Wait for init so bond is accepted.
+ <-tab.initDone
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
- if l := len(tab.buckets[253].entries); l != bucketSize {
- t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
- }
+ // fill up the sender's bucket.
+ pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ last := fillBucket(tab, pingSender)
- if lastInBucketIsResponding || !newNodeIsResponding {
- if !contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was removed")
- }
- if contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was added")
- }
- } else {
- if contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was not removed")
- }
- if !contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was not added")
- }
- }
+ // this call to bond should replace the last node
+ // in its bucket if the node is not responding.
+ transport.dead[last.ID] = !lastInBucketIsResponding
+ transport.dead[pingSender.ID] = !newNodeIsResponding
+ tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+ tab.doRevalidate(make(chan struct{}, 1))
+
+ // first ping goes to sender (bonding pingback)
+ if !transport.pinged[pingSender.ID] {
+ t.Error("table did not ping back sender")
+ }
+ if !transport.pinged[last.ID] {
+ // second ping goes to oldest node in bucket
+ // to see whether it is still alive.
+ t.Error("table did not ping last node in bucket")
}
- doit(true, true)
- doit(false, true)
- doit(true, false)
- doit(false, false)
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ wantSize := bucketSize
+ if !lastInBucketIsResponding && !newNodeIsResponding {
+ wantSize--
+ }
+ if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
+ t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
+ }
+ if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
+ t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
+ }
+ wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
+ if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
+ t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
+ }
}
func TestBucket_bumpNoDuplicates(t *testing.T) {
@@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
}
}
+// TestTable_IPLimit checks that the table-wide IP limit is applied correctly:
+// after adding tableIPLimit+1 nodes whose addresses share 172.0.1.x, each at a
+// different distance, the table must hold no more than tableIPLimit nodes.
+func TestTable_IPLimit(t *testing.T) {
+	tab, _ := newTable(newPingRecorder(), NodeID{}, &net.UDPAddr{}, "", nil)
+	defer tab.Close()
+
+	for i := 0; i <= tableIPLimit; i++ {
+		node := nodeAtDistance(tab.self.sha, i)
+		node.IP = net.IP{172, 0, 1, byte(i)}
+		tab.add(node)
+	}
+	if count := tab.len(); count > tableIPLimit {
+		t.Errorf("too many nodes in table")
+	}
+}
+
+// This checks that the per-bucket IP limit is applied correctly.
+func TestTable_BucketIPLimit(t *testing.T) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
+
+ d := 3
+ for i := 0; i < bucketIPLimit+1; i++ {
+ n := nodeAtDistance(tab.self.sha, d)
+ n.IP = net.IP{172, 0, 1, byte(i)}
+ tab.add(n)
+ }
+ if tab.len() > bucketIPLimit {
+ t.Errorf("too many nodes in table")
+ }
+}
+
// fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their
// hashes.
-func fillBucket(tab *Table, ld int) (last *Node) {
- b := tab.buckets[ld]
+func fillBucket(tab *Table, n *Node) (last *Node) {
+ ld := logdist(tab.self.sha, n.sha)
+ b := tab.bucket(n.sha)
for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
}
@@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node)
n.sha = hashAtDistance(base, ld)
- n.IP = net.IP{10, 0, 2, byte(ld)}
+ n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n
}
-type pingRecorder struct{ responding, pinged map[NodeID]bool }
+type pingRecorder struct {
+ mu sync.Mutex
+ dead, pinged map[NodeID]bool
+}
func newPingRecorder() *pingRecorder {
- return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+ return &pingRecorder{
+ dead: make(map[NodeID]bool),
+ pinged: make(map[NodeID]bool),
+ }
}
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
- panic("findnode called on pingRecorder")
+ return nil, nil
}
func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings
}
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
t.pinged[toid] = true
- if t.responding[toid] {
- return nil
- } else {
+ if t.dead[toid] {
return errTimeout
+ } else {
+ return nil
}
}
@@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool {
// for any node table, Target and N
- tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close()
tab.stuff(test.All)
@@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
},
}
test := func(buf []*Node) bool {
- tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
+ <-tab.initDone
+
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
@@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0)
- tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "")
+ tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close()
// lookup on empty table returns no nodes
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 60436952d..e40de2c36 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr
}
+// Config holds the Table-related settings that are passed to ListenUDP.
+type Config struct {
+ // These settings are required and configure the UDP listener:
+ PrivateKey *ecdsa.PrivateKey
+
+ // These settings are optional:
+ AnnounceAddr *net.UDPAddr // local address announced in the DHT; the connection's local address is used if nil
+ NodeDBPath string // if set, the node database is stored at this filesystem location
+ NetRestrict *netutil.Netlist // network whitelist
+ Bootnodes []*Node // list of bootstrap nodes
+ Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
+}
+
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
- tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
+func ListenUDP(c conn, cfg Config) (*Table, error) {
+ tab, _, err := newUDP(c, cfg)
if err != nil {
return nil, err
}
@@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
+func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
- priv: priv,
- netrestrict: netrestrict,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
+ realaddr := c.LocalAddr().(*net.UDPAddr)
+ if cfg.AnnounceAddr != nil {
+ realaddr = cfg.AnnounceAddr
+ }
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
+ tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
udp.Table = tab
go udp.loop()
- go udp.readLoop(unhandled)
+ go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil
}
@@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
- // TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
- t.send(toaddr, pingPacket, &ping{
+ req := &ping{
Version: Version,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
+ }
+ packet, hash, err := encodePacket(t.priv, pingPacket, req)
+ if err != nil {
+ return err
+ }
+ errc := t.pending(toid, pongPacket, func(p interface{}) bool {
+ return bytes.Equal(p.(*pong).ReplyTok, hash)
})
+ t.write(toaddr, req.name(), packet)
return <-errc
}
@@ -447,40 +470,45 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
- packet, err := encodePacket(t.priv, ptype, req)
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+ packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
- return err
+ return hash, err
}
- _, err = t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
+ return hash, t.write(toaddr, req.name(), packet)
+}
+
+func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+ _, err := t.conn.WriteToUDP(packet, toaddr)
+ log.Trace(">> "+what, "addr", toaddr, "err", err)
return err
}
-func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
- packet := b.Bytes()
+ packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil {
log.Error("Can't sign discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// The future.
- copy(packet, crypto.Keccak256(packet[macSize:]))
- return packet, nil
+ hash = crypto.Keccak256(packet[macSize:])
+ copy(packet, hash)
+ return packet, hash, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
-func (t *udp) readLoop(unhandled chan ReadPacket) {
+func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close()
if unhandled != nil {
defer close(unhandled)
@@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
- for i, n := range closest {
- if netutil.CheckRelayIP(from.IP, n.IP) != nil {
- continue
+ for _, n := range closest {
+ if netutil.CheckRelayIP(from.IP, n.IP) == nil {
+ p.Nodes = append(p.Nodes, nodeToRPC(n))
}
- p.Nodes = append(p.Nodes, nodeToRPC(n))
- if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
+ if len(p.Nodes) == maxNeighbors {
t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
+ sent = true
}
}
+ if len(p.Nodes) > 0 || !sent {
+ t.send(from, neighborsPacket, &p)
+ }
return nil
}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index b81caf839..3ffa5c4dd 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- realaddr := test.pipe.LocalAddr().(*net.UDPAddr)
- test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil)
+ test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
+ // Wait for initial refresh so the table doesn't send unexpected findnode.
+ <-test.table.initDone
return test
}
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, err := encodePacket(test.remotekey, ptype, data)
+ enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
@@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) error {
+func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, _, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram)
if err != nil {
- return test.errorf("sent packet decode error: %v", err)
+ return hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return nil
+ return hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- test.waitPacketOut(func(p *ping) error {
+ hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
}
@@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) {
}
return nil
})
- test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+ test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the
// pong packet.
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go
index f6005afd2..656abb682 100644
--- a/p2p/netutil/net.go
+++ b/p2p/netutil/net.go
@@ -18,8 +18,11 @@
package netutil
import (
+ "bytes"
"errors"
+ "fmt"
"net"
+ "sort"
"strings"
)
@@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
}
return nil
}
+
+// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
+func SameNet(bits uint, ip, other net.IP) bool {
+ ip4, other4 := ip.To4(), other.To4()
+ switch {
+ case (ip4 == nil) != (other4 == nil):
+ return false
+ case ip4 != nil:
+ return sameNet(bits, ip4, other4)
+ default:
+ return sameNet(bits, ip.To16(), other.To16())
+ }
+}
+
+func sameNet(bits uint, ip, other net.IP) bool {
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
+ return false
+ }
+ return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
+}
+
+// DistinctNetSet tracks IPs, ensuring that at most N of them
+// fall into the same network range.
+type DistinctNetSet struct {
+ Subnet uint // number of common prefix bits
+ Limit uint // maximum number of IPs in each subnet
+
+ members map[string]uint
+ buf net.IP
+}
+
+// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
+// number of existing IPs in the defined range exceeds the limit.
+func (s *DistinctNetSet) Add(ip net.IP) bool {
+ key := s.key(ip)
+ n := s.members[string(key)]
+ if n < s.Limit {
+ s.members[string(key)] = n + 1
+ return true
+ }
+ return false
+}
+
+// Remove removes an IP from the set.
+func (s *DistinctNetSet) Remove(ip net.IP) {
+ key := s.key(ip)
+ if n, ok := s.members[string(key)]; ok {
+ if n == 1 {
+ delete(s.members, string(key))
+ } else {
+ s.members[string(key)] = n - 1
+ }
+ }
+}
+
+// Contains reports whether the given IP is contained in the set.
+func (s DistinctNetSet) Contains(ip net.IP) bool {
+ key := s.key(ip)
+ _, ok := s.members[string(key)]
+ return ok
+}
+
+// Len returns the number of tracked IPs.
+func (s DistinctNetSet) Len() int {
+ n := uint(0)
+ for _, i := range s.members {
+ n += i
+ }
+ return int(n)
+}
+
+// key encodes the map key for an address into a temporary buffer.
+//
+// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
+// The remainder of the key is the IP, truncated to the number of bits.
+func (s *DistinctNetSet) key(ip net.IP) net.IP {
+ // Lazily initialize storage.
+ if s.members == nil {
+ s.members = make(map[string]uint)
+ s.buf = make(net.IP, 17)
+ }
+ // Canonicalize ip and bits.
+ typ := byte('6')
+ if ip4 := ip.To4(); ip4 != nil {
+ typ, ip = '4', ip4
+ }
+ bits := s.Subnet
+ if bits > uint(len(ip)*8) {
+ bits = uint(len(ip) * 8)
+ }
+ // Encode the prefix into s.buf.
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ s.buf[0] = typ
+ buf := append(s.buf[:1], ip[:nb]...)
+ if nb < len(ip) && mask != 0 {
+ buf = append(buf, ip[nb]&mask)
+ }
+ return buf
+}
+
+// String implements fmt.Stringer
+func (s DistinctNetSet) String() string {
+ var buf bytes.Buffer
+ buf.WriteString("{")
+ keys := make([]string, 0, len(s.members))
+ for k := range s.members {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for i, k := range keys {
+ var ip net.IP
+ if k[0] == '4' {
+ ip = make(net.IP, 4)
+ } else {
+ ip = make(net.IP, 16)
+ }
+ copy(ip, k[1:])
+ fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
+ if i != len(keys)-1 {
+ buf.WriteString(" ")
+ }
+ }
+ buf.WriteString("}")
+ return buf.String()
+}
diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go
index 1ee1fcb4d..3a6aa081f 100644
--- a/p2p/netutil/net_test.go
+++ b/p2p/netutil/net_test.go
@@ -17,9 +17,11 @@
package netutil
import (
+ "fmt"
"net"
"reflect"
"testing"
+ "testing/quick"
"github.com/davecgh/go-spew/spew"
)
@@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr)
}
}
+
+func TestSameNet(t *testing.T) {
+ tests := []struct {
+ ip, other string
+ bits uint
+ want bool
+ }{
+ {"0.0.0.0", "0.0.0.0", 32, true},
+ {"0.0.0.0", "0.0.0.1", 0, true},
+ {"0.0.0.0", "0.0.0.1", 31, true},
+ {"0.0.0.0", "0.0.0.1", 32, false},
+ {"0.33.0.1", "0.34.0.2", 8, true},
+ {"0.33.0.1", "0.34.0.2", 13, true},
+ {"0.33.0.1", "0.34.0.2", 15, false},
+ }
+
+ for _, test := range tests {
+ if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
+ t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
+ }
+ }
+}
+
+func ExampleSameNet() {
+ // This returns true because the IPs are in the same /24 network:
+ fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
+ // This call returns false:
+ fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
+ // Output:
+ // true
+ // false
+}
+
+func TestDistinctNetSet(t *testing.T) {
+ ops := []struct {
+ add, remove string
+ fails bool
+ }{
+ {add: "127.0.0.1"},
+ {add: "127.0.0.2"},
+ {add: "127.0.0.3", fails: true},
+ {add: "127.32.0.1"},
+ {add: "127.32.0.2"},
+ {add: "127.32.0.3", fails: true},
+ {add: "127.33.0.1", fails: true},
+ {add: "127.34.0.1"},
+ {add: "127.34.0.2"},
+ {add: "127.34.0.3", fails: true},
+ // Make room for an address, then add again.
+ {remove: "127.0.0.1"},
+ {add: "127.0.0.3"},
+ {add: "127.0.0.3", fails: true},
+ }
+
+ set := DistinctNetSet{Subnet: 15, Limit: 2}
+ for _, op := range ops {
+ var desc string
+ if op.add != "" {
+ desc = fmt.Sprintf("Add(%s)", op.add)
+ if ok := set.Add(parseIP(op.add)); ok != !op.fails {
+ t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
+ }
+ } else {
+ desc = fmt.Sprintf("Remove(%s)", op.remove)
+ set.Remove(parseIP(op.remove))
+ }
+ t.Logf("%s: %v", desc, set)
+ }
+}
+
+func TestDistinctNetSetAddRemove(t *testing.T) {
+ cfg := &quick.Config{}
+ fn := func(ips []net.IP) bool {
+ s := DistinctNetSet{Limit: 3, Subnet: 2}
+ for _, ip := range ips {
+ s.Add(ip)
+ }
+ for _, ip := range ips {
+ s.Remove(ip)
+ }
+ return s.Len() == 0
+ }
+
+ if err := quick.Check(fn, cfg); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/p2p/peer.go b/p2p/peer.go
index bad1c8c8b..477d8c219 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -419,6 +419,9 @@ type PeerInfo struct {
Network struct {
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
+ Inbound bool `json:"inbound"`
+ Trusted bool `json:"trusted"`
+ Static bool `json:"static"`
} `json:"network"`
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
}
@@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
}
info.Network.LocalAddress = p.LocalAddr().String()
info.Network.RemoteAddress = p.RemoteAddr().String()
+ info.Network.Inbound = p.rw.is(inboundConn)
+ info.Network.Trusted = p.rw.is(trustedConn)
+ info.Network.Static = p.rw.is(staticDialedConn)
// Gather all the running protocol infos
for _, proto := range p.running {
diff --git a/p2p/server.go b/p2p/server.go
index 2cff94ea5..edc1d9d21 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -40,11 +40,10 @@ const (
refreshPeersInterval = 30 * time.Second
staticPeerCheckInterval = 15 * time.Second
- // Maximum number of concurrently handshaking inbound connections.
- maxAcceptConns = 50
-
- // Maximum number of concurrently dialing outbound connections.
- maxActiveDialTasks = 16
+ // Connectivity defaults.
+ maxActiveDialTasks = 16
+ defaultMaxPendingPeers = 50
+ defaultDialRatio = 3
// Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle.
@@ -70,6 +69,11 @@ type Config struct {
// Zero defaults to preset values.
MaxPendingPeers int `toml:",omitempty"`
+ // DialRatio controls the ratio of inbound to dialed connections.
+ // Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
+ // Setting DialRatio to zero defaults it to 3.
+ DialRatio int `toml:",omitempty"`
+
// NoDiscovery can be used to disable the peer discovery mechanism.
// Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool
@@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) {
if err != nil {
return err
}
-
realaddr = conn.LocalAddr().(*net.UDPAddr)
if srv.NAT != nil {
if !realaddr.IP.IsLoopback() {
@@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) {
// node table
if !srv.NoDiscovery {
- ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict)
- if err != nil {
- return err
+ cfg := discover.Config{
+ PrivateKey: srv.PrivateKey,
+ AnnounceAddr: realaddr,
+ NodeDBPath: srv.NodeDatabase,
+ NetRestrict: srv.NetRestrict,
+ Bootnodes: srv.BootstrapNodes,
+ Unhandled: unhandled,
}
- if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil {
+ ntab, err := discover.ListenUDP(conn, cfg)
+ if err != nil {
return err
}
srv.ntab = ntab
@@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) {
srv.DiscV5 = ntab
}
- dynPeers := (srv.MaxPeers + 1) / 2
- if srv.NoDiscovery {
- dynPeers = 0
- }
+ dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
@@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done()
var (
peers = make(map[discover.NodeID]*Peer)
+ inboundCount = 0
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task
@@ -621,14 +627,14 @@ running:
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select {
- case c.cont <- srv.encHandshakeChecks(peers, c):
+ case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit:
break running
}
case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
- err := srv.protoHandshakeChecks(peers, c)
+ err := srv.protoHandshakeChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols)
@@ -639,8 +645,11 @@ running:
}
name := truncateName(c.name)
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
- peers[c.id] = p
go srv.runPeer(p)
+ peers[c.id] = p
+ if p.Inbound() {
+ inboundCount++
+ }
}
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
@@ -655,6 +664,9 @@ running:
d := common.PrettyDuration(mclock.Now() - pd.created)
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
delete(peers, pd.ID())
+ if pd.Inbound() {
+ inboundCount--
+ }
}
}
@@ -681,20 +693,22 @@ running:
}
}
-func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
- return srv.encHandshakeChecks(peers, c)
+ return srv.encHandshakeChecks(peers, inboundCount, c)
}
-func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
+ case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
+ return DiscTooManyPeers
case peers[c.id] != nil:
return DiscAlreadyConnected
case c.id == srv.Self().ID:
@@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
}
}
+func (srv *Server) maxInboundConns() int {
+ return srv.MaxPeers - srv.maxDialedConns()
+}
+
+func (srv *Server) maxDialedConns() int {
+ if srv.NoDiscovery || srv.NoDial {
+ return 0
+ }
+ r := srv.DialRatio
+ if r == 0 {
+ r = defaultDialRatio
+ }
+ return srv.MaxPeers / r
+}
+
type tempError interface {
Temporary() bool
}
@@ -714,10 +743,7 @@ func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
- // This channel acts as a semaphore limiting
- // active inbound connections that are lingering pre-handshake.
- // If all slots are taken, no further connections are accepted.
- tokens := maxAcceptConns
+ tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers
}
@@ -758,9 +784,6 @@ func (srv *Server) listenLoop() {
fd = newMeteredConn(fd, true)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
-
- // Spawn the handler. It will give the slot back when the connection
- // has been established.
go func() {
srv.SetupConn(fd, inboundConn, nil)
slots <- struct{}{}