diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/discover/database.go | 24 | ||||
-rw-r--r-- | p2p/discover/database_test.go | 34 | ||||
-rw-r--r-- | p2p/discover/table.go | 239 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 26 | ||||
-rw-r--r-- | p2p/discover/udp.go | 59 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 10 | ||||
-rw-r--r-- | p2p/metrics.go | 10 | ||||
-rw-r--r-- | p2p/nat/natpmp.go | 3 | ||||
-rw-r--r-- | p2p/nat/natupnp.go | 7 | ||||
-rw-r--r-- | p2p/peer.go | 9 | ||||
-rw-r--r-- | p2p/protocols/protocol.go | 96 | ||||
-rw-r--r-- | p2p/protocols/protocol_test.go | 4 | ||||
-rw-r--r-- | p2p/rlpx.go | 15 | ||||
-rw-r--r-- | p2p/server.go | 3 | ||||
-rw-r--r-- | p2p/simulations/adapters/inproc.go | 21 | ||||
-rw-r--r-- | p2p/simulations/adapters/state.go | 36 | ||||
-rw-r--r-- | p2p/simulations/network.go | 67 | ||||
-rw-r--r-- | p2p/testing/protocolsession.go | 2 | ||||
-rw-r--r-- | p2p/testing/protocoltester.go | 21 |
19 files changed, 333 insertions, 353 deletions
diff --git a/p2p/discover/database.go b/p2p/discover/database.go index 6f98de9b4..22554145f 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -42,6 +42,7 @@ var ( nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element. nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. + nodeDBVersion = 5 ) // nodeDB stores all nodes we know about. @@ -257,7 +258,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.bondTime(id); seen.After(threshold) { + if seen := db.lastPongReceived(id); seen.After(threshold) { continue } } @@ -267,29 +268,28 @@ func (db *nodeDB) expireNodes() error { return nil } -// lastPing retrieves the time of the last ping packet send to a remote node, -// requesting binding. -func (db *nodeDB) lastPing(id NodeID) time.Time { +// lastPingReceived retrieves the time of the last ping packet sent by the remote node. +func (db *nodeDB) lastPingReceived(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) } -// updateLastPing updates the last time we tried contacting a remote node. -func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { +// updateLastPing updates the last time remote node pinged us. +func (db *nodeDB) updateLastPingReceived(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// bondTime retrieves the time of the last successful pong from remote node. -func (db *nodeDB) bondTime(id NodeID) time.Time { +// lastPongReceived retrieves the time of the last successful pong from remote node. +func (db *nodeDB) lastPongReceived(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } // hasBond reports whether the given node is considered bonded. func (db *nodeDB) hasBond(id NodeID) bool { - return time.Since(db.bondTime(id)) < nodeDBNodeExpiration + return time.Since(db.lastPongReceived(id)) < nodeDBNodeExpiration } -// updateBondTime updates the last pong time of a node. -func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { +// updateLastPongReceived updates the last pong time of a node. +func (db *nodeDB) updateLastPongReceived(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -332,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.bondTime(n.ID)) > maxAge { + if now.Sub(db.lastPongReceived(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index c4fa44d09..27974344e 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -79,7 +79,7 @@ var nodeDBInt64Tests = []struct { } func TestNodeDBInt64(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() tests := nodeDBInt64Tests @@ -111,27 +111,27 @@ func TestNodeDBFetchStore(t *testing.T) { inst := time.Now() num := 314 - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() // Check fetch/store operations on a node ping object - if stored := db.lastPing(node.ID); stored.Unix() != 0 { + if stored := db.lastPingReceived(node.ID); stored.Unix() != 0 { t.Errorf("ping: non-existing object: %v", stored) } - if err := db.updateLastPing(node.ID, inst); err != nil { + if err := db.updateLastPingReceived(node.ID, inst); err != nil { t.Errorf("ping: failed to update: %v", err) } - if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() { + if stored := db.lastPingReceived(node.ID); stored.Unix() != inst.Unix() { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.bondTime(node.ID); stored.Unix() != 0 { + if stored := db.lastPongReceived(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateBondTime(node.ID, inst); err != nil { + if err := db.updateLastPongReceived(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { + if stored := db.lastPongReceived(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -216,7 +216,7 @@ var nodeDBSeedQueryNodes = []struct { } func TestNodeDBSeedQuery(t *testing.T) { - db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID) + db, _ := newNodeDB("", nodeDBVersion, nodeDBSeedQueryNodes[1].node.ID) defer db.close() // Insert a batch of nodes for querying @@ -224,7 +224,7 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -267,7 +267,7 @@ func TestNodeDBPersistency(t *testing.T) { ) // Create a persistent database and store some values - db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) + db, err := newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{}) if err != nil { t.Fatalf("failed to create persistent database: %v", err) } @@ -277,7 +277,7 @@ func TestNodeDBPersistency(t *testing.T) { db.close() // Reopen the database and check the value - db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) + db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{}) if err != nil { t.Fatalf("failed to open persistent database: %v", err) } @@ -287,7 +287,7 @@ func TestNodeDBPersistency(t *testing.T) { db.close() // Change the database version and check flush - db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{}) + db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion+1, NodeID{}) if err != nil { t.Fatalf("failed to open persistent database: %v", err) } @@ -324,7 +324,7 @@ var nodeDBExpirationNodes = []struct { } func TestNodeDBExpiration(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() // Add all the test nodes and set their last pong time @@ -332,7 +332,7 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } @@ -357,7 +357,7 @@ func TestNodeDBSelfExpiration(t *testing.T) { break } } - db, _ := newNodeDB("", Version, self) + db, _ := newNodeDB("", nodeDBVersion, self) defer db.close() // Add all the test nodes and set their last pong time @@ -365,7 +365,7 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 18920ccfd..8803daa56 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -25,7 +25,6 @@ package discover import ( crand "crypto/rand" "encoding/binary" - "errors" "fmt" mrand "math/rand" "net" @@ -54,15 +53,13 @@ const ( bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 tableIPLimit, tableSubnet = 10, 24 - maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions - maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped - - refreshInterval = 30 * time.Minute - revalidateInterval = 10 * time.Second - copyNodesInterval = 30 * time.Second - seedMinTableTime = 5 * time.Minute - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { @@ -78,28 +75,17 @@ type Table struct { closeReq chan struct{} closed chan struct{} - bondmu sync.Mutex - bonding map[NodeID]*bondproc - bondslots chan struct{} // limits total number of active bonding processes - nodeAddedHook func(*Node) // for testing net transport self *Node // metadata of the local node } -type bondproc struct { - err error - n *Node - done chan struct{} -} - // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. type transport interface { ping(NodeID, *net.UDPAddr) error - waitping(NodeID) error findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error) close() } @@ -114,7 +100,7 @@ type bucket struct { func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { // If no node database was given, use an in-memory one - db, err := newNodeDB(nodeDBPath, Version, ourID) + db, err := newNodeDB(nodeDBPath, nodeDBVersion, ourID) if err != nil { return nil, err } @@ -122,8 +108,6 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string net: t, db: db, self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), - bonding: make(map[NodeID]*bondproc), - bondslots: make(chan struct{}, maxBondingPingPongs), refreshReq: make(chan chan struct{}), initDone: make(chan struct{}), closeReq: make(chan struct{}), @@ -134,16 +118,13 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string if err := tab.setFallbackNodes(bootnodes); err != nil { return nil, err } - for i := 0; i < cap(tab.bondslots); i++ { - tab.bondslots <- struct{}{} - } for i := range tab.buckets { tab.buckets[i] = &bucket{ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, } } tab.seedRand() - tab.loadSeedNodes(false) + tab.loadSeedNodes() // Start the background expiration goroutine after loading seeds so that the search for // seed nodes also considers older nodes that would otherwise be removed by the // expiration. @@ -315,22 +296,7 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { if !asked[n.ID] { asked[n.ID] = true pendingQueries++ - go func() { - // Find potential neighbors to bond with - r, err := tab.net.findnode(n.ID, n.addr(), targetID) - if err != nil { - // Bump the failure counter to detect and evacuate non-bonded entries - fails := tab.db.findFails(n.ID) + 1 - tab.db.updateFindFails(n.ID, fails) - log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails) - - if fails >= maxFindnodeFailures { - log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails) - tab.delete(n) - } - } - reply <- tab.bondall(r) - }() + go tab.findnode(n, targetID, reply) } } if pendingQueries == 0 { @@ -349,6 +315,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { return result.entries } +func (tab *Table) findnode(n *Node, targetID NodeID, reply chan<- []*Node) { + fails := tab.db.findFails(n.ID) + r, err := tab.net.findnode(n.ID, n.addr(), targetID) + if err != nil || len(r) == 0 { + fails++ + tab.db.updateFindFails(n.ID, fails) + log.Trace("Findnode failed", "id", n.ID, "failcount", fails, "err", err) + if fails >= maxFindnodeFailures { + log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails) + tab.delete(n) + } + } else if fails > 0 { + tab.db.updateFindFails(n.ID, fails-1) + } + + // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll + // just remove those again during revalidation. + for _, n := range r { + tab.add(n) + } + reply <- r +} + func (tab *Table) refresh() <-chan struct{} { done := make(chan struct{}) select { @@ -401,7 +390,7 @@ loop: case <-revalidateDone: revalidate.Reset(tab.nextRevalidateTime()) case <-copyNodes.C: - go tab.copyBondedNodes() + go tab.copyLiveNodes() case <-tab.closeReq: break loop } @@ -429,7 +418,7 @@ func (tab *Table) doRefresh(done chan struct{}) { // Load nodes from the database and insert // them. This should yield a few previously seen nodes that are // (hopefully) still alive. - tab.loadSeedNodes(true) + tab.loadSeedNodes() // Run self lookup to discover new neighbor nodes. tab.lookup(tab.self.ID, false) @@ -447,15 +436,12 @@ func (tab *Table) doRefresh(done chan struct{}) { } } -func (tab *Table) loadSeedNodes(bond bool) { +func (tab *Table) loadSeedNodes() { seeds := tab.db.querySeeds(seedCount, seedMaxAge) seeds = append(seeds, tab.nursery...) - if bond { - seeds = tab.bondall(seeds) - } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPongReceived(seed.ID)) }} log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) tab.add(seed) } @@ -473,7 +459,7 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { } // Ping the selected node and wait for a pong. - err := tab.ping(last.ID, last.addr()) + err := tab.net.ping(last.ID, last.addr()) tab.mutex.Lock() defer tab.mutex.Unlock() @@ -515,9 +501,9 @@ func (tab *Table) nextRevalidateTime() time.Duration { return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) } -// copyBondedNodes adds nodes from the table to the database if they have been in the table +// copyLiveNodes adds nodes from the table to the database if they have been in the table // longer then minTableTime. -func (tab *Table) copyBondedNodes() { +func (tab *Table) copyLiveNodes() { tab.mutex.Lock() defer tab.mutex.Unlock() @@ -553,120 +539,6 @@ func (tab *Table) len() (n int) { return n } -// bondall bonds with all given nodes concurrently and returns -// those nodes for which bonding has probably succeeded. -func (tab *Table) bondall(nodes []*Node) (result []*Node) { - rc := make(chan *Node, len(nodes)) - for i := range nodes { - go func(n *Node) { - nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP) - rc <- nn - }(nodes[i]) - } - for range nodes { - if n := <-rc; n != nil { - result = append(result, n) - } - } - return result -} - -// bond ensures the local node has a bond with the given remote node. -// It also attempts to insert the node into the table if bonding succeeds. -// The caller must not hold tab.mutex. -// -// A bond is must be established before sending findnode requests. -// Both sides must have completed a ping/pong exchange for a bond to -// exist. The total number of active bonding processes is limited in -// order to restrain network use. -// -// bond is meant to operate idempotently in that bonding with a remote -// node which still remembers a previously established bond will work. -// The remote node will simply not send a ping back, causing waitping -// to time out. -// -// If pinged is true, the remote node has just pinged us and one half -// of the process can be skipped. -func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) { - if id == tab.self.ID { - return nil, errors.New("is self") - } - if pinged && !tab.isInitDone() { - return nil, errors.New("still initializing") - } - // Start bonding if we haven't seen this node for a while or if it failed findnode too often. - node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.bondTime(id)) - var result error - if fails > 0 || age > nodeDBNodeExpiration { - log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) - - tab.bondmu.Lock() - w := tab.bonding[id] - if w != nil { - // Wait for an existing bonding process to complete. - tab.bondmu.Unlock() - <-w.done - } else { - // Register a new bonding process. - w = &bondproc{done: make(chan struct{})} - tab.bonding[id] = w - tab.bondmu.Unlock() - // Do the ping/pong. The result goes into w. - tab.pingpong(w, pinged, id, addr, tcpPort) - // Unregister the process after it's done. - tab.bondmu.Lock() - delete(tab.bonding, id) - tab.bondmu.Unlock() - } - // Retrieve the bonding results - result = w.err - if result == nil { - node = w.n - } - } - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. - if node != nil { - tab.add(node) - tab.db.updateFindFails(id, 0) - } - return node, result -} - -func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) { - // Request a bonding slot to limit network usage - <-tab.bondslots - defer func() { tab.bondslots <- struct{}{} }() - - // Ping the remote side and wait for a pong. - if w.err = tab.ping(id, addr); w.err != nil { - close(w.done) - return - } - if !pinged { - // Give the remote node a chance to ping us before we start - // sending findnode requests. If they still remember us, - // waitping will simply time out. - tab.net.waitping(id) - } - // Bonding succeeded, update the node database. - w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - close(w.done) -} - -// ping a remote endpoint and wait for a reply, also updating the node -// database accordingly. -func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { - tab.db.updateLastPing(id, time.Now()) - if err := tab.net.ping(id, addr); err != nil { - return err - } - tab.db.updateBondTime(id, time.Now()) - return nil -} - // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(sha common.Hash) *bucket { d := logdist(tab.self.sha, sha) @@ -676,21 +548,33 @@ func (tab *Table) bucket(sha common.Hash) *bucket { return tab.buckets[d-bucketMinDistance-1] } -// add attempts to add the given node its corresponding bucket. If the -// bucket has space available, adding the node succeeds immediately. -// Otherwise, the node is added if the least recently active node in -// the bucket does not respond to a ping packet. +// add attempts to add the given node to its corresponding bucket. If the bucket has space +// available, adding the node succeeds immediately. Otherwise, the node is added if the +// least recently active node in the bucket does not respond to a ping packet. // // The caller must not hold tab.mutex. -func (tab *Table) add(new *Node) { +func (tab *Table) add(n *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() - b := tab.bucket(new.sha) - if !tab.bumpOrAdd(b, new) { + b := tab.bucket(n.sha) + if !tab.bumpOrAdd(b, n) { // Node is not in table. Add it to the replacement list. - tab.addReplacement(b, new) + tab.addReplacement(b, n) + } +} + +// addThroughPing adds the given node to the table. Compared to plain +// 'add' there is an additional safety measure: if the table is still +// initializing the node is not added. This prevents an attack where the +// table could be filled by just sending ping repeatedly. +// +// The caller must not hold tab.mutex. +func (tab *Table) addThroughPing(n *Node) { + if !tab.isInitDone() { + return } + tab.add(n) } // stuff adds nodes the table to the end of their corresponding bucket @@ -710,8 +594,7 @@ func (tab *Table) stuff(nodes []*Node) { } } -// delete removes an entry from the node table (used to evacuate -// failed/non-bonded discovery peers). +// delete removes an entry from the node table. It is used to evacuate dead nodes. func (tab *Table) delete(node *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index f2d3f9a2a..ed55ebd9a 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -52,27 +52,22 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) defer tab.Close() - // Wait for init so bond is accepted. <-tab.initDone - // fill up the sender's bucket. + // Fill up the sender's bucket. pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) last := fillBucket(tab, pingSender) - // this call to bond should replace the last node - // in its bucket if the node is not responding. + // Add the sender as if it just pinged us. Revalidate should replace the last node in + // its bucket if it is unresponsive. Revalidate again to ensure that transport.dead[last.ID] = !lastInBucketIsResponding transport.dead[pingSender.ID] = !newNodeIsResponding - tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) + tab.add(pingSender) + tab.doRevalidate(make(chan struct{}, 1)) tab.doRevalidate(make(chan struct{}, 1)) - // first ping goes to sender (bonding pingback) - if !transport.pinged[pingSender.ID] { - t.Error("table did not ping back sender") - } if !transport.pinged[last.ID] { - // second ping goes to oldest node in bucket - // to see whether it is still alive. + // Oldest node in bucket is pinged to see whether it is still alive. t.Error("table did not ping last node in bucket") } @@ -83,7 +78,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding wantSize-- } if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + t.Errorf("wrong bucket size after add: got %d, want %d", l, wantSize) } if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) @@ -206,10 +201,7 @@ func newPingRecorder() *pingRecorder { func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { return nil, nil } -func (t *pingRecorder) close() {} -func (t *pingRecorder) waitping(from NodeID) error { - return nil // remote always pings -} + func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { t.mu.Lock() defer t.mu.Unlock() @@ -222,6 +214,8 @@ func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { } } +func (t *pingRecorder) close() {} + func TestTable_closest(t *testing.T) { t.Parallel() diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index f6bcd9708..0ff47c5e4 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -32,8 +32,6 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) -const Version = 4 - // Errors var ( errPacketTooSmall = errors.New("too small") @@ -272,21 +270,33 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { + return <-t.sendPing(toid, toaddr, nil) +} + +// sendPing sends a ping message to the given node and invokes the callback +// when the reply arrives. +func (t *udp) sendPing(toid NodeID, toaddr *net.UDPAddr, callback func()) <-chan error { req := &ping{ - Version: Version, + Version: 4, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), } packet, hash, err := encodePacket(t.priv, pingPacket, req) if err != nil { - return err + errc := make(chan error, 1) + errc <- err + return errc } errc := t.pending(toid, pongPacket, func(p interface{}) bool { - return bytes.Equal(p.(*pong).ReplyTok, hash) + ok := bytes.Equal(p.(*pong).ReplyTok, hash) + if ok && callback != nil { + callback() + } + return ok }) t.write(toaddr, req.name(), packet) - return <-errc + return errc } func (t *udp) waitping(from NodeID) error { @@ -296,6 +306,13 @@ func (t *udp) waitping(from NodeID) error { // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { + // If we haven't seen a ping from the destination node for a while, it won't remember + // our endpoint proof and reject findnode. Solicit a ping first. + if time.Since(t.db.lastPingReceived(toid)) > nodeDBNodeExpiration { + t.ping(toid, toaddr) + t.waitping(toid) + } + nodes := make([]*Node, 0, bucketSize) nreceived := 0 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { @@ -315,8 +332,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - err := <-errc - return nodes, err + return nodes, <-errc } // pending adds a reply callback to the pending reply queue. @@ -587,10 +603,17 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - if !t.handleReply(fromID, pingPacket, req) { - // Note: we're ignoring the provided IP address right now - go t.bond(true, fromID, from, req.From.TCP) + t.handleReply(fromID, pingPacket, req) + + // Add the node to the table. Before doing so, ensure that we have a recent enough pong + // recorded in the database so their findnode requests will be accepted later. + n := NewNode(fromID, from.IP, uint16(from.Port), req.From.TCP) + if time.Since(t.db.lastPongReceived(fromID)) > nodeDBNodeExpiration { + t.sendPing(fromID, from, func() { t.addThroughPing(n) }) + } else { + t.addThroughPing(n) } + t.db.updateLastPingReceived(fromID, time.Now()) return nil } @@ -603,6 +626,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er if !t.handleReply(fromID, pongPacket, req) { return errUnsolicitedReply } + t.db.updateLastPongReceived(fromID, time.Now()) return nil } @@ -613,13 +637,12 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte return errExpired } if !t.db.hasBond(fromID) { - // No bond exists, we don't process the packet. This prevents - // an attack vector where the discovery protocol could be used - // to amplify traffic in a DDOS attack. A malicious actor - // would send a findnode request with the IP address and UDP - // port of the target as the source address. The recipient of - // the findnode packet would then send a neighbors packet - // (which is a much bigger packet than findnode) to the victim. + // No endpoint proof pong exists, we don't process the packet. This prevents an + // attack vector where the discovery protocol could be used to amplify traffic in a + // DDOS attack. A malicious actor would send a findnode request with the IP address + // and UDP port of the target as the source address. The recipient of the findnode + // packet would then send a neighbors packet (which is a much bigger packet than + // findnode) to the victim. return errUnknownNode } target := crypto.Keccak256Hash(req.Target[:]) diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index db9804f7b..b4363a12b 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -124,7 +124,7 @@ func TestUDP_packetErrors(t *testing.T) { test := newUDPTest(t) defer test.table.Close() - test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version}) + test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp}) test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp}) @@ -247,7 +247,7 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + test.table.db.updateLastPongReceived(PubkeyID(&test.remotekey.PublicKey), time.Now()) // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) @@ -273,10 +273,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) { test := newUDPTest(t) defer test.table.Close() + rid := PubkeyID(&test.remotekey.PublicKey) + test.table.db.updateLastPingReceived(rid, time.Now()) + // queue a pending findnode request resultc, errc := make(chan []*Node), make(chan error) go func() { - rid := PubkeyID(&test.remotekey.PublicKey) ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) if err != nil && len(ns) == 0 { errc <- err @@ -328,7 +330,7 @@ func TestUDP_successfulPing(t *testing.T) { defer test.table.Close() // The remote side sends a ping packet to initiate the exchange. - go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) + go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) // the ping is replied to. test.waitPacketOut(func(p *pong) { diff --git a/p2p/metrics.go b/p2p/metrics.go index 4cbff90ac..2d52fd1fd 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -31,10 +31,10 @@ var ( egressTrafficMeter = metrics.NewRegisteredMeter("p2p/OutboundTraffic", nil) ) -// meteredConn is a wrapper around a network TCP connection that meters both the +// meteredConn is a wrapper around a net.Conn that meters both the // inbound and outbound network traffic. type meteredConn struct { - *net.TCPConn // Network connection to wrap with metering + net.Conn // Network connection to wrap with metering } // newMeteredConn creates a new metered connection, also bumping the ingress or @@ -51,13 +51,13 @@ func newMeteredConn(conn net.Conn, ingress bool) net.Conn { } else { egressConnectMeter.Mark(1) } - return &meteredConn{conn.(*net.TCPConn)} + return &meteredConn{Conn: conn} } // Read delegates a network read to the underlying connection, bumping the ingress // traffic meter along the way. func (c *meteredConn) Read(b []byte) (n int, err error) { - n, err = c.TCPConn.Read(b) + n, err = c.Conn.Read(b) ingressTrafficMeter.Mark(int64(n)) return } @@ -65,7 +65,7 @@ func (c *meteredConn) Read(b []byte) (n int, err error) { // Write delegates a network write to the underlying connection, bumping the // egress traffic meter along the way. func (c *meteredConn) Write(b []byte) (n int, err error) { - n, err = c.TCPConn.Write(b) + n, err = c.Conn.Write(b) egressTrafficMeter.Mark(int64(n)) return } diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go index 577a424fb..8ba971472 100644 --- a/p2p/nat/natpmp.go +++ b/p2p/nat/natpmp.go @@ -115,8 +115,7 @@ func potentialGateways() (gws []net.IP) { return gws } for _, addr := range ifaddrs { - switch x := addr.(type) { - case *net.IPNet: + if x, ok := addr.(*net.IPNet); ok { if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) { ip := x.IP.Mask(x.Mask).To4() if ip != nil { diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go index 69099ac04..029143b7b 100644 --- a/p2p/nat/natupnp.go +++ b/p2p/nat/natupnp.go @@ -81,11 +81,8 @@ func (n *upnp) internalAddress() (net.IP, error) { return nil, err } for _, addr := range addrs { - switch x := addr.(type) { - case *net.IPNet: - if x.Contains(devaddr.IP) { - return x.IP, nil - } + if x, ok := addr.(*net.IPNet); ok && x.Contains(devaddr.IP) { + return x.IP, nil } } } diff --git a/p2p/peer.go b/p2p/peer.go index c4c1fcd7c..482e3d506 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -17,6 +17,7 @@ package p2p import ( + "errors" "fmt" "io" "net" @@ -31,6 +32,10 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +var ( + ErrShuttingDown = errors.New("shutting down") +) + const ( baseProtocolVersion = 5 baseProtocolLength = uint64(16) @@ -371,7 +376,7 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) { type protoRW struct { Protocol - in chan Msg // receices read messages + in chan Msg // receives read messages closed <-chan struct{} // receives when peer is shutting down wstart <-chan struct{} // receives when write may start werr chan<- error // for write results @@ -393,7 +398,7 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) { // as well but we don't want to rely on that. rw.werr <- err case <-rw.closed: - err = fmt.Errorf("shutting down") + err = ErrShuttingDown } return err } diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go index 849a7ef39..615f74b56 100644 --- a/p2p/protocols/protocol.go +++ b/p2p/protocols/protocol.go @@ -29,14 +29,22 @@ devp2p subprotocols by abstracting away code standardly shared by protocols. package protocols import ( + "bufio" + "bytes" "context" "fmt" + "io" "reflect" "sync" "time" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/swarm/spancontext" + "github.com/ethereum/go-ethereum/swarm/tracing" + opentracing "github.com/opentracing/opentracing-go" ) // error codes used by this protocol scheme @@ -107,6 +115,13 @@ func errorf(code int, format string, params ...interface{}) *Error { } } +// WrappedMsg is used to propagate marshalled context alongside message payloads +type WrappedMsg struct { + Context []byte + Size uint32 + Payload []byte +} + // Spec is a protocol specification including its name and version as well as // the types of messages which are exchanged type Spec struct { @@ -199,9 +214,14 @@ func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { // the handler argument is a function which is called for each message received // from the remote peer, a returned error causes the loop to exit // resulting in disconnection -func (p *Peer) Run(handler func(msg interface{}) error) error { +func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) error { for { if err := p.handleIncoming(handler); err != nil { + if err != io.EOF { + metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1) + log.Error("peer.handleIncoming", "err", err) + } + return err } } @@ -218,14 +238,47 @@ func (p *Peer) Drop(err error) { // message off to the peer // this low level call will be wrapped by libraries providing routed or broadcast sends // but often just used to forward and push messages to directly connected peers -func (p *Peer) Send(msg interface{}) error { +func (p *Peer) Send(ctx context.Context, msg interface{}) error { defer metrics.GetOrRegisterResettingTimer("peer.send_t", nil).UpdateSince(time.Now()) metrics.GetOrRegisterCounter("peer.send", nil).Inc(1) + + var b bytes.Buffer + if tracing.Enabled { + writer := bufio.NewWriter(&b) + + tracer := opentracing.GlobalTracer() + + sctx := spancontext.FromContext(ctx) + + if sctx != nil { + err := tracer.Inject( + sctx, + opentracing.Binary, + writer) + if err != nil { + return err + } + } + + writer.Flush() + } + + r, err := rlp.EncodeToBytes(msg) + if err != nil { + return err + } + + wmsg := WrappedMsg{ + Context: b.Bytes(), + Size: uint32(len(r)), + Payload: r, + } + code, found := p.spec.GetCode(msg) if !found { return errorf(ErrInvalidMsgType, "%v", code) } - return p2p.Send(p.rw, code, msg) + return p2p.Send(p.rw, code, wmsg) } // handleIncoming(code) @@ -236,7 +289,7 @@ func (p *Peer) Send(msg interface{}) error { // * checks for out-of-range message codes, // * handles decoding with reflection, // * call handlers as callbacks -func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { +func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) error) error { msg, err := p.rw.ReadMsg() if err != nil { return err @@ -248,11 +301,38 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) } + // unmarshal wrapped msg, which might contain context + var wmsg WrappedMsg + err = msg.Decode(&wmsg) + if err != nil { + log.Error(err.Error()) + return err + } + + ctx := context.Background() + + // if tracing is enabled and the context coming within the request is + // not empty, try to unmarshal it + if tracing.Enabled && len(wmsg.Context) > 0 { + var sctx opentracing.SpanContext + + tracer := opentracing.GlobalTracer() + sctx, err = tracer.Extract( + opentracing.Binary, + bytes.NewReader(wmsg.Context)) + if err != nil { + log.Error(err.Error()) + return err + } + + ctx = spancontext.WithContext(ctx, sctx) + } + val, ok := p.spec.NewMsg(msg.Code) if !ok { return errorf(ErrInvalidMsgCode, "%v", msg.Code) } - if err := msg.Decode(val); err != nil { + if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil { return errorf(ErrDecode, "<= %v: %v", msg, err) } @@ -261,7 +341,7 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { // which the handler is supposed to cast to the appropriate type // it is entirely safe not to check the cast in the handler since the handler is // chosen based on the proper type in the first place - if err := handle(val); err != nil { + if err := handle(ctx, val); err != nil { return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) } return nil @@ -281,14 +361,14 @@ func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interf return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) } errc := make(chan error, 2) - handle := func(msg interface{}) error { + handle := func(ctx context.Context, msg interface{}) error { rhs = msg if verify != nil { return verify(rhs) } return nil } - send := func() { errc <- p.Send(hs) } + send := func() { errc <- p.Send(ctx, hs) } receive := func() { errc <- p.handleIncoming(handle) } go func() { diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index aaae7502b..11df8ff39 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -104,7 +104,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) } - handle := func(msg interface{}) error { + handle := func(ctx context.Context, msg interface{}) error { switch msg := msg.(type) { case *protoHandshake: @@ -116,7 +116,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) } lhs.C += rhs.C - return peer.Send(lhs) + return peer.Send(ctx, lhs) case *kill: // demonstrates use of peerPool, killing another peer connection as a response to a message diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 149eda689..46b666869 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -181,9 +181,9 @@ func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (disco err error ) if dial == nil { - sec, err = receiverEncHandshake(t.fd, prv, nil) + sec, err = receiverEncHandshake(t.fd, prv) } else { - sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil) + sec, err = initiatorEncHandshake(t.fd, prv, dial.ID) } if err != nil { return discover.NodeID{}, err @@ -280,9 +280,9 @@ func (h *encHandshake) staticSharedSecret(prv *ecdsa.PrivateKey) ([]byte, error) // it should be called on the dialing side of the connection. // // prv is the local client's private key. -func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID) (s secrets, err error) { h := &encHandshake{initiator: true, remoteID: remoteID} - authMsg, err := h.makeAuthMsg(prv, token) + authMsg, err := h.makeAuthMsg(prv) if err != nil { return s, err } @@ -306,7 +306,7 @@ func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID d } // makeAuthMsg creates the initiator handshake message. -func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMsgV4, error) { +func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey) (*authMsgV4, error) { rpub, err := h.remoteID.Pubkey() if err != nil { return nil, fmt.Errorf("bad remoteID: %v", err) @@ -324,7 +324,7 @@ func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMs } // Sign known message: static-shared-secret ^ nonce - token, err = h.staticSharedSecret(prv) + token, err := h.staticSharedSecret(prv) if err != nil { return nil, err } @@ -352,8 +352,7 @@ func (h *encHandshake) handleAuthResp(msg *authRespV4) (err error) { // it should be called on the listening side of the connection. // // prv is the local client's private key. -// token is the token from a previous session with this node. -func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { +func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s secrets, err error) { authMsg := new(authMsgV4) authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn) if err != nil { diff --git a/p2p/server.go b/p2p/server.go index d2cb94925..669ef740d 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -371,8 +371,8 @@ func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *discover // It blocks until all active connections have been closed. func (srv *Server) Stop() { srv.lock.Lock() - defer srv.lock.Unlock() if !srv.running { + srv.lock.Unlock() return } srv.running = false @@ -381,6 +381,7 @@ func (srv *Server) Stop() { srv.listener.Close() } close(srv.quit) + srv.lock.Unlock() srv.loopWG.Wait() } diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index b68d08f39..b0fdf49b9 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -296,6 +296,13 @@ func (sn *SimNode) Stop() error { return sn.node.Stop() } +// Service returns a running service by name +func (sn *SimNode) Service(name string) node.Service { + sn.lock.RLock() + defer sn.lock.RUnlock() + return sn.running[name] +} + // Services returns a copy of the underlying services func (sn *SimNode) Services() []node.Service { sn.lock.RLock() @@ -307,6 +314,17 @@ func (sn *SimNode) Services() []node.Service { return services } +// ServiceMap returns a map by names of the underlying services +func (sn *SimNode) ServiceMap() map[string]node.Service { + sn.lock.RLock() + defer sn.lock.RUnlock() + services := make(map[string]node.Service, len(sn.running)) + for name, service := range sn.running { + services[name] = service + } + return services +} + // Server returns the underlying p2p.Server func (sn *SimNode) Server() *p2p.Server { return sn.node.Server() @@ -335,8 +353,7 @@ func (sn *SimNode) NodeInfo() *p2p.NodeInfo { } func setSocketBuffer(conn net.Conn, socketReadBuffer int, socketWriteBuffer int) error { - switch v := conn.(type) { - case *net.UnixConn: + if v, ok := conn.(*net.UnixConn); ok { err := v.SetReadBuffer(socketReadBuffer) if err != nil { return err diff --git a/p2p/simulations/adapters/state.go b/p2p/simulations/adapters/state.go deleted file mode 100644 index 78dfb11f9..000000000 --- a/p2p/simulations/adapters/state.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package adapters - -type SimStateStore struct { - m map[string][]byte -} - -func (st *SimStateStore) Load(s string) ([]byte, error) { - return st.m[s], nil -} - -func (st *SimStateStore) Save(s string, data []byte) error { - st.m[s] = data - return nil -} - -func NewSimStateStore() *SimStateStore { - return &SimStateStore{ - make(map[string][]byte), - } -} diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index a8a46cd87..0fb7485ad 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -31,7 +31,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/simulations/adapters" ) -var dialBanTimeout = 200 * time.Millisecond +var DialBanTimeout = 200 * time.Millisecond // NetworkConfig defines configuration options for starting a Network type NetworkConfig struct { @@ -78,41 +78,25 @@ func (net *Network) Events() *event.Feed { return &net.events } -// NewNode adds a new node to the network with a random ID -func (net *Network) NewNode() (*Node, error) { - conf := adapters.RandomNodeConfig() - conf.Services = []string{net.DefaultService} - return net.NewNodeWithConfig(conf) -} - // NewNodeWithConfig adds a new node to the network with the given config, // returning an error if a node with the same ID or name already exists func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) { net.lock.Lock() defer net.lock.Unlock() - // create a random ID and PrivateKey if not set - if conf.ID == (discover.NodeID{}) { - c := adapters.RandomNodeConfig() - conf.ID = c.ID - conf.PrivateKey = c.PrivateKey - } - id := conf.ID if conf.Reachable == nil { conf.Reachable = func(otherID discover.NodeID) bool { _, err := net.InitConn(conf.ID, otherID) - return err == nil + if err != nil && bytes.Compare(conf.ID.Bytes(), otherID.Bytes()) < 0 { + return false + } + return true } } - // assign a name to the node if not set - if conf.Name == "" { - conf.Name = fmt.Sprintf("node%02d", len(net.Nodes)+1) - } - // check the node doesn't already exist - if node := net.getNode(id); node != nil { - return nil, fmt.Errorf("node with ID %q already exists", id) + if node := net.getNode(conf.ID); node != nil { + return nil, fmt.Errorf("node with ID %q already exists", conf.ID) } if node := net.getNodeByName(conf.Name); node != nil { return nil, fmt.Errorf("node with name %q already exists", conf.Name) @@ -132,8 +116,8 @@ func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) Node: adapterNode, Config: conf, } - log.Trace(fmt.Sprintf("node %v created", id)) - net.nodeMap[id] = len(net.Nodes) + log.Trace(fmt.Sprintf("node %v created", conf.ID)) + net.nodeMap[conf.ID] = len(net.Nodes) net.Nodes = append(net.Nodes, node) // emit a "control" event @@ -181,7 +165,9 @@ func (net *Network) Start(id discover.NodeID) error { // startWithSnapshots starts the node with the given ID using the give // snapshots func (net *Network) startWithSnapshots(id discover.NodeID, snapshots map[string][]byte) error { - node := net.GetNode(id) + net.lock.Lock() + defer net.lock.Unlock() + node := net.getNode(id) if node == nil { return fmt.Errorf("node %v does not exist", id) } @@ -220,9 +206,13 @@ func (net *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEve // assume the node is now down net.lock.Lock() + defer net.lock.Unlock() node := net.getNode(id) + if node == nil { + log.Error("Can not find node for id", "id", id) + return + } node.Up = false - net.lock.Unlock() net.events.Send(NewEvent(node)) }() for { @@ -259,7 +249,9 @@ func (net *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEve // Stop stops the node with the given ID func (net *Network) Stop(id discover.NodeID) error { - node := net.GetNode(id) + net.lock.Lock() + defer net.lock.Unlock() + node := net.getNode(id) if node == nil { return fmt.Errorf("node %v does not exist", id) } @@ -312,7 +304,9 @@ func (net *Network) Disconnect(oneID, otherID discover.NodeID) error { // DidConnect tracks the fact that the "one" node connected to the "other" node func (net *Network) DidConnect(one, other discover.NodeID) error { - conn, err := net.GetOrCreateConn(one, other) + net.lock.Lock() + defer net.lock.Unlock() + conn, err := net.getOrCreateConn(one, other) if err != nil { return fmt.Errorf("connection between %v and %v does not exist", one, other) } @@ -327,7 +321,9 @@ func (net *Network) DidConnect(one, other discover.NodeID) error { // DidDisconnect tracks the fact that the "one" node disconnected from the // "other" node func (net *Network) DidDisconnect(one, other discover.NodeID) error { - conn := net.GetConn(one, other) + net.lock.Lock() + defer net.lock.Unlock() + conn := net.getConn(one, other) if conn == nil { return fmt.Errorf("connection between %v and %v does not exist", one, other) } @@ -335,7 +331,7 @@ func (net *Network) DidDisconnect(one, other discover.NodeID) error { return fmt.Errorf("%v and %v already disconnected", one, other) } conn.Up = false - conn.initiated = time.Now().Add(-dialBanTimeout) + conn.initiated = time.Now().Add(-DialBanTimeout) net.events.Send(NewEvent(conn)) return nil } @@ -476,16 +472,19 @@ func (net *Network) InitConn(oneID, otherID discover.NodeID) (*Conn, error) { if err != nil { return nil, err } - if time.Since(conn.initiated) < dialBanTimeout { - return nil, fmt.Errorf("connection between %v and %v recently attempted", oneID, otherID) - } if conn.Up { return nil, fmt.Errorf("%v and %v already connected", oneID, otherID) } + if time.Since(conn.initiated) < DialBanTimeout { + return nil, fmt.Errorf("connection between %v and %v recently attempted", oneID, otherID) + } + err = conn.nodesUp() if err != nil { + log.Trace(fmt.Sprintf("nodes not up: %v", err)) return nil, fmt.Errorf("nodes not up: %v", err) } + log.Debug("InitConn - connection initiated") conn.initiated = time.Now() return conn, nil } diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index 8f73bfa03..e3ec41ad6 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -91,7 +91,9 @@ func (s *ProtocolSession) trigger(trig Trigger) error { errc := make(chan error) go func() { + log.Trace(fmt.Sprintf("trigger %v (%v)....", trig.Msg, trig.Code)) errc <- mockNode.Trigger(&trig) + log.Trace(fmt.Sprintf("triggered %v (%v)", trig.Msg, trig.Code)) }() t := trig.Timeout diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index 636613c57..c99578fe0 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -180,7 +180,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { for { select { case trig := <-m.trigger: - m.err <- p2p.Send(rw, trig.Code, trig.Msg) + wmsg := Wrap(trig.Msg) + m.err <- p2p.Send(rw, trig.Code, wmsg) case exps := <-m.expect: m.err <- expectMsgs(rw, exps) case <-m.stop: @@ -220,7 +221,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { } var found bool for i, exp := range exps { - if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) { if matched[i] { return fmt.Errorf("message #%d received two times", i) } @@ -235,7 +236,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { if matched[i] { continue } - expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg)))) } return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) } @@ -267,3 +268,17 @@ func mustEncodeMsg(msg interface{}) []byte { } return contentEnc } + +type WrappedMsg struct { + Context []byte + Size uint32 + Payload []byte +} + +func Wrap(msg interface{}) interface{} { + data, _ := rlp.EncodeToBytes(msg) + return &WrappedMsg{ + Size: uint32(len(data)), + Payload: data, + } +} |