aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/testing/protocoltester.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/testing/protocoltester.go')
-rw-r--r--p2p/testing/protocoltester.go79
1 files changed, 74 insertions, 5 deletions
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index ea5b106ff..a797412d6 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -24,7 +24,11 @@ that can be used to send and receive messages
package testing
import (
+ "bytes"
"fmt"
+ "io"
+ "io/ioutil"
+ "strings"
"sync"
"testing"
@@ -34,6 +38,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -152,7 +157,7 @@ type mockNode struct {
testNode
trigger chan *Trigger
- expect chan *Expect
+ expect chan []Expect
err chan error
stop chan struct{}
stopOnce sync.Once
@@ -161,7 +166,7 @@ type mockNode struct {
func newMockNode() *mockNode {
mock := &mockNode{
trigger: make(chan *Trigger),
- expect: make(chan *Expect),
+ expect: make(chan []Expect),
err: make(chan error),
stop: make(chan struct{}),
}
@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
select {
case trig := <-m.trigger:
m.err <- p2p.Send(rw, trig.Code, trig.Msg)
- case exp := <-m.expect:
- m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
+ case exps := <-m.expect:
+ m.err <- expectMsgs(rw, exps)
case <-m.stop:
return nil
}
@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
return <-m.err
}
-func (m *mockNode) Expect(exp *Expect) error {
+func (m *mockNode) Expect(exp ...Expect) error {
m.expect <- exp
return <-m.err
}
@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
m.stopOnce.Do(func() { close(m.stop) })
return nil
}
+
+func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
+ matched := make([]bool, len(exps))
+ for {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+ actualContent, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return err
+ }
+ var found bool
+ for i, exp := range exps {
+ if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+ if matched[i] {
+ return fmt.Errorf("message #%d received two times", i)
+ }
+ matched[i] = true
+ found = true
+ break
+ }
+ }
+ if !found {
+ expected := make([]string, 0)
+ for i, exp := range exps {
+ if matched[i] {
+ continue
+ }
+ expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+ }
+ return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
+ }
+ done := true
+ for _, m := range matched {
+ if !m {
+ done = false
+ break
+ }
+ }
+ if done {
+ return nil
+ }
+ }
+ for i, m := range matched {
+ if !m {
+ return fmt.Errorf("expected message #%d not received", i)
+ }
+ }
+ return nil
+}
+
+// mustEncodeMsg uses rlp to encode a message.
+// In case of error it panics.
+func mustEncodeMsg(msg interface{}) []byte {
+ contentEnc, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ panic("content encode error: " + err.Error())
+ }
+ return contentEnc
+}