package p2p

import (
	"bytes"
	"fmt"
	"io"
	"net"
	"testing"
	"time"
)

// TestNetworkConnection is an in-memory stand-in for a network connection
// used by the tests in this file. Its method set mirrors net.Conn: inbound
// packets are queued with In and consumed by Read, and everything passed to
// Write is recorded in Out.
type TestNetworkConnection struct {
	in      chan []byte // packets sent by In and received by Read
	current []byte      // unread remainder of the packet most recently taken from in
	Out     [][]byte    // one entry per Write call, in call order
	addr    net.Addr    // value returned by RemoteAddr
}

// NewTestNetworkConnection builds a TestNetworkConnection that stores addr
// for later retrieval. The inbound channel is unbuffered, and both the read
// buffer and the Out log start out empty but non-nil.
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
	conn := new(TestNetworkConnection)
	conn.addr = addr
	conn.in = make(chan []byte)
	conn.current = make([]byte, 0)
	conn.Out = make([][]byte, 0)
	return conn
}

// In sleeps for latency and then sends each packet, in order, on the inbound
// channel. Every send blocks until a receiver takes the packet, which is why
// the tests in this file call In from its own goroutine.
func (c *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
	time.Sleep(latency)
	for i := 0; i < len(packets); i++ {
		c.in <- packets[i]
	}
}

// Read copies pending packet bytes into buff.
//
// When no bytes are left over from an earlier call, Read makes a non-blocking
// attempt to take the next packet from the inbound channel and returns
// (0, io.EOF) if none is ready. If the pending bytes do not all fit into buff,
// Read fills buff, keeps the remainder for the next call and returns a nil
// error. Otherwise it copies everything that is pending, resets the pending
// buffer and returns io.EOF together with the byte count, so io.EOF marks the
// end of each packet.
func (c *TestNetworkConnection) Read(buff []byte) (n int, err error) {
	if len(c.current) == 0 {
		select {
		case next := <-c.in:
			c.current = next
		default:
			// Nothing queued: report EOF instead of blocking.
			return 0, io.EOF
		}
	}

	if len(c.current) <= len(buff) {
		// The whole remainder fits; hand it over and signal end of packet.
		n = copy(buff, c.current)
		c.current = []byte{}
		return n, io.EOF
	}

	// Partial read: buff is filled completely, the rest stays pending.
	n = copy(buff, c.current)
	c.current = c.current[n:]
	return n, nil
}

// Write records the bytes of buff as the next entry of Out, prints the
// running write count and the payload to standard output, and reports all
// of buff as written with a nil error.
//
// A copy of buff is stored rather than buff itself: io.Writer
// implementations must not retain the caller's slice, and a writer that
// reuses its buffer would otherwise overwrite packets already recorded.
func (c *TestNetworkConnection) Write(buff []byte) (n int, err error) {
	packet := make([]byte, len(buff))
	copy(packet, buff)
	c.Out = append(c.Out, packet)
	fmt.Printf("net write %v\n%v\n", len(c.Out), buff)
	return len(buff), nil
}

// Close does nothing and always returns a nil error.
func (c *TestNetworkConnection) Close() (err error) {
	return nil
}

// LocalAddr always returns nil; the test connection has no local address.
func (c *TestNetworkConnection) LocalAddr() (addr net.Addr) {
	return nil
}

// RemoteAddr returns the address stored in the connection's addr field.
func (c *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
	addr = c.addr
	return addr
}

// SetDeadline ignores t and always returns a nil error.
func (c *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
	return nil
}

// SetReadDeadline ignores t and always returns a nil error.
func (c *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
	return nil
}

// SetWriteDeadline ignores t and always returns a nil error.
func (c *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
	return nil
}

// setupConnection creates a TestNetworkConnection backed by
// TestAddr{"test:30303"}, wraps it in a new Connection with a fresh peer
// error channel, opens that Connection, and returns both so a test can feed
// the network side and observe the Connection side.
func setupConnection() (*Connection, *TestNetworkConnection) {
	network := NewTestNetworkConnection(&TestAddr{"test:30303"})
	conn := NewConnection(network, NewPeerErrorChannel())
	conn.Open()
	return conn, network
}

// TestReadingNilPacket feeds an empty packet to the connection and fails if
// a packet or an error is already available when the channels are polled.
//
// The poll is non-blocking and runs immediately after the feeder goroutine
// is started; it does not wait for the connection to process the input.
func TestReadingNilPacket(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{})

	// A 10ms pause before polling was disabled here:
	// time.Sleep(10 * time.Millisecond)
	select {
	case packet := <-conn.Read():
		t.Errorf("read %v", packet)
	case err := <-conn.Error():
		t.Errorf("incorrect error %v", err)
	default:
		// Nothing pending is the expected outcome.
	}
}

