diff options
Diffstat (limited to 'p2p/testing/protocoltester.go')
-rw-r--r-- | p2p/testing/protocoltester.go | 79 |
1 files changed, 74 insertions, 5 deletions
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index ea5b106ff..a797412d6 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -24,7 +24,11 @@ that can be used to send and receive messages package testing import ( + "bytes" "fmt" + "io" + "io/ioutil" + "strings" "sync" "testing" @@ -34,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/simulations" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" ) @@ -152,7 +157,7 @@ type mockNode struct { testNode trigger chan *Trigger - expect chan *Expect + expect chan []Expect err chan error stop chan struct{} stopOnce sync.Once @@ -161,7 +166,7 @@ type mockNode struct { func newMockNode() *mockNode { mock := &mockNode{ trigger: make(chan *Trigger), - expect: make(chan *Expect), + expect: make(chan []Expect), err: make(chan error), stop: make(chan struct{}), } @@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { select { case trig := <-m.trigger: m.err <- p2p.Send(rw, trig.Code, trig.Msg) - case exp := <-m.expect: - m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg) + case exps := <-m.expect: + m.err <- expectMsgs(rw, exps) case <-m.stop: return nil } @@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error { return <-m.err } -func (m *mockNode) Expect(exp *Expect) error { +func (m *mockNode) Expect(exp ...Expect) error { m.expect <- exp return <-m.err } @@ -198,3 +203,67 @@ func (m *mockNode) Stop() error { m.stopOnce.Do(func() { close(m.stop) }) return nil } + +func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { + matched := make([]bool, len(exps)) + for { + msg, err := rw.ReadMsg() + if err != nil { + if err == io.EOF { + break + } + return err + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + var found bool + for i, exp := range exps { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if matched[i] { + return fmt.Errorf("message #%d received two times", i) + } + matched[i] = true + found = true + break + } + } + if !found { + expected := make([]string, 0) + for i, exp := range exps { + if matched[i] { + continue + } + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + } + return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) + } + done := true + for _, m := range matched { + if !m { + done = false + break + } + } + if done { + return nil + } + } + for i, m := range matched { + if !m { + return fmt.Errorf("expected message #%d not received", i) + } + } + return nil +} + +// mustEncodeMsg uses rlp to encode a message. +// In case of error it panics. +func mustEncodeMsg(msg interface{}) []byte { + contentEnc, err := rlp.EncodeToBytes(msg) + if err != nil { + panic("content encode error: " + err.Error()) + } + return contentEnc +} |