diff options
Diffstat (limited to 'p2p/peer_test.go')
-rw-r--r-- | p2p/peer_test.go | 75 |
1 files changed, 19 insertions, 56 deletions
diff --git a/p2p/peer_test.go b/p2p/peer_test.go index a1260adbd..cc9f1f0cd 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -3,6 +3,7 @@ package p2p import ( "bytes" "fmt" + "io" "io/ioutil" "net" "reflect" @@ -29,8 +30,8 @@ var discard = Protocol{ }, } -func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) { - fd1, fd2 := net.Pipe() +func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) { + fd1, _ := net.Pipe() hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} for _, p := range protos { @@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) { hs2.Caps = append(hs2.Caps, p.cap()) } - peer := newPeer(newConn(fd1, hs1), protos) + p1, p2 := MsgPipe() + peer := newPeer(fd1, &conn{p1, hs1}, protos) errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() - return newConn(fd2, hs2), peer, errc + return p1, &conn{p2, hs2}, peer, errc } func TestPeerProtoReadMsg(t *testing.T) { @@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) { }, } - rw, _, errc := testPeer([]Protocol{proto}) - defer rw.Close() + closer, rw, _, errc := testPeer([]Protocol{proto}) + defer closer.Close() EncodeMsg(rw, baseProtocolLength+2, 1) EncodeMsg(rw, baseProtocolLength+3, 2) @@ -83,41 +85,6 @@ func TestPeerProtoReadMsg(t *testing.T) { } } -func TestPeerProtoReadLargeMsg(t *testing.T) { - defer testlog(t).detach() - - msgsize := uint32(10 * 1024 * 1024) - done := make(chan struct{}) - proto := Protocol{ - Name: "a", - Length: 5, - Run: func(peer *Peer, rw MsgReadWriter) error { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Size != msgsize+4 { - t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) - } - msg.Discard() - close(done) - return nil - }, - } - - rw, _, errc := testPeer([]Protocol{proto}) - defer rw.Close() - - EncodeMsg(rw, 18, make([]byte, msgsize)) - select { - case <-done: - case err := <-errc: - t.Errorf("peer returned: %v", err) - case <-time.After(2 * time.Second): - t.Errorf("receive timeout") - } -} - func TestPeerProtoEncodeMsg(t *testing.T) { defer testlog(t).detach() @@ -134,8 +101,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) { return nil }, } - rw, _, _ := testPeer([]Protocol{proto}) - defer rw.Close() + closer, rw, _, _ := testPeer([]Protocol{proto}) + defer closer.Close() if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { t.Error(err) @@ -145,8 +112,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() - rw, peer, peerErr := testPeer([]Protocol{discard}) - defer rw.Close() + closer, rw, peer, peerErr := testPeer([]Protocol{discard}) + defer closer.Close() // test write errors if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { @@ -181,8 +148,8 @@ func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerPing(t *testing.T) { defer testlog(t).detach() - rw, _, _ := testPeer(nil) - defer rw.Close() + closer, rw, _, _ := testPeer(nil) + defer closer.Close() if err := EncodeMsg(rw, pingMsg); err != nil { t.Fatal(err) } @@ -194,15 +161,15 @@ func TestPeerPing(t *testing.T) { func TestPeerDisconnect(t *testing.T) { defer testlog(t).detach() - rw, _, disc := testPeer(nil) - defer rw.Close() + closer, rw, _, disc := testPeer(nil) + defer closer.Close() if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) } if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { t.Error(err) } - rw.Close() // make test end faster + closer.Close() // make test end faster if reason := <-disc; reason != DiscRequested { t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } @@ -244,13 +211,9 @@ func expectMsg(r MsgReader, code uint64, content interface{}) error { if err != nil { panic("content encode error: " + err.Error()) } - // skip over list header in encoded value. this is temporary. - contentEncR := bytes.NewReader(contentEnc) - if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil { - panic("content must encode as RLP list") + if int(msg.Size) != len(contentEnc) { + return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) } - contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():] - actualContent, err := ioutil.ReadAll(msg.Payload) if err != nil { return err |