// TestReadingShortPacket feeds a single-byte packet to the connection and
// expects a PacketTooShort error and no packet.
//
// The wait is bounded: if the connection produces neither a packet nor an
// error within one second the test fails instead of blocking until the
// global test timeout.
func TestReadingShortPacket(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{0})

	select {
	case packet := <-conn.Read():
		t.Errorf("read %v", packet)
	case err := <-conn.Error():
		if err.Code != PacketTooShort {
			t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
		}
	case <-time.After(time.Second):
		t.Errorf("timed out waiting for error %v", PacketTooShort)
	}
}

// TestReadingInvalidPacket feeds eight bytes whose leading bytes differ from
// the 34, 64, 8, 145 prefix used by the well-formed packets in this file and
// expects a MagicTokenMismatch error and no packet.
//
// The wait is bounded: if the connection produces neither a packet nor an
// error within one second the test fails instead of blocking until the
// global test timeout.
func TestReadingInvalidPacket(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})

	select {
	case packet := <-conn.Read():
		t.Errorf("read %v", packet)
	case err := <-conn.Error():
		if err.Code != MagicTokenMismatch {
			t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
		}
	case <-time.After(time.Second):
		t.Errorf("timed out waiting for error %v", MagicTokenMismatch)
	}
}

// TestReadingInvalidPayload feeds a packet whose header bytes 0, 0, 0, 2 are
// followed by only one further byte and expects a PayloadTooShort error and
// no packet.
// NOTE(review): the four bytes after the 34, 64, 8, 145 prefix look like a
// big-endian payload length — confirm against the Connection implementation.
//
// The wait is bounded: if the connection produces neither a packet nor an
// error within one second the test fails instead of blocking until the
// global test timeout.
func TestReadingInvalidPayload(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})

	select {
	case packet := <-conn.Read():
		t.Errorf("read %v", packet)
	case err := <-conn.Error():
		if err.Code != PayloadTooShort {
			t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
		}
	case <-time.After(time.Second):
		t.Errorf("timed out waiting for error %v", PayloadTooShort)
	}
}

// TestReadingEmptyPayload feeds a header-only packet (the 34, 64, 8, 145
// prefix followed by four zero bytes), waits 10ms for the connection to
// process it, and then checks with non-blocking polls that no packet was
// delivered and that an EmptyPayload error is pending.
func TestReadingEmptyPayload(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
	time.Sleep(10 * time.Millisecond)

	// No packet may have been produced.
	select {
	case packet := <-conn.Read():
		t.Errorf("read %v", packet)
	default:
	}

	// An EmptyPayload error must already be waiting.
	select {
	case err := <-conn.Error():
		if got := err.Code; got != EmptyPayload {
			t.Errorf("incorrect error, expected EmptyPayload, got %v", got)
		}
	default:
		t.Errorf("no error, expected EmptyPayload")
	}
}

// TestReadingCompletePacket feeds one well-formed packet carrying the
// single payload byte 1, waits 10ms, and expects that payload to be
// available from conn.Read() with no error reported.
func TestReadingCompletePacket(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
	time.Sleep(10 * time.Millisecond)

	want := []byte{1}
	select {
	case packet := <-conn.Read():
		if !bytes.Equal(packet, want) {
			t.Errorf("incorrect payload read")
		}
	case err := <-conn.Error():
		t.Errorf("incorrect error %v", err)
	default:
		t.Errorf("nothing read")
	}
}

// TestReadingTwoCompletePackets feeds two well-formed packets back to back
// in a single network write, with payloads 0 and 1, and expects them to be
// delivered in that order. Before each non-blocking poll the test waits 10ms
// to give the connection time to process the input.
func TestReadingTwoCompletePackets(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	go network.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})

	for _, want := range [][]byte{{0}, {1}} {
		time.Sleep(10 * time.Millisecond)
		select {
		case packet := <-conn.Read():
			if !bytes.Equal(packet, want) {
				t.Errorf("incorrect payload read")
			}
		case err := <-conn.Error():
			t.Errorf("incorrect error %v", err)
		default:
			t.Errorf("nothing read")
		}
	}
}

// TestWriting sends the one-byte payload 0 through conn.Write(), waits 10ms,
// and expects the first recorded network write to be that payload framed as
// 34, 64, 8, 145, 0, 0, 0, 1, 0.
func TestWriting(t *testing.T) {
	conn, network := setupConnection()
	defer conn.Close()

	conn.Write() <- []byte{0}
	time.Sleep(10 * time.Millisecond)

	if len(network.Out) == 0 {
		t.Errorf("no output")
		return
	}

	want := []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}
	if got := network.Out[0]; !bytes.Equal(got, want) {
		t.Errorf("incorrect packet %v", got)
	}
}

// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243