aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/peer_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/peer_test.go')
-rw-r--r--p2p/peer_test.go75
1 files changed, 19 insertions, 56 deletions
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index a1260adbd..cc9f1f0cd 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -3,6 +3,7 @@ package p2p
import (
"bytes"
"fmt"
+ "io"
"io/ioutil"
"net"
"reflect"
@@ -29,8 +30,8 @@ var discard = Protocol{
},
}
-func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
- fd1, fd2 := net.Pipe()
+func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) {
+ fd1, _ := net.Pipe()
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
for _, p := range protos {
@@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
hs2.Caps = append(hs2.Caps, p.cap())
}
- peer := newPeer(newConn(fd1, hs1), protos)
+ p1, p2 := MsgPipe()
+ peer := newPeer(fd1, &conn{p1, hs1}, protos)
errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }()
- return newConn(fd2, hs2), peer, errc
+ return p1, &conn{p2, hs2}, peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
@@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
},
}
- rw, _, errc := testPeer([]Protocol{proto})
- defer rw.Close()
+ closer, rw, _, errc := testPeer([]Protocol{proto})
+ defer closer.Close()
EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2)
@@ -83,41 +85,6 @@ func TestPeerProtoReadMsg(t *testing.T) {
}
}
-func TestPeerProtoReadLargeMsg(t *testing.T) {
- defer testlog(t).detach()
-
- msgsize := uint32(10 * 1024 * 1024)
- done := make(chan struct{})
- proto := Protocol{
- Name: "a",
- Length: 5,
- Run: func(peer *Peer, rw MsgReadWriter) error {
- msg, err := rw.ReadMsg()
- if err != nil {
- t.Errorf("read error: %v", err)
- }
- if msg.Size != msgsize+4 {
- t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
- }
- msg.Discard()
- close(done)
- return nil
- },
- }
-
- rw, _, errc := testPeer([]Protocol{proto})
- defer rw.Close()
-
- EncodeMsg(rw, 18, make([]byte, msgsize))
- select {
- case <-done:
- case err := <-errc:
- t.Errorf("peer returned: %v", err)
- case <-time.After(2 * time.Second):
- t.Errorf("receive timeout")
- }
-}
-
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
@@ -134,8 +101,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil
},
}
- rw, _, _ := testPeer([]Protocol{proto})
- defer rw.Close()
+ closer, rw, _, _ := testPeer([]Protocol{proto})
+ defer closer.Close()
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err)
@@ -145,8 +112,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach()
- rw, peer, peerErr := testPeer([]Protocol{discard})
- defer rw.Close()
+ closer, rw, peer, peerErr := testPeer([]Protocol{discard})
+ defer closer.Close()
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@@ -181,8 +148,8 @@ func TestPeerWriteForBroadcast(t *testing.T) {
func TestPeerPing(t *testing.T) {
defer testlog(t).detach()
- rw, _, _ := testPeer(nil)
- defer rw.Close()
+ closer, rw, _, _ := testPeer(nil)
+ defer closer.Close()
if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err)
}
@@ -194,15 +161,15 @@ func TestPeerPing(t *testing.T) {
func TestPeerDisconnect(t *testing.T) {
defer testlog(t).detach()
- rw, _, disc := testPeer(nil)
- defer rw.Close()
+ closer, rw, _, disc := testPeer(nil)
+ defer closer.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err)
}
- rw.Close() // make test end faster
+ closer.Close() // make test end faster
if reason := <-disc; reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
}
@@ -244,13 +211,9 @@ func expectMsg(r MsgReader, code uint64, content interface{}) error {
if err != nil {
panic("content encode error: " + err.Error())
}
- // skip over list header in encoded value. this is temporary.
- contentEncR := bytes.NewReader(contentEnc)
- if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
- panic("content must encode as RLP list")
+ if int(msg.Size) != len(contentEnc) {
+ return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc))
}
- contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
-
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err