aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/database.go24
-rw-r--r--p2p/discover/database_test.go34
-rw-r--r--p2p/discover/table.go239
-rw-r--r--p2p/discover/table_test.go26
-rw-r--r--p2p/discover/udp.go59
-rw-r--r--p2p/discover/udp_test.go10
-rw-r--r--p2p/metrics.go10
-rw-r--r--p2p/nat/natpmp.go3
-rw-r--r--p2p/nat/natupnp.go7
-rw-r--r--p2p/peer.go9
-rw-r--r--p2p/protocols/protocol.go96
-rw-r--r--p2p/protocols/protocol_test.go4
-rw-r--r--p2p/rlpx.go15
-rw-r--r--p2p/server.go3
-rw-r--r--p2p/simulations/adapters/inproc.go21
-rw-r--r--p2p/simulations/adapters/state.go36
-rw-r--r--p2p/simulations/network.go67
-rw-r--r--p2p/testing/protocolsession.go2
-rw-r--r--p2p/testing/protocoltester.go21
19 files changed, 333 insertions, 353 deletions
diff --git a/p2p/discover/database.go b/p2p/discover/database.go
index 6f98de9b4..22554145f 100644
--- a/p2p/discover/database.go
+++ b/p2p/discover/database.go
@@ -42,6 +42,7 @@ var (
nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element.
nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
+ nodeDBVersion = 5
)
// nodeDB stores all nodes we know about.
@@ -257,7 +258,7 @@ func (db *nodeDB) expireNodes() error {
}
// Skip the node if not expired yet (and not self)
if !bytes.Equal(id[:], db.self[:]) {
- if seen := db.bondTime(id); seen.After(threshold) {
+ if seen := db.lastPongReceived(id); seen.After(threshold) {
continue
}
}
@@ -267,29 +268,28 @@ func (db *nodeDB) expireNodes() error {
return nil
}
-// lastPing retrieves the time of the last ping packet send to a remote node,
-// requesting binding.
-func (db *nodeDB) lastPing(id NodeID) time.Time {
+// lastPingReceived retrieves the time of the last ping packet sent by the remote node.
+func (db *nodeDB) lastPingReceived(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
}
-// updateLastPing updates the last time we tried contacting a remote node.
-func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error {
+// updateLastPing updates the last time remote node pinged us.
+func (db *nodeDB) updateLastPingReceived(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
}
-// bondTime retrieves the time of the last successful pong from remote node.
-func (db *nodeDB) bondTime(id NodeID) time.Time {
+// lastPongReceived retrieves the time of the last successful pong from remote node.
+func (db *nodeDB) lastPongReceived(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
}
// hasBond reports whether the given node is considered bonded.
func (db *nodeDB) hasBond(id NodeID) bool {
- return time.Since(db.bondTime(id)) < nodeDBNodeExpiration
+ return time.Since(db.lastPongReceived(id)) < nodeDBNodeExpiration
}
-// updateBondTime updates the last pong time of a node.
-func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error {
+// updateLastPongReceived updates the last pong time of a node.
+func (db *nodeDB) updateLastPongReceived(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
}
@@ -332,7 +332,7 @@ seek:
if n.ID == db.self {
continue seek
}
- if now.Sub(db.bondTime(n.ID)) > maxAge {
+ if now.Sub(db.lastPongReceived(n.ID)) > maxAge {
continue seek
}
for i := range nodes {
diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go
index c4fa44d09..27974344e 100644
--- a/p2p/discover/database_test.go
+++ b/p2p/discover/database_test.go
@@ -79,7 +79,7 @@ var nodeDBInt64Tests = []struct {
}
func TestNodeDBInt64(t *testing.T) {
- db, _ := newNodeDB("", Version, NodeID{})
+ db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close()
tests := nodeDBInt64Tests
@@ -111,27 +111,27 @@ func TestNodeDBFetchStore(t *testing.T) {
inst := time.Now()
num := 314
- db, _ := newNodeDB("", Version, NodeID{})
+ db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close()
// Check fetch/store operations on a node ping object
- if stored := db.lastPing(node.ID); stored.Unix() != 0 {
+ if stored := db.lastPingReceived(node.ID); stored.Unix() != 0 {
t.Errorf("ping: non-existing object: %v", stored)
}
- if err := db.updateLastPing(node.ID, inst); err != nil {
+ if err := db.updateLastPingReceived(node.ID, inst); err != nil {
t.Errorf("ping: failed to update: %v", err)
}
- if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() {
+ if stored := db.lastPingReceived(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node pong object
- if stored := db.bondTime(node.ID); stored.Unix() != 0 {
+ if stored := db.lastPongReceived(node.ID); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored)
}
- if err := db.updateBondTime(node.ID, inst); err != nil {
+ if err := db.updateLastPongReceived(node.ID, inst); err != nil {
t.Errorf("pong: failed to update: %v", err)
}
- if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() {
+ if stored := db.lastPongReceived(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node findnode-failure object
@@ -216,7 +216,7 @@ var nodeDBSeedQueryNodes = []struct {
}
func TestNodeDBSeedQuery(t *testing.T) {
- db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID)
+ db, _ := newNodeDB("", nodeDBVersion, nodeDBSeedQueryNodes[1].node.ID)
defer db.close()
// Insert a batch of nodes for querying
@@ -224,7 +224,7 @@ func TestNodeDBSeedQuery(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to insert bondTime: %v", i, err)
}
}
@@ -267,7 +267,7 @@ func TestNodeDBPersistency(t *testing.T) {
)
// Create a persistent database and store some values
- db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
+ db, err := newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{})
if err != nil {
t.Fatalf("failed to create persistent database: %v", err)
}
@@ -277,7 +277,7 @@ func TestNodeDBPersistency(t *testing.T) {
db.close()
// Reopen the database and check the value
- db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
+ db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{})
if err != nil {
t.Fatalf("failed to open persistent database: %v", err)
}
@@ -287,7 +287,7 @@ func TestNodeDBPersistency(t *testing.T) {
db.close()
// Change the database version and check flush
- db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{})
+ db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion+1, NodeID{})
if err != nil {
t.Fatalf("failed to open persistent database: %v", err)
}
@@ -324,7 +324,7 @@ var nodeDBExpirationNodes = []struct {
}
func TestNodeDBExpiration(t *testing.T) {
- db, _ := newNodeDB("", Version, NodeID{})
+ db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close()
// Add all the test nodes and set their last pong time
@@ -332,7 +332,7 @@ func TestNodeDBExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
@@ -357,7 +357,7 @@ func TestNodeDBSelfExpiration(t *testing.T) {
break
}
}
- db, _ := newNodeDB("", Version, self)
+ db, _ := newNodeDB("", nodeDBVersion, self)
defer db.close()
// Add all the test nodes and set their last pong time
@@ -365,7 +365,7 @@ func TestNodeDBSelfExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
+ if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 18920ccfd..8803daa56 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -25,7 +25,6 @@ package discover
import (
crand "crypto/rand"
"encoding/binary"
- "errors"
"fmt"
mrand "math/rand"
"net"
@@ -54,15 +53,13 @@ const (
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
tableIPLimit, tableSubnet = 10, 24
- maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
- maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
-
- refreshInterval = 30 * time.Minute
- revalidateInterval = 10 * time.Second
- copyNodesInterval = 30 * time.Second
- seedMinTableTime = 5 * time.Minute
- seedCount = 30
- seedMaxAge = 5 * 24 * time.Hour
+ maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
+ refreshInterval = 30 * time.Minute
+ revalidateInterval = 10 * time.Second
+ copyNodesInterval = 30 * time.Second
+ seedMinTableTime = 5 * time.Minute
+ seedCount = 30
+ seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
@@ -78,28 +75,17 @@ type Table struct {
closeReq chan struct{}
closed chan struct{}
- bondmu sync.Mutex
- bonding map[NodeID]*bondproc
- bondslots chan struct{} // limits total number of active bonding processes
-
nodeAddedHook func(*Node) // for testing
net transport
self *Node // metadata of the local node
}
-type bondproc struct {
- err error
- n *Node
- done chan struct{}
-}
-
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
ping(NodeID, *net.UDPAddr) error
- waitping(NodeID) error
findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
close()
}
@@ -114,7 +100,7 @@ type bucket struct {
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one
- db, err := newNodeDB(nodeDBPath, Version, ourID)
+ db, err := newNodeDB(nodeDBPath, nodeDBVersion, ourID)
if err != nil {
return nil, err
}
@@ -122,8 +108,6 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
net: t,
db: db,
self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)),
- bonding: make(map[NodeID]*bondproc),
- bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}),
@@ -134,16 +118,13 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err
}
- for i := 0; i < cap(tab.bondslots); i++ {
- tab.bondslots <- struct{}{}
- }
for i := range tab.buckets {
tab.buckets[i] = &bucket{
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
}
}
tab.seedRand()
- tab.loadSeedNodes(false)
+ tab.loadSeedNodes()
// Start the background expiration goroutine after loading seeds so that the search for
// seed nodes also considers older nodes that would otherwise be removed by the
// expiration.
@@ -315,22 +296,7 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
if !asked[n.ID] {
asked[n.ID] = true
pendingQueries++
- go func() {
- // Find potential neighbors to bond with
- r, err := tab.net.findnode(n.ID, n.addr(), targetID)
- if err != nil {
- // Bump the failure counter to detect and evacuate non-bonded entries
- fails := tab.db.findFails(n.ID) + 1
- tab.db.updateFindFails(n.ID, fails)
- log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails)
-
- if fails >= maxFindnodeFailures {
- log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
- tab.delete(n)
- }
- }
- reply <- tab.bondall(r)
- }()
+ go tab.findnode(n, targetID, reply)
}
}
if pendingQueries == 0 {
@@ -349,6 +315,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
return result.entries
}
+func (tab *Table) findnode(n *Node, targetID NodeID, reply chan<- []*Node) {
+ fails := tab.db.findFails(n.ID)
+ r, err := tab.net.findnode(n.ID, n.addr(), targetID)
+ if err != nil || len(r) == 0 {
+ fails++
+ tab.db.updateFindFails(n.ID, fails)
+ log.Trace("Findnode failed", "id", n.ID, "failcount", fails, "err", err)
+ if fails >= maxFindnodeFailures {
+ log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
+ tab.delete(n)
+ }
+ } else if fails > 0 {
+ tab.db.updateFindFails(n.ID, fails-1)
+ }
+
+ // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
+ // just remove those again during revalidation.
+ for _, n := range r {
+ tab.add(n)
+ }
+ reply <- r
+}
+
func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{})
select {
@@ -401,7 +390,7 @@ loop:
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
case <-copyNodes.C:
- go tab.copyBondedNodes()
+ go tab.copyLiveNodes()
case <-tab.closeReq:
break loop
}
@@ -429,7 +418,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are
// (hopefully) still alive.
- tab.loadSeedNodes(true)
+ tab.loadSeedNodes()
// Run self lookup to discover new neighbor nodes.
tab.lookup(tab.self.ID, false)
@@ -447,15 +436,12 @@ func (tab *Table) doRefresh(done chan struct{}) {
}
}
-func (tab *Table) loadSeedNodes(bond bool) {
+func (tab *Table) loadSeedNodes() {
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
seeds = append(seeds, tab.nursery...)
- if bond {
- seeds = tab.bondall(seeds)
- }
for i := range seeds {
seed := seeds[i]
- age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }}
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPongReceived(seed.ID)) }}
log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
tab.add(seed)
}
@@ -473,7 +459,7 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
}
// Ping the selected node and wait for a pong.
- err := tab.ping(last.ID, last.addr())
+ err := tab.net.ping(last.ID, last.addr())
tab.mutex.Lock()
defer tab.mutex.Unlock()
@@ -515,9 +501,9 @@ func (tab *Table) nextRevalidateTime() time.Duration {
return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
}
-// copyBondedNodes adds nodes from the table to the database if they have been in the table
+// copyLiveNodes adds nodes from the table to the database if they have been in the table
// longer then minTableTime.
-func (tab *Table) copyBondedNodes() {
+func (tab *Table) copyLiveNodes() {
tab.mutex.Lock()
defer tab.mutex.Unlock()
@@ -553,120 +539,6 @@ func (tab *Table) len() (n int) {
return n
}
-// bondall bonds with all given nodes concurrently and returns
-// those nodes for which bonding has probably succeeded.
-func (tab *Table) bondall(nodes []*Node) (result []*Node) {
- rc := make(chan *Node, len(nodes))
- for i := range nodes {
- go func(n *Node) {
- nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP)
- rc <- nn
- }(nodes[i])
- }
- for range nodes {
- if n := <-rc; n != nil {
- result = append(result, n)
- }
- }
- return result
-}
-
-// bond ensures the local node has a bond with the given remote node.
-// It also attempts to insert the node into the table if bonding succeeds.
-// The caller must not hold tab.mutex.
-//
-// A bond is must be established before sending findnode requests.
-// Both sides must have completed a ping/pong exchange for a bond to
-// exist. The total number of active bonding processes is limited in
-// order to restrain network use.
-//
-// bond is meant to operate idempotently in that bonding with a remote
-// node which still remembers a previously established bond will work.
-// The remote node will simply not send a ping back, causing waitping
-// to time out.
-//
-// If pinged is true, the remote node has just pinged us and one half
-// of the process can be skipped.
-func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
- if id == tab.self.ID {
- return nil, errors.New("is self")
- }
- if pinged && !tab.isInitDone() {
- return nil, errors.New("still initializing")
- }
- // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
- node, fails := tab.db.node(id), tab.db.findFails(id)
- age := time.Since(tab.db.bondTime(id))
- var result error
- if fails > 0 || age > nodeDBNodeExpiration {
- log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
-
- tab.bondmu.Lock()
- w := tab.bonding[id]
- if w != nil {
- // Wait for an existing bonding process to complete.
- tab.bondmu.Unlock()
- <-w.done
- } else {
- // Register a new bonding process.
- w = &bondproc{done: make(chan struct{})}
- tab.bonding[id] = w
- tab.bondmu.Unlock()
- // Do the ping/pong. The result goes into w.
- tab.pingpong(w, pinged, id, addr, tcpPort)
- // Unregister the process after it's done.
- tab.bondmu.Lock()
- delete(tab.bonding, id)
- tab.bondmu.Unlock()
- }
- // Retrieve the bonding results
- result = w.err
- if result == nil {
- node = w.n
- }
- }
- // Add the node to the table even if the bonding ping/pong
- // fails. It will be relaced quickly if it continues to be
- // unresponsive.
- if node != nil {
- tab.add(node)
- tab.db.updateFindFails(id, 0)
- }
- return node, result
-}
-
-func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
- // Request a bonding slot to limit network usage
- <-tab.bondslots
- defer func() { tab.bondslots <- struct{}{} }()
-
- // Ping the remote side and wait for a pong.
- if w.err = tab.ping(id, addr); w.err != nil {
- close(w.done)
- return
- }
- if !pinged {
- // Give the remote node a chance to ping us before we start
- // sending findnode requests. If they still remember us,
- // waitping will simply time out.
- tab.net.waitping(id)
- }
- // Bonding succeeded, update the node database.
- w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
- close(w.done)
-}
-
-// ping a remote endpoint and wait for a reply, also updating the node
-// database accordingly.
-func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
- tab.db.updateLastPing(id, time.Now())
- if err := tab.net.ping(id, addr); err != nil {
- return err
- }
- tab.db.updateBondTime(id, time.Now())
- return nil
-}
-
// bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(sha common.Hash) *bucket {
d := logdist(tab.self.sha, sha)
@@ -676,21 +548,33 @@ func (tab *Table) bucket(sha common.Hash) *bucket {
return tab.buckets[d-bucketMinDistance-1]
}
-// add attempts to add the given node its corresponding bucket. If the
-// bucket has space available, adding the node succeeds immediately.
-// Otherwise, the node is added if the least recently active node in
-// the bucket does not respond to a ping packet.
+// add attempts to add the given node to its corresponding bucket. If the bucket has space
+// available, adding the node succeeds immediately. Otherwise, the node is added if the
+// least recently active node in the bucket does not respond to a ping packet.
//
// The caller must not hold tab.mutex.
-func (tab *Table) add(new *Node) {
+func (tab *Table) add(n *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- b := tab.bucket(new.sha)
- if !tab.bumpOrAdd(b, new) {
+ b := tab.bucket(n.sha)
+ if !tab.bumpOrAdd(b, n) {
// Node is not in table. Add it to the replacement list.
- tab.addReplacement(b, new)
+ tab.addReplacement(b, n)
+ }
+}
+
+// addThroughPing adds the given node to the table. Compared to plain
+// 'add' there is an additional safety measure: if the table is still
+// initializing the node is not added. This prevents an attack where the
+// table could be filled by just sending ping repeatedly.
+//
+// The caller must not hold tab.mutex.
+func (tab *Table) addThroughPing(n *Node) {
+ if !tab.isInitDone() {
+ return
}
+ tab.add(n)
}
// stuff adds nodes the table to the end of their corresponding bucket
@@ -710,8 +594,7 @@ func (tab *Table) stuff(nodes []*Node) {
}
}
-// delete removes an entry from the node table (used to evacuate
-// failed/non-bonded discovery peers).
+// delete removes an entry from the node table. It is used to evacuate dead nodes.
func (tab *Table) delete(node *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index f2d3f9a2a..ed55ebd9a 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -52,27 +52,22 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
- // Wait for init so bond is accepted.
<-tab.initDone
- // fill up the sender's bucket.
+ // Fill up the sender's bucket.
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
last := fillBucket(tab, pingSender)
- // this call to bond should replace the last node
- // in its bucket if the node is not responding.
+ // Add the sender as if it just pinged us. Revalidate should replace the last node in
+ // its bucket if it is unresponsive. Revalidate again to ensure that
transport.dead[last.ID] = !lastInBucketIsResponding
transport.dead[pingSender.ID] = !newNodeIsResponding
- tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+ tab.add(pingSender)
+ tab.doRevalidate(make(chan struct{}, 1))
tab.doRevalidate(make(chan struct{}, 1))
- // first ping goes to sender (bonding pingback)
- if !transport.pinged[pingSender.ID] {
- t.Error("table did not ping back sender")
- }
if !transport.pinged[last.ID] {
- // second ping goes to oldest node in bucket
- // to see whether it is still alive.
+ // Oldest node in bucket is pinged to see whether it is still alive.
t.Error("table did not ping last node in bucket")
}
@@ -83,7 +78,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
wantSize--
}
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
- t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
+ t.Errorf("wrong bucket size after add: got %d, want %d", l, wantSize)
}
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
@@ -206,10 +201,7 @@ func newPingRecorder() *pingRecorder {
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
return nil, nil
}
-func (t *pingRecorder) close() {}
-func (t *pingRecorder) waitping(from NodeID) error {
- return nil // remote always pings
-}
+
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
@@ -222,6 +214,8 @@ func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
}
}
+func (t *pingRecorder) close() {}
+
func TestTable_closest(t *testing.T) {
t.Parallel()
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index f6bcd9708..0ff47c5e4 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -32,8 +32,6 @@ import (
"github.com/ethereum/go-ethereum/rlp"
)
-const Version = 4
-
// Errors
var (
errPacketTooSmall = errors.New("too small")
@@ -272,21 +270,33 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
+ return <-t.sendPing(toid, toaddr, nil)
+}
+
+// sendPing sends a ping message to the given node and invokes the callback
+// when the reply arrives.
+func (t *udp) sendPing(toid NodeID, toaddr *net.UDPAddr, callback func()) <-chan error {
req := &ping{
- Version: Version,
+ Version: 4,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
}
packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil {
- return err
+ errc := make(chan error, 1)
+ errc <- err
+ return errc
}
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- return bytes.Equal(p.(*pong).ReplyTok, hash)
+ ok := bytes.Equal(p.(*pong).ReplyTok, hash)
+ if ok && callback != nil {
+ callback()
+ }
+ return ok
})
t.write(toaddr, req.name(), packet)
- return <-errc
+ return errc
}
func (t *udp) waitping(from NodeID) error {
@@ -296,6 +306,13 @@ func (t *udp) waitping(from NodeID) error {
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+ // If we haven't seen a ping from the destination node for a while, it won't remember
+ // our endpoint proof and reject findnode. Solicit a ping first.
+ if time.Since(t.db.lastPingReceived(toid)) > nodeDBNodeExpiration {
+ t.ping(toid, toaddr)
+ t.waitping(toid)
+ }
+
nodes := make([]*Node, 0, bucketSize)
nreceived := 0
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
@@ -315,8 +332,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- err := <-errc
- return nodes, err
+ return nodes, <-errc
}
// pending adds a reply callback to the pending reply queue.
@@ -587,10 +603,17 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- if !t.handleReply(fromID, pingPacket, req) {
- // Note: we're ignoring the provided IP address right now
- go t.bond(true, fromID, from, req.From.TCP)
+ t.handleReply(fromID, pingPacket, req)
+
+ // Add the node to the table. Before doing so, ensure that we have a recent enough pong
+ // recorded in the database so their findnode requests will be accepted later.
+ n := NewNode(fromID, from.IP, uint16(from.Port), req.From.TCP)
+ if time.Since(t.db.lastPongReceived(fromID)) > nodeDBNodeExpiration {
+ t.sendPing(fromID, from, func() { t.addThroughPing(n) })
+ } else {
+ t.addThroughPing(n)
}
+ t.db.updateLastPingReceived(fromID, time.Now())
return nil
}
@@ -603,6 +626,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if !t.handleReply(fromID, pongPacket, req) {
return errUnsolicitedReply
}
+ t.db.updateLastPongReceived(fromID, time.Now())
return nil
}
@@ -613,13 +637,12 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
return errExpired
}
if !t.db.hasBond(fromID) {
- // No bond exists, we don't process the packet. This prevents
- // an attack vector where the discovery protocol could be used
- // to amplify traffic in a DDOS attack. A malicious actor
- // would send a findnode request with the IP address and UDP
- // port of the target as the source address. The recipient of
- // the findnode packet would then send a neighbors packet
- // (which is a much bigger packet than findnode) to the victim.
+ // No endpoint proof pong exists, we don't process the packet. This prevents an
+ // attack vector where the discovery protocol could be used to amplify traffic in a
+ // DDOS attack. A malicious actor would send a findnode request with the IP address
+ // and UDP port of the target as the source address. The recipient of the findnode
+ // packet would then send a neighbors packet (which is a much bigger packet than
+ // findnode) to the victim.
return errUnknownNode
}
target := crypto.Keccak256Hash(req.Target[:])
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index db9804f7b..b4363a12b 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -124,7 +124,7 @@ func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t)
defer test.table.Close()
- test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version})
+ test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
@@ -247,7 +247,7 @@ func TestUDP_findnode(t *testing.T) {
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
- test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now())
+ test.table.db.updateLastPongReceived(PubkeyID(&test.remotekey.PublicKey), time.Now())
// check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
@@ -273,10 +273,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t)
defer test.table.Close()
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ test.table.db.updateLastPingReceived(rid, time.Now())
+
// queue a pending findnode request
resultc, errc := make(chan []*Node), make(chan error)
go func() {
- rid := PubkeyID(&test.remotekey.PublicKey)
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
if err != nil && len(ns) == 0 {
errc <- err
@@ -328,7 +330,7 @@ func TestUDP_successfulPing(t *testing.T) {
defer test.table.Close()
// The remote side sends a ping packet to initiate the exchange.
- go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp})
+ go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
// the ping is replied to.
test.waitPacketOut(func(p *pong) {
diff --git a/p2p/metrics.go b/p2p/metrics.go
index 4cbff90ac..2d52fd1fd 100644
--- a/p2p/metrics.go
+++ b/p2p/metrics.go
@@ -31,10 +31,10 @@ var (
egressTrafficMeter = metrics.NewRegisteredMeter("p2p/OutboundTraffic", nil)
)
-// meteredConn is a wrapper around a network TCP connection that meters both the
+// meteredConn is a wrapper around a net.Conn that meters both the
// inbound and outbound network traffic.
type meteredConn struct {
- *net.TCPConn // Network connection to wrap with metering
+ net.Conn // Network connection to wrap with metering
}
// newMeteredConn creates a new metered connection, also bumping the ingress or
@@ -51,13 +51,13 @@ func newMeteredConn(conn net.Conn, ingress bool) net.Conn {
} else {
egressConnectMeter.Mark(1)
}
- return &meteredConn{conn.(*net.TCPConn)}
+ return &meteredConn{Conn: conn}
}
// Read delegates a network read to the underlying connection, bumping the ingress
// traffic meter along the way.
func (c *meteredConn) Read(b []byte) (n int, err error) {
- n, err = c.TCPConn.Read(b)
+ n, err = c.Conn.Read(b)
ingressTrafficMeter.Mark(int64(n))
return
}
@@ -65,7 +65,7 @@ func (c *meteredConn) Read(b []byte) (n int, err error) {
// Write delegates a network write to the underlying connection, bumping the
// egress traffic meter along the way.
func (c *meteredConn) Write(b []byte) (n int, err error) {
- n, err = c.TCPConn.Write(b)
+ n, err = c.Conn.Write(b)
egressTrafficMeter.Mark(int64(n))
return
}
diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go
index 577a424fb..8ba971472 100644
--- a/p2p/nat/natpmp.go
+++ b/p2p/nat/natpmp.go
@@ -115,8 +115,7 @@ func potentialGateways() (gws []net.IP) {
return gws
}
for _, addr := range ifaddrs {
- switch x := addr.(type) {
- case *net.IPNet:
+ if x, ok := addr.(*net.IPNet); ok {
if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
ip := x.IP.Mask(x.Mask).To4()
if ip != nil {
diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go
index 69099ac04..029143b7b 100644
--- a/p2p/nat/natupnp.go
+++ b/p2p/nat/natupnp.go
@@ -81,11 +81,8 @@ func (n *upnp) internalAddress() (net.IP, error) {
return nil, err
}
for _, addr := range addrs {
- switch x := addr.(type) {
- case *net.IPNet:
- if x.Contains(devaddr.IP) {
- return x.IP, nil
- }
+ if x, ok := addr.(*net.IPNet); ok && x.Contains(devaddr.IP) {
+ return x.IP, nil
}
}
}
diff --git a/p2p/peer.go b/p2p/peer.go
index c4c1fcd7c..482e3d506 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -17,6 +17,7 @@
package p2p
import (
+ "errors"
"fmt"
"io"
"net"
@@ -31,6 +32,10 @@ import (
"github.com/ethereum/go-ethereum/rlp"
)
+var (
+ ErrShuttingDown = errors.New("shutting down")
+)
+
const (
baseProtocolVersion = 5
baseProtocolLength = uint64(16)
@@ -371,7 +376,7 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
type protoRW struct {
Protocol
- in chan Msg // receices read messages
+ in chan Msg // receives read messages
closed <-chan struct{} // receives when peer is shutting down
wstart <-chan struct{} // receives when write may start
werr chan<- error // for write results
@@ -393,7 +398,7 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) {
// as well but we don't want to rely on that.
rw.werr <- err
case <-rw.closed:
- err = fmt.Errorf("shutting down")
+ err = ErrShuttingDown
}
return err
}
diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go
index 849a7ef39..615f74b56 100644
--- a/p2p/protocols/protocol.go
+++ b/p2p/protocols/protocol.go
@@ -29,14 +29,22 @@ devp2p subprotocols by abstracting away code standardly shared by protocols.
package protocols
import (
+ "bufio"
+ "bytes"
"context"
"fmt"
+ "io"
"reflect"
"sync"
"time"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/swarm/spancontext"
+ "github.com/ethereum/go-ethereum/swarm/tracing"
+ opentracing "github.com/opentracing/opentracing-go"
)
// error codes used by this protocol scheme
@@ -107,6 +115,13 @@ func errorf(code int, format string, params ...interface{}) *Error {
}
}
+// WrappedMsg is used to propagate marshalled context alongside message payloads
+type WrappedMsg struct {
+ Context []byte
+ Size uint32
+ Payload []byte
+}
+
// Spec is a protocol specification including its name and version as well as
// the types of messages which are exchanged
type Spec struct {
@@ -199,9 +214,14 @@ func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
// the handler argument is a function which is called for each message received
// from the remote peer, a returned error causes the loop to exit
// resulting in disconnection
-func (p *Peer) Run(handler func(msg interface{}) error) error {
+func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) error {
for {
if err := p.handleIncoming(handler); err != nil {
+ if err != io.EOF {
+ metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1)
+ log.Error("peer.handleIncoming", "err", err)
+ }
+
return err
}
}
@@ -218,14 +238,47 @@ func (p *Peer) Drop(err error) {
// message off to the peer
// this low level call will be wrapped by libraries providing routed or broadcast sends
// but often just used to forward and push messages to directly connected peers
-func (p *Peer) Send(msg interface{}) error {
+func (p *Peer) Send(ctx context.Context, msg interface{}) error {
defer metrics.GetOrRegisterResettingTimer("peer.send_t", nil).UpdateSince(time.Now())
metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
+
+ var b bytes.Buffer
+ if tracing.Enabled {
+ writer := bufio.NewWriter(&b)
+
+ tracer := opentracing.GlobalTracer()
+
+ sctx := spancontext.FromContext(ctx)
+
+ if sctx != nil {
+ err := tracer.Inject(
+ sctx,
+ opentracing.Binary,
+ writer)
+ if err != nil {
+ return err
+ }
+ }
+
+ writer.Flush()
+ }
+
+ r, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ return err
+ }
+
+ wmsg := WrappedMsg{
+ Context: b.Bytes(),
+ Size: uint32(len(r)),
+ Payload: r,
+ }
+
code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}
- return p2p.Send(p.rw, code, msg)
+ return p2p.Send(p.rw, code, wmsg)
}
// handleIncoming(code)
@@ -236,7 +289,7 @@ func (p *Peer) Send(msg interface{}) error {
// * checks for out-of-range message codes,
// * handles decoding with reflection,
// * call handlers as callbacks
-func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
+func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) error) error {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
@@ -248,11 +301,38 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
}
+ // unmarshal wrapped msg, which might contain context
+ var wmsg WrappedMsg
+ err = msg.Decode(&wmsg)
+ if err != nil {
+ log.Error(err.Error())
+ return err
+ }
+
+ ctx := context.Background()
+
+ // if tracing is enabled and the context coming within the request is
+ // not empty, try to unmarshal it
+ if tracing.Enabled && len(wmsg.Context) > 0 {
+ var sctx opentracing.SpanContext
+
+ tracer := opentracing.GlobalTracer()
+ sctx, err = tracer.Extract(
+ opentracing.Binary,
+ bytes.NewReader(wmsg.Context))
+ if err != nil {
+ log.Error(err.Error())
+ return err
+ }
+
+ ctx = spancontext.WithContext(ctx, sctx)
+ }
+
val, ok := p.spec.NewMsg(msg.Code)
if !ok {
return errorf(ErrInvalidMsgCode, "%v", msg.Code)
}
- if err := msg.Decode(val); err != nil {
+ if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {
return errorf(ErrDecode, "<= %v: %v", msg, err)
}
@@ -261,7 +341,7 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
// which the handler is supposed to cast to the appropriate type
// it is entirely safe not to check the cast in the handler since the handler is
// chosen based on the proper type in the first place
- if err := handle(val); err != nil {
+ if err := handle(ctx, val); err != nil {
return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
}
return nil
@@ -281,14 +361,14 @@ func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interf
return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
}
errc := make(chan error, 2)
- handle := func(msg interface{}) error {
+ handle := func(ctx context.Context, msg interface{}) error {
rhs = msg
if verify != nil {
return verify(rhs)
}
return nil
}
- send := func() { errc <- p.Send(hs) }
+ send := func() { errc <- p.Send(ctx, hs) }
receive := func() { errc <- p.handleIncoming(handle) }
go func() {
diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
index aaae7502b..11df8ff39 100644
--- a/p2p/protocols/protocol_test.go
+++ b/p2p/protocols/protocol_test.go
@@ -104,7 +104,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er
return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C)
}
- handle := func(msg interface{}) error {
+ handle := func(ctx context.Context, msg interface{}) error {
switch msg := msg.(type) {
case *protoHandshake:
@@ -116,7 +116,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er
return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C)
}
lhs.C += rhs.C
- return peer.Send(lhs)
+ return peer.Send(ctx, lhs)
case *kill:
// demonstrates use of peerPool, killing another peer connection as a response to a message
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index 149eda689..46b666869 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -181,9 +181,9 @@ func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (disco
err error
)
if dial == nil {
- sec, err = receiverEncHandshake(t.fd, prv, nil)
+ sec, err = receiverEncHandshake(t.fd, prv)
} else {
- sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
+ sec, err = initiatorEncHandshake(t.fd, prv, dial.ID)
}
if err != nil {
return discover.NodeID{}, err
@@ -280,9 +280,9 @@ func (h *encHandshake) staticSharedSecret(prv *ecdsa.PrivateKey) ([]byte, error)
// it should be called on the dialing side of the connection.
//
// prv is the local client's private key.
-func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
+func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID) (s secrets, err error) {
h := &encHandshake{initiator: true, remoteID: remoteID}
- authMsg, err := h.makeAuthMsg(prv, token)
+ authMsg, err := h.makeAuthMsg(prv)
if err != nil {
return s, err
}
@@ -306,7 +306,7 @@ func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID d
}
// makeAuthMsg creates the initiator handshake message.
-func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMsgV4, error) {
+func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey) (*authMsgV4, error) {
rpub, err := h.remoteID.Pubkey()
if err != nil {
return nil, fmt.Errorf("bad remoteID: %v", err)
@@ -324,7 +324,7 @@ func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMs
}
// Sign known message: static-shared-secret ^ nonce
- token, err = h.staticSharedSecret(prv)
+ token, err := h.staticSharedSecret(prv)
if err != nil {
return nil, err
}
@@ -352,8 +352,7 @@ func (h *encHandshake) handleAuthResp(msg *authRespV4) (err error) {
// it should be called on the listening side of the connection.
//
// prv is the local client's private key.
-// token is the token from a previous session with this node.
-func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
+func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s secrets, err error) {
authMsg := new(authMsgV4)
authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn)
if err != nil {
diff --git a/p2p/server.go b/p2p/server.go
index d2cb94925..669ef740d 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -371,8 +371,8 @@ func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *discover
// It blocks until all active connections have been closed.
func (srv *Server) Stop() {
srv.lock.Lock()
- defer srv.lock.Unlock()
if !srv.running {
+ srv.lock.Unlock()
return
}
srv.running = false
@@ -381,6 +381,7 @@ func (srv *Server) Stop() {
srv.listener.Close()
}
close(srv.quit)
+ srv.lock.Unlock()
srv.loopWG.Wait()
}
diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go
index b68d08f39..b0fdf49b9 100644
--- a/p2p/simulations/adapters/inproc.go
+++ b/p2p/simulations/adapters/inproc.go
@@ -296,6 +296,13 @@ func (sn *SimNode) Stop() error {
return sn.node.Stop()
}
+// Service returns a running service by name
+func (sn *SimNode) Service(name string) node.Service {
+ sn.lock.RLock()
+ defer sn.lock.RUnlock()
+ return sn.running[name]
+}
+
// Services returns a copy of the underlying services
func (sn *SimNode) Services() []node.Service {
sn.lock.RLock()
@@ -307,6 +314,17 @@ func (sn *SimNode) Services() []node.Service {
return services
}
+// ServiceMap returns a map by names of the underlying services
+func (sn *SimNode) ServiceMap() map[string]node.Service {
+ sn.lock.RLock()
+ defer sn.lock.RUnlock()
+ services := make(map[string]node.Service, len(sn.running))
+ for name, service := range sn.running {
+ services[name] = service
+ }
+ return services
+}
+
// Server returns the underlying p2p.Server
func (sn *SimNode) Server() *p2p.Server {
return sn.node.Server()
@@ -335,8 +353,7 @@ func (sn *SimNode) NodeInfo() *p2p.NodeInfo {
}
func setSocketBuffer(conn net.Conn, socketReadBuffer int, socketWriteBuffer int) error {
- switch v := conn.(type) {
- case *net.UnixConn:
+ if v, ok := conn.(*net.UnixConn); ok {
err := v.SetReadBuffer(socketReadBuffer)
if err != nil {
return err
diff --git a/p2p/simulations/adapters/state.go b/p2p/simulations/adapters/state.go
deleted file mode 100644
index 78dfb11f9..000000000
--- a/p2p/simulations/adapters/state.go
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2017 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package adapters
-
-type SimStateStore struct {
- m map[string][]byte
-}
-
-func (st *SimStateStore) Load(s string) ([]byte, error) {
- return st.m[s], nil
-}
-
-func (st *SimStateStore) Save(s string, data []byte) error {
- st.m[s] = data
- return nil
-}
-
-func NewSimStateStore() *SimStateStore {
- return &SimStateStore{
- make(map[string][]byte),
- }
-}
diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go
index a8a46cd87..0fb7485ad 100644
--- a/p2p/simulations/network.go
+++ b/p2p/simulations/network.go
@@ -31,7 +31,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)
-var dialBanTimeout = 200 * time.Millisecond
+var DialBanTimeout = 200 * time.Millisecond
// NetworkConfig defines configuration options for starting a Network
type NetworkConfig struct {
@@ -78,41 +78,25 @@ func (net *Network) Events() *event.Feed {
return &net.events
}
-// NewNode adds a new node to the network with a random ID
-func (net *Network) NewNode() (*Node, error) {
- conf := adapters.RandomNodeConfig()
- conf.Services = []string{net.DefaultService}
- return net.NewNodeWithConfig(conf)
-}
-
// NewNodeWithConfig adds a new node to the network with the given config,
// returning an error if a node with the same ID or name already exists
func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) {
net.lock.Lock()
defer net.lock.Unlock()
- // create a random ID and PrivateKey if not set
- if conf.ID == (discover.NodeID{}) {
- c := adapters.RandomNodeConfig()
- conf.ID = c.ID
- conf.PrivateKey = c.PrivateKey
- }
- id := conf.ID
if conf.Reachable == nil {
conf.Reachable = func(otherID discover.NodeID) bool {
_, err := net.InitConn(conf.ID, otherID)
- return err == nil
+ if err != nil && bytes.Compare(conf.ID.Bytes(), otherID.Bytes()) < 0 {
+ return false
+ }
+ return true
}
}
- // assign a name to the node if not set
- if conf.Name == "" {
- conf.Name = fmt.Sprintf("node%02d", len(net.Nodes)+1)
- }
-
// check the node doesn't already exist
- if node := net.getNode(id); node != nil {
- return nil, fmt.Errorf("node with ID %q already exists", id)
+ if node := net.getNode(conf.ID); node != nil {
+ return nil, fmt.Errorf("node with ID %q already exists", conf.ID)
}
if node := net.getNodeByName(conf.Name); node != nil {
return nil, fmt.Errorf("node with name %q already exists", conf.Name)
@@ -132,8 +116,8 @@ func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error)
Node: adapterNode,
Config: conf,
}
- log.Trace(fmt.Sprintf("node %v created", id))
- net.nodeMap[id] = len(net.Nodes)
+ log.Trace(fmt.Sprintf("node %v created", conf.ID))
+ net.nodeMap[conf.ID] = len(net.Nodes)
net.Nodes = append(net.Nodes, node)
// emit a "control" event
@@ -181,7 +165,9 @@ func (net *Network) Start(id discover.NodeID) error {
// startWithSnapshots starts the node with the given ID using the give
// snapshots
func (net *Network) startWithSnapshots(id discover.NodeID, snapshots map[string][]byte) error {
- node := net.GetNode(id)
+ net.lock.Lock()
+ defer net.lock.Unlock()
+ node := net.getNode(id)
if node == nil {
return fmt.Errorf("node %v does not exist", id)
}
@@ -220,9 +206,13 @@ func (net *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEve
// assume the node is now down
net.lock.Lock()
+ defer net.lock.Unlock()
node := net.getNode(id)
+ if node == nil {
+ log.Error("Can not find node for id", "id", id)
+ return
+ }
node.Up = false
- net.lock.Unlock()
net.events.Send(NewEvent(node))
}()
for {
@@ -259,7 +249,9 @@ func (net *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEve
// Stop stops the node with the given ID
func (net *Network) Stop(id discover.NodeID) error {
- node := net.GetNode(id)
+ net.lock.Lock()
+ defer net.lock.Unlock()
+ node := net.getNode(id)
if node == nil {
return fmt.Errorf("node %v does not exist", id)
}
@@ -312,7 +304,9 @@ func (net *Network) Disconnect(oneID, otherID discover.NodeID) error {
// DidConnect tracks the fact that the "one" node connected to the "other" node
func (net *Network) DidConnect(one, other discover.NodeID) error {
- conn, err := net.GetOrCreateConn(one, other)
+ net.lock.Lock()
+ defer net.lock.Unlock()
+ conn, err := net.getOrCreateConn(one, other)
if err != nil {
return fmt.Errorf("connection between %v and %v does not exist", one, other)
}
@@ -327,7 +321,9 @@ func (net *Network) DidConnect(one, other discover.NodeID) error {
// DidDisconnect tracks the fact that the "one" node disconnected from the
// "other" node
func (net *Network) DidDisconnect(one, other discover.NodeID) error {
- conn := net.GetConn(one, other)
+ net.lock.Lock()
+ defer net.lock.Unlock()
+ conn := net.getConn(one, other)
if conn == nil {
return fmt.Errorf("connection between %v and %v does not exist", one, other)
}
@@ -335,7 +331,7 @@ func (net *Network) DidDisconnect(one, other discover.NodeID) error {
return fmt.Errorf("%v and %v already disconnected", one, other)
}
conn.Up = false
- conn.initiated = time.Now().Add(-dialBanTimeout)
+ conn.initiated = time.Now().Add(-DialBanTimeout)
net.events.Send(NewEvent(conn))
return nil
}
@@ -476,16 +472,19 @@ func (net *Network) InitConn(oneID, otherID discover.NodeID) (*Conn, error) {
if err != nil {
return nil, err
}
- if time.Since(conn.initiated) < dialBanTimeout {
- return nil, fmt.Errorf("connection between %v and %v recently attempted", oneID, otherID)
- }
if conn.Up {
return nil, fmt.Errorf("%v and %v already connected", oneID, otherID)
}
+ if time.Since(conn.initiated) < DialBanTimeout {
+ return nil, fmt.Errorf("connection between %v and %v recently attempted", oneID, otherID)
+ }
+
err = conn.nodesUp()
if err != nil {
+ log.Trace(fmt.Sprintf("nodes not up: %v", err))
return nil, fmt.Errorf("nodes not up: %v", err)
}
+ log.Debug("InitConn - connection initiated")
conn.initiated = time.Now()
return conn, nil
}
diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go
index 8f73bfa03..e3ec41ad6 100644
--- a/p2p/testing/protocolsession.go
+++ b/p2p/testing/protocolsession.go
@@ -91,7 +91,9 @@ func (s *ProtocolSession) trigger(trig Trigger) error {
errc := make(chan error)
go func() {
+ log.Trace(fmt.Sprintf("trigger %v (%v)....", trig.Msg, trig.Code))
errc <- mockNode.Trigger(&trig)
+ log.Trace(fmt.Sprintf("triggered %v (%v)", trig.Msg, trig.Code))
}()
t := trig.Timeout
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index 636613c57..c99578fe0 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -180,7 +180,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
for {
select {
case trig := <-m.trigger:
- m.err <- p2p.Send(rw, trig.Code, trig.Msg)
+ wmsg := Wrap(trig.Msg)
+ m.err <- p2p.Send(rw, trig.Code, wmsg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
@@ -220,7 +221,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
}
var found bool
for i, exp := range exps {
- if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+ if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
@@ -235,7 +236,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
if matched[i] {
continue
}
- expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+ expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg))))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
@@ -267,3 +268,17 @@ func mustEncodeMsg(msg interface{}) []byte {
}
return contentEnc
}
+
+type WrappedMsg struct {
+ Context []byte
+ Size uint32
+ Payload []byte
+}
+
+func Wrap(msg interface{}) interface{} {
+ data, _ := rlp.EncodeToBytes(msg)
+ return &WrappedMsg{
+ Size: uint32(len(data)),
+ Payload: data,
+ }
+}