diff options
Diffstat (limited to 'p2p/testing/protocolsession.go')
-rw-r--r-- | p2p/testing/protocolsession.go | 178 |
1 files changed, 126 insertions, 52 deletions
diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index a779aeebb..361285f06 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -19,13 +19,17 @@ package testing import ( "errors" "fmt" + "sync" "time" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" ) +var errTimedOut = errors.New("timed out") + // ProtocolSession is a quasi simulation of a pivot node running // a service and a number of dummy peers that can send (trigger) or // receive (expect) messages @@ -46,6 +50,7 @@ type Exchange struct { Label string Triggers []Trigger Expects []Expect + Timeout time.Duration } // Trigger is part of the exchange, incoming message for the pivot node @@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error { } // expect checks an expectation of a message sent out by the pivot node -func (self *ProtocolSession) expect(exp Expect) error { - if exp.Msg == nil { - return errors.New("no message to expect") - } - simNode, ok := self.adapter.GetNode(exp.Peer) - if !ok { - return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs)) +func (self *ProtocolSession) expect(exps []Expect) error { + // construct a map of expectations for each node + peerExpects := make(map[discover.NodeID][]Expect) + for _, exp := range exps { + if exp.Msg == nil { + return errors.New("no message to expect") + } + peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) } - mockNode, ok := simNode.Services()[0].(*mockNode) - if !ok { - return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer) + + // construct a map of mockNodes for each node + mockNodes := make(map[discover.NodeID]*mockNode) + for nodeID := range peerExpects { + simNode, ok := self.adapter.GetNode(nodeID) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", nodeID) + } + mockNodes[nodeID] = mockNode } + // done chanell cancels all created goroutines when function returns + done := make(chan struct{}) + defer close(done) + // errc catches the first error from errc := make(chan error) + + wg := &sync.WaitGroup{} + wg.Add(len(mockNodes)) + for nodeID, mockNode := range mockNodes { + nodeID := nodeID + mockNode := mockNode + go func() { + defer wg.Done() + + // Sum all Expect timeouts to give the maximum + // time for all expectations to finish. + // mockNode.Expect checks all received messages against + // a list of expected messages and timeout for each + // of them can not be checked separately. + var t time.Duration + for _, exp := range peerExpects[nodeID] { + if exp.Timeout == time.Duration(0) { + t += 2000 * time.Millisecond + } else { + t += exp.Timeout + } + } + alarm := time.NewTimer(t) + defer alarm.Stop() + + // expectErrc is used to check if error returned + // from mockNode.Expect is not nil and to send it to + // errc only in that case. + // done channel will be closed when function + expectErrc := make(chan error) + go func() { + select { + case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): + case <-done: + case <-alarm.C: + } + }() + + select { + case err := <-expectErrc: + if err != nil { + select { + case errc <- err: + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + } + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + + }() + } + go func() { - errc <- mockNode.Expect(&exp) + wg.Wait() + // close errc when all goroutines finish to return nill err from errc + close(errc) }() - t := exp.Timeout - if t == time.Duration(0) { - t = 2000 * time.Millisecond - } - select { - case err := <-errc: - return err - case <-time.After(t): - return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer) - } + return <-errc } // TestExchanges tests a series of exchanges against the session func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { - // launch all triggers of this exchanges + for i, e := range exchanges { + if err := self.testExchange(e); err != nil { + return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) + } + log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) + } + return nil +} + +// testExchange tests a single Exchange. +// Default timeout value is 2 seconds. +func (self *ProtocolSession) testExchange(e Exchange) error { + errc := make(chan error) + done := make(chan struct{}) + defer close(done) - for _, e := range exchanges { - errc := make(chan error, len(e.Triggers)+len(e.Expects)) + go func() { for _, trig := range e.Triggers { - errc <- self.trigger(trig) + err := self.trigger(trig) + if err != nil { + errc <- err + return + } } - // each expectation is spawned in separate go-routine - // expectations of an exchange are conjunctive but unordered, i.e., - // only all of them arriving constitutes a pass - // each expectation is meant to be for a different peer, otherwise they are expected to panic - // testing of an exchange blocks until all expectations are decided - // an expectation is decided if - // expected message arrives OR - // an unexpected message arrives (panic) - // times out on their individual timeout - for _, ex := range e.Expects { - // expect msg spawned to separate go routine - go func(exp Expect) { - errc <- self.expect(exp) - }(ex) + select { + case errc <- self.expect(e.Expects): + case <-done: } + }() - // time out globally or finish when all expectations satisfied - timeout := time.After(5 * time.Second) - for i := 0; i < len(e.Triggers)+len(e.Expects); i++ { - select { - case err := <-errc: - if err != nil { - return fmt.Errorf("exchange failed with: %v", err) - } - case <-timeout: - return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label) - } - } + // time out globally or finish when all expectations satisfied + t := e.Timeout + if t == 0 { + t = 2000 * time.Millisecond + } + alarm := time.NewTimer(t) + select { + case err := <-errc: + return err + case <-alarm.C: + return errTimedOut } - return nil } // TestDisconnected tests the disconnections given as arguments |