package p2p

import (
	"bufio"
	"bytes"
	"encoding/hex"
	"io"
	"io/ioutil"
	"net"
	"reflect"
	"testing"
	"time"
)

// discard is a helper protocol for tests. It declares a Length of 1 and its
// Run function does nothing but drain the connection: every message read
// from rw is discarded, and the loop ends with the first error returned by
// either the read or the discard.
var discard = Protocol{
	Name:   "discard",
	Length: 1,
	Run: func(_ *Peer, rw MsgReadWriter) error {
		for {
			msg, err := rw.ReadMsg()
			if err == nil {
				// Only a successfully read message is discarded.
				err = msg.Discard()
			}
			if err != nil {
				return err
			}
		}
	},
}

// testPeer builds a Peer on one end of an in-memory pipe and starts its
// loop in a background goroutine.
//
// It returns the opposite end of the pipe (the side a test writes to and
// reads from), the peer itself, and a channel that receives the error
// returned by the peer loop once it exits. The channel has room for one
// value, so the goroutine can finish even if nobody reads the result.
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
	local, remote := net.Pipe()

	p := newPeer(local, protos, nil)
	p.ourID = &peerId{}
	// Accept every address so no real key check is involved in the tests.
	p.pubkeyHook = func(*peerAddr) error { return nil }

	result := make(chan error, 1)
	go func() {
		_, err := p.loop()
		result <- err
	}()

	return remote, p, result
}

// TestPeerProtoReadMsg checks that a message written to the connection is
// handed to the subprotocol's Run function. The message is sent with wire
// code 18 and must arrive with protocol-relative code 2 and the expected
// payload bytes.
func TestPeerProtoReadMsg(t *testing.T) {
	defer testlog(t).detach()

	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// Release the waiting test on every exit path.
			defer close(done)

			msg, err := rw.ReadMsg()
			if err != nil {
				// msg is not meaningful after a failed read; stop here
				// instead of inspecting its code and payload.
				t.Errorf("read error: %v", err)
				return err
			}
			if msg.Code != 2 {
				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
			}
			data, err := ioutil.ReadAll(msg.Payload)
			if err != nil {
				t.Errorf("payload read error: %v", err)
			}
			// The literal is valid hex, so the decode error can be ignored.
			expdata, _ := hex.DecodeString("0183303030")
			if !bytes.Equal(expdata, data) {
				t.Errorf("incorrect msg data %x", data)
			}
			return nil
		},
	}

	// Named conn rather than net so the imported package is not shadowed.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	writeMsg(conn, NewMsg(18, 1, "000"))
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}

// TestPeerProtoReadLargeMsg sends a message carrying a 10 MiB byte slice
// and checks the Size the subprotocol sees for it.
func TestPeerProtoReadLargeMsg(t *testing.T) {
	defer testlog(t).detach()

	msgsize := uint32(10 * 1024 * 1024)
	// NOTE(review): the 4 extra bytes are presumably encoding overhead
	// added around the raw byte slice — confirm against the encoder.
	wantSize := msgsize + 4

	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// Release the waiting test on every exit path.
			defer close(done)

			msg, err := rw.ReadMsg()
			if err != nil {
				// msg is not meaningful after a failed read.
				t.Errorf("read error: %v", err)
				return err
			}
			if msg.Size != wantSize {
				// Report the value that is actually being compared against.
				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, wantSize)
			}
			if err := msg.Discard(); err != nil {
				t.Errorf("discard error: %v", err)
				return err
			}
			return nil
		},
	}

	// Named conn rather than net so the imported package is not shadowed.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	writeMsg(conn, NewMsg(18, make([]byte, msgsize)))
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}

// TestPeerProtoEncodeMsg checks EncodeMsg when called from inside a running
// subprotocol: a code the test expects to be out of range must be rejected,
// and a valid message must appear on the connection with wire code 17 and
// a payload that decodes to the strings that were sent.
func TestPeerProtoEncodeMsg(t *testing.T) {
	defer testlog(t).detach()

	proto := Protocol{
		Name:   "a",
		Length: 2,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// With Length 2 the test expects code 2 to be refused.
			if err := EncodeMsg(rw, 2); err == nil {
				t.Error("expected error for out-of-range msg code, got nil")
			}
			if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
				t.Errorf("write error: %v", err)
			}
			return nil
		},
	}
	// Named conn rather than net so the imported package is not shadowed.
	conn, peer, _ := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	msg, err := readMsg(bufio.NewReader(conn))
	if err != nil {
		// Without a message the remaining checks would only operate on
		// an invalid value, so stop the test here.
		t.Fatalf("read error: %v", err)
	}
	if msg.Code != 17 {
		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
	}

	want := []string{"foo", "bar"}
	var data []string
	if err := msg.Decode(&data); err != nil {
		t.Errorf("payload decode error: %v", err)
	}
	if !reflect.DeepEqual(data, want) {
		t.Errorf("payload RLP mismatch, got %#v, want %#v", data, want)
	}
}

