diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/protocols/accounting.go | 35 | ||||
-rw-r--r-- | p2p/protocols/protocol.go | 14 | ||||
-rw-r--r-- | p2p/protocols/protocol_test.go | 12 | ||||
-rw-r--r-- | p2p/protocols/reporter_test.go | 28 | ||||
-rw-r--r-- | p2p/simulations/connect.go | 43 | ||||
-rw-r--r-- | p2p/simulations/events.go | 2 | ||||
-rw-r--r-- | p2p/simulations/http_test.go | 18 | ||||
-rw-r--r-- | p2p/simulations/mocker_test.go | 9 | ||||
-rw-r--r-- | p2p/simulations/network.go | 158 | ||||
-rw-r--r-- | p2p/simulations/network_test.go | 135 | ||||
-rw-r--r-- | p2p/testing/protocoltester.go | 3 |
11 files changed, 346 insertions, 111 deletions
diff --git a/p2p/protocols/accounting.go b/p2p/protocols/accounting.go index bdc490e59..558247254 100644 --- a/p2p/protocols/accounting.go +++ b/p2p/protocols/accounting.go @@ -27,23 +27,21 @@ var ( // All metrics are cumulative // total amount of units credited - mBalanceCredit metrics.Counter + mBalanceCredit = metrics.NewRegisteredCounterForced("account.balance.credit", metrics.AccountingRegistry) // total amount of units debited - mBalanceDebit metrics.Counter + mBalanceDebit = metrics.NewRegisteredCounterForced("account.balance.debit", metrics.AccountingRegistry) // total amount of bytes credited - mBytesCredit metrics.Counter + mBytesCredit = metrics.NewRegisteredCounterForced("account.bytes.credit", metrics.AccountingRegistry) // total amount of bytes debited - mBytesDebit metrics.Counter + mBytesDebit = metrics.NewRegisteredCounterForced("account.bytes.debit", metrics.AccountingRegistry) // total amount of credited messages - mMsgCredit metrics.Counter + mMsgCredit = metrics.NewRegisteredCounterForced("account.msg.credit", metrics.AccountingRegistry) // total amount of debited messages - mMsgDebit metrics.Counter + mMsgDebit = metrics.NewRegisteredCounterForced("account.msg.debit", metrics.AccountingRegistry) // how many times local node had to drop remote peers - mPeerDrops metrics.Counter + mPeerDrops = metrics.NewRegisteredCounterForced("account.peerdrops", metrics.AccountingRegistry) // how many times local node overdrafted and dropped - mSelfDrops metrics.Counter - - MetricsRegistry metrics.Registry + mSelfDrops = metrics.NewRegisteredCounterForced("account.selfdrops", metrics.AccountingRegistry) ) // Prices defines how prices are being passed on to the accounting instance @@ -110,24 +108,13 @@ func NewAccounting(balance Balance, po Prices) *Accounting { return ah } -// SetupAccountingMetrics creates a separate registry for p2p accounting metrics; +// SetupAccountingMetrics uses a separate registry for p2p accounting metrics; // this registry should be independent of any other metrics as it persists at different endpoints. -// It also instantiates the given metrics and starts the persisting go-routine which +// It also starts the persisting go-routine which // at the passed interval writes the metrics to a LevelDB func SetupAccountingMetrics(reportInterval time.Duration, path string) *AccountingMetrics { - // create an empty registry - MetricsRegistry = metrics.NewRegistry() - // instantiate the metrics - mBalanceCredit = metrics.NewRegisteredCounterForced("account.balance.credit", MetricsRegistry) - mBalanceDebit = metrics.NewRegisteredCounterForced("account.balance.debit", MetricsRegistry) - mBytesCredit = metrics.NewRegisteredCounterForced("account.bytes.credit", MetricsRegistry) - mBytesDebit = metrics.NewRegisteredCounterForced("account.bytes.debit", MetricsRegistry) - mMsgCredit = metrics.NewRegisteredCounterForced("account.msg.credit", MetricsRegistry) - mMsgDebit = metrics.NewRegisteredCounterForced("account.msg.debit", MetricsRegistry) - mPeerDrops = metrics.NewRegisteredCounterForced("account.peerdrops", MetricsRegistry) - mSelfDrops = metrics.NewRegisteredCounterForced("account.selfdrops", MetricsRegistry) // create the DB and start persisting - return NewAccountingMetrics(MetricsRegistry, reportInterval, path) + return NewAccountingMetrics(metrics.AccountingRegistry, reportInterval, path) } // Send takes a peer, a size and a msg and diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go index b16720dd3..bf879b985 100644 --- a/p2p/protocols/protocol.go +++ b/p2p/protocols/protocol.go @@ -423,3 +423,17 @@ func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interf } return rhs, nil } + +// HasCap returns true if Peer has a capability +// with provided name. +func (p *Peer) HasCap(capName string) (yes bool) { + if p == nil || p.Peer == nil { + return false + } + for _, c := range p.Caps() { + if c.Name == capName { + return true + } + } + return false +} diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index a26222cd8..4bc1e547e 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -142,9 +142,9 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er } } -func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { +func protocolTester(pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { conf := adapters.RandomNodeConfig() - return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) + return p2ptest.NewProtocolTester(conf.ID, 2, newProtocol(pp)) } func protoHandshakeExchange(id enode.ID, proto *protoHandshake) []p2ptest.Exchange { @@ -173,7 +173,7 @@ func protoHandshakeExchange(id enode.ID, proto *protoHandshake) []p2ptest.Exchan func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { pp := p2ptest.NewTestPeerPool() - s := protocolTester(t, pp) + s := protocolTester(pp) // TODO: make this more than one handshake node := s.Nodes[0] if err := s.TestExchanges(protoHandshakeExchange(node.ID(), proto)...); err != nil { @@ -250,7 +250,7 @@ func TestProtocolHook(t *testing.T) { } conf := adapters.RandomNodeConfig() - tester := p2ptest.NewProtocolTester(t, conf.ID, 2, runFunc) + tester := p2ptest.NewProtocolTester(conf.ID, 2, runFunc) err := tester.TestExchanges(p2ptest.Exchange{ Expects: []p2ptest.Expect{ { @@ -389,7 +389,7 @@ func moduleHandshakeExchange(id enode.ID, resp uint) []p2ptest.Exchange { func runModuleHandshake(t *testing.T, resp uint, errs ...error) { pp := p2ptest.NewTestPeerPool() - s := protocolTester(t, pp) + s := protocolTester(pp) node := s.Nodes[0] if err := s.TestExchanges(protoHandshakeExchange(node.ID(), &protoHandshake{42, "420"})...); err != nil { t.Fatal(err) @@ -469,7 +469,7 @@ func testMultiPeerSetup(a, b enode.ID) []p2ptest.Exchange { func runMultiplePeers(t *testing.T, peer int, errs ...error) { pp := p2ptest.NewTestPeerPool() - s := protocolTester(t, pp) + s := protocolTester(pp) if err := s.TestExchanges(testMultiPeerSetup(s.Nodes[0].ID(), s.Nodes[1].ID())...); err != nil { t.Fatal(err) diff --git a/p2p/protocols/reporter_test.go b/p2p/protocols/reporter_test.go index b9f06e674..c5c025d20 100644 --- a/p2p/protocols/reporter_test.go +++ b/p2p/protocols/reporter_test.go @@ -43,21 +43,27 @@ func TestReporter(t *testing.T) { metrics := SetupAccountingMetrics(reportInterval, filepath.Join(dir, "test.db")) log.Debug("Done.") - //do some metrics + //change metrics mBalanceCredit.Inc(12) mBytesCredit.Inc(34) mMsgDebit.Inc(9) + //store expected metrics + expectedBalanceCredit := mBalanceCredit.Count() + expectedBytesCredit := mBytesCredit.Count() + expectedMsgDebit := mMsgDebit.Count() + //give the reporter time to write the metrics to DB time.Sleep(20 * time.Millisecond) - //set the metrics to nil - this effectively simulates the node having shut down... - mBalanceCredit = nil - mBytesCredit = nil - mMsgDebit = nil //close the DB also, or we can't create a new one metrics.Close() + //clear the metrics - this effectively simulates the node having shut down... + mBalanceCredit.Clear() + mBytesCredit.Clear() + mMsgDebit.Clear() + //setup the metrics again log.Debug("Setting up metrics second time") metrics = SetupAccountingMetrics(reportInterval, filepath.Join(dir, "test.db")) @@ -65,13 +71,13 @@ func TestReporter(t *testing.T) { log.Debug("Done.") //now check the metrics, they should have the same value as before "shutdown" - if mBalanceCredit.Count() != 12 { - t.Fatalf("Expected counter to be %d, but is %d", 12, mBalanceCredit.Count()) + if mBalanceCredit.Count() != expectedBalanceCredit { + t.Fatalf("Expected counter to be %d, but is %d", expectedBalanceCredit, mBalanceCredit.Count()) } - if mBytesCredit.Count() != 34 { - t.Fatalf("Expected counter to be %d, but is %d", 23, mBytesCredit.Count()) + if mBytesCredit.Count() != expectedBytesCredit { + t.Fatalf("Expected counter to be %d, but is %d", expectedBytesCredit, mBytesCredit.Count()) } - if mMsgDebit.Count() != 9 { - t.Fatalf("Expected counter to be %d, but is %d", 9, mMsgDebit.Count()) + if mMsgDebit.Count() != expectedMsgDebit { + t.Fatalf("Expected counter to be %d, but is %d", expectedMsgDebit, mMsgDebit.Count()) } } diff --git a/p2p/simulations/connect.go b/p2p/simulations/connect.go index bb7e7999a..ede96b34c 100644 --- a/p2p/simulations/connect.go +++ b/p2p/simulations/connect.go @@ -32,6 +32,9 @@ var ( // It is useful when constructing a chain network topology // when Network adds and removes nodes dynamically. func (net *Network) ConnectToLastNode(id enode.ID) (err error) { + net.lock.Lock() + defer net.lock.Unlock() + ids := net.getUpNodeIDs() l := len(ids) if l < 2 { @@ -41,29 +44,35 @@ func (net *Network) ConnectToLastNode(id enode.ID) (err error) { if last == id { last = ids[l-2] } - return net.connect(last, id) + return net.connectNotConnected(last, id) } // ConnectToRandomNode connects the node with provided NodeID // to a random node that is up. func (net *Network) ConnectToRandomNode(id enode.ID) (err error) { - selected := net.GetRandomUpNode(id) + net.lock.Lock() + defer net.lock.Unlock() + + selected := net.getRandomUpNode(id) if selected == nil { return ErrNodeNotFound } - return net.connect(selected.ID(), id) + return net.connectNotConnected(selected.ID(), id) } // ConnectNodesFull connects all nodes one to another. // It provides a complete connectivity in the network // which should be rarely needed. func (net *Network) ConnectNodesFull(ids []enode.ID) (err error) { + net.lock.Lock() + defer net.lock.Unlock() + if ids == nil { ids = net.getUpNodeIDs() } for i, lid := range ids { for _, rid := range ids[i+1:] { - if err = net.connect(lid, rid); err != nil { + if err = net.connectNotConnected(lid, rid); err != nil { return err } } @@ -74,12 +83,19 @@ func (net *Network) ConnectNodesFull(ids []enode.ID) (err error) { // ConnectNodesChain connects all nodes in a chain topology. // If ids argument is nil, all nodes that are up will be connected. func (net *Network) ConnectNodesChain(ids []enode.ID) (err error) { + net.lock.Lock() + defer net.lock.Unlock() + + return net.connectNodesChain(ids) +} + +func (net *Network) connectNodesChain(ids []enode.ID) (err error) { if ids == nil { ids = net.getUpNodeIDs() } l := len(ids) for i := 0; i < l-1; i++ { - if err := net.connect(ids[i], ids[i+1]); err != nil { + if err := net.connectNotConnected(ids[i], ids[i+1]); err != nil { return err } } @@ -89,6 +105,9 @@ func (net *Network) ConnectNodesChain(ids []enode.ID) (err error) { // ConnectNodesRing connects all nodes in a ring topology. // If ids argument is nil, all nodes that are up will be connected. func (net *Network) ConnectNodesRing(ids []enode.ID) (err error) { + net.lock.Lock() + defer net.lock.Unlock() + if ids == nil { ids = net.getUpNodeIDs() } @@ -96,15 +115,18 @@ func (net *Network) ConnectNodesRing(ids []enode.ID) (err error) { if l < 2 { return nil } - if err := net.ConnectNodesChain(ids); err != nil { + if err := net.connectNodesChain(ids); err != nil { return err } - return net.connect(ids[l-1], ids[0]) + return net.connectNotConnected(ids[l-1], ids[0]) } // ConnectNodesStar connects all nodes into a star topology // If ids argument is nil, all nodes that are up will be connected. func (net *Network) ConnectNodesStar(ids []enode.ID, center enode.ID) (err error) { + net.lock.Lock() + defer net.lock.Unlock() + if ids == nil { ids = net.getUpNodeIDs() } @@ -112,16 +134,15 @@ func (net *Network) ConnectNodesStar(ids []enode.ID, center enode.ID) (err error if center == id { continue } - if err := net.connect(center, id); err != nil { + if err := net.connectNotConnected(center, id); err != nil { return err } } return nil } -// connect connects two nodes but ignores already connected error. -func (net *Network) connect(oneID, otherID enode.ID) error { - return ignoreAlreadyConnectedErr(net.Connect(oneID, otherID)) +func (net *Network) connectNotConnected(oneID, otherID enode.ID) error { + return ignoreAlreadyConnectedErr(net.connect(oneID, otherID)) } func ignoreAlreadyConnectedErr(err error) error { diff --git a/p2p/simulations/events.go b/p2p/simulations/events.go index 9b2a990e0..984c2e088 100644 --- a/p2p/simulations/events.go +++ b/p2p/simulations/events.go @@ -100,7 +100,7 @@ func ControlEvent(v interface{}) *Event { func (e *Event) String() string { switch e.Type { case EventTypeNode: - return fmt.Sprintf("<node-event> id: %s up: %t", e.Node.ID().TerminalString(), e.Node.Up) + return fmt.Sprintf("<node-event> id: %s up: %t", e.Node.ID().TerminalString(), e.Node.Up()) case EventTypeConn: return fmt.Sprintf("<conn-event> nodes: %s->%s up: %t", e.Conn.One.TerminalString(), e.Conn.Other.TerminalString(), e.Conn.Up) case EventTypeMsg: diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index c0a5acb3d..ed43c0ed7 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -421,14 +421,15 @@ type expectEvents struct { } func (t *expectEvents) nodeEvent(id string, up bool) *Event { + node := Node{ + Config: &adapters.NodeConfig{ + ID: enode.HexID(id), + }, + up: up, + } return &Event{ Type: EventTypeNode, - Node: &Node{ - Config: &adapters.NodeConfig{ - ID: enode.HexID(id), - }, - Up: up, - }, + Node: &node, } } @@ -480,6 +481,7 @@ loop: } func (t *expectEvents) expect(events ...*Event) { + t.Helper() timeout := time.After(10 * time.Second) i := 0 for { @@ -501,8 +503,8 @@ func (t *expectEvents) expect(events ...*Event) { if event.Node.ID() != expected.Node.ID() { t.Fatalf("expected node event %d to have id %q, got %q", i, expected.Node.ID().TerminalString(), event.Node.ID().TerminalString()) } - if event.Node.Up != expected.Node.Up { - t.Fatalf("expected node event %d to have up=%t, got up=%t", i, expected.Node.Up, event.Node.Up) + if event.Node.Up() != expected.Node.Up() { + t.Fatalf("expected node event %d to have up=%t, got up=%t", i, expected.Node.Up(), event.Node.Up()) } case EventTypeConn: diff --git a/p2p/simulations/mocker_test.go b/p2p/simulations/mocker_test.go index 192be1732..069040257 100644 --- a/p2p/simulations/mocker_test.go +++ b/p2p/simulations/mocker_test.go @@ -90,15 +90,12 @@ func TestMocker(t *testing.T) { for { select { case event := <-events: - //if the event is a node Up event only - if event.Node != nil && event.Node.Up { + if isNodeUp(event) { //add the correspondent node ID to the map nodemap[event.Node.Config.ID] = true //this means all nodes got a nodeUp event, so we can continue the test if len(nodemap) == nodeCount { nodesComplete = true - //wait for 3s as the mocker will need time to connect the nodes - //time.Sleep( 3 *time.Second) } } else if event.Conn != nil && nodesComplete { connCount += 1 @@ -169,3 +166,7 @@ func TestMocker(t *testing.T) { t.Fatalf("Expected empty list of nodes, got: %d", len(nodesInfo)) } } + +func isNodeUp(event *Event) bool { + return event.Node != nil && event.Node.Up() +} diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index 86f7dc9be..2049a5108 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -136,7 +136,7 @@ func (net *Network) Config() *NetworkConfig { // StartAll starts all nodes in the network func (net *Network) StartAll() error { for _, node := range net.Nodes { - if node.Up { + if node.Up() { continue } if err := net.Start(node.ID()); err != nil { @@ -149,7 +149,7 @@ func (net *Network) StartAll() error { // StopAll stops all nodes in the network func (net *Network) StopAll() error { for _, node := range net.Nodes { - if !node.Up { + if !node.Up() { continue } if err := net.Stop(node.ID()); err != nil { @@ -174,7 +174,7 @@ func (net *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte) if node == nil { return fmt.Errorf("node %v does not exist", id) } - if node.Up { + if node.Up() { return fmt.Errorf("node %v already up", id) } log.Trace("Starting node", "id", id, "adapter", net.nodeAdapter.Name()) @@ -182,10 +182,10 @@ func (net *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte) log.Warn("Node startup failed", "id", id, "err", err) return err } - node.Up = true + node.SetUp(true) log.Info("Started node", "id", id) - - net.events.Send(NewEvent(node)) + ev := NewEvent(node) + net.events.Send(ev) // subscribe to peer events client, err := node.Client() @@ -210,12 +210,14 @@ func (net *Network) watchPeerEvents(id enode.ID, events chan *p2p.PeerEvent, sub // assume the node is now down net.lock.Lock() defer net.lock.Unlock() + node := net.getNode(id) if node == nil { return } - node.Up = false - net.events.Send(NewEvent(node)) + node.SetUp(false) + ev := NewEvent(node) + net.events.Send(ev) }() for { select { @@ -251,34 +253,57 @@ func (net *Network) watchPeerEvents(id enode.ID, events chan *p2p.PeerEvent, sub // Stop stops the node with the given ID func (net *Network) Stop(id enode.ID) error { - net.lock.Lock() - node := net.getNode(id) - if node == nil { - return fmt.Errorf("node %v does not exist", id) - } - if !node.Up { - return fmt.Errorf("node %v already down", id) + // IMPORTANT: node.Stop() must NOT be called under net.lock as + // node.Reachable() closure has a reference to the network and + // calls net.InitConn() what also locks the network. => DEADLOCK + // That holds until the following ticket is not resolved: + + var err error + + node, err := func() (*Node, error) { + net.lock.Lock() + defer net.lock.Unlock() + + node := net.getNode(id) + if node == nil { + return nil, fmt.Errorf("node %v does not exist", id) + } + if !node.Up() { + return nil, fmt.Errorf("node %v already down", id) + } + node.SetUp(false) + return node, nil + }() + if err != nil { + return err } - node.Up = false - net.lock.Unlock() - err := node.Stop() + err = node.Stop() // must be called without net.lock + + net.lock.Lock() + defer net.lock.Unlock() + if err != nil { - net.lock.Lock() - node.Up = true - net.lock.Unlock() + node.SetUp(true) return err } log.Info("Stopped node", "id", id, "err", err) - net.events.Send(ControlEvent(node)) + ev := ControlEvent(node) + net.events.Send(ev) return nil } // Connect connects two nodes together by calling the "admin_addPeer" RPC // method on the "one" node so that it connects to the "other" node func (net *Network) Connect(oneID, otherID enode.ID) error { + net.lock.Lock() + defer net.lock.Unlock() + return net.connect(oneID, otherID) +} + +func (net *Network) connect(oneID, otherID enode.ID) error { log.Debug("Connecting nodes with addPeer", "id", oneID, "other", otherID) - conn, err := net.InitConn(oneID, otherID) + conn, err := net.initConn(oneID, otherID) if err != nil { return err } @@ -376,6 +401,14 @@ func (net *Network) GetNode(id enode.ID) *Node { return net.getNode(id) } +func (net *Network) getNode(id enode.ID) *Node { + i, found := net.nodeMap[id] + if !found { + return nil + } + return net.Nodes[i] +} + // GetNode gets the node with the given name, returning nil if the node does // not exist func (net *Network) GetNodeByName(name string) *Node { @@ -398,28 +431,29 @@ func (net *Network) GetNodes() (nodes []*Node) { net.lock.RLock() defer net.lock.RUnlock() - nodes = append(nodes, net.Nodes...) - return nodes + return net.getNodes() } -func (net *Network) getNode(id enode.ID) *Node { - i, found := net.nodeMap[id] - if !found { - return nil - } - return net.Nodes[i] +func (net *Network) getNodes() (nodes []*Node) { + nodes = append(nodes, net.Nodes...) + return nodes } // GetRandomUpNode returns a random node on the network, which is running. func (net *Network) GetRandomUpNode(excludeIDs ...enode.ID) *Node { net.lock.RLock() defer net.lock.RUnlock() + return net.getRandomUpNode(excludeIDs...) +} + +// GetRandomUpNode returns a random node on the network, which is running. +func (net *Network) getRandomUpNode(excludeIDs ...enode.ID) *Node { return net.getRandomNode(net.getUpNodeIDs(), excludeIDs) } func (net *Network) getUpNodeIDs() (ids []enode.ID) { for _, node := range net.Nodes { - if node.Up { + if node.Up() { ids = append(ids, node.ID()) } } @@ -434,8 +468,8 @@ func (net *Network) GetRandomDownNode(excludeIDs ...enode.ID) *Node { } func (net *Network) getDownNodeIDs() (ids []enode.ID) { - for _, node := range net.GetNodes() { - if !node.Up { + for _, node := range net.getNodes() { + if !node.Up() { ids = append(ids, node.ID()) } } @@ -449,7 +483,7 @@ func (net *Network) getRandomNode(ids []enode.ID, excludeIDs []enode.ID) *Node { if l == 0 { return nil } - return net.GetNode(filtered[rand.Intn(l)]) + return net.getNode(filtered[rand.Intn(l)]) } func filterIDs(ids []enode.ID, excludeIDs []enode.ID) []enode.ID { @@ -527,6 +561,10 @@ func (net *Network) getConn(oneID, otherID enode.ID) *Conn { func (net *Network) InitConn(oneID, otherID enode.ID) (*Conn, error) { net.lock.Lock() defer net.lock.Unlock() + return net.initConn(oneID, otherID) +} + +func (net *Network) initConn(oneID, otherID enode.ID) (*Conn, error) { if oneID == otherID { return nil, fmt.Errorf("refusing to connect to self %v", oneID) } @@ -584,8 +622,21 @@ type Node struct { // Config if the config used to created the node Config *adapters.NodeConfig `json:"config"` - // Up tracks whether or not the node is running - Up bool `json:"up"` + // up tracks whether or not the node is running + up bool + upMu sync.RWMutex +} + +func (n *Node) Up() bool { + n.upMu.RLock() + defer n.upMu.RUnlock() + return n.up +} + +func (n *Node) SetUp(up bool) { + n.upMu.Lock() + defer n.upMu.Unlock() + n.up = up } // ID returns the ID of the node @@ -619,10 +670,29 @@ func (n *Node) MarshalJSON() ([]byte, error) { }{ Info: n.NodeInfo(), Config: n.Config, - Up: n.Up, + Up: n.Up(), }) } +// UnmarshalJSON implements json.Unmarshaler interface so that we don't lose +// Node.up status. IMPORTANT: The implementation is incomplete; we lose +// p2p.NodeInfo. +func (n *Node) UnmarshalJSON(raw []byte) error { + // TODO: How should we turn back NodeInfo into n.Node? + // Ticket: https://github.com/ethersphere/go-ethereum/issues/1177 + node := struct { + Config *adapters.NodeConfig `json:"config,omitempty"` + Up bool `json:"up"` + }{} + if err := json.Unmarshal(raw, &node); err != nil { + return err + } + + n.SetUp(node.Up) + n.Config = node.Config + return nil +} + // Conn represents a connection between two nodes in the network type Conn struct { // One is the node which initiated the connection @@ -642,10 +712,10 @@ type Conn struct { // nodesUp returns whether both nodes are currently up func (c *Conn) nodesUp() error { - if !c.one.Up { + if !c.one.Up() { return fmt.Errorf("one %v is not up", c.One) } - if !c.other.Up { + if !c.other.Up() { return fmt.Errorf("other %v is not up", c.Other) } return nil @@ -717,7 +787,7 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn } for i, node := range net.Nodes { snap.Nodes[i] = NodeSnapshot{Node: *node} - if !node.Up { + if !node.Up() { continue } snapshots, err := node.Snapshots() @@ -772,7 +842,7 @@ func (net *Network) Load(snap *Snapshot) error { if _, err := net.NewNodeWithConfig(n.Node.Config); err != nil { return err } - if !n.Node.Up { + if !n.Node.Up() { continue } if err := net.startWithSnapshots(n.Node.Config.ID, n.Snapshots); err != nil { @@ -844,7 +914,7 @@ func (net *Network) Load(snap *Snapshot) error { // Start connecting. for _, conn := range snap.Conns { - if !net.GetNode(conn.One).Up || !net.GetNode(conn.Other).Up { + if !net.GetNode(conn.One).Up() || !net.GetNode(conn.Other).Up() { //in this case, at least one of the nodes of a connection is not up, //so it would result in the snapshot `Load` to fail continue @@ -898,7 +968,7 @@ func (net *Network) executeControlEvent(event *Event) { } func (net *Network) executeNodeEvent(e *Event) error { - if !e.Node.Up { + if !e.Node.Up() { return net.Stop(e.Node.ID()) } diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go index b7852addb..8b644ffb0 100644 --- a/p2p/simulations/network_test.go +++ b/p2p/simulations/network_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "reflect" "strconv" "strings" "testing" @@ -485,3 +486,137 @@ func benchmarkMinimalServiceTmp(b *testing.B) { } } } + +func TestNode_UnmarshalJSON(t *testing.T) { + t.Run( + "test unmarshal of Node up field", + func(t *testing.T) { + runNodeUnmarshalJSON(t, casesNodeUnmarshalJSONUpField()) + }, + ) + t.Run( + "test unmarshal of Node Config field", + func(t *testing.T) { + runNodeUnmarshalJSON(t, casesNodeUnmarshalJSONConfigField()) + }, + ) +} + +func runNodeUnmarshalJSON(t *testing.T, tests []nodeUnmarshalTestCase) { + t.Helper() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Node + if err := got.UnmarshalJSON([]byte(tt.marshaled)); err != nil { + expectErrorMessageToContain(t, err, tt.wantErr) + } + expectNodeEquality(t, got, tt.want) + }) + } +} + +type nodeUnmarshalTestCase struct { + name string + marshaled string + want Node + wantErr string +} + +func expectErrorMessageToContain(t *testing.T, got error, want string) { + t.Helper() + if got == nil && want == "" { + return + } + + if got == nil && want != "" { + t.Errorf("error was expected, got: nil, want: %v", want) + return + } + + if !strings.Contains(got.Error(), want) { + t.Errorf( + "unexpected error message, got %v, want: %v", + want, + got, + ) + } +} + +func expectNodeEquality(t *testing.T, got Node, want Node) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Errorf("Node.UnmarshalJSON() = %v, want %v", got, want) + } +} + +func casesNodeUnmarshalJSONUpField() []nodeUnmarshalTestCase { + return []nodeUnmarshalTestCase{ + { + name: "empty json", + marshaled: "{}", + want: Node{ + up: false, + }, + }, + { + name: "a stopped node", + marshaled: "{\"up\": false}", + want: Node{ + up: false, + }, + }, + { + name: "a running node", + marshaled: "{\"up\": true}", + want: Node{ + up: true, + }, + }, + { + name: "invalid JSON value on valid key", + marshaled: "{\"up\": foo}", + wantErr: "invalid character", + }, + { + name: "invalid JSON key and value", + marshaled: "{foo: bar}", + wantErr: "invalid character", + }, + { + name: "bool value expected but got something else (string)", + marshaled: "{\"up\": \"true\"}", + wantErr: "cannot unmarshal string into Go struct", + }, + } +} + +func casesNodeUnmarshalJSONConfigField() []nodeUnmarshalTestCase { + // Don't do a big fuss around testing, as adapters.NodeConfig should + // handle it's own serialization. Just do a sanity check. + return []nodeUnmarshalTestCase{ + { + name: "Config field is omitted", + marshaled: "{}", + want: Node{ + Config: nil, + }, + }, + { + name: "Config field is nil", + marshaled: "{\"config\": nil}", + want: Node{ + Config: nil, + }, + }, + { + name: "a non default Config field", + marshaled: "{\"config\":{\"name\":\"node_ecdd0\",\"port\":44665}}", + want: Node{ + Config: &adapters.NodeConfig{ + Name: "node_ecdd0", + Port: 44665, + }, + }, + }, + } +} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index afc03b009..cbd8ce6fe 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -30,7 +30,6 @@ import ( "io/ioutil" "strings" "sync" - "testing" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" @@ -52,7 +51,7 @@ type ProtocolTester struct { // NewProtocolTester constructs a new ProtocolTester // it takes as argument the pivot node id, the number of dummy peers and the // protocol run function called on a peer connection by the p2p server -func NewProtocolTester(t *testing.T, id enode.ID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { +func NewProtocolTester(id enode.ID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { services := adapters.Services{ "test": func(ctx *adapters.ServiceContext) (node.Service, error) { return &testNode{run}, nil |