aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/testing/protocolsession.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/testing/protocolsession.go')
-rw-r--r--p2p/testing/protocolsession.go178
1 files changed, 126 insertions, 52 deletions
diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go
index a779aeebb..361285f06 100644
--- a/p2p/testing/protocolsession.go
+++ b/p2p/testing/protocolsession.go
@@ -19,13 +19,17 @@ package testing
import (
"errors"
"fmt"
+ "sync"
"time"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)
+var errTimedOut = errors.New("timed out")
+
// ProtocolSession is a quasi simulation of a pivot node running
// a service and a number of dummy peers that can send (trigger) or
// receive (expect) messages
@@ -46,6 +50,7 @@ type Exchange struct {
Label string
Triggers []Trigger
Expects []Expect
+ Timeout time.Duration
}
// Trigger is part of the exchange, incoming message for the pivot node
@@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error {
}
// expect checks an expectation of a message sent out by the pivot node
-func (self *ProtocolSession) expect(exp Expect) error {
- if exp.Msg == nil {
- return errors.New("no message to expect")
- }
- simNode, ok := self.adapter.GetNode(exp.Peer)
- if !ok {
- return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs))
+func (self *ProtocolSession) expect(exps []Expect) error {
+ // construct a map of expectations for each node
+ peerExpects := make(map[discover.NodeID][]Expect)
+ for _, exp := range exps {
+ if exp.Msg == nil {
+ return errors.New("no message to expect")
+ }
+ peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
}
- mockNode, ok := simNode.Services()[0].(*mockNode)
- if !ok {
- return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer)
+
+ // construct a map of mockNodes for each node
+ mockNodes := make(map[discover.NodeID]*mockNode)
+ for nodeID := range peerExpects {
+ simNode, ok := self.adapter.GetNode(nodeID)
+ if !ok {
+ return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs))
+ }
+ mockNode, ok := simNode.Services()[0].(*mockNode)
+ if !ok {
+ return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
+ }
+ mockNodes[nodeID] = mockNode
}
+ // done chanell cancels all created goroutines when function returns
+ done := make(chan struct{})
+ defer close(done)
+ // errc catches the first error from
errc := make(chan error)
+
+ wg := &sync.WaitGroup{}
+ wg.Add(len(mockNodes))
+ for nodeID, mockNode := range mockNodes {
+ nodeID := nodeID
+ mockNode := mockNode
+ go func() {
+ defer wg.Done()
+
+ // Sum all Expect timeouts to give the maximum
+ // time for all expectations to finish.
+ // mockNode.Expect checks all received messages against
+ // a list of expected messages and timeout for each
+ // of them can not be checked separately.
+ var t time.Duration
+ for _, exp := range peerExpects[nodeID] {
+ if exp.Timeout == time.Duration(0) {
+ t += 2000 * time.Millisecond
+ } else {
+ t += exp.Timeout
+ }
+ }
+ alarm := time.NewTimer(t)
+ defer alarm.Stop()
+
+ // expectErrc is used to check if error returned
+ // from mockNode.Expect is not nil and to send it to
+ // errc only in that case.
+ // done channel will be closed when function
+ expectErrc := make(chan error)
+ go func() {
+ select {
+ case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
+ case <-done:
+ case <-alarm.C:
+ }
+ }()
+
+ select {
+ case err := <-expectErrc:
+ if err != nil {
+ select {
+ case errc <- err:
+ case <-done:
+ case <-alarm.C:
+ errc <- errTimedOut
+ }
+ }
+ case <-done:
+ case <-alarm.C:
+ errc <- errTimedOut
+ }
+
+ }()
+ }
+
go func() {
- errc <- mockNode.Expect(&exp)
+ wg.Wait()
+ // close errc when all goroutines finish to return nill err from errc
+ close(errc)
}()
- t := exp.Timeout
- if t == time.Duration(0) {
- t = 2000 * time.Millisecond
- }
- select {
- case err := <-errc:
- return err
- case <-time.After(t):
- return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer)
- }
+ return <-errc
}
// TestExchanges tests a series of exchanges against the session
func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
- // launch all triggers of this exchanges
+ for i, e := range exchanges {
+ if err := self.testExchange(e); err != nil {
+ return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
+ }
+ log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
+ }
+ return nil
+}
+
+// testExchange tests a single Exchange.
+// Default timeout value is 2 seconds.
+func (self *ProtocolSession) testExchange(e Exchange) error {
+ errc := make(chan error)
+ done := make(chan struct{})
+ defer close(done)
- for _, e := range exchanges {
- errc := make(chan error, len(e.Triggers)+len(e.Expects))
+ go func() {
for _, trig := range e.Triggers {
- errc <- self.trigger(trig)
+ err := self.trigger(trig)
+ if err != nil {
+ errc <- err
+ return
+ }
}
- // each expectation is spawned in separate go-routine
- // expectations of an exchange are conjunctive but unordered, i.e.,
- // only all of them arriving constitutes a pass
- // each expectation is meant to be for a different peer, otherwise they are expected to panic
- // testing of an exchange blocks until all expectations are decided
- // an expectation is decided if
- // expected message arrives OR
- // an unexpected message arrives (panic)
- // times out on their individual timeout
- for _, ex := range e.Expects {
- // expect msg spawned to separate go routine
- go func(exp Expect) {
- errc <- self.expect(exp)
- }(ex)
+ select {
+ case errc <- self.expect(e.Expects):
+ case <-done:
}
+ }()
- // time out globally or finish when all expectations satisfied
- timeout := time.After(5 * time.Second)
- for i := 0; i < len(e.Triggers)+len(e.Expects); i++ {
- select {
- case err := <-errc:
- if err != nil {
- return fmt.Errorf("exchange failed with: %v", err)
- }
- case <-timeout:
- return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label)
- }
- }
+ // time out globally or finish when all expectations satisfied
+ t := e.Timeout
+ if t == 0 {
+ t = 2000 * time.Millisecond
+ }
+ alarm := time.NewTimer(t)
+ select {
+ case err := <-errc:
+ return err
+ case <-alarm.C:
+ return errTimedOut
}
- return nil
}
// TestDisconnected tests the disconnections given as arguments