// TestPeerWrite exercises peer.writeProtoMsg: writes for an unknown
// protocol name or a code the protocol does not accept must fail, and a
// valid write must reach the other end of the connection with wire code 16.
func TestPeerWrite(t *testing.T) {
	defer testlog(t).detach()

	// Named conn rather than net so the imported package is not shadowed.
	conn, peer, peerErr := testPeer([]Protocol{discard})
	defer conn.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	// test write errors
	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
		t.Errorf("expected error for unknown protocol, got nil")
	}
	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
		t.Errorf("expected error for out-of-range msg code, got nil")
	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
	}

	// setup for reading the message on the other end
	read := make(chan struct{})
	go func() {
		// Signal completion on every exit path so the test never waits
		// on a reader that has already given up.
		defer close(read)

		msg, err := readMsg(bufio.NewReader(conn))
		if err != nil {
			// msg is not meaningful after a failed read; do not touch it.
			t.Errorf("read error: %v", err)
			return
		}
		if msg.Code != 16 {
			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
		}
		msg.Discard()
	}()

	// test successful write
	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
		t.Errorf("expect no error for known protocol: %v", err)
	}
	select {
	case <-read:
	case err := <-peerErr:
		t.Fatalf("peer stopped: %v", err)
	case <-time.After(2 * time.Second):
		// Same bound the other tests in this file use for receives.
		t.Fatalf("read timeout")
	}
}

// TestPeerActivity checks the peer's activity notifications: each message
// written to the connection must produce an event on the subscription
// within half of inactivityTimeout, and no event may be delivered while
// the connection stays silent for twice that timeout.
func TestPeerActivity(t *testing.T) {
	// shorten inactivityTimeout while this test is running
	oldT := inactivityTimeout
	defer func() { inactivityTimeout = oldT }()
	inactivityTimeout = 20 * time.Millisecond

	// Named conn rather than net so the imported package is not shadowed.
	conn, peer, peerErr := testPeer([]Protocol{discard})
	defer conn.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	sub := peer.activity.Subscribe(time.Time{})
	defer sub.Unsubscribe()

	for i := 0; i < 6; i++ {
		writeMsg(conn, NewMsg(16))
		select {
		case <-sub.Chan():
		case <-time.After(inactivityTimeout / 2):
			t.Fatalf("no event within %v", inactivityTimeout/2)
		case err := <-peerErr:
			// Fatalf keeps the label and the error apart; Fatal would
			// print them without a separating space.
			t.Fatalf("peer error: %v", err)
		}
	}

	// Nothing is written from here on, so no further event is expected.
	select {
	case <-time.After(inactivityTimeout * 2):
	case <-sub.Chan():
		t.Fatal("got activity event while connection was inactive")
	case err := <-peerErr:
		t.Fatalf("peer error: %v", err)
	}
}

func TestNewPeer(t *testing.T) {
	caps := []Cap{{"foo", 2}, {"bar", 3}}
	id := &peerId{}
	p := NewPeer(id, caps)
	if !reflect.DeepEqual(p.Caps(), caps) {
		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
	}
	if p.Identity() != id {
		t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
	}
	// Should not hang.
	p.Disconnect(DiscAlreadyConnected)
}

// TestEOFSignal covers four scenarios for eofSignal, each with a fresh
// signal channel: an empty reader, the count being reached before any read
// error, a read error arriving before the count is reached, and a read
// where neither happens (in which case no signal is expected).
func TestEOFSignal(t *testing.T) {
	buf := make([]byte, 10)

	// signaled reports, without blocking, whether a value is waiting on ch.
	signaled := func(ch <-chan struct{}) bool {
		select {
		case <-ch:
			return true
		default:
			return false
		}
	}

	// empty reader
	{
		eof := make(chan struct{}, 1)
		sig := &eofSignal{new(bytes.Buffer), 0, eof}
		n, err := sig.Read(buf)
		if !(n == 0 && err == io.EOF) {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
		if !signaled(eof) {
			t.Error("EOF chan not signaled")
		}
	}

	// count before error
	{
		eof := make(chan struct{}, 1)
		sig := &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
		n, err := sig.Read(buf)
		if !(n == 8 && err == nil) {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
		if !signaled(eof) {
			t.Error("EOF chan not signaled")
		}
	}

	// error before count
	{
		eof := make(chan struct{}, 1)
		sig := &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
		// The first read drains the buffer; the second one hits EOF.
		n, err := sig.Read(buf)
		if !(n == 4 && err == nil) {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
		n, err = sig.Read(buf)
		if !(n == 0 && err == io.EOF) {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
		if !signaled(eof) {
			t.Error("EOF chan not signaled")
		}
	}

	// no signal if neither occurs
	{
		eof := make(chan struct{}, 1)
		sig := &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
		n, err := sig.Read(buf)
		if !(n == 10 && err == nil) {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
		if signaled(eof) {
			t.Error("unexpected EOF signal")
		}
	}
}