aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/simulations/network.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/simulations/network.go')
-rw-r--r--p2p/simulations/network.go83
1 files changed, 71 insertions, 12 deletions
diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go
index ab9f582c5..a6fac2c2a 100644
--- a/p2p/simulations/network.go
+++ b/p2p/simulations/network.go
@@ -22,6 +22,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "math/rand"
"sync"
"time"
@@ -57,6 +58,8 @@ type Network struct {
Conns []*Conn `json:"conns"`
connMap map[string]int
+ pivotNodeID enode.ID
+
nodeAdapter adapters.NodeAdapter
events event.Feed
lock sync.RWMutex
@@ -370,23 +373,32 @@ func (net *Network) DidReceive(sender, receiver enode.ID, proto string, code uin
// GetNode gets the node with the given ID, returning nil if the node does not
// exist
func (net *Network) GetNode(id enode.ID) *Node {
- net.lock.Lock()
- defer net.lock.Unlock()
+ net.lock.RLock()
+ defer net.lock.RUnlock()
return net.getNode(id)
}
// GetNode gets the node with the given name, returning nil if the node does
// not exist
func (net *Network) GetNodeByName(name string) *Node {
- net.lock.Lock()
- defer net.lock.Unlock()
+ net.lock.RLock()
+ defer net.lock.RUnlock()
return net.getNodeByName(name)
}
+func (net *Network) getNodeByName(name string) *Node {
+ for _, node := range net.Nodes {
+ if node.Config.Name == name {
+ return node
+ }
+ }
+ return nil
+}
+
// GetNodes returns the existing nodes
func (net *Network) GetNodes() (nodes []*Node) {
- net.lock.Lock()
- defer net.lock.Unlock()
+ net.lock.RLock()
+ defer net.lock.RUnlock()
nodes = append(nodes, net.Nodes...)
return nodes
@@ -400,20 +412,67 @@ func (net *Network) getNode(id enode.ID) *Node {
return net.Nodes[i]
}
-func (net *Network) getNodeByName(name string) *Node {
+// GetRandomUpNode returns a random node on the network, which is running.
+func (net *Network) GetRandomUpNode(excludeIDs ...enode.ID) *Node {
+ net.lock.RLock()
+ defer net.lock.RUnlock()
+ return net.getRandomNode(net.getUpNodeIDs(), excludeIDs)
+}
+
+func (net *Network) getUpNodeIDs() (ids []enode.ID) {
for _, node := range net.Nodes {
- if node.Config.Name == name {
- return node
+ if node.Up {
+ ids = append(ids, node.ID())
}
}
- return nil
+ return ids
+}
+
+// GetRandomDownNode returns a random node on the network, which is stopped.
+func (net *Network) GetRandomDownNode(excludeIDs ...enode.ID) *Node {
+ net.lock.RLock()
+ defer net.lock.RUnlock()
+ return net.getRandomNode(net.getDownNodeIDs(), excludeIDs)
+}
+
+func (net *Network) getDownNodeIDs() (ids []enode.ID) {
+ for _, node := range net.GetNodes() {
+ if !node.Up {
+ ids = append(ids, node.ID())
+ }
+ }
+ return ids
+}
+
+func (net *Network) getRandomNode(ids []enode.ID, excludeIDs []enode.ID) *Node {
+ filtered := filterIDs(ids, excludeIDs)
+
+ l := len(filtered)
+ if l == 0 {
+ return nil
+ }
+ return net.GetNode(filtered[rand.Intn(l)])
+}
+
+func filterIDs(ids []enode.ID, excludeIDs []enode.ID) []enode.ID {
+ exclude := make(map[enode.ID]bool)
+ for _, id := range excludeIDs {
+ exclude[id] = true
+ }
+ var filtered []enode.ID
+ for _, id := range ids {
+ if _, found := exclude[id]; !found {
+ filtered = append(filtered, id)
+ }
+ }
+ return filtered
}
// GetConn returns the connection which exists between "one" and "other"
// regardless of which node initiated the connection
func (net *Network) GetConn(oneID, otherID enode.ID) *Conn {
- net.lock.Lock()
- defer net.lock.Unlock()
+ net.lock.RLock()
+ defer net.lock.RUnlock()
return net.getConn(oneID, otherID)
}