package p2p
import (
"bufio"
"bytes"
"encoding/hex"
"io"
"io/ioutil"
"net"
"reflect"
"testing"
"time"
)
// discard is a test protocol that reads incoming messages in an endless loop
// and discards each one. Its Run function only returns when reading or
// discarding a message fails, and it returns that error.
var discard = Protocol{
	Name:   "discard",
	Length: 1,
	Run: func(_ *Peer, rw MsgReadWriter) error {
		for {
			msg, err := rw.ReadMsg()
			if err == nil {
				// Only touch msg when the read succeeded.
				err = msg.Discard()
			}
			if err != nil {
				return err
			}
		}
	},
}
// testPeer builds a Peer on one end of an in-memory net.Pipe and runs the
// peer's loop in a background goroutine.
//
// It returns the opposite end of the pipe, the peer, and a channel that
// receives the error returned by peer.loop once the loop exits. The channel
// is buffered with capacity 1, so the goroutine can deliver its result even
// when nobody is receiving.
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
	peerEnd, testEnd := net.Pipe()

	p := newPeer(peerEnd, protos, nil)
	p.ourID = &peerId{}
	// Replace the public key hook with a function that always returns nil.
	p.pubkeyHook = func(*peerAddr) error { return nil }

	loopErr := make(chan error, 1)
	go func() {
		_, err := p.loop()
		loopErr <- err
	}()

	return testEnd, p, loopErr
}
// TestPeerProtoReadMsg checks that a message written to the connection is
// relayed to the running subprotocol through rw.ReadMsg.
//
// The test writes a message with code 18 and expects the protocol to observe
// code 2 and the payload bytes 0x0183303030.
// NOTE(review): the 16-code offset presumably accounts for the base protocol's
// message codes, and the payload looks like the RLP encodings of 1 and "000"
// back to back — confirm against the message framing code.
func TestPeerProtoReadMsg(t *testing.T) {
	defer testlog(t).detach()

	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// Signal completion on every exit path so the test never has to
			// wait for the timeout after a failure was already reported.
			defer close(done)

			msg, err := rw.ReadMsg()
			if err != nil {
				t.Errorf("read error: %v", err)
				// Do not inspect msg after a failed read; its code and
				// payload are not meaningful.
				return nil
			}
			if msg.Code != 2 {
				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
			}
			data, err := ioutil.ReadAll(msg.Payload)
			if err != nil {
				t.Errorf("payload read error: %v", err)
			}
			// The literal is valid hex, so the decode error can be ignored.
			expdata, _ := hex.DecodeString("0183303030")
			if !bytes.Equal(expdata, data) {
				t.Errorf("incorrect msg data %x", data)
			}
			return nil
		},
	}

	// Named conn (not net) to avoid shadowing the imported net package.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	writeMsg(conn, NewMsg(18, 1, "000"))
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}
// TestPeerProtoReadLargeMsg checks that a message carrying a 10 MiB byte
// slice is relayed to the running subprotocol with the expected size.
//
// The size observed by the protocol is expected to be msgsize+4.
// NOTE(review): the extra 4 bytes are presumably encoding overhead for the
// byte slice — confirm against the message encoding code.
func TestPeerProtoReadLargeMsg(t *testing.T) {
	defer testlog(t).detach()

	msgsize := uint32(10 * 1024 * 1024)
	wantSize := msgsize + 4

	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// Signal completion on every exit path so the test never has to
			// wait for the timeout after a failure was already reported.
			defer close(done)

			msg, err := rw.ReadMsg()
			if err != nil {
				t.Errorf("read error: %v", err)
				// Do not inspect or discard msg after a failed read.
				return nil
			}
			if msg.Size != wantSize {
				// Report the value that is actually compared against.
				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, wantSize)
			}
			if err := msg.Discard(); err != nil {
				t.Errorf("discard error: %v", err)
			}
			return nil
		},
	}

	// Named conn (not net) to avoid shadowing the imported net package.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	writeMsg(conn, NewMsg(18, make([]byte, msgsize)))
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}
// TestPeerProtoEncodeMsg checks message encoding from inside a running
// subprotocol: encoding with code 2 must fail for a protocol of Length 2,
// while code 1 with the values "foo" and "bar" must succeed.
//
// The test then reads the message from the other end of the connection and
// expects code 17 and a payload that decodes to []string{"foo", "bar"}.
// NOTE(review): the 16-code offset presumably accounts for the base protocol's
// message codes — confirm.
func TestPeerProtoEncodeMsg(t *testing.T) {
	defer testlog(t).detach()

	proto := Protocol{
		Name:   "a",
		Length: 2,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			if err := rw.EncodeMsg(2); err == nil {
				t.Error("expected error for out-of-range msg code, got nil")
			}
			if err := rw.EncodeMsg(1, "foo", "bar"); err != nil {
				t.Errorf("write error: %v", err)
			}
			return nil
		},
	}

	// Named conn (not net) to avoid shadowing the imported net package.
	conn, peer, _ := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	msg, err := readMsg(bufio.NewReader(conn))
	if err != nil {
		// Abort: msg must not be inspected or decoded after a failed read.
		t.Fatalf("read error: %v", err)
	}
	if msg.Code != 17 {
		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
	}

	want := []string{"foo", "bar"}
	var data []string
	if err := msg.Decode(&data); err != nil {
		t.Errorf("payload decode error: %v", err)
	}
	if !reflect.DeepEqual(data, want) {
		t.Errorf("payload RLP mismatch, got %#v, want %#v", data, want)
	}
}
// TestPeerWrite exercises peer.writeProtoMsg with the discard protocol
// running:
//
//   - writing to an unknown protocol name must return an error;
//   - writing code 8 to "discard" must return a *peerError whose Code is
//     errInvalidMsgCode;
//   - writing code 0 to "discard" must succeed, and the other end of the
//     connection must then read a message with code 16.
func TestPeerWrite(t *testing.T) {
	defer testlog(t).detach()

	// Named conn (not net) to avoid shadowing the imported net package.
	conn, peer, peerErr := testPeer([]Protocol{discard})
	defer conn.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	// test write errors
	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
		t.Errorf("expected error for unknown protocol, got nil")
	}
	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
		t.Errorf("expected error for out-of-range msg code, got nil")
	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
	}

	// setup for reading the message on the other end
	read := make(chan struct{})
	go func() {
		// Closed on every exit path so the select below is released even
		// when the read fails.
		defer close(read)

		msg, err := readMsg(bufio.NewReader(conn))
		if err != nil {
			t.Errorf("read error: %v", err)
			// Do not inspect or discard msg after a failed read.
			return
		}
		if msg.Code != 16 {
			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
		}
		msg.Discard()
	}()

	// test successful write
	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
		t.Errorf("expect no error for known protocol: %v", err)
	}

	select {
	case <-read:
	case err := <-peerErr:
		t.Fatalf("peer stopped: %v", err)
	case <-time.After(2 * time.Second):
		// Same bound as the other tests in this file; without it a message
		// that never arrives would hang the test run.
		t.Fatalf("timeout waiting for the written message to be read")
	}
}
// TestPeerActivity checks the peer's activity event feed: each of six
// messages written to the connection must produce an event within half of
// inactivityTimeout, and no event may arrive during a subsequent quiet period
// of twice inactivityTimeout.
func TestPeerActivity(t *testing.T) {
	// Run with a short inactivityTimeout; the deferred call captures the
	// current value now and puts it back when the test returns.
	defer func(saved time.Duration) { inactivityTimeout = saved }(inactivityTimeout)
	inactivityTimeout = 20 * time.Millisecond

	conn, peer, peerErr := testPeer([]Protocol{discard})
	defer conn.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	events := peer.activity.Subscribe(time.Time{})
	defer events.Unsubscribe()

	// Active phase: every message must be followed by an event in time.
	window := inactivityTimeout / 2
	for remaining := 6; remaining > 0; remaining-- {
		writeMsg(conn, NewMsg(16))
		select {
		case <-events.Chan():
		case <-time.After(window):
			t.Fatal("no event within ", window)
		case err := <-peerErr:
			t.Fatal("peer error", err)
		}
	}

	// Quiet phase: nothing is written, so no event is expected.
	select {
	case <-time.After(inactivityTimeout * 2):
	case <-events.Chan():
		t.Fatal("got activity event while connection was inactive")
	case err := <-peerErr:
		t.Fatal("peer error", err)
	}
}
// TestNewPeer checks that a peer built with NewPeer reports the capabilities
// and identity it was constructed with, and that calling Disconnect on it
// returns.
func TestNewPeer(t *testing.T) {
	var (
		wantCaps = []Cap{{"foo", 2}, {"bar", 3}}
		wantID   = &peerId{}
	)
	peer := NewPeer(wantID, wantCaps)

	if got := peer.Caps(); !reflect.DeepEqual(got, wantCaps) {
		t.Errorf("Caps mismatch: got %v, expected %v", got, wantCaps)
	}
	if got := peer.Identity(); got != wantID {
		t.Errorf("Identity mismatch: got %v, expected %v", got, wantID)
	}

	// Disconnect must return rather than block.
	peer.Disconnect(DiscAlreadyConnected)
}
// TestEOFSignal checks when an eofSignal reader signals on its channel. It
// covers four situations: an empty source, the count being reached before any
// read error, a read error occurring before the count is reached, and neither
// of the two happening (in which case no signal is expected).
func TestEOFSignal(t *testing.T) {
	// signaled reports, without blocking, whether a value is pending on c.
	signaled := func(c <-chan struct{}) bool {
		select {
		case <-c:
			return true
		default:
			return false
		}
	}
	buf := make([]byte, 10)

	// Empty source: the first read returns io.EOF and must signal.
	ch := make(chan struct{}, 1)
	r := &eofSignal{new(bytes.Buffer), 0, ch}
	n, err := r.Read(buf)
	if n != 0 || err != io.EOF {
		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
	}
	if !signaled(ch) {
		t.Error("EOF chan not signaled")
	}

	// Count reached before any error: 8 bytes are read against a count of 4.
	ch = make(chan struct{}, 1)
	r = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, ch}
	n, err = r.Read(buf)
	if n != 8 || err != nil {
		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
	}
	if !signaled(ch) {
		t.Error("EOF chan not signaled")
	}

	// Error before the count is reached: the source runs dry after 4 bytes,
	// so the second read returns io.EOF.
	ch = make(chan struct{}, 1)
	r = &eofSignal{bytes.NewBufferString("aaaa"), 999, ch}
	n, err = r.Read(buf)
	if n != 4 || err != nil {
		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
	}
	n, err = r.Read(buf)
	if n != 0 || err != io.EOF {
		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
	}
	if !signaled(ch) {
		t.Error("EOF chan not signaled")
	}

	// Neither condition met: a single 10-byte read from a 21-byte source
	// with a count of 999 must not signal.
	ch = make(chan struct{}, 1)
	r = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, ch}
	n, err = r.Read(buf)
	if n != 10 || err != nil {
		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
	}
	if signaled(ch) {
		t.Error("unexpected EOF signal")
	}
}