diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/dial.go | 13 | ||||
-rw-r--r-- | p2p/dial_test.go | 44 | ||||
-rw-r--r-- | p2p/discover/database.go | 17 | ||||
-rw-r--r-- | p2p/discover/database_test.go | 18 | ||||
-rw-r--r-- | p2p/discover/node.go | 6 | ||||
-rw-r--r-- | p2p/discover/table.go | 486 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 171 | ||||
-rw-r--r-- | p2p/discover/udp.go | 118 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 28 | ||||
-rw-r--r-- | p2p/discv5/net.go | 43 | ||||
-rw-r--r-- | p2p/discv5/net_test.go | 2 | ||||
-rw-r--r-- | p2p/discv5/sim_test.go | 2 | ||||
-rw-r--r-- | p2p/discv5/ticket.go | 16 | ||||
-rw-r--r-- | p2p/discv5/udp.go | 67 | ||||
-rw-r--r-- | p2p/netutil/net.go | 131 | ||||
-rw-r--r-- | p2p/netutil/net_test.go | 89 | ||||
-rw-r--r-- | p2p/peer.go | 6 | ||||
-rw-r--r-- | p2p/rlpx.go | 10 | ||||
-rw-r--r-- | p2p/rlpx_test.go | 38 | ||||
-rw-r--r-- | p2p/server.go | 150 | ||||
-rw-r--r-- | p2p/simulations/adapters/state.go | 1 |
21 files changed, 1067 insertions, 389 deletions
diff --git a/p2p/dial.go b/p2p/dial.go index f5ff2c211..d8feceb9f 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -154,6 +154,9 @@ func (s *dialstate) addStatic(n *discover.Node) { func (s *dialstate) removeStatic(n *discover.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID) + // This removes a previous dial timestamp so that application + // can force a server to reconnect with chosen peer immediately. + s.hist.remove(n.ID) } func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { @@ -390,6 +393,16 @@ func (h dialHistory) min() pastDial { } func (h *dialHistory) add(id discover.NodeID, exp time.Time) { heap.Push(h, pastDial{id, exp}) + +} +func (h *dialHistory) remove(id discover.NodeID) bool { + for i, v := range *h { + if v.id == id { + heap.Remove(h, i) + return true + } + } + return false } func (h dialHistory) contains(id discover.NodeID) bool { for _, v := range h { diff --git a/p2p/dial_test.go b/p2p/dial_test.go index ad18ef9ab..2a7941fc6 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -515,6 +515,50 @@ func TestDialStateStaticDial(t *testing.T) { }) } +// This test checks that static peers will be redialed immediately if they were re-added to a static list. +func TestDialStaticAfterReset(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + } + + rounds := []round{ + // Static dials are launched for the nodes that aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + }, + // No new dial tasks, all peers are connected. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + new: []task{ + &waitExpireTask{Duration: 30 * time.Second}, + }, + }, + } + dTest := dialtest{ + init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + rounds: rounds, + } + runDialTest(t, dTest) + for _, n := range wantStatic { + dTest.init.removeStatic(n) + dTest.init.addStatic(n) + } + // without removing peers they will be considered recently dialed + runDialTest(t, dTest) +} + // This test checks that past dials are not retried for some time. func TestDialStateCache(t *testing.T) { wantStatic := []*discover.Node{ diff --git a/p2p/discover/database.go b/p2p/discover/database.go index b136609f2..6f98de9b4 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.lastPong(id); seen.After(threshold) { + if seen := db.bondTime(id); seen.After(threshold) { continue } } @@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// lastPong retrieves the time of the last successful contact from remote node. -func (db *nodeDB) lastPong(id NodeID) time.Time { +// bondTime retrieves the time of the last successful pong from remote node. +func (db *nodeDB) bondTime(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } -// updateLastPong updates the last time a remote node successfully contacted. -func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { +// hasBond reports whether the given node is considered bonded. +func (db *nodeDB) hasBond(id NodeID) bool { + return time.Since(db.bondTime(id)) < nodeDBNodeExpiration +} + +// updateBondTime updates the last pong time of a node. +func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -327,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.lastPong(n.ID)) > maxAge { + if now.Sub(db.bondTime(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index be972fd2c..c4fa44d09 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.lastPong(node.ID); stored.Unix() != 0 { + if stored := db.bondTime(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateLastPong(node.ID, inst); err != nil { + if err := db.updateBondTime(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() { + if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to insert lastPong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire some of them, and check the rest @@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire the nodes and make sure self has been evacuated too diff --git a/p2p/discover/node.go b/p2p/discover/node.go index fc928a91a..3b0c84115 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -29,6 +29,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -51,9 +52,8 @@ type Node struct { // with ID. sha common.Hash - // whether this node is currently being pinged in order to replace - // it in a bucket - contested bool + // Time when the node was added to the table. + addedAt time.Time } // NewNode creates a new node. It is mostly meant to be used for diff --git a/p2p/discover/table.go b/p2p/discover/table.go index ec4eb94ad..6509326e6 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -23,10 +23,11 @@ package discover import ( - "crypto/rand" + crand "crypto/rand" "encoding/binary" "errors" "fmt" + mrand "math/rand" "net" "sort" "sync" @@ -35,29 +36,45 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 16 // Kademlia bucket size - hashBits = len(common.Hash{}) * 8 - nBuckets = hashBits + 1 // Number of buckets - - maxBondingPingPongs = 16 - maxFindnodeFailures = 5 - - autoRefreshInterval = 1 * time.Hour - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + maxReplacements = 10 // Size of per-bucket replacement list + + // We keep buckets for the upper 1/15 of distances because + // it's very unlikely we'll ever encounter a node that's closer. + hashBits = len(common.Hash{}) * 8 + nBuckets = hashBits / 15 // Number of buckets + bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket + + // IP address limits. + bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 + tableIPLimit, tableSubnet = 10, 24 + + maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { - mutex sync.Mutex // protects buckets, their content, and nursery + mutex sync.Mutex // protects buckets, bucket content, nursery, rand buckets [nBuckets]*bucket // index of known nodes by distance nursery []*Node // bootstrap nodes - db *nodeDB // database of known nodes + rand *mrand.Rand // source of randomness, periodically reseeded + ips netutil.DistinctNetSet + db *nodeDB // database of known nodes refreshReq chan chan struct{} + initDone chan struct{} closeReq chan struct{} closed chan struct{} @@ -89,9 +106,13 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. -type bucket struct{ entries []*Node } +type bucket struct { + entries []*Node // live entries, sorted by time of last contact + replacements []*Node // recently seen nodes to be used if revalidation fails + ips netutil.DistinctNetSet +} -func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) { +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { // If no node database was given, use an in-memory one db, err := newNodeDB(nodeDBPath, Version, ourID) if err != nil { @@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string bonding: make(map[NodeID]*bondproc), bondslots: make(chan struct{}, maxBondingPingPongs), refreshReq: make(chan chan struct{}), + initDone: make(chan struct{}), closeReq: make(chan struct{}), closed: make(chan struct{}), + rand: mrand.New(mrand.NewSource(0)), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, + } + if err := tab.setFallbackNodes(bootnodes); err != nil { + return nil, err } for i := 0; i < cap(tab.bondslots); i++ { tab.bondslots <- struct{}{} } for i := range tab.buckets { - tab.buckets[i] = new(bucket) + tab.buckets[i] = &bucket{ + ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, + } } - go tab.refreshLoop() + tab.seedRand() + tab.loadSeedNodes(false) + // Start the background expiration goroutine after loading seeds so that the search for + // seed nodes also considers older nodes that would otherwise be removed by the + // expiration. + tab.db.ensureExpirer() + go tab.loop() return tab, nil } +func (tab *Table) seedRand() { + var b [8]byte + crand.Read(b[:]) + + tab.mutex.Lock() + tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:]))) + tab.mutex.Unlock() +} + // Self returns the local node. // The returned node should not be modified by the caller. func (tab *Table) Self() *Node { @@ -127,9 +171,12 @@ func (tab *Table) Self() *Node { // table. It will not write the same node more than once. The nodes in // the slice are copies and can be modified by the caller. func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { + if !tab.isInitDone() { + return 0 + } tab.mutex.Lock() defer tab.mutex.Unlock() - // TODO: tree-based buckets would help here + // Find all non-empty buckets and get a fresh slice of their entries. var buckets [][]*Node for _, b := range tab.buckets { @@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return 0 } // Shuffle the buckets. - for i := uint32(len(buckets)) - 1; i > 0; i-- { - j := randUint(i) + for i := len(buckets) - 1; i > 0; i-- { + j := tab.rand.Intn(len(buckets)) buckets[i], buckets[j] = buckets[j], buckets[i] } // Move head of each bucket into buf, removing buckets that become empty. @@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return i + 1 } -func randUint(max uint32) uint32 { - if max == 0 { - return 0 - } - var b [4]byte - rand.Read(b[:]) - return binary.BigEndian.Uint32(b[:]) % max -} - // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { select { @@ -180,16 +218,15 @@ func (tab *Table) Close() { } } -// SetFallbackNodes sets the initial points of contact. These nodes +// setFallbackNodes sets the initial points of contact. These nodes // are used to connect to the network if the table is empty and there // are no known nodes in the database. -func (tab *Table) SetFallbackNodes(nodes []*Node) error { +func (tab *Table) setFallbackNodes(nodes []*Node) error { for _, n := range nodes { if err := n.validateComplete(); err != nil { return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) } } - tab.mutex.Lock() tab.nursery = make([]*Node, 0, len(nodes)) for _, n := range nodes { cpy := *n @@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error { cpy.sha = crypto.Keccak256Hash(n.ID[:]) tab.nursery = append(tab.nursery, &cpy) } - tab.mutex.Unlock() - tab.refresh() return nil } +// isInitDone returns whether the table's initial seeding procedure has completed. +func (tab *Table) isInitDone() bool { + select { + case <-tab.initDone: + return true + default: + return false + } +} + // Resolve searches for a specific node with the given ID. // It returns nil if the node could not be found. func (tab *Table) Resolve(targetID NodeID) *Node { @@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// refreshLoop schedules doRefresh runs and coordinates shutdown. -func (tab *Table) refreshLoop() { +// loop schedules refresh, revalidate runs and coordinates shutdown. +func (tab *Table) loop() { var ( - timer = time.NewTicker(autoRefreshInterval) - waiting []chan struct{} // accumulates waiting callers while doRefresh runs - done chan struct{} // where doRefresh reports completion + revalidate = time.NewTimer(tab.nextRevalidateTime()) + refresh = time.NewTicker(refreshInterval) + copyNodes = time.NewTicker(copyNodesInterval) + revalidateDone = make(chan struct{}) + refreshDone = make(chan struct{}) // where doRefresh reports completion + waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs ) + defer refresh.Stop() + defer revalidate.Stop() + defer copyNodes.Stop() + + // Start initial refresh. + go tab.doRefresh(refreshDone) + loop: for { select { - case <-timer.C: - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + case <-refresh.C: + tab.seedRand() + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } case req := <-tab.refreshReq: waiting = append(waiting, req) - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } - case <-done: + case <-refreshDone: for _, ch := range waiting { close(ch) } - waiting = nil - done = nil + waiting, refreshDone = nil, nil + case <-revalidate.C: + go tab.doRevalidate(revalidateDone) + case <-revalidateDone: + revalidate.Reset(tab.nextRevalidateTime()) + case <-copyNodes.C: + go tab.copyBondedNodes() case <-tab.closeReq: break loop } @@ -349,8 +410,8 @@ loop: if tab.net != nil { tab.net.close() } - if done != nil { - <-done + if refreshDone != nil { + <-refreshDone } for _, ch := range waiting { close(ch) @@ -365,38 +426,109 @@ loop: func (tab *Table) doRefresh(done chan struct{}) { defer close(done) + // Load nodes from the database and insert + // them. This should yield a few previously seen nodes that are + // (hopefully) still alive. + tab.loadSeedNodes(true) + + // Run self lookup to discover new neighbor nodes. + tab.lookup(tab.self.ID, false) + // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. We cannot // adhere to this because the findnode target is a 512bit value // (not hash-sized) and it is not easily possible to generate a // sha3 preimage that falls into a chosen bucket. - // We perform a lookup with a random target instead. - var target NodeID - rand.Read(target[:]) - result := tab.lookup(target, false) - if len(result) > 0 { - return + // We perform a few lookups with a random target instead. + for i := 0; i < 3; i++ { + var target NodeID + crand.Read(target[:]) + tab.lookup(target, false) } +} - // The table is empty. Load nodes from the database and insert - // them. This should yield a few previously seen nodes that are - // (hopefully) still alive. +func (tab *Table) loadSeedNodes(bond bool) { seeds := tab.db.querySeeds(seedCount, seedMaxAge) - seeds = tab.bondall(append(seeds, tab.nursery...)) + seeds = append(seeds, tab.nursery...) + if bond { + seeds = tab.bondall(seeds) + } + for i := range seeds { + seed := seeds[i] + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} + log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) + tab.add(seed) + } +} - if len(seeds) == 0 { - log.Debug("No discv4 seed nodes found") +// doRevalidate checks that the last node in a random bucket is still live +// and replaces or deletes the node if it isn't. +func (tab *Table) doRevalidate(done chan<- struct{}) { + defer func() { done <- struct{}{} }() + + last, bi := tab.nodeToRevalidate() + if last == nil { + // No non-empty bucket found. + return + } + + // Ping the selected node and wait for a pong. + err := tab.ping(last.ID, last.addr()) + + tab.mutex.Lock() + defer tab.mutex.Unlock() + b := tab.buckets[bi] + if err == nil { + // The node responded, move it to the front. + log.Debug("Revalidated node", "b", bi, "id", last.ID) + b.bump(last) + return } - for _, n := range seeds { - age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }} - log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age) + // No reply received, pick a replacement or delete the node if there aren't + // any replacements. + if r := tab.replace(b, last); r != nil { + log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP) + } else { + log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP) } +} + +// nodeToRevalidate returns the last node in a random, non-empty bucket. +func (tab *Table) nodeToRevalidate() (n *Node, bi int) { tab.mutex.Lock() - tab.stuff(seeds) - tab.mutex.Unlock() + defer tab.mutex.Unlock() - // Finally, do a self lookup to fill up the buckets. - tab.lookup(tab.self.ID, false) + for _, bi = range tab.rand.Perm(len(tab.buckets)) { + b := tab.buckets[bi] + if len(b.entries) > 0 { + last := b.entries[len(b.entries)-1] + return last, bi + } + } + return nil, 0 +} + +func (tab *Table) nextRevalidateTime() time.Duration { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) +} + +// copyBondedNodes adds nodes from the table to the database if they have been in the table +// longer then minTableTime. +func (tab *Table) copyBondedNodes() { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + now := time.Now() + for _, b := range tab.buckets { + for _, n := range b.entries { + if now.Sub(n.addedAt) >= seedMinTableTime { + tab.db.updateNode(n) + } + } + } } // closest returns the n nodes in the table that are closest to the @@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 if id == tab.self.ID { return nil, errors.New("is self") } - // Retrieve a previously known node and any recent findnode failures - node, fails := tab.db.node(id), 0 - if node != nil { - fails = tab.db.findFails(id) + if pinged && !tab.isInitDone() { + return nil, errors.New("still initializing") } - // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch + // Start bonding if we haven't seen this node for a while or if it failed findnode too often. + node, fails := tab.db.node(id), tab.db.findFails(id) + age := time.Since(tab.db.bondTime(id)) var result error - age := time.Since(tab.db.lastPong(id)) - if node == nil || fails > 0 || age > nodeDBNodeExpiration { + if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) tab.bondmu.Lock() @@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 node = w.n } } + // Add the node to the table even if the bonding ping/pong + // fails. It will be relaced quickly if it continues to be + // unresponsive. if node != nil { - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. tab.add(node) tab.db.updateFindFails(id, 0) } @@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd } // Bonding succeeded, update the node database. w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - tab.db.updateNode(w.n) close(w.done) } @@ -533,17 +663,19 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { if err := tab.net.ping(id, addr); err != nil { return err } - tab.db.updateLastPong(id, time.Now()) - - // Start the background expiration goroutine after the first - // successful communication. Subsequent calls have no effect if it - // is already running. We do this here instead of somewhere else - // so that the search for seed nodes also considers older nodes - // that would otherwise be removed by the expiration. - tab.db.ensureExpirer() + tab.db.updateBondTime(id, time.Now()) return nil } +// bucket returns the bucket for the given node ID hash. +func (tab *Table) bucket(sha common.Hash) *bucket { + d := logdist(tab.self.sha, sha) + if d <= bucketMinDistance { + return tab.buckets[0] + } + return tab.buckets[d-bucketMinDistance-1] +} + // add attempts to add the given node its corresponding bucket. If the // bucket has space available, adding the node succeeds immediately. // Otherwise, the node is added if the least recently active node in @@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { // // The caller must not hold tab.mutex. func (tab *Table) add(new *Node) { - b := tab.buckets[logdist(tab.self.sha, new.sha)] tab.mutex.Lock() defer tab.mutex.Unlock() - if b.bump(new) { - return - } - var oldest *Node - if len(b.entries) == bucketSize { - oldest = b.entries[bucketSize-1] - if oldest.contested { - // The node is already being replaced, don't attempt - // to replace it. - return - } - oldest.contested = true - // Let go of the mutex so other goroutines can access - // the table while we ping the least recently active node. - tab.mutex.Unlock() - err := tab.ping(oldest.ID, oldest.addr()) - tab.mutex.Lock() - oldest.contested = false - if err == nil { - // The node responded, don't replace it. - return - } - } - added := b.replace(new, oldest) - if added && tab.nodeAddedHook != nil { - tab.nodeAddedHook(new) + + b := tab.bucket(new.sha) + if !tab.bumpOrAdd(b, new) { + // Node is not in table. Add it to the replacement list. + tab.addReplacement(b, new) } } // stuff adds nodes the table to the end of their corresponding bucket -// if the bucket is not full. The caller must hold tab.mutex. +// if the bucket is not full. The caller must not hold tab.mutex. func (tab *Table) stuff(nodes []*Node) { -outer: + tab.mutex.Lock() + defer tab.mutex.Unlock() + for _, n := range nodes { if n.ID == tab.self.ID { continue // don't add self } - bucket := tab.buckets[logdist(tab.self.sha, n.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == n.ID { - continue outer // already in bucket - } - } - if len(bucket.entries) < bucketSize { - bucket.entries = append(bucket.entries, n) - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(n) - } + b := tab.bucket(n.sha) + if len(b.entries) < bucketSize { + tab.bumpOrAdd(b, n) } } } @@ -611,36 +715,72 @@ outer: func (tab *Table) delete(node *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() - bucket := tab.buckets[logdist(tab.self.sha, node.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == node.ID { - bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...) - return - } - } + + tab.deleteInBucket(tab.bucket(node.sha), node) } -func (b *bucket) replace(n *Node, last *Node) bool { - // Don't add if b already contains n. - for i := range b.entries { - if b.entries[i].ID == n.ID { - return false - } +func (tab *Table) addIP(b *bucket, ip net.IP) bool { + if netutil.IsLAN(ip) { + return true } - // Replace last if it is still the last entry or just add n if b - // isn't full. If is no longer the last entry, it has either been - // replaced with someone else or became active. - if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) { + if !tab.ips.Add(ip) { + log.Debug("IP exceeds table limit", "ip", ip) return false } - if len(b.entries) < bucketSize { - b.entries = append(b.entries, nil) + if !b.ips.Add(ip) { + log.Debug("IP exceeds bucket limit", "ip", ip) + tab.ips.Remove(ip) + return false } - copy(b.entries[1:], b.entries) - b.entries[0] = n return true } +func (tab *Table) removeIP(b *bucket, ip net.IP) { + if netutil.IsLAN(ip) { + return + } + tab.ips.Remove(ip) + b.ips.Remove(ip) +} + +func (tab *Table) addReplacement(b *bucket, n *Node) { + for _, e := range b.replacements { + if e.ID == n.ID { + return // already in list + } + } + if !tab.addIP(b, n.IP) { + return + } + var removed *Node + b.replacements, removed = pushNode(b.replacements, n, maxReplacements) + if removed != nil { + tab.removeIP(b, removed.IP) + } +} + +// replace removes n from the replacement list and replaces 'last' with it if it is the +// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced +// with someone else or became active. +func (tab *Table) replace(b *bucket, last *Node) *Node { + if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID { + // Entry has moved, don't replace it. + return nil + } + // Still the last entry. + if len(b.replacements) == 0 { + tab.deleteInBucket(b, last) + return nil + } + r := b.replacements[tab.rand.Intn(len(b.replacements))] + b.replacements = deleteNode(b.replacements, r) + b.entries[len(b.entries)-1] = r + tab.removeIP(b, last.IP) + return r +} + +// bump moves the given node to the front of the bucket entry list +// if it is contained in that list. func (b *bucket) bump(n *Node) bool { for i := range b.entries { if b.entries[i].ID == n.ID { @@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool { return false } +// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't +// full. The return value is true if n is in the bucket. +func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool { + if b.bump(n) { + return true + } + if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) { + return false + } + b.entries, _ = pushNode(b.entries, n, bucketSize) + b.replacements = deleteNode(b.replacements, n) + n.addedAt = time.Now() + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(n) + } + return true +} + +func (tab *Table) deleteInBucket(b *bucket, n *Node) { + b.entries = deleteNode(b.entries, n) + tab.removeIP(b, n.IP) +} + +// pushNode adds n to the front of list, keeping at most max items. +func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) { + if len(list) < max { + list = append(list, nil) + } + removed := list[len(list)-1] + copy(list[1:], list) + list[0] = n + return list, removed +} + +// deleteNode removes n from list. +func deleteNode(list []*Node, n *Node) []*Node { + for i := range list { + if list[i].ID == n.ID { + return append(list[:i], list[i+1:]...) + } + } + return list +} + // nodesByDistance is a list of nodes, ordered by // distance to target. type nodesByDistance struct { diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 1037cc609..3ce48d299 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -20,6 +20,7 @@ import ( "crypto/ecdsa" "fmt" "math/rand" + "sync" "net" "reflect" @@ -32,60 +33,65 @@ import ( ) func TestTable_pingReplace(t *testing.T) { - doit := func(newNodeIsResponding, lastInBucketIsResponding bool) { - transport := newPingRecorder() - tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "") - defer tab.Close() - pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) + run := func(newNodeResponding, lastInBucketResponding bool) { + name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding) + t.Run(name, func(t *testing.T) { + t.Parallel() + testPingReplace(t, newNodeResponding, lastInBucketResponding) + }) + } - // fill up the sender's bucket. - last := fillBucket(tab, 253) + run(true, true) + run(false, true) + run(true, false) + run(false, false) +} - // this call to bond should replace the last node - // in its bucket if the node is not responding. - transport.responding[last.ID] = lastInBucketIsResponding - transport.responding[pingSender.ID] = newNodeIsResponding - tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) +func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() - // first ping goes to sender (bonding pingback) - if !transport.pinged[pingSender.ID] { - t.Error("table did not ping back sender") - } - if newNodeIsResponding { - // second ping goes to oldest node in bucket - // to see whether it is still alive. - if !transport.pinged[last.ID] { - t.Error("table did not ping last node in bucket") - } - } + // Wait for init so bond is accepted. + <-tab.initDone - tab.mutex.Lock() - defer tab.mutex.Unlock() - if l := len(tab.buckets[253].entries); l != bucketSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize) - } + // fill up the sender's bucket. + pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) + last := fillBucket(tab, pingSender) - if lastInBucketIsResponding || !newNodeIsResponding { - if !contains(tab.buckets[253].entries, last.ID) { - t.Error("last entry was removed") - } - if contains(tab.buckets[253].entries, pingSender.ID) { - t.Error("new entry was added") - } - } else { - if contains(tab.buckets[253].entries, last.ID) { - t.Error("last entry was not removed") - } - if !contains(tab.buckets[253].entries, pingSender.ID) { - t.Error("new entry was not added") - } - } + // this call to bond should replace the last node + // in its bucket if the node is not responding. + transport.dead[last.ID] = !lastInBucketIsResponding + transport.dead[pingSender.ID] = !newNodeIsResponding + tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) + tab.doRevalidate(make(chan struct{}, 1)) + + // first ping goes to sender (bonding pingback) + if !transport.pinged[pingSender.ID] { + t.Error("table did not ping back sender") + } + if !transport.pinged[last.ID] { + // second ping goes to oldest node in bucket + // to see whether it is still alive. + t.Error("table did not ping last node in bucket") } - doit(true, true) - doit(false, true) - doit(true, false) - doit(false, false) + tab.mutex.Lock() + defer tab.mutex.Unlock() + wantSize := bucketSize + if !lastInBucketIsResponding && !newNodeIsResponding { + wantSize-- + } + if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { + t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + } + if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { + t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) + } + wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding + if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry { + t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) + } } func TestBucket_bumpNoDuplicates(t *testing.T) { @@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { } } +// This checks that the table-wide IP limit is applied correctly. +func TestTable_IPLimit(t *testing.T) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() + + for i := 0; i < tableIPLimit+1; i++ { + n := nodeAtDistance(tab.self.sha, i) + n.IP = net.IP{172, 0, 1, byte(i)} + tab.add(n) + } + if tab.len() > tableIPLimit { + t.Errorf("too many nodes in table") + } +} + +// This checks that the table-wide IP limit is applied correctly. +func TestTable_BucketIPLimit(t *testing.T) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() + + d := 3 + for i := 0; i < bucketIPLimit+1; i++ { + n := nodeAtDistance(tab.self.sha, d) + n.IP = net.IP{172, 0, 1, byte(i)} + tab.add(n) + } + if tab.len() > bucketIPLimit { + t.Errorf("too many nodes in table") + } +} + // fillBucket inserts nodes into the given bucket until // it is full. The node's IDs dont correspond to their // hashes. -func fillBucket(tab *Table, ld int) (last *Node) { - b := tab.buckets[ld] +func fillBucket(tab *Table, n *Node) (last *Node) { + ld := logdist(tab.self.sha, n.sha) + b := tab.bucket(n.sha) for len(b.entries) < bucketSize { b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld)) } @@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) { n = new(Node) n.sha = hashAtDistance(base, ld) - n.IP = net.IP{10, 0, 2, byte(ld)} + n.IP = net.IP{byte(ld), 0, 2, byte(ld)} copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID return n } -type pingRecorder struct{ responding, pinged map[NodeID]bool } +type pingRecorder struct { + mu sync.Mutex + dead, pinged map[NodeID]bool +} func newPingRecorder() *pingRecorder { - return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)} + return &pingRecorder{ + dead: make(map[NodeID]bool), + pinged: make(map[NodeID]bool), + } } func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { - panic("findnode called on pingRecorder") + return nil, nil } func (t *pingRecorder) close() {} func (t *pingRecorder) waitping(from NodeID) error { return nil // remote always pings } func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { + t.mu.Lock() + defer t.mu.Unlock() + t.pinged[toid] = true - if t.responding[toid] { - return nil - } else { + if t.dead[toid] { return errTimeout + } else { + return nil } } @@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) { test := func(test *closeTest) bool { // for any node table, Target and N - tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "") + transport := newPingRecorder() + tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil) defer tab.Close() tab.stuff(test.All) @@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { }, } test := func(buf []*Node) bool { - tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "") + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) defer tab.Close() + <-tab.initDone + for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)}) @@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { func TestTable_Lookup(t *testing.T) { self := nodeAtDistance(common.Hash{}, 0) - tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "") + tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil) defer tab.Close() // lookup on empty table returns no nodes diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index f9eb99ee3..524c6e498 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -210,17 +210,28 @@ type reply struct { matched chan<- bool } +// ReadPacket is sent to the unhandled channel when it could not be processed +type ReadPacket struct { + Data []byte + Addr *net.UDPAddr +} + +// Config holds Table-related settings. +type Config struct { + // These settings are required and configure the UDP listener: + PrivateKey *ecdsa.PrivateKey + + // These settings are optional: + AnnounceAddr *net.UDPAddr // local address announced in the DHT + NodeDBPath string // if set, the node database is stored at this filesystem location + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel +} + // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) +func ListenUDP(c conn, cfg Config) (*Table, error) { + tab, _, err := newUDP(c, cfg) if err != nil { return nil, err } @@ -228,35 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { +func newUDP(c conn, cfg Config) (*Table, *udp, error) { udp := &udp{ conn: c, - priv: priv, - netrestrict: netrestrict, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, closing: make(chan struct{}), gotreply: make(chan reply), addpending: make(chan *pending), } realaddr := c.LocalAddr().(*net.UDPAddr) - if natm != nil { - if !realaddr.IP.IsLoopback() { - go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") - } - // TODO: react to external IP changes over time. - if ext, err := natm.ExternalIP(); err == nil { - realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} - } + if cfg.AnnounceAddr != nil { + realaddr = cfg.AnnounceAddr } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) + tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) if err != nil { return nil, nil, err } udp.Table = tab go udp.loop() - go udp.readLoop() + go udp.readLoop(cfg.Unhandled) return udp.Table, udp, nil } @@ -268,14 +273,20 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { - // TODO: maybe check for ReplyTo field in callback to measure RTT - errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) - t.send(toaddr, pingPacket, &ping{ + req := &ping{ Version: Version, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), + } + packet, hash, err := encodePacket(t.priv, pingPacket, req) + if err != nil { + return err + } + errc := t.pending(toid, pongPacket, func(p interface{}) bool { + return bytes.Equal(p.(*pong).ReplyTok, hash) }) + t.write(toaddr, req.name(), packet) return <-errc } @@ -459,41 +470,49 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { - packet, err := encodePacket(t.priv, ptype, req) +func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { + packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { - return err + return hash, err } - _, err = t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+req.name(), "addr", toaddr, "err", err) + return hash, t.write(toaddr, req.name(), packet) +} + +func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { + _, err := t.conn.WriteToUDP(packet, toaddr) + log.Trace(">> "+what, "addr", toaddr, "err", err) return err } -func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { +func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { b := new(bytes.Buffer) b.Write(headSpace) b.WriteByte(ptype) if err := rlp.Encode(b, req); err != nil { log.Error("Can't encode discv4 packet", "err", err) - return nil, err + return nil, nil, err } - packet := b.Bytes() + packet = b.Bytes() sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) if err != nil { log.Error("Can't sign discv4 packet", "err", err) - return nil, err + return nil, nil, err } copy(packet[macSize:], sig) // add the hash to the front. Note: this doesn't protect the // packet in any way. Our public key will be part of this hash in // The future. - copy(packet, crypto.Keccak256(packet[macSize:])) - return packet, nil + hash = crypto.Keccak256(packet[macSize:]) + copy(packet, hash) + return packet, hash, nil } // readLoop runs in its own goroutine. it handles incoming UDP packets. -func (t *udp) readLoop() { +func (t *udp) readLoop(unhandled chan<- ReadPacket) { defer t.conn.Close() + if unhandled != nil { + defer close(unhandled) + } // Discovery packets are defined to be no larger than 1280 bytes. // Packets larger than this size will be cut at the end and treated // as invalid because their hash won't match. @@ -509,7 +528,12 @@ func (t *udp) readLoop() { log.Debug("UDP read error", "err", err) return } - t.handlePacket(from, buf[:nbytes]) + if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { + select { + case unhandled <- ReadPacket{buf[:nbytes], from}: + default: + } + } } } @@ -589,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor @@ -605,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte t.mutex.Unlock() p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. - for i, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP) != nil { - continue + for _, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(n)) } - p.Nodes = append(p.Nodes, nodeToRPC(n)) - if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { + if len(p.Nodes) == maxNeighbors { t.send(from, neighborsPacket, &p) p.Nodes = p.Nodes[:0] + sent = true } } + if len(p.Nodes) > 0 || !sent { + t.send(from, neighborsPacket, &p) + } return nil } diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 21e8b561d..db9804f7b 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -70,13 +70,15 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) + test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey}) + // Wait for initial refresh so the table doesn't send unexpected findnode. + <-test.table.initDone return test } // handles a packet as if it had been sent to the transport. func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { - enc, err := encodePacket(test.remotekey, ptype, data) + enc, _, err := encodePacket(test.remotekey, ptype, data) if err != nil { return test.errorf("packet (%d) encode error: %v", ptype, err) } @@ -89,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { // waits for a packet to be sent by the transport. // validate should have type func(*udpTest, X) error, where X is a packet type. -func (test *udpTest) waitPacketOut(validate interface{}) error { +func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) { dgram := test.pipe.waitPacketOut() - p, _, _, err := decodePacket(dgram) + p, _, hash, err := decodePacket(dgram) if err != nil { - return test.errorf("sent packet decode error: %v", err) + return hash, test.errorf("sent packet decode error: %v", err) } fn := reflect.ValueOf(validate) exptype := fn.Type().In(0) if reflect.TypeOf(p) != exptype { - return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) } fn.Call([]reflect.Value{reflect.ValueOf(p)}) - return nil + return hash, nil } func (test *udpTest) errorf(format string, args ...interface{}) error { @@ -245,12 +247,8 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateNode(NewNode( - PubkeyID(&test.remotekey.PublicKey), - test.remoteaddr.IP, - uint16(test.remoteaddr.Port), - 99, - )) + test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(targetHash, bucketSize) @@ -350,7 +348,7 @@ func TestUDP_successfulPing(t *testing.T) { }) // remote is unknown, the table pings back. - test.waitPacketOut(func(p *ping) error { + hash, _ := test.waitPacketOut(func(p *ping) error { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) } @@ -364,7 +362,7 @@ func TestUDP_successfulPing(t *testing.T) { } return nil }) - test.packetIn(nil, pongPacket, &pong{Expiration: futureExp}) + test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp}) // the node should be added to the table shortly after getting the // pong packet. diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index cd9981584..52c677b62 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -29,7 +29,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -134,7 +133,7 @@ type timeoutEvent struct { node *Node } -func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { +func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { ourID := PubkeyID(&ourPubkey) var db *nodeDB @@ -566,11 +565,8 @@ loop: if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil { lookupChn <- net.ticketStore.radius[res.target.topic].converged } - net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node) []byte { - net.ping(n, n.addr()) - return n.pingEcho - }, func(n *Node, topic Topic) []byte { - if n.state == known { + net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte { + if n.state != nil && n.state.canQuery { return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration } else { if n.state == unknown { @@ -634,15 +630,20 @@ loop: } net.refreshResp <- refreshDone case <-refreshDone: - log.Trace("<-net.refreshDone") - refreshDone = nil - list := searchReqWhenRefreshDone - searchReqWhenRefreshDone = nil - go func() { - for _, req := range list { - net.topicSearchReq <- req - } - }() + log.Trace("<-net.refreshDone", "table size", net.tab.count) + if net.tab.count != 0 { + refreshDone = nil + list := searchReqWhenRefreshDone + searchReqWhenRefreshDone = nil + go func() { + for _, req := range list { + net.topicSearchReq <- req + } + }() + } else { + refreshDone = make(chan struct{}) + net.refresh(refreshDone) + } } } log.Trace("loop stopped") @@ -752,7 +753,15 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n return n, err } if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP { - err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + if n.state == known { + // reject address change if node is known by us + err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + } else { + // accept otherwise; this will be handled nicer with signed ENRs + n.IP = rn.IP + n.UDP = rn.UDP + n.TCP = rn.TCP + } } return n, err } diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go index bd234f5ba..369282ca9 100644 --- a/p2p/discv5/net_test.go +++ b/p2p/discv5/net_test.go @@ -28,7 +28,7 @@ import ( func TestNetwork_Lookup(t *testing.T) { key, _ := crypto.GenerateKey() - network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) + network, err := newNetwork(lookupTestnet, key.PublicKey, "", nil) if err != nil { t.Fatal(err) } diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go index bf57872e2..543faecd4 100644 --- a/p2p/discv5/sim_test.go +++ b/p2p/discv5/sim_test.go @@ -282,7 +282,7 @@ func (s *simulation) launchNode(log bool) *Network { addr := &net.UDPAddr{IP: ip, Port: 30303} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} - net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil) + net, err := newNetwork(transport, key.PublicKey, "<no database>", nil) if err != nil { panic("cannot launch new node: " + err.Error()) } diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go index b45ec4d2b..b3d1ac4ba 100644 --- a/p2p/discv5/ticket.go +++ b/p2p/discv5/ticket.go @@ -350,7 +350,7 @@ func (s *ticketStore) nextFilteredTicket() (*ticketRef, time.Duration) { regTime := now + mclock.AbsTime(wait) topic := ticket.t.topics[ticket.idx] - if regTime >= s.tickets[topic].nextReg { + if s.tickets[topic] != nil && regTime >= s.tickets[topic].nextReg { return ticket, wait } s.removeTicketRef(*ticket) @@ -420,11 +420,14 @@ func (s *ticketStore) nextRegisterableTicket() (*ticketRef, time.Duration) { func (s *ticketStore) removeTicketRef(ref ticketRef) { log.Trace("Removing discovery ticket reference", "node", ref.t.node.ID, "serial", ref.t.serial) + // Make nextRegisterableTicket return the next available ticket. + s.nextTicketCached = nil + topic := ref.topic() tickets := s.tickets[topic] if tickets == nil { - log.Warn("Removing tickets from unknown topic", "topic", topic) + log.Trace("Removing tickets from unknown topic", "topic", topic) return } bucket := timeBucket(ref.t.regTime[ref.idx] / mclock.AbsTime(ticketTimeBucketLen)) @@ -450,9 +453,6 @@ func (s *ticketStore) removeTicketRef(ref ticketRef) { delete(s.nodes, ref.t.node) delete(s.nodeLastReq, ref.t.node) } - - // Make nextRegisterableTicket return the next available ticket. - s.nextTicketCached = nil } type lookupInfo struct { @@ -494,13 +494,13 @@ func (s *ticketStore) registerLookupDone(lookup lookupInfo, nodes []*Node, ping } } -func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, ping func(n *Node) []byte, query func(n *Node, topic Topic) []byte) { +func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, query func(n *Node, topic Topic) []byte) { now := mclock.Now() for i, n := range nodes { if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < s.radius[lookup.topic].minRadius { if lookup.radiusLookup { if lastReq, ok := s.nodeLastReq[n]; !ok || time.Duration(now-lastReq.time) > radiusTC { - s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now} + s.nodeLastReq[n] = reqInfo{pingHash: nil, lookup: lookup, time: now} } } // else { if s.canQueryTopic(n, lookup.topic) { @@ -642,7 +642,7 @@ func (s *ticketStore) gotTopicNodes(from *Node, hash common.Hash, nodes []rpcNod if ip.IsUnspecified() || ip.IsLoopback() { ip = from.IP } - n := NewNode(node.ID, ip, node.UDP-1, node.TCP-1) // subtract one from port while discv5 is running in test mode on UDPport+1 + n := NewNode(node.ID, ip, node.UDP, node.TCP) select { case chn <- n: default: diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 26087cd8e..6ce72d2c1 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -37,7 +37,7 @@ const Version = 4 // Errors var ( errPacketTooSmall = errors.New("too small") - errBadHash = errors.New("bad hash") + errBadPrefix = errors.New("bad prefix") errExpired = errors.New("expired") errUnsolicitedReply = errors.New("unsolicited reply") errUnknownNode = errors.New("unknown node") @@ -49,7 +49,7 @@ var ( // Timeouts const ( respTimeout = 500 * time.Millisecond - sendTimeout = 500 * time.Millisecond + queryDelay = 1000 * time.Millisecond expiration = 20 * time.Second ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP @@ -145,10 +145,11 @@ type ( } ) -const ( - macSize = 256 / 8 - sigSize = 520 / 8 - headSize = macSize + sigSize // space of packet frame data +var ( + versionPrefix = []byte("temporary discovery v5") + versionPrefixSize = len(versionPrefix) + sigSize = 520 / 8 + headSize = versionPrefixSize + sigSize // space of packet frame data ) // Neighbors replies are sent across multiple packets to @@ -237,30 +238,23 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { - transport, err := listenUDP(priv, laddr) +func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { + transport, err := listenUDP(priv, conn, realaddr) if err != nil { return nil, err } - net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) + net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict) if err != nil { return nil, err } + log.Info("UDP listener up", "net", net.tab.self) transport.net = net go transport.readLoop() return net, nil } -func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil +func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) { + return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil } func (t *udp) localAddr() *net.UDPAddr { @@ -324,20 +318,20 @@ func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []by func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) { p := topicNodes{Echo: queryHash} - if len(nodes) == 0 { - t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) - return - } - for i, result := range nodes { - if netutil.CheckRelayIP(remote.IP, result.IP) != nil { - continue + var sent bool + for _, result := range nodes { + if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(result)) } - p.Nodes = append(p.Nodes, nodeToRPC(result)) - if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { + if len(p.Nodes) == maxTopicNodes { t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) p.Nodes = p.Nodes[:0] + sent = true } } + if !sent || len(p.Nodes) > 0 { + t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) + } } func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) { @@ -372,11 +366,9 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash log.Error(fmt.Sprint("could not sign packet:", err)) return nil, nil, err } - copy(packet[macSize:], sig) - // add the hash to the front. Note: this doesn't protect the - // packet in any way. - hash = crypto.Keccak256(packet[macSize:]) - copy(packet, hash) + copy(packet, versionPrefix) + copy(packet[versionPrefixSize:], sig) + hash = crypto.Keccak256(packet[versionPrefixSize:]) return packet, hash, nil } @@ -420,17 +412,16 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error { } buf := make([]byte, len(buffer)) copy(buf, buffer) - hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] - shouldhash := crypto.Keccak256(buf[macSize:]) - if !bytes.Equal(hash, shouldhash) { - return errBadHash + prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:] + if !bytes.Equal(prefix, versionPrefix) { + return errBadPrefix } fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) if err != nil { return err } pkt.rawData = buf - pkt.hash = hash + pkt.hash = crypto.Keccak256(buf[versionPrefixSize:]) pkt.remoteID = fromID switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev { case pingPacket: diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go index f6005afd2..656abb682 100644 --- a/p2p/netutil/net.go +++ b/p2p/netutil/net.go @@ -18,8 +18,11 @@ package netutil import ( + "bytes" "errors" + "fmt" "net" + "sort" "strings" ) @@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error { } return nil } + +// SameNet reports whether two IP addresses have an equal prefix of the given bit length. +func SameNet(bits uint, ip, other net.IP) bool { + ip4, other4 := ip.To4(), other.To4() + switch { + case (ip4 == nil) != (other4 == nil): + return false + case ip4 != nil: + return sameNet(bits, ip4, other4) + default: + return sameNet(bits, ip.To16(), other.To16()) + } +} + +func sameNet(bits uint, ip, other net.IP) bool { + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask { + return false + } + return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb]) +} + +// DistinctNetSet tracks IPs, ensuring that at most N of them +// fall into the same network range. +type DistinctNetSet struct { + Subnet uint // number of common prefix bits + Limit uint // maximum number of IPs in each subnet + + members map[string]uint + buf net.IP +} + +// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the +// number of existing IPs in the defined range exceeds the limit. +func (s *DistinctNetSet) Add(ip net.IP) bool { + key := s.key(ip) + n := s.members[string(key)] + if n < s.Limit { + s.members[string(key)] = n + 1 + return true + } + return false +} + +// Remove removes an IP from the set. +func (s *DistinctNetSet) Remove(ip net.IP) { + key := s.key(ip) + if n, ok := s.members[string(key)]; ok { + if n == 1 { + delete(s.members, string(key)) + } else { + s.members[string(key)] = n - 1 + } + } +} + +// Contains whether the given IP is contained in the set. +func (s DistinctNetSet) Contains(ip net.IP) bool { + key := s.key(ip) + _, ok := s.members[string(key)] + return ok +} + +// Len returns the number of tracked IPs. +func (s DistinctNetSet) Len() int { + n := uint(0) + for _, i := range s.members { + n += i + } + return int(n) +} + +// key encodes the map key for an address into a temporary buffer. +// +// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types. +// The remainder of the key is the IP, truncated to the number of bits. +func (s *DistinctNetSet) key(ip net.IP) net.IP { + // Lazily initialize storage. + if s.members == nil { + s.members = make(map[string]uint) + s.buf = make(net.IP, 17) + } + // Canonicalize ip and bits. + typ := byte('6') + if ip4 := ip.To4(); ip4 != nil { + typ, ip = '4', ip4 + } + bits := s.Subnet + if bits > uint(len(ip)*8) { + bits = uint(len(ip) * 8) + } + // Encode the prefix into s.buf. + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + s.buf[0] = typ + buf := append(s.buf[:1], ip[:nb]...) + if nb < len(ip) && mask != 0 { + buf = append(buf, ip[nb]&mask) + } + return buf +} + +// String implements fmt.Stringer +func (s DistinctNetSet) String() string { + var buf bytes.Buffer + buf.WriteString("{") + keys := make([]string, 0, len(s.members)) + for k := range s.members { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + var ip net.IP + if k[0] == '4' { + ip = make(net.IP, 4) + } else { + ip = make(net.IP, 16) + } + copy(ip, k[1:]) + fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k]) + if i != len(keys)-1 { + buf.WriteString(" ") + } + } + buf.WriteString("}") + return buf.String() +} diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go index 1ee1fcb4d..3a6aa081f 100644 --- a/p2p/netutil/net_test.go +++ b/p2p/netutil/net_test.go @@ -17,9 +17,11 @@ package netutil import ( + "fmt" "net" "reflect" "testing" + "testing/quick" "github.com/davecgh/go-spew/spew" ) @@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) { CheckRelayIP(sender, addr) } } + +func TestSameNet(t *testing.T) { + tests := []struct { + ip, other string + bits uint + want bool + }{ + {"0.0.0.0", "0.0.0.0", 32, true}, + {"0.0.0.0", "0.0.0.1", 0, true}, + {"0.0.0.0", "0.0.0.1", 31, true}, + {"0.0.0.0", "0.0.0.1", 32, false}, + {"0.33.0.1", "0.34.0.2", 8, true}, + {"0.33.0.1", "0.34.0.2", 13, true}, + {"0.33.0.1", "0.34.0.2", 15, false}, + } + + for _, test := range tests { + if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want { + t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want) + } + } +} + +func ExampleSameNet() { + // This returns true because the IPs are in the same /24 network: + fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3})) + // This call returns false: + fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3})) + // Output: + // true + // false +} + +func TestDistinctNetSet(t *testing.T) { + ops := []struct { + add, remove string + fails bool + }{ + {add: "127.0.0.1"}, + {add: "127.0.0.2"}, + {add: "127.0.0.3", fails: true}, + {add: "127.32.0.1"}, + {add: "127.32.0.2"}, + {add: "127.32.0.3", fails: true}, + {add: "127.33.0.1", fails: true}, + {add: "127.34.0.1"}, + {add: "127.34.0.2"}, + {add: "127.34.0.3", fails: true}, + // Make room for an address, then add again. + {remove: "127.0.0.1"}, + {add: "127.0.0.3"}, + {add: "127.0.0.3", fails: true}, + } + + set := DistinctNetSet{Subnet: 15, Limit: 2} + for _, op := range ops { + var desc string + if op.add != "" { + desc = fmt.Sprintf("Add(%s)", op.add) + if ok := set.Add(parseIP(op.add)); ok != !op.fails { + t.Errorf("%s == %t, want %t", desc, ok, !op.fails) + } + } else { + desc = fmt.Sprintf("Remove(%s)", op.remove) + set.Remove(parseIP(op.remove)) + } + t.Logf("%s: %v", desc, set) + } +} + +func TestDistinctNetSetAddRemove(t *testing.T) { + cfg := &quick.Config{} + fn := func(ips []net.IP) bool { + s := DistinctNetSet{Limit: 3, Subnet: 2} + for _, ip := range ips { + s.Add(ip) + } + for _, ip := range ips { + s.Remove(ip) + } + return s.Len() == 0 + } + + if err := quick.Check(fn, cfg); err != nil { + t.Fatal(err) + } +} diff --git a/p2p/peer.go b/p2p/peer.go index bad1c8c8b..477d8c219 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -419,6 +419,9 @@ type PeerInfo struct { Network struct { LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection + Inbound bool `json:"inbound"` + Trusted bool `json:"trusted"` + Static bool `json:"static"` } `json:"network"` Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields } @@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo { } info.Network.LocalAddress = p.LocalAddr().String() info.Network.RemoteAddress = p.RemoteAddr().String() + info.Network.Inbound = p.rw.is(inboundConn) + info.Network.Trusted = p.rw.is(trustedConn) + info.Network.Static = p.rw.is(staticDialedConn) // Gather all the running protocol infos for _, proto := range p.running { diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 24037ecc1..e65a0b604 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -108,8 +108,14 @@ func (t *rlpx) close(err error) { // Tell the remote end why we're disconnecting if possible. if t.rw != nil { if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) - SendItems(t.rw, discMsg, r) + // rlpx tries to send DiscReason to disconnected peer + // if the connection is net.Pipe (in-memory simulation) + // it hangs forever, since net.Pipe does not implement + // a write deadline. Because of this only try to send + // the disconnect reason message if there is no error. + if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil { + SendItems(t.rw, discMsg, r) + } } } t.fd.Close() diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index f4cefa650..bca460402 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) { node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} - fd0, fd1 = net.Pipe() - wg sync.WaitGroup + wg sync.WaitGroup ) + fd0, fd1, err := tcpPipe() + if err != nil { + t.Fatal(err) + } + wg.Add(2) go func() { defer wg.Done() - defer fd1.Close() + defer fd0.Close() rlpx := newRLPX(fd0) remid, err := rlpx.doEncHandshake(prv0, node1) if err != nil { @@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) { t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash) } } + +// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket +func tcpPipe() (net.Conn, net.Conn, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer l.Close() + + var aconn net.Conn + aerr := make(chan error, 1) + go func() { + var err error + aconn, err = l.Accept() + aerr <- err + }() + + dconn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + <-aerr + return nil, nil, err + } + if err := <-aerr; err != nil { + dconn.Close() + return nil, nil, err + } + return aconn, dconn, nil +} diff --git a/p2p/server.go b/p2p/server.go index 922df55ba..90e92dc05 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -40,11 +40,10 @@ const ( refreshPeersInterval = 30 * time.Second staticPeerCheckInterval = 15 * time.Second - // Maximum number of concurrently handshaking inbound connections. - maxAcceptConns = 50 - - // Maximum number of concurrently dialing outbound connections. - maxActiveDialTasks = 16 + // Connectivity defaults. + maxActiveDialTasks = 16 + defaultMaxPendingPeers = 50 + defaultDialRatio = 3 // Maximum time allowed for reading a complete message. // This is effectively the amount of time a connection can be idle. @@ -70,6 +69,11 @@ type Config struct { // Zero defaults to preset values. MaxPendingPeers int `toml:",omitempty"` + // DialRatio controls the ratio of inbound to dialed connections. + // Example: a DialRatio of 2 allows 1/2 of connections to be dialed. + // Setting DialRatio to zero defaults it to 3. + DialRatio int `toml:",omitempty"` + // NoDiscovery can be used to disable the peer discovery mechanism. // Disabling is useful for protocol debugging (manual topology). NoDiscovery bool @@ -78,9 +82,6 @@ type Config struct { // protocol should be started or not. DiscoveryV5 bool `toml:",omitempty"` - // Listener address for the V5 discovery protocol UDP traffic. - DiscoveryV5Addr string `toml:",omitempty"` - // Name sets the node name of this server. // Use common.MakeName to create a name that follows existing conventions. Name string `toml:"-"` @@ -141,7 +142,7 @@ type Config struct { EnableMsgEvents bool // Logger is a custom logger to use with the p2p.Server. - Logger log.Logger + Logger log.Logger `toml:",omitempty"` } // Server manages all peer connections. @@ -354,6 +355,32 @@ func (srv *Server) Stop() { srv.loopWG.Wait() } +// sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns +// messages that were found unprocessable and sent to the unhandled channel by the primary listener. +type sharedUDPConn struct { + *net.UDPConn + unhandled chan discover.ReadPacket +} + +// ReadFromUDP implements discv5.conn +func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { + packet, ok := <-s.unhandled + if !ok { + return 0, nil, fmt.Errorf("Connection was closed") + } + l := len(packet.Data) + if l > len(b) { + l = len(b) + } + copy(b[:l], packet.Data[:l]) + return l, packet.Addr, nil +} + +// Close implements discv5.conn +func (s *sharedUDPConn) Close() error { + return nil +} + // Start starts running the server. // Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { @@ -388,20 +415,66 @@ func (srv *Server) Start() (err error) { srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) - // node table - if !srv.NoDiscovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) + var ( + conn *net.UDPConn + sconn *sharedUDPConn + realaddr *net.UDPAddr + unhandled chan discover.ReadPacket + ) + + if !srv.NoDiscovery || srv.DiscoveryV5 { + addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) if err != nil { return err } - if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { + conn, err = net.ListenUDP("udp", addr) + if err != nil { + return err + } + realaddr = conn.LocalAddr().(*net.UDPAddr) + if srv.NAT != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := srv.NAT.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + } + + if !srv.NoDiscovery && srv.DiscoveryV5 { + unhandled = make(chan discover.ReadPacket, 100) + sconn = &sharedUDPConn{conn, unhandled} + } + + // node table + if !srv.NoDiscovery { + cfg := discover.Config{ + PrivateKey: srv.PrivateKey, + AnnounceAddr: realaddr, + NodeDBPath: srv.NodeDatabase, + NetRestrict: srv.NetRestrict, + Bootnodes: srv.BootstrapNodes, + Unhandled: unhandled, + } + ntab, err := discover.ListenUDP(conn, cfg) + if err != nil { return err } srv.ntab = ntab } if srv.DiscoveryV5 { - ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) + var ( + ntab *discv5.Network + err error + ) + if sconn != nil { + ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } else { + ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } if err != nil { return err } @@ -411,10 +484,7 @@ func (srv *Server) Start() (err error) { srv.DiscV5 = ntab } - dynPeers := (srv.MaxPeers + 1) / 2 - if srv.NoDiscovery { - dynPeers = 0 - } + dynPeers := srv.maxDialedConns() dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake @@ -471,6 +541,7 @@ func (srv *Server) run(dialstate dialer) { defer srv.loopWG.Done() var ( peers = make(map[discover.NodeID]*Peer) + inboundCount = 0 trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) taskdone = make(chan task, maxActiveDialTasks) runningTasks []task @@ -556,14 +627,14 @@ running: } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. select { - case c.cont <- srv.encHandshakeChecks(peers, c): + case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c): case <-srv.quit: break running } case c := <-srv.addpeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. - err := srv.protoHandshakeChecks(peers, c) + err := srv.protoHandshakeChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. p := newPeer(c, srv.Protocols) @@ -574,8 +645,11 @@ running: } name := truncateName(c.name) srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) - peers[c.id] = p go srv.runPeer(p) + peers[c.id] = p + if p.Inbound() { + inboundCount++ + } } // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or @@ -590,6 +664,9 @@ running: d := common.PrettyDuration(mclock.Now() - pd.created) pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err) delete(peers, pd.ID()) + if pd.Inbound() { + inboundCount-- + } } } @@ -616,20 +693,22 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { // Drop connections with no matching protocols. if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { return DiscUselessPeer } // Repeat the encryption handshake checks because the // peer set might have changed between the handshakes. - return srv.encHandshakeChecks(peers, c) + return srv.encHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers + case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): + return DiscTooManyPeers case peers[c.id] != nil: return DiscAlreadyConnected case c.id == srv.Self().ID: @@ -639,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) } } +func (srv *Server) maxInboundConns() int { + return srv.MaxPeers - srv.maxDialedConns() +} + +func (srv *Server) maxDialedConns() int { + if srv.NoDiscovery || srv.NoDial { + return 0 + } + r := srv.DialRatio + if r == 0 { + r = defaultDialRatio + } + return srv.MaxPeers / r +} + type tempError interface { Temporary() bool } @@ -649,10 +743,7 @@ func (srv *Server) listenLoop() { defer srv.loopWG.Done() srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) - // This channel acts as a semaphore limiting - // active inbound connections that are lingering pre-handshake. - // If all slots are taken, no further connections are accepted. - tokens := maxAcceptConns + tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { tokens = srv.MaxPendingPeers } @@ -693,9 +784,6 @@ func (srv *Server) listenLoop() { fd = newMeteredConn(fd, true) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) - - // Spawn the handler. It will give the slot back when the connection - // has been established. go func() { srv.SetupConn(fd, inboundConn, nil) slots <- struct{}{} diff --git a/p2p/simulations/adapters/state.go b/p2p/simulations/adapters/state.go index 8b1dfef90..0d4ecfb0f 100644 --- a/p2p/simulations/adapters/state.go +++ b/p2p/simulations/adapters/state.go @@ -13,6 +13,7 @@ // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + package adapters type SimStateStore struct { |