aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/dial.go13
-rw-r--r--p2p/dial_test.go44
-rw-r--r--p2p/discover/database.go17
-rw-r--r--p2p/discover/database_test.go18
-rw-r--r--p2p/discover/node.go6
-rw-r--r--p2p/discover/table.go486
-rw-r--r--p2p/discover/table_test.go171
-rw-r--r--p2p/discover/udp.go118
-rw-r--r--p2p/discover/udp_test.go28
-rw-r--r--p2p/discv5/net.go43
-rw-r--r--p2p/discv5/net_test.go2
-rw-r--r--p2p/discv5/sim_test.go2
-rw-r--r--p2p/discv5/ticket.go16
-rw-r--r--p2p/discv5/udp.go67
-rw-r--r--p2p/netutil/net.go131
-rw-r--r--p2p/netutil/net_test.go89
-rw-r--r--p2p/peer.go6
-rw-r--r--p2p/rlpx.go10
-rw-r--r--p2p/rlpx_test.go38
-rw-r--r--p2p/server.go150
-rw-r--r--p2p/simulations/adapters/state.go1
21 files changed, 1067 insertions, 389 deletions
diff --git a/p2p/dial.go b/p2p/dial.go
index f5ff2c211..d8feceb9f 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -154,6 +154,9 @@ func (s *dialstate) addStatic(n *discover.Node) {
func (s *dialstate) removeStatic(n *discover.Node) {
// This removes a task so future attempts to connect will not be made.
delete(s.static, n.ID)
+ // This removes a previous dial timestamp so that application
+ // can force a server to reconnect with chosen peer immediately.
+ s.hist.remove(n.ID)
}
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
@@ -390,6 +393,16 @@ func (h dialHistory) min() pastDial {
}
func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
+
+}
+func (h *dialHistory) remove(id discover.NodeID) bool {
+ for i, v := range *h {
+ if v.id == id {
+ heap.Remove(h, i)
+ return true
+ }
+ }
+ return false
}
func (h dialHistory) contains(id discover.NodeID) bool {
for _, v := range h {
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index ad18ef9ab..2a7941fc6 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -515,6 +515,50 @@ func TestDialStateStaticDial(t *testing.T) {
})
}
+// This test checks that static peers will be redialed immediately if they were re-added to a static list.
+func TestDialStaticAfterReset(t *testing.T) {
+ wantStatic := []*discover.Node{
+ {ID: uintID(1)},
+ {ID: uintID(2)},
+ }
+
+ rounds := []round{
+ // Static dials are launched for the nodes that aren't yet connected.
+ {
+ peers: nil,
+ new: []task{
+ &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
+ &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ },
+ },
+ // No new dial tasks, all peers are connected.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(2)}},
+ },
+ done: []task{
+ &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
+ &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ },
+ new: []task{
+ &waitExpireTask{Duration: 30 * time.Second},
+ },
+ },
+ }
+ dTest := dialtest{
+ init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
+ rounds: rounds,
+ }
+ runDialTest(t, dTest)
+ for _, n := range wantStatic {
+ dTest.init.removeStatic(n)
+ dTest.init.addStatic(n)
+ }
+ // without removing peers they will be considered recently dialed
+ runDialTest(t, dTest)
+}
+
// This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) {
wantStatic := []*discover.Node{
diff --git a/p2p/discover/database.go b/p2p/discover/database.go
index b136609f2..6f98de9b4 100644
--- a/p2p/discover/database.go
+++ b/p2p/discover/database.go
@@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error {
}
// Skip the node if not expired yet (and not self)
if !bytes.Equal(id[:], db.self[:]) {
- if seen := db.lastPong(id); seen.After(threshold) {
+ if seen := db.bondTime(id); seen.After(threshold) {
continue
}
}
@@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
}
-// lastPong retrieves the time of the last successful contact from remote node.
-func (db *nodeDB) lastPong(id NodeID) time.Time {
+// bondTime retrieves the time of the last successful pong from remote node.
+func (db *nodeDB) bondTime(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
}
-// updateLastPong updates the last time a remote node successfully contacted.
-func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error {
+// hasBond reports whether the given node is considered bonded.
+func (db *nodeDB) hasBond(id NodeID) bool {
+ return time.Since(db.bondTime(id)) < nodeDBNodeExpiration
+}
+
+// updateBondTime updates the last pong time of a node.
+func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
}
@@ -327,7 +332,7 @@ seek:
if n.ID == db.self {
continue seek
}
- if now.Sub(db.lastPong(n.ID)) > maxAge {
+ if now.Sub(db.bondTime(n.ID)) > maxAge {
continue seek
}
for i := range nodes {
diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go
index be972fd2c..c4fa44d09 100644
--- a/p2p/discover/database_test.go
+++ b/p2p/discover/database_test.go
@@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node pong object
- if stored := db.lastPong(node.ID); stored.Unix() != 0 {
+ if stored := db.bondTime(node.ID); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored)
}
- if err := db.updateLastPong(node.ID, inst); err != nil {
+ if err := db.updateBondTime(node.ID, inst); err != nil {
t.Errorf("pong: failed to update: %v", err)
}
- if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() {
+ if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node findnode-failure object
@@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to insert lastPong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to insert bondTime: %v", i, err)
}
}
@@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update pong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
// Expire some of them, and check the rest
@@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update pong: %v", i, err)
+ if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
// Expire the nodes and make sure self has been evacuated too
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index fc928a91a..3b0c84115 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -29,6 +29,7 @@ import (
"regexp"
"strconv"
"strings"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
@@ -51,9 +52,8 @@ type Node struct {
// with ID.
sha common.Hash
- // whether this node is currently being pinged in order to replace
- // it in a bucket
- contested bool
+ // Time when the node was added to the table.
+ addedAt time.Time
}
// NewNode creates a new node. It is mostly meant to be used for
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index ec4eb94ad..6509326e6 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -23,10 +23,11 @@
package discover
import (
- "crypto/rand"
+ crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
+ mrand "math/rand"
"net"
"sort"
"sync"
@@ -35,29 +36,45 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
- alpha = 3 // Kademlia concurrency factor
- bucketSize = 16 // Kademlia bucket size
- hashBits = len(common.Hash{}) * 8
- nBuckets = hashBits + 1 // Number of buckets
-
- maxBondingPingPongs = 16
- maxFindnodeFailures = 5
-
- autoRefreshInterval = 1 * time.Hour
- seedCount = 30
- seedMaxAge = 5 * 24 * time.Hour
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ maxReplacements = 10 // Size of per-bucket replacement list
+
+ // We keep buckets for the upper 1/15 of distances because
+ // it's very unlikely we'll ever encounter a node that's closer.
+ hashBits = len(common.Hash{}) * 8
+ nBuckets = hashBits / 15 // Number of buckets
+ bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
+
+ // IP address limits.
+ bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
+ tableIPLimit, tableSubnet = 10, 24
+
+ maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
+ maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
+
+ refreshInterval = 30 * time.Minute
+ revalidateInterval = 10 * time.Second
+ copyNodesInterval = 30 * time.Second
+ seedMinTableTime = 5 * time.Minute
+ seedCount = 30
+ seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
- mutex sync.Mutex // protects buckets, their content, and nursery
+ mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
- db *nodeDB // database of known nodes
+ rand *mrand.Rand // source of randomness, periodically reseeded
+ ips netutil.DistinctNetSet
+ db *nodeDB // database of known nodes
refreshReq chan chan struct{}
+ initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
@@ -89,9 +106,13 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
-type bucket struct{ entries []*Node }
+type bucket struct {
+ entries []*Node // live entries, sorted by time of last contact
+ replacements []*Node // recently seen nodes to be used if revalidation fails
+ ips netutil.DistinctNetSet
+}
-func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) {
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID)
if err != nil {
@@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}),
+ initDone: make(chan struct{}),
closeReq: make(chan struct{}),
closed: make(chan struct{}),
+ rand: mrand.New(mrand.NewSource(0)),
+ ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
+ }
+ if err := tab.setFallbackNodes(bootnodes); err != nil {
+ return nil, err
}
for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{}
}
for i := range tab.buckets {
- tab.buckets[i] = new(bucket)
+ tab.buckets[i] = &bucket{
+ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
+ }
}
- go tab.refreshLoop()
+ tab.seedRand()
+ tab.loadSeedNodes(false)
+ // Start the background expiration goroutine after loading seeds so that the search for
+ // seed nodes also considers older nodes that would otherwise be removed by the
+ // expiration.
+ tab.db.ensureExpirer()
+ go tab.loop()
return tab, nil
}
+func (tab *Table) seedRand() {
+ var b [8]byte
+ crand.Read(b[:])
+
+ tab.mutex.Lock()
+ tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
+ tab.mutex.Unlock()
+}
+
// Self returns the local node.
// The returned node should not be modified by the caller.
func (tab *Table) Self() *Node {
@@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
// table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller.
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
+ if !tab.isInitDone() {
+ return 0
+ }
tab.mutex.Lock()
defer tab.mutex.Unlock()
- // TODO: tree-based buckets would help here
+
// Find all non-empty buckets and get a fresh slice of their entries.
var buckets [][]*Node
for _, b := range tab.buckets {
@@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return 0
}
// Shuffle the buckets.
- for i := uint32(len(buckets)) - 1; i > 0; i-- {
- j := randUint(i)
+ for i := len(buckets) - 1; i > 0; i-- {
+ j := tab.rand.Intn(len(buckets))
buckets[i], buckets[j] = buckets[j], buckets[i]
}
// Move head of each bucket into buf, removing buckets that become empty.
@@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return i + 1
}
-func randUint(max uint32) uint32 {
- if max == 0 {
- return 0
- }
- var b [4]byte
- rand.Read(b[:])
- return binary.BigEndian.Uint32(b[:]) % max
-}
-
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
select {
@@ -180,16 +218,15 @@ func (tab *Table) Close() {
}
}
-// SetFallbackNodes sets the initial points of contact. These nodes
+// setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
-func (tab *Table) SetFallbackNodes(nodes []*Node) error {
+func (tab *Table) setFallbackNodes(nodes []*Node) error {
for _, n := range nodes {
if err := n.validateComplete(); err != nil {
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
}
}
- tab.mutex.Lock()
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
@@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
cpy.sha = crypto.Keccak256Hash(n.ID[:])
tab.nursery = append(tab.nursery, &cpy)
}
- tab.mutex.Unlock()
- tab.refresh()
return nil
}
+// isInitDone returns whether the table's initial seeding procedure has completed.
+func (tab *Table) isInitDone() bool {
+ select {
+ case <-tab.initDone:
+ return true
+ default:
+ return false
+ }
+}
+
// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
func (tab *Table) Resolve(targetID NodeID) *Node {
@@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
return done
}
-// refreshLoop schedules doRefresh runs and coordinates shutdown.
-func (tab *Table) refreshLoop() {
+// loop schedules refresh, revalidate runs and coordinates shutdown.
+func (tab *Table) loop() {
var (
- timer = time.NewTicker(autoRefreshInterval)
- waiting []chan struct{} // accumulates waiting callers while doRefresh runs
- done chan struct{} // where doRefresh reports completion
+ revalidate = time.NewTimer(tab.nextRevalidateTime())
+ refresh = time.NewTicker(refreshInterval)
+ copyNodes = time.NewTicker(copyNodesInterval)
+ revalidateDone = make(chan struct{})
+ refreshDone = make(chan struct{}) // where doRefresh reports completion
+ waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
+ defer refresh.Stop()
+ defer revalidate.Stop()
+ defer copyNodes.Stop()
+
+ // Start initial refresh.
+ go tab.doRefresh(refreshDone)
+
loop:
for {
select {
- case <-timer.C:
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ case <-refresh.C:
+ tab.seedRand()
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
case req := <-tab.refreshReq:
waiting = append(waiting, req)
- if done == nil {
- done = make(chan struct{})
- go tab.doRefresh(done)
+ if refreshDone == nil {
+ refreshDone = make(chan struct{})
+ go tab.doRefresh(refreshDone)
}
- case <-done:
+ case <-refreshDone:
for _, ch := range waiting {
close(ch)
}
- waiting = nil
- done = nil
+ waiting, refreshDone = nil, nil
+ case <-revalidate.C:
+ go tab.doRevalidate(revalidateDone)
+ case <-revalidateDone:
+ revalidate.Reset(tab.nextRevalidateTime())
+ case <-copyNodes.C:
+ go tab.copyBondedNodes()
case <-tab.closeReq:
break loop
}
@@ -349,8 +410,8 @@ loop:
if tab.net != nil {
tab.net.close()
}
- if done != nil {
- <-done
+ if refreshDone != nil {
+ <-refreshDone
}
for _, ch := range waiting {
close(ch)
@@ -365,38 +426,109 @@ loop:
func (tab *Table) doRefresh(done chan struct{}) {
defer close(done)
+ // Load nodes from the database and insert
+ // them. This should yield a few previously seen nodes that are
+ // (hopefully) still alive.
+ tab.loadSeedNodes(true)
+
+ // Run self lookup to discover new neighbor nodes.
+ tab.lookup(tab.self.ID, false)
+
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket.
- // We perform a lookup with a random target instead.
- var target NodeID
- rand.Read(target[:])
- result := tab.lookup(target, false)
- if len(result) > 0 {
- return
+ // We perform a few lookups with a random target instead.
+ for i := 0; i < 3; i++ {
+ var target NodeID
+ crand.Read(target[:])
+ tab.lookup(target, false)
}
+}
- // The table is empty. Load nodes from the database and insert
- // them. This should yield a few previously seen nodes that are
- // (hopefully) still alive.
+func (tab *Table) loadSeedNodes(bond bool) {
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
- seeds = tab.bondall(append(seeds, tab.nursery...))
+ seeds = append(seeds, tab.nursery...)
+ if bond {
+ seeds = tab.bondall(seeds)
+ }
+ for i := range seeds {
+ seed := seeds[i]
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }}
+ log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
+ tab.add(seed)
+ }
+}
- if len(seeds) == 0 {
- log.Debug("No discv4 seed nodes found")
+// doRevalidate checks that the last node in a random bucket is still live
+// and replaces or deletes the node if it isn't.
+func (tab *Table) doRevalidate(done chan<- struct{}) {
+ defer func() { done <- struct{}{} }()
+
+ last, bi := tab.nodeToRevalidate()
+ if last == nil {
+ // No non-empty bucket found.
+ return
+ }
+
+ // Ping the selected node and wait for a pong.
+ err := tab.ping(last.ID, last.addr())
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ b := tab.buckets[bi]
+ if err == nil {
+ // The node responded, move it to the front.
+ log.Debug("Revalidated node", "b", bi, "id", last.ID)
+ b.bump(last)
+ return
}
- for _, n := range seeds {
- age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
- log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
+ // No reply received, pick a replacement or delete the node if there aren't
+ // any replacements.
+ if r := tab.replace(b, last); r != nil {
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
+ } else {
+ log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
}
+}
+
+// nodeToRevalidate returns the last node in a random, non-empty bucket.
+func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
tab.mutex.Lock()
- tab.stuff(seeds)
- tab.mutex.Unlock()
+ defer tab.mutex.Unlock()
- // Finally, do a self lookup to fill up the buckets.
- tab.lookup(tab.self.ID, false)
+ for _, bi = range tab.rand.Perm(len(tab.buckets)) {
+ b := tab.buckets[bi]
+ if len(b.entries) > 0 {
+ last := b.entries[len(b.entries)-1]
+ return last, bi
+ }
+ }
+ return nil, 0
+}
+
+func (tab *Table) nextRevalidateTime() time.Duration {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
+}
+
+// copyBondedNodes adds nodes from the table to the database if they have been in the table
+// longer then minTableTime.
+func (tab *Table) copyBondedNodes() {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ now := time.Now()
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if now.Sub(n.addedAt) >= seedMinTableTime {
+ tab.db.updateNode(n)
+ }
+ }
+ }
}
// closest returns the n nodes in the table that are closest to the
@@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
if id == tab.self.ID {
return nil, errors.New("is self")
}
- // Retrieve a previously known node and any recent findnode failures
- node, fails := tab.db.node(id), 0
- if node != nil {
- fails = tab.db.findFails(id)
+ if pinged && !tab.isInitDone() {
+ return nil, errors.New("still initializing")
}
- // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch
+ // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
+ node, fails := tab.db.node(id), tab.db.findFails(id)
+ age := time.Since(tab.db.bondTime(id))
var result error
- age := time.Since(tab.db.lastPong(id))
- if node == nil || fails > 0 || age > nodeDBNodeExpiration {
+ if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock()
@@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n
}
}
+ // Add the node to the table even if the bonding ping/pong
+ // fails. It will be relaced quickly if it continues to be
+ // unresponsive.
if node != nil {
- // Add the node to the table even if the bonding ping/pong
- // fails. It will be relaced quickly if it continues to be
- // unresponsive.
tab.add(node)
tab.db.updateFindFails(id, 0)
}
@@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
}
// Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
- tab.db.updateNode(w.n)
close(w.done)
}
@@ -533,17 +663,19 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
if err := tab.net.ping(id, addr); err != nil {
return err
}
- tab.db.updateLastPong(id, time.Now())
-
- // Start the background expiration goroutine after the first
- // successful communication. Subsequent calls have no effect if it
- // is already running. We do this here instead of somewhere else
- // so that the search for seed nodes also considers older nodes
- // that would otherwise be removed by the expiration.
- tab.db.ensureExpirer()
+ tab.db.updateBondTime(id, time.Now())
return nil
}
+// bucket returns the bucket for the given node ID hash.
+func (tab *Table) bucket(sha common.Hash) *bucket {
+ d := logdist(tab.self.sha, sha)
+ if d <= bucketMinDistance {
+ return tab.buckets[0]
+ }
+ return tab.buckets[d-bucketMinDistance-1]
+}
+
// add attempts to add the given node its corresponding bucket. If the
// bucket has space available, adding the node succeeds immediately.
// Otherwise, the node is added if the least recently active node in
@@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) {
- b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock()
defer tab.mutex.Unlock()
- if b.bump(new) {
- return
- }
- var oldest *Node
- if len(b.entries) == bucketSize {
- oldest = b.entries[bucketSize-1]
- if oldest.contested {
- // The node is already being replaced, don't attempt
- // to replace it.
- return
- }
- oldest.contested = true
- // Let go of the mutex so other goroutines can access
- // the table while we ping the least recently active node.
- tab.mutex.Unlock()
- err := tab.ping(oldest.ID, oldest.addr())
- tab.mutex.Lock()
- oldest.contested = false
- if err == nil {
- // The node responded, don't replace it.
- return
- }
- }
- added := b.replace(new, oldest)
- if added && tab.nodeAddedHook != nil {
- tab.nodeAddedHook(new)
+
+ b := tab.bucket(new.sha)
+ if !tab.bumpOrAdd(b, new) {
+ // Node is not in table. Add it to the replacement list.
+ tab.addReplacement(b, new)
}
}
// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must hold tab.mutex.
+// if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) {
-outer:
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
for _, n := range nodes {
if n.ID == tab.self.ID {
continue // don't add self
}
- bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == n.ID {
- continue outer // already in bucket
- }
- }
- if len(bucket.entries) < bucketSize {
- bucket.entries = append(bucket.entries, n)
- if tab.nodeAddedHook != nil {
- tab.nodeAddedHook(n)
- }
+ b := tab.bucket(n.sha)
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
}
}
}
@@ -611,36 +715,72 @@ outer:
func (tab *Table) delete(node *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
- for i := range bucket.entries {
- if bucket.entries[i].ID == node.ID {
- bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
- return
- }
- }
+
+ tab.deleteInBucket(tab.bucket(node.sha), node)
}
-func (b *bucket) replace(n *Node, last *Node) bool {
- // Don't add if b already contains n.
- for i := range b.entries {
- if b.entries[i].ID == n.ID {
- return false
- }
+func (tab *Table) addIP(b *bucket, ip net.IP) bool {
+ if netutil.IsLAN(ip) {
+ return true
}
- // Replace last if it is still the last entry or just add n if b
- // isn't full. If is no longer the last entry, it has either been
- // replaced with someone else or became active.
- if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
+ if !tab.ips.Add(ip) {
+ log.Debug("IP exceeds table limit", "ip", ip)
return false
}
- if len(b.entries) < bucketSize {
- b.entries = append(b.entries, nil)
+ if !b.ips.Add(ip) {
+ log.Debug("IP exceeds bucket limit", "ip", ip)
+ tab.ips.Remove(ip)
+ return false
}
- copy(b.entries[1:], b.entries)
- b.entries[0] = n
return true
}
+func (tab *Table) removeIP(b *bucket, ip net.IP) {
+ if netutil.IsLAN(ip) {
+ return
+ }
+ tab.ips.Remove(ip)
+ b.ips.Remove(ip)
+}
+
+func (tab *Table) addReplacement(b *bucket, n *Node) {
+ for _, e := range b.replacements {
+ if e.ID == n.ID {
+ return // already in list
+ }
+ }
+ if !tab.addIP(b, n.IP) {
+ return
+ }
+ var removed *Node
+ b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
+ if removed != nil {
+ tab.removeIP(b, removed.IP)
+ }
+}
+
+// replace removes n from the replacement list and replaces 'last' with it if it is the
+// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced
+// with someone else or became active.
+func (tab *Table) replace(b *bucket, last *Node) *Node {
+ if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
+ // Entry has moved, don't replace it.
+ return nil
+ }
+ // Still the last entry.
+ if len(b.replacements) == 0 {
+ tab.deleteInBucket(b, last)
+ return nil
+ }
+ r := b.replacements[tab.rand.Intn(len(b.replacements))]
+ b.replacements = deleteNode(b.replacements, r)
+ b.entries[len(b.entries)-1] = r
+ tab.removeIP(b, last.IP)
+ return r
+}
+
+// bump moves the given node to the front of the bucket entry list
+// if it is contained in that list.
func (b *bucket) bump(n *Node) bool {
for i := range b.entries {
if b.entries[i].ID == n.ID {
@@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
return false
}
+// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't
+// full. The return value is true if n is in the bucket.
+func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
+ if b.bump(n) {
+ return true
+ }
+ if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
+ return false
+ }
+ b.entries, _ = pushNode(b.entries, n, bucketSize)
+ b.replacements = deleteNode(b.replacements, n)
+ n.addedAt = time.Now()
+ if tab.nodeAddedHook != nil {
+ tab.nodeAddedHook(n)
+ }
+ return true
+}
+
+func (tab *Table) deleteInBucket(b *bucket, n *Node) {
+ b.entries = deleteNode(b.entries, n)
+ tab.removeIP(b, n.IP)
+}
+
+// pushNode adds n to the front of list, keeping at most max items.
+func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
+ if len(list) < max {
+ list = append(list, nil)
+ }
+ removed := list[len(list)-1]
+ copy(list[1:], list)
+ list[0] = n
+ return list, removed
+}
+
+// deleteNode removes n from list.
+func deleteNode(list []*Node, n *Node) []*Node {
+ for i := range list {
+ if list[i].ID == n.ID {
+ return append(list[:i], list[i+1:]...)
+ }
+ }
+ return list
+}
+
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 1037cc609..3ce48d299 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -20,6 +20,7 @@ import (
"crypto/ecdsa"
"fmt"
"math/rand"
+ "sync"
"net"
"reflect"
@@ -32,60 +33,65 @@ import (
)
func TestTable_pingReplace(t *testing.T) {
- doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
- transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "")
- defer tab.Close()
- pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ run := func(newNodeResponding, lastInBucketResponding bool) {
+ name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ testPingReplace(t, newNodeResponding, lastInBucketResponding)
+ })
+ }
- // fill up the sender's bucket.
- last := fillBucket(tab, 253)
+ run(true, true)
+ run(false, true)
+ run(true, false)
+ run(false, false)
+}
- // this call to bond should replace the last node
- // in its bucket if the node is not responding.
- transport.responding[last.ID] = lastInBucketIsResponding
- transport.responding[pingSender.ID] = newNodeIsResponding
- tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
- // first ping goes to sender (bonding pingback)
- if !transport.pinged[pingSender.ID] {
- t.Error("table did not ping back sender")
- }
- if newNodeIsResponding {
- // second ping goes to oldest node in bucket
- // to see whether it is still alive.
- if !transport.pinged[last.ID] {
- t.Error("table did not ping last node in bucket")
- }
- }
+ // Wait for init so bond is accepted.
+ <-tab.initDone
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
- if l := len(tab.buckets[253].entries); l != bucketSize {
- t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
- }
+ // fill up the sender's bucket.
+ pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ last := fillBucket(tab, pingSender)
- if lastInBucketIsResponding || !newNodeIsResponding {
- if !contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was removed")
- }
- if contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was added")
- }
- } else {
- if contains(tab.buckets[253].entries, last.ID) {
- t.Error("last entry was not removed")
- }
- if !contains(tab.buckets[253].entries, pingSender.ID) {
- t.Error("new entry was not added")
- }
- }
+ // this call to bond should replace the last node
+ // in its bucket if the node is not responding.
+ transport.dead[last.ID] = !lastInBucketIsResponding
+ transport.dead[pingSender.ID] = !newNodeIsResponding
+ tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+ tab.doRevalidate(make(chan struct{}, 1))
+
+ // first ping goes to sender (bonding pingback)
+ if !transport.pinged[pingSender.ID] {
+ t.Error("table did not ping back sender")
+ }
+ if !transport.pinged[last.ID] {
+ // second ping goes to oldest node in bucket
+ // to see whether it is still alive.
+ t.Error("table did not ping last node in bucket")
}
- doit(true, true)
- doit(false, true)
- doit(true, false)
- doit(false, false)
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ wantSize := bucketSize
+ if !lastInBucketIsResponding && !newNodeIsResponding {
+ wantSize--
+ }
+ if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
+ t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
+ }
+ if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
+ t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
+ }
+ wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
+ if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
+ t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
+ }
}
func TestBucket_bumpNoDuplicates(t *testing.T) {
@@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
}
}
+// This checks that the table-wide IP limit is applied correctly.
+func TestTable_IPLimit(t *testing.T) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
+
+ for i := 0; i < tableIPLimit+1; i++ {
+ n := nodeAtDistance(tab.self.sha, i)
+ n.IP = net.IP{172, 0, 1, byte(i)}
+ tab.add(n)
+ }
+ if tab.len() > tableIPLimit {
+ t.Errorf("too many nodes in table")
+ }
+}
+
+// This checks that the table-wide IP limit is applied correctly.
+func TestTable_BucketIPLimit(t *testing.T) {
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ defer tab.Close()
+
+ d := 3
+ for i := 0; i < bucketIPLimit+1; i++ {
+ n := nodeAtDistance(tab.self.sha, d)
+ n.IP = net.IP{172, 0, 1, byte(i)}
+ tab.add(n)
+ }
+ if tab.len() > bucketIPLimit {
+ t.Errorf("too many nodes in table")
+ }
+}
+
// fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their
// hashes.
-func fillBucket(tab *Table, ld int) (last *Node) {
- b := tab.buckets[ld]
+func fillBucket(tab *Table, n *Node) (last *Node) {
+ ld := logdist(tab.self.sha, n.sha)
+ b := tab.bucket(n.sha)
for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
}
@@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node)
n.sha = hashAtDistance(base, ld)
- n.IP = net.IP{10, 0, 2, byte(ld)}
+ n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n
}
-type pingRecorder struct{ responding, pinged map[NodeID]bool }
+type pingRecorder struct {
+ mu sync.Mutex
+ dead, pinged map[NodeID]bool
+}
func newPingRecorder() *pingRecorder {
- return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+ return &pingRecorder{
+ dead: make(map[NodeID]bool),
+ pinged: make(map[NodeID]bool),
+ }
}
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
- panic("findnode called on pingRecorder")
+ return nil, nil
}
func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings
}
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
t.pinged[toid] = true
- if t.responding[toid] {
- return nil
- } else {
+ if t.dead[toid] {
return errTimeout
+ } else {
+ return nil
}
}
@@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool {
// for any node table, Target and N
- tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close()
tab.stuff(test.All)
@@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
},
}
test := func(buf []*Node) bool {
- tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
+ transport := newPingRecorder()
+ tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
+ <-tab.initDone
+
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
@@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0)
- tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "")
+ tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close()
// lookup on empty table returns no nodes
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index f9eb99ee3..524c6e498 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -210,17 +210,28 @@ type reply struct {
matched chan<- bool
}
+// ReadPacket is sent to the unhandled channel when it could not be processed
+type ReadPacket struct {
+ Data []byte
+ Addr *net.UDPAddr
+}
+
+// Config holds Table-related settings.
+type Config struct {
+ // These settings are required and configure the UDP listener:
+ PrivateKey *ecdsa.PrivateKey
+
+ // These settings are optional:
+ AnnounceAddr *net.UDPAddr // local address announced in the DHT
+ NodeDBPath string // if set, the node database is stored at this filesystem location
+ NetRestrict *netutil.Netlist // network whitelist
+ Bootnodes []*Node // list of bootstrap nodes
+ Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
+}
+
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
- addr, err := net.ResolveUDPAddr("udp", laddr)
- if err != nil {
- return nil, err
- }
- conn, err := net.ListenUDP("udp", addr)
- if err != nil {
- return nil, err
- }
- tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
+func ListenUDP(c conn, cfg Config) (*Table, error) {
+ tab, _, err := newUDP(c, cfg)
if err != nil {
return nil, err
}
@@ -228,35 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
+func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
- priv: priv,
- netrestrict: netrestrict,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
realaddr := c.LocalAddr().(*net.UDPAddr)
- if natm != nil {
- if !realaddr.IP.IsLoopback() {
- go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
- }
- // TODO: react to external IP changes over time.
- if ext, err := natm.ExternalIP(); err == nil {
- realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
- }
+ if cfg.AnnounceAddr != nil {
+ realaddr = cfg.AnnounceAddr
}
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
+ tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
udp.Table = tab
go udp.loop()
- go udp.readLoop()
+ go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil
}
@@ -268,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
- // TODO: maybe check for ReplyTo field in callback to measure RTT
- errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
- t.send(toaddr, pingPacket, &ping{
+ req := &ping{
Version: Version,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
+ }
+ packet, hash, err := encodePacket(t.priv, pingPacket, req)
+ if err != nil {
+ return err
+ }
+ errc := t.pending(toid, pongPacket, func(p interface{}) bool {
+ return bytes.Equal(p.(*pong).ReplyTok, hash)
})
+ t.write(toaddr, req.name(), packet)
return <-errc
}
@@ -459,41 +470,49 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
- packet, err := encodePacket(t.priv, ptype, req)
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+ packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
- return err
+ return hash, err
}
- _, err = t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
+ return hash, t.write(toaddr, req.name(), packet)
+}
+
+func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+ _, err := t.conn.WriteToUDP(packet, toaddr)
+ log.Trace(">> "+what, "addr", toaddr, "err", err)
return err
}
-func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
- packet := b.Bytes()
+ packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil {
log.Error("Can't sign discv4 packet", "err", err)
- return nil, err
+ return nil, nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// The future.
- copy(packet, crypto.Keccak256(packet[macSize:]))
- return packet, nil
+ hash = crypto.Keccak256(packet[macSize:])
+ copy(packet, hash)
+ return packet, hash, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
-func (t *udp) readLoop() {
+func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close()
+ if unhandled != nil {
+ defer close(unhandled)
+ }
// Discovery packets are defined to be no larger than 1280 bytes.
// Packets larger than this size will be cut at the end and treated
// as invalid because their hash won't match.
@@ -509,7 +528,12 @@ func (t *udp) readLoop() {
log.Debug("UDP read error", "err", err)
return
}
- t.handlePacket(from, buf[:nbytes])
+ if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil {
+ select {
+ case unhandled <- ReadPacket{buf[:nbytes], from}:
+ default:
+ }
+ }
}
}
@@ -589,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired(req.Expiration) {
return errExpired
}
- if t.db.node(fromID) == nil {
+ if !t.db.hasBond(fromID) {
// No bond exists, we don't process the packet. This prevents
// an attack vector where the discovery protocol could be used
// to amplify traffic in a DDOS attack. A malicious actor
@@ -605,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
- for i, n := range closest {
- if netutil.CheckRelayIP(from.IP, n.IP) != nil {
- continue
+ for _, n := range closest {
+ if netutil.CheckRelayIP(from.IP, n.IP) == nil {
+ p.Nodes = append(p.Nodes, nodeToRPC(n))
}
- p.Nodes = append(p.Nodes, nodeToRPC(n))
- if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
+ if len(p.Nodes) == maxNeighbors {
t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
+ sent = true
}
}
+ if len(p.Nodes) > 0 || !sent {
+ t.send(from, neighborsPacket, &p)
+ }
return nil
}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 21e8b561d..db9804f7b 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -70,13 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
+ test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
+ // Wait for initial refresh so the table doesn't send unexpected findnode.
+ <-test.table.initDone
return test
}
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, err := encodePacket(test.remotekey, ptype, data)
+ enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
@@ -89,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) error {
+func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, _, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram)
if err != nil {
- return test.errorf("sent packet decode error: %v", err)
+ return hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return nil
+ return hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -245,12 +247,8 @@ func TestUDP_findnode(t *testing.T) {
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
- test.table.db.updateNode(NewNode(
- PubkeyID(&test.remotekey.PublicKey),
- test.remoteaddr.IP,
- uint16(test.remoteaddr.Port),
- 99,
- ))
+ test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now())
+
// check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
expected := test.table.closest(targetHash, bucketSize)
@@ -350,7 +348,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- test.waitPacketOut(func(p *ping) error {
+ hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
}
@@ -364,7 +362,7 @@ func TestUDP_successfulPing(t *testing.T) {
}
return nil
})
- test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+ test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the
// pong packet.
diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go
index cd9981584..52c677b62 100644
--- a/p2p/discv5/net.go
+++ b/p2p/discv5/net.go
@@ -29,7 +29,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/log"
- "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -134,7 +133,7 @@ type timeoutEvent struct {
node *Node
}
-func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
+func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
ourID := PubkeyID(&ourPubkey)
var db *nodeDB
@@ -566,11 +565,8 @@ loop:
if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
lookupChn <- net.ticketStore.radius[res.target.topic].converged
}
- net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node) []byte {
- net.ping(n, n.addr())
- return n.pingEcho
- }, func(n *Node, topic Topic) []byte {
- if n.state == known {
+ net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
+ if n.state != nil && n.state.canQuery {
return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration
} else {
if n.state == unknown {
@@ -634,15 +630,20 @@ loop:
}
net.refreshResp <- refreshDone
case <-refreshDone:
- log.Trace("<-net.refreshDone")
- refreshDone = nil
- list := searchReqWhenRefreshDone
- searchReqWhenRefreshDone = nil
- go func() {
- for _, req := range list {
- net.topicSearchReq <- req
- }
- }()
+ log.Trace("<-net.refreshDone", "table size", net.tab.count)
+ if net.tab.count != 0 {
+ refreshDone = nil
+ list := searchReqWhenRefreshDone
+ searchReqWhenRefreshDone = nil
+ go func() {
+ for _, req := range list {
+ net.topicSearchReq <- req
+ }
+ }()
+ } else {
+ refreshDone = make(chan struct{})
+ net.refresh(refreshDone)
+ }
}
}
log.Trace("loop stopped")
@@ -752,7 +753,15 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n
return n, err
}
if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
- err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
+ if n.state == known {
+ // reject address change if node is known by us
+ err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
+ } else {
+ // accept otherwise; this will be handled nicer with signed ENRs
+ n.IP = rn.IP
+ n.UDP = rn.UDP
+ n.TCP = rn.TCP
+ }
}
return n, err
}
diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go
index bd234f5ba..369282ca9 100644
--- a/p2p/discv5/net_test.go
+++ b/p2p/discv5/net_test.go
@@ -28,7 +28,7 @@ import (
func TestNetwork_Lookup(t *testing.T) {
key, _ := crypto.GenerateKey()
- network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
+ network, err := newNetwork(lookupTestnet, key.PublicKey, "", nil)
if err != nil {
t.Fatal(err)
}
diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go
index bf57872e2..543faecd4 100644
--- a/p2p/discv5/sim_test.go
+++ b/p2p/discv5/sim_test.go
@@ -282,7 +282,7 @@ func (s *simulation) launchNode(log bool) *Network {
addr := &net.UDPAddr{IP: ip, Port: 30303}
transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
- net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
+ net, err := newNetwork(transport, key.PublicKey, "<no database>", nil)
if err != nil {
panic("cannot launch new node: " + err.Error())
}
diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go
index b45ec4d2b..b3d1ac4ba 100644
--- a/p2p/discv5/ticket.go
+++ b/p2p/discv5/ticket.go
@@ -350,7 +350,7 @@ func (s *ticketStore) nextFilteredTicket() (*ticketRef, time.Duration) {
regTime := now + mclock.AbsTime(wait)
topic := ticket.t.topics[ticket.idx]
- if regTime >= s.tickets[topic].nextReg {
+ if s.tickets[topic] != nil && regTime >= s.tickets[topic].nextReg {
return ticket, wait
}
s.removeTicketRef(*ticket)
@@ -420,11 +420,14 @@ func (s *ticketStore) nextRegisterableTicket() (*ticketRef, time.Duration) {
func (s *ticketStore) removeTicketRef(ref ticketRef) {
log.Trace("Removing discovery ticket reference", "node", ref.t.node.ID, "serial", ref.t.serial)
+ // Make nextRegisterableTicket return the next available ticket.
+ s.nextTicketCached = nil
+
topic := ref.topic()
tickets := s.tickets[topic]
if tickets == nil {
- log.Warn("Removing tickets from unknown topic", "topic", topic)
+ log.Trace("Removing tickets from unknown topic", "topic", topic)
return
}
bucket := timeBucket(ref.t.regTime[ref.idx] / mclock.AbsTime(ticketTimeBucketLen))
@@ -450,9 +453,6 @@ func (s *ticketStore) removeTicketRef(ref ticketRef) {
delete(s.nodes, ref.t.node)
delete(s.nodeLastReq, ref.t.node)
}
-
- // Make nextRegisterableTicket return the next available ticket.
- s.nextTicketCached = nil
}
type lookupInfo struct {
@@ -494,13 +494,13 @@ func (s *ticketStore) registerLookupDone(lookup lookupInfo, nodes []*Node, ping
}
}
-func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, ping func(n *Node) []byte, query func(n *Node, topic Topic) []byte) {
+func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, query func(n *Node, topic Topic) []byte) {
now := mclock.Now()
for i, n := range nodes {
if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < s.radius[lookup.topic].minRadius {
if lookup.radiusLookup {
if lastReq, ok := s.nodeLastReq[n]; !ok || time.Duration(now-lastReq.time) > radiusTC {
- s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now}
+ s.nodeLastReq[n] = reqInfo{pingHash: nil, lookup: lookup, time: now}
}
} // else {
if s.canQueryTopic(n, lookup.topic) {
@@ -642,7 +642,7 @@ func (s *ticketStore) gotTopicNodes(from *Node, hash common.Hash, nodes []rpcNod
if ip.IsUnspecified() || ip.IsLoopback() {
ip = from.IP
}
- n := NewNode(node.ID, ip, node.UDP-1, node.TCP-1) // subtract one from port while discv5 is running in test mode on UDPport+1
+ n := NewNode(node.ID, ip, node.UDP, node.TCP)
select {
case chn <- n:
default:
diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go
index 26087cd8e..6ce72d2c1 100644
--- a/p2p/discv5/udp.go
+++ b/p2p/discv5/udp.go
@@ -37,7 +37,7 @@ const Version = 4
// Errors
var (
errPacketTooSmall = errors.New("too small")
- errBadHash = errors.New("bad hash")
+ errBadPrefix = errors.New("bad prefix")
errExpired = errors.New("expired")
errUnsolicitedReply = errors.New("unsolicited reply")
errUnknownNode = errors.New("unknown node")
@@ -49,7 +49,7 @@ var (
// Timeouts
const (
respTimeout = 500 * time.Millisecond
- sendTimeout = 500 * time.Millisecond
+ queryDelay = 1000 * time.Millisecond
expiration = 20 * time.Second
ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP
@@ -145,10 +145,11 @@ type (
}
)
-const (
- macSize = 256 / 8
- sigSize = 520 / 8
- headSize = macSize + sigSize // space of packet frame data
+var (
+ versionPrefix = []byte("temporary discovery v5")
+ versionPrefixSize = len(versionPrefix)
+ sigSize = 520 / 8
+ headSize = versionPrefixSize + sigSize // space of packet frame data
)
// Neighbors replies are sent across multiple packets to
@@ -237,30 +238,23 @@ type udp struct {
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
- transport, err := listenUDP(priv, laddr)
+func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
+ transport, err := listenUDP(priv, conn, realaddr)
if err != nil {
return nil, err
}
- net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
+ net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
+ log.Info("UDP listener up", "net", net.tab.self)
transport.net = net
go transport.readLoop()
return net, nil
}
-func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) {
- addr, err := net.ResolveUDPAddr("udp", laddr)
- if err != nil {
- return nil, err
- }
- conn, err := net.ListenUDP("udp", addr)
- if err != nil {
- return nil, err
- }
- return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil
+func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
+ return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
}
func (t *udp) localAddr() *net.UDPAddr {
@@ -324,20 +318,20 @@ func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []by
func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
p := topicNodes{Echo: queryHash}
- if len(nodes) == 0 {
- t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
- return
- }
- for i, result := range nodes {
- if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
- continue
+ var sent bool
+ for _, result := range nodes {
+ if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
+ p.Nodes = append(p.Nodes, nodeToRPC(result))
}
- p.Nodes = append(p.Nodes, nodeToRPC(result))
- if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
+ if len(p.Nodes) == maxTopicNodes {
t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
p.Nodes = p.Nodes[:0]
+ sent = true
}
}
+ if !sent || len(p.Nodes) > 0 {
+ t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
+ }
}
func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
@@ -372,11 +366,9 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash
log.Error(fmt.Sprint("could not sign packet:", err))
return nil, nil, err
}
- copy(packet[macSize:], sig)
- // add the hash to the front. Note: this doesn't protect the
- // packet in any way.
- hash = crypto.Keccak256(packet[macSize:])
- copy(packet, hash)
+ copy(packet, versionPrefix)
+ copy(packet[versionPrefixSize:], sig)
+ hash = crypto.Keccak256(packet[versionPrefixSize:])
return packet, hash, nil
}
@@ -420,17 +412,16 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error {
}
buf := make([]byte, len(buffer))
copy(buf, buffer)
- hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
- shouldhash := crypto.Keccak256(buf[macSize:])
- if !bytes.Equal(hash, shouldhash) {
- return errBadHash
+ prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:]
+ if !bytes.Equal(prefix, versionPrefix) {
+ return errBadPrefix
}
fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
if err != nil {
return err
}
pkt.rawData = buf
- pkt.hash = hash
+ pkt.hash = crypto.Keccak256(buf[versionPrefixSize:])
pkt.remoteID = fromID
switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
case pingPacket:
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go
index f6005afd2..656abb682 100644
--- a/p2p/netutil/net.go
+++ b/p2p/netutil/net.go
@@ -18,8 +18,11 @@
package netutil
import (
+ "bytes"
"errors"
+ "fmt"
"net"
+ "sort"
"strings"
)
@@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
}
return nil
}
+
+// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
+func SameNet(bits uint, ip, other net.IP) bool {
+ ip4, other4 := ip.To4(), other.To4()
+ switch {
+ case (ip4 == nil) != (other4 == nil):
+ return false
+ case ip4 != nil:
+ return sameNet(bits, ip4, other4)
+ default:
+ return sameNet(bits, ip.To16(), other.To16())
+ }
+}
+
+func sameNet(bits uint, ip, other net.IP) bool {
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
+ return false
+ }
+ return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
+}
+
+// DistinctNetSet tracks IPs, ensuring that at most N of them
+// fall into the same network range.
+type DistinctNetSet struct {
+ Subnet uint // number of common prefix bits
+ Limit uint // maximum number of IPs in each subnet
+
+ members map[string]uint
+ buf net.IP
+}
+
+// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
+// number of existing IPs in the defined range exceeds the limit.
+func (s *DistinctNetSet) Add(ip net.IP) bool {
+ key := s.key(ip)
+ n := s.members[string(key)]
+ if n < s.Limit {
+ s.members[string(key)] = n + 1
+ return true
+ }
+ return false
+}
+
+// Remove removes an IP from the set.
+func (s *DistinctNetSet) Remove(ip net.IP) {
+ key := s.key(ip)
+ if n, ok := s.members[string(key)]; ok {
+ if n == 1 {
+ delete(s.members, string(key))
+ } else {
+ s.members[string(key)] = n - 1
+ }
+ }
+}
+
+// Contains whether the given IP is contained in the set.
+func (s DistinctNetSet) Contains(ip net.IP) bool {
+ key := s.key(ip)
+ _, ok := s.members[string(key)]
+ return ok
+}
+
+// Len returns the number of tracked IPs.
+func (s DistinctNetSet) Len() int {
+ n := uint(0)
+ for _, i := range s.members {
+ n += i
+ }
+ return int(n)
+}
+
+// key encodes the map key for an address into a temporary buffer.
+//
+// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
+// The remainder of the key is the IP, truncated to the number of bits.
+func (s *DistinctNetSet) key(ip net.IP) net.IP {
+ // Lazily initialize storage.
+ if s.members == nil {
+ s.members = make(map[string]uint)
+ s.buf = make(net.IP, 17)
+ }
+ // Canonicalize ip and bits.
+ typ := byte('6')
+ if ip4 := ip.To4(); ip4 != nil {
+ typ, ip = '4', ip4
+ }
+ bits := s.Subnet
+ if bits > uint(len(ip)*8) {
+ bits = uint(len(ip) * 8)
+ }
+ // Encode the prefix into s.buf.
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ s.buf[0] = typ
+ buf := append(s.buf[:1], ip[:nb]...)
+ if nb < len(ip) && mask != 0 {
+ buf = append(buf, ip[nb]&mask)
+ }
+ return buf
+}
+
+// String implements fmt.Stringer
+func (s DistinctNetSet) String() string {
+ var buf bytes.Buffer
+ buf.WriteString("{")
+ keys := make([]string, 0, len(s.members))
+ for k := range s.members {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for i, k := range keys {
+ var ip net.IP
+ if k[0] == '4' {
+ ip = make(net.IP, 4)
+ } else {
+ ip = make(net.IP, 16)
+ }
+ copy(ip, k[1:])
+ fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k])
+ if i != len(keys)-1 {
+ buf.WriteString(" ")
+ }
+ }
+ buf.WriteString("}")
+ return buf.String()
+}
diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go
index 1ee1fcb4d..3a6aa081f 100644
--- a/p2p/netutil/net_test.go
+++ b/p2p/netutil/net_test.go
@@ -17,9 +17,11 @@
package netutil
import (
+ "fmt"
"net"
"reflect"
"testing"
+ "testing/quick"
"github.com/davecgh/go-spew/spew"
)
@@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr)
}
}
+
+func TestSameNet(t *testing.T) {
+ tests := []struct {
+ ip, other string
+ bits uint
+ want bool
+ }{
+ {"0.0.0.0", "0.0.0.0", 32, true},
+ {"0.0.0.0", "0.0.0.1", 0, true},
+ {"0.0.0.0", "0.0.0.1", 31, true},
+ {"0.0.0.0", "0.0.0.1", 32, false},
+ {"0.33.0.1", "0.34.0.2", 8, true},
+ {"0.33.0.1", "0.34.0.2", 13, true},
+ {"0.33.0.1", "0.34.0.2", 15, false},
+ }
+
+ for _, test := range tests {
+ if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
+ t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
+ }
+ }
+}
+
+func ExampleSameNet() {
+ // This returns true because the IPs are in the same /24 network:
+ fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
+ // This call returns false:
+ fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
+ // Output:
+ // true
+ // false
+}
+
+func TestDistinctNetSet(t *testing.T) {
+ ops := []struct {
+ add, remove string
+ fails bool
+ }{
+ {add: "127.0.0.1"},
+ {add: "127.0.0.2"},
+ {add: "127.0.0.3", fails: true},
+ {add: "127.32.0.1"},
+ {add: "127.32.0.2"},
+ {add: "127.32.0.3", fails: true},
+ {add: "127.33.0.1", fails: true},
+ {add: "127.34.0.1"},
+ {add: "127.34.0.2"},
+ {add: "127.34.0.3", fails: true},
+ // Make room for an address, then add again.
+ {remove: "127.0.0.1"},
+ {add: "127.0.0.3"},
+ {add: "127.0.0.3", fails: true},
+ }
+
+ set := DistinctNetSet{Subnet: 15, Limit: 2}
+ for _, op := range ops {
+ var desc string
+ if op.add != "" {
+ desc = fmt.Sprintf("Add(%s)", op.add)
+ if ok := set.Add(parseIP(op.add)); ok != !op.fails {
+ t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
+ }
+ } else {
+ desc = fmt.Sprintf("Remove(%s)", op.remove)
+ set.Remove(parseIP(op.remove))
+ }
+ t.Logf("%s: %v", desc, set)
+ }
+}
+
+func TestDistinctNetSetAddRemove(t *testing.T) {
+ cfg := &quick.Config{}
+ fn := func(ips []net.IP) bool {
+ s := DistinctNetSet{Limit: 3, Subnet: 2}
+ for _, ip := range ips {
+ s.Add(ip)
+ }
+ for _, ip := range ips {
+ s.Remove(ip)
+ }
+ return s.Len() == 0
+ }
+
+ if err := quick.Check(fn, cfg); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/p2p/peer.go b/p2p/peer.go
index bad1c8c8b..477d8c219 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -419,6 +419,9 @@ type PeerInfo struct {
Network struct {
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
+ Inbound bool `json:"inbound"`
+ Trusted bool `json:"trusted"`
+ Static bool `json:"static"`
} `json:"network"`
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
}
@@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
}
info.Network.LocalAddress = p.LocalAddr().String()
info.Network.RemoteAddress = p.RemoteAddr().String()
+ info.Network.Inbound = p.rw.is(inboundConn)
+ info.Network.Trusted = p.rw.is(trustedConn)
+ info.Network.Static = p.rw.is(staticDialedConn)
// Gather all the running protocol infos
for _, proto := range p.running {
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index 24037ecc1..e65a0b604 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -108,8 +108,14 @@ func (t *rlpx) close(err error) {
// Tell the remote end why we're disconnecting if possible.
if t.rw != nil {
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
- t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
- SendItems(t.rw, discMsg, r)
+ // rlpx tries to send DiscReason to disconnected peer
+ // if the connection is net.Pipe (in-memory simulation)
+ // it hangs forever, since net.Pipe does not implement
+ // a write deadline. Because of this only try to send
+ // the disconnect reason message if there is no error.
+ if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil {
+ SendItems(t.rw, discMsg, r)
+ }
}
}
t.fd.Close()
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index f4cefa650..bca460402 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) {
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
- fd0, fd1 = net.Pipe()
- wg sync.WaitGroup
+ wg sync.WaitGroup
)
+ fd0, fd1, err := tcpPipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+
wg.Add(2)
go func() {
defer wg.Done()
- defer fd1.Close()
+ defer fd0.Close()
rlpx := newRLPX(fd0)
remid, err := rlpx.doEncHandshake(prv0, node1)
if err != nil {
@@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) {
t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash)
}
}
+
+// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket
+func tcpPipe() (net.Conn, net.Conn, error) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, nil, err
+ }
+ defer l.Close()
+
+ var aconn net.Conn
+ aerr := make(chan error, 1)
+ go func() {
+ var err error
+ aconn, err = l.Accept()
+ aerr <- err
+ }()
+
+ dconn, err := net.Dial("tcp", l.Addr().String())
+ if err != nil {
+ <-aerr
+ return nil, nil, err
+ }
+ if err := <-aerr; err != nil {
+ dconn.Close()
+ return nil, nil, err
+ }
+ return aconn, dconn, nil
+}
diff --git a/p2p/server.go b/p2p/server.go
index 922df55ba..90e92dc05 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -40,11 +40,10 @@ const (
refreshPeersInterval = 30 * time.Second
staticPeerCheckInterval = 15 * time.Second
- // Maximum number of concurrently handshaking inbound connections.
- maxAcceptConns = 50
-
- // Maximum number of concurrently dialing outbound connections.
- maxActiveDialTasks = 16
+ // Connectivity defaults.
+ maxActiveDialTasks = 16
+ defaultMaxPendingPeers = 50
+ defaultDialRatio = 3
// Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle.
@@ -70,6 +69,11 @@ type Config struct {
// Zero defaults to preset values.
MaxPendingPeers int `toml:",omitempty"`
+ // DialRatio controls the ratio of inbound to dialed connections.
+ // Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
+ // Setting DialRatio to zero defaults it to 3.
+ DialRatio int `toml:",omitempty"`
+
// NoDiscovery can be used to disable the peer discovery mechanism.
// Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool
@@ -78,9 +82,6 @@ type Config struct {
// protocol should be started or not.
DiscoveryV5 bool `toml:",omitempty"`
- // Listener address for the V5 discovery protocol UDP traffic.
- DiscoveryV5Addr string `toml:",omitempty"`
-
// Name sets the node name of this server.
// Use common.MakeName to create a name that follows existing conventions.
Name string `toml:"-"`
@@ -141,7 +142,7 @@ type Config struct {
EnableMsgEvents bool
// Logger is a custom logger to use with the p2p.Server.
- Logger log.Logger
+ Logger log.Logger `toml:",omitempty"`
}
// Server manages all peer connections.
@@ -354,6 +355,32 @@ func (srv *Server) Stop() {
srv.loopWG.Wait()
}
+// sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns
+// messages that were found unprocessable and sent to the unhandled channel by the primary listener.
+type sharedUDPConn struct {
+ *net.UDPConn
+ unhandled chan discover.ReadPacket
+}
+
+// ReadFromUDP implements discv5.conn
+func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+ packet, ok := <-s.unhandled
+ if !ok {
+ return 0, nil, fmt.Errorf("Connection was closed")
+ }
+ l := len(packet.Data)
+ if l > len(b) {
+ l = len(b)
+ }
+ copy(b[:l], packet.Data[:l])
+ return l, packet.Addr, nil
+}
+
+// Close implements discv5.conn
+func (s *sharedUDPConn) Close() error {
+ return nil
+}
+
// Start starts running the server.
// Servers can not be re-used after stopping.
func (srv *Server) Start() (err error) {
@@ -388,20 +415,66 @@ func (srv *Server) Start() (err error) {
srv.peerOp = make(chan peerOpFunc)
srv.peerOpDone = make(chan struct{})
- // node table
- if !srv.NoDiscovery {
- ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
+ var (
+ conn *net.UDPConn
+ sconn *sharedUDPConn
+ realaddr *net.UDPAddr
+ unhandled chan discover.ReadPacket
+ )
+
+ if !srv.NoDiscovery || srv.DiscoveryV5 {
+ addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
if err != nil {
return err
}
- if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil {
+ conn, err = net.ListenUDP("udp", addr)
+ if err != nil {
+ return err
+ }
+ realaddr = conn.LocalAddr().(*net.UDPAddr)
+ if srv.NAT != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := srv.NAT.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ }
+
+ if !srv.NoDiscovery && srv.DiscoveryV5 {
+ unhandled = make(chan discover.ReadPacket, 100)
+ sconn = &sharedUDPConn{conn, unhandled}
+ }
+
+ // node table
+ if !srv.NoDiscovery {
+ cfg := discover.Config{
+ PrivateKey: srv.PrivateKey,
+ AnnounceAddr: realaddr,
+ NodeDBPath: srv.NodeDatabase,
+ NetRestrict: srv.NetRestrict,
+ Bootnodes: srv.BootstrapNodes,
+ Unhandled: unhandled,
+ }
+ ntab, err := discover.ListenUDP(conn, cfg)
+ if err != nil {
return err
}
srv.ntab = ntab
}
if srv.DiscoveryV5 {
- ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
+ var (
+ ntab *discv5.Network
+ err error
+ )
+ if sconn != nil {
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ } else {
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ }
if err != nil {
return err
}
@@ -411,10 +484,7 @@ func (srv *Server) Start() (err error) {
srv.DiscV5 = ntab
}
- dynPeers := (srv.MaxPeers + 1) / 2
- if srv.NoDiscovery {
- dynPeers = 0
- }
+ dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
@@ -471,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done()
var (
peers = make(map[discover.NodeID]*Peer)
+ inboundCount = 0
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task
@@ -556,14 +627,14 @@ running:
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select {
- case c.cont <- srv.encHandshakeChecks(peers, c):
+ case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit:
break running
}
case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
- err := srv.protoHandshakeChecks(peers, c)
+ err := srv.protoHandshakeChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols)
@@ -574,8 +645,11 @@ running:
}
name := truncateName(c.name)
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
- peers[c.id] = p
go srv.runPeer(p)
+ peers[c.id] = p
+ if p.Inbound() {
+ inboundCount++
+ }
}
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
@@ -590,6 +664,9 @@ running:
d := common.PrettyDuration(mclock.Now() - pd.created)
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
delete(peers, pd.ID())
+ if pd.Inbound() {
+ inboundCount--
+ }
}
}
@@ -616,20 +693,22 @@ running:
}
}
-func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
- return srv.encHandshakeChecks(peers, c)
+ return srv.encHandshakeChecks(peers, inboundCount, c)
}
-func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
+ case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
+ return DiscTooManyPeers
case peers[c.id] != nil:
return DiscAlreadyConnected
case c.id == srv.Self().ID:
@@ -639,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
}
}
+func (srv *Server) maxInboundConns() int {
+ return srv.MaxPeers - srv.maxDialedConns()
+}
+
+func (srv *Server) maxDialedConns() int {
+ if srv.NoDiscovery || srv.NoDial {
+ return 0
+ }
+ r := srv.DialRatio
+ if r == 0 {
+ r = defaultDialRatio
+ }
+ return srv.MaxPeers / r
+}
+
type tempError interface {
Temporary() bool
}
@@ -649,10 +743,7 @@ func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
- // This channel acts as a semaphore limiting
- // active inbound connections that are lingering pre-handshake.
- // If all slots are taken, no further connections are accepted.
- tokens := maxAcceptConns
+ tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers
}
@@ -693,9 +784,6 @@ func (srv *Server) listenLoop() {
fd = newMeteredConn(fd, true)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
-
- // Spawn the handler. It will give the slot back when the connection
- // has been established.
go func() {
srv.SetupConn(fd, inboundConn, nil)
slots <- struct{}{}
diff --git a/p2p/simulations/adapters/state.go b/p2p/simulations/adapters/state.go
index 8b1dfef90..0d4ecfb0f 100644
--- a/p2p/simulations/adapters/state.go
+++ b/p2p/simulations/adapters/state.go
@@ -13,6 +13,7 @@
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
package adapters
type SimStateStore struct {