aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/testing/protocoltester.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/testing/protocoltester.go')
-rw-r--r--p2p/testing/protocoltester.go21
1 files changed, 18 insertions, 3 deletions
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index 636613c57..c99578fe0 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -180,7 +180,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
for {
select {
case trig := <-m.trigger:
- m.err <- p2p.Send(rw, trig.Code, trig.Msg)
+ wmsg := Wrap(trig.Msg)
+ m.err <- p2p.Send(rw, trig.Code, wmsg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
@@ -220,7 +221,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
}
var found bool
for i, exp := range exps {
- if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+ if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
@@ -235,7 +236,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
if matched[i] {
continue
}
- expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+ expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg))))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
@@ -267,3 +268,17 @@ func mustEncodeMsg(msg interface{}) []byte {
}
return contentEnc
}
+
+type WrappedMsg struct {
+ Context []byte
+ Size uint32
+ Payload []byte
+}
+
+func Wrap(msg interface{}) interface{} {
+ data, _ := rlp.EncodeToBytes(msg)
+ return &WrappedMsg{
+ Size: uint32(len(data)),
+ Payload: data,
+ }
+}