aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/testing/protocolsession.go178
-rw-r--r--p2p/testing/protocoltester.go79
2 files changed, 200 insertions, 57 deletions
diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go
index a779aeebb..361285f06 100644
--- a/p2p/testing/protocolsession.go
+++ b/p2p/testing/protocolsession.go
@@ -19,13 +19,17 @@ package testing
import (
"errors"
"fmt"
+ "sync"
"time"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)
+var errTimedOut = errors.New("timed out")
+
// ProtocolSession is a quasi simulation of a pivot node running
// a service and a number of dummy peers that can send (trigger) or
// receive (expect) messages
@@ -46,6 +50,7 @@ type Exchange struct {
Label string
Triggers []Trigger
Expects []Expect
+ Timeout time.Duration
}
// Trigger is part of the exchange, incoming message for the pivot node
@@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error {
}
// expect checks an expectation of a message sent out by the pivot node
-func (self *ProtocolSession) expect(exp Expect) error {
- if exp.Msg == nil {
- return errors.New("no message to expect")
- }
- simNode, ok := self.adapter.GetNode(exp.Peer)
- if !ok {
- return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs))
+func (self *ProtocolSession) expect(exps []Expect) error {
+ // construct a map of expectations for each node
+ peerExpects := make(map[discover.NodeID][]Expect)
+ for _, exp := range exps {
+ if exp.Msg == nil {
+ return errors.New("no message to expect")
+ }
+ peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
}
- mockNode, ok := simNode.Services()[0].(*mockNode)
- if !ok {
- return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer)
+
+ // construct a map of mockNodes for each node
+ mockNodes := make(map[discover.NodeID]*mockNode)
+ for nodeID := range peerExpects {
+ simNode, ok := self.adapter.GetNode(nodeID)
+ if !ok {
+ return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs))
+ }
+ mockNode, ok := simNode.Services()[0].(*mockNode)
+ if !ok {
+ return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
+ }
+ mockNodes[nodeID] = mockNode
}
+ // done chanell cancels all created goroutines when function returns
+ done := make(chan struct{})
+ defer close(done)
+ // errc catches the first error from
errc := make(chan error)
+
+ wg := &sync.WaitGroup{}
+ wg.Add(len(mockNodes))
+ for nodeID, mockNode := range mockNodes {
+ nodeID := nodeID
+ mockNode := mockNode
+ go func() {
+ defer wg.Done()
+
+ // Sum all Expect timeouts to give the maximum
+ // time for all expectations to finish.
+ // mockNode.Expect checks all received messages against
+ // a list of expected messages and timeout for each
+ // of them can not be checked separately.
+ var t time.Duration
+ for _, exp := range peerExpects[nodeID] {
+ if exp.Timeout == time.Duration(0) {
+ t += 2000 * time.Millisecond
+ } else {
+ t += exp.Timeout
+ }
+ }
+ alarm := time.NewTimer(t)
+ defer alarm.Stop()
+
+ // expectErrc is used to check if error returned
+ // from mockNode.Expect is not nil and to send it to
+ // errc only in that case.
+ // done channel will be closed when function
+ expectErrc := make(chan error)
+ go func() {
+ select {
+ case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
+ case <-done:
+ case <-alarm.C:
+ }
+ }()
+
+ select {
+ case err := <-expectErrc:
+ if err != nil {
+ select {
+ case errc <- err:
+ case <-done:
+ case <-alarm.C:
+ errc <- errTimedOut
+ }
+ }
+ case <-done:
+ case <-alarm.C:
+ errc <- errTimedOut
+ }
+
+ }()
+ }
+
go func() {
- errc <- mockNode.Expect(&exp)
+ wg.Wait()
+ // close errc when all goroutines finish to return nill err from errc
+ close(errc)
}()
- t := exp.Timeout
- if t == time.Duration(0) {
- t = 2000 * time.Millisecond
- }
- select {
- case err := <-errc:
- return err
- case <-time.After(t):
- return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer)
- }
+ return <-errc
}
// TestExchanges tests a series of exchanges against the session
func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
- // launch all triggers of this exchanges
+ for i, e := range exchanges {
+ if err := self.testExchange(e); err != nil {
+ return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
+ }
+ log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
+ }
+ return nil
+}
+
+// testExchange tests a single Exchange.
+// Default timeout value is 2 seconds.
+func (self *ProtocolSession) testExchange(e Exchange) error {
+ errc := make(chan error)
+ done := make(chan struct{})
+ defer close(done)
- for _, e := range exchanges {
- errc := make(chan error, len(e.Triggers)+len(e.Expects))
+ go func() {
for _, trig := range e.Triggers {
- errc <- self.trigger(trig)
+ err := self.trigger(trig)
+ if err != nil {
+ errc <- err
+ return
+ }
}
- // each expectation is spawned in separate go-routine
- // expectations of an exchange are conjunctive but unordered, i.e.,
- // only all of them arriving constitutes a pass
- // each expectation is meant to be for a different peer, otherwise they are expected to panic
- // testing of an exchange blocks until all expectations are decided
- // an expectation is decided if
- // expected message arrives OR
- // an unexpected message arrives (panic)
- // times out on their individual timeout
- for _, ex := range e.Expects {
- // expect msg spawned to separate go routine
- go func(exp Expect) {
- errc <- self.expect(exp)
- }(ex)
+ select {
+ case errc <- self.expect(e.Expects):
+ case <-done:
}
+ }()
- // time out globally or finish when all expectations satisfied
- timeout := time.After(5 * time.Second)
- for i := 0; i < len(e.Triggers)+len(e.Expects); i++ {
- select {
- case err := <-errc:
- if err != nil {
- return fmt.Errorf("exchange failed with: %v", err)
- }
- case <-timeout:
- return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label)
- }
- }
+ // time out globally or finish when all expectations satisfied
+ t := e.Timeout
+ if t == 0 {
+ t = 2000 * time.Millisecond
+ }
+ alarm := time.NewTimer(t)
+ select {
+ case err := <-errc:
+ return err
+ case <-alarm.C:
+ return errTimedOut
}
- return nil
}
// TestDisconnected tests the disconnections given as arguments
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index ea5b106ff..a797412d6 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -24,7 +24,11 @@ that can be used to send and receive messages
package testing
import (
+ "bytes"
"fmt"
+ "io"
+ "io/ioutil"
+ "strings"
"sync"
"testing"
@@ -34,6 +38,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -152,7 +157,7 @@ type mockNode struct {
testNode
trigger chan *Trigger
- expect chan *Expect
+ expect chan []Expect
err chan error
stop chan struct{}
stopOnce sync.Once
@@ -161,7 +166,7 @@ type mockNode struct {
func newMockNode() *mockNode {
mock := &mockNode{
trigger: make(chan *Trigger),
- expect: make(chan *Expect),
+ expect: make(chan []Expect),
err: make(chan error),
stop: make(chan struct{}),
}
@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
select {
case trig := <-m.trigger:
m.err <- p2p.Send(rw, trig.Code, trig.Msg)
- case exp := <-m.expect:
- m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
+ case exps := <-m.expect:
+ m.err <- expectMsgs(rw, exps)
case <-m.stop:
return nil
}
@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
return <-m.err
}
-func (m *mockNode) Expect(exp *Expect) error {
+func (m *mockNode) Expect(exp ...Expect) error {
m.expect <- exp
return <-m.err
}
@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
m.stopOnce.Do(func() { close(m.stop) })
return nil
}
+
+func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
+ matched := make([]bool, len(exps))
+ for {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+ actualContent, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return err
+ }
+ var found bool
+ for i, exp := range exps {
+ if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+ if matched[i] {
+ return fmt.Errorf("message #%d received two times", i)
+ }
+ matched[i] = true
+ found = true
+ break
+ }
+ }
+ if !found {
+ expected := make([]string, 0)
+ for i, exp := range exps {
+ if matched[i] {
+ continue
+ }
+ expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+ }
+ return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
+ }
+ done := true
+ for _, m := range matched {
+ if !m {
+ done = false
+ break
+ }
+ }
+ if done {
+ return nil
+ }
+ }
+ for i, m := range matched {
+ if !m {
+ return fmt.Errorf("expected message #%d not received", i)
+ }
+ }
+ return nil
+}
+
+// mustEncodeMsg uses rlp to encode a message.
+// In case of error it panics.
+func mustEncodeMsg(msg interface{}) []byte {
+ contentEnc, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ panic("content encode error: " + err.Error())
+ }
+ return contentEnc
+}