From 429dd2a100f3b9e2b612b59bcb48f79a805cd6f9 Mon Sep 17 00:00:00 2001 From: obscuren Date: Fri, 7 Nov 2014 12:18:48 +0100 Subject: Implemented new miner w/ ui interface for merged mining. Closes #177 * Miner has been rewritten * Added new miner pane * Added option for local txs * Added option to read from MergeMining contract and list them for merged mining --- p2p/connection.go | 2 +- p2p/message.go | 2 +- p2p/messenger_test.go | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) (limited to 'p2p') diff --git a/p2p/connection.go b/p2p/connection.go index e999cbe55..be366235d 100644 --- a/p2p/connection.go +++ b/p2p/connection.go @@ -6,7 +6,7 @@ import ( "net" "time" - "github.com/ethereum/eth-go/ethutil" + "github.com/ethereum/go-ethereum/ethutil" ) type Connection struct { diff --git a/p2p/message.go b/p2p/message.go index 4886eaa1f..446e74dff 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -2,7 +2,7 @@ package p2p import ( // "fmt" - "github.com/ethereum/eth-go/ethutil" + "github.com/ethereum/go-ethereum/ethutil" ) type MsgCode uint8 diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go index bc21d34ba..472d74515 100644 --- a/p2p/messenger_test.go +++ b/p2p/messenger_test.go @@ -3,9 +3,10 @@ package p2p import ( // "fmt" "bytes" - "github.com/ethereum/eth-go/ethutil" "testing" "time" + + "github.com/ethereum/go-ethereum/ethutil" ) func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { -- cgit v1.2.3 From f38052c499c1fee61423efeddb1f52677f1442e9 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 4 Nov 2014 13:21:44 +0100 Subject: p2p: rework protocol API --- p2p/connection.go | 275 -------------------------------- p2p/connection_test.go | 222 -------------------------- p2p/message.go | 201 +++++++++++++++++------ p2p/message_test.go | 75 ++++++--- p2p/messenger.go | 353 +++++++++++++++++++++-------------------- p2p/messenger_test.go | 224 +++++++++++++------------- p2p/peer.go | 29 +--- p2p/peer_error.go | 10 +- p2p/peer_error_handler.go | 31 ++-- p2p/peer_error_handler_test.go | 2 +- p2p/peer_test.go | 170 ++++++++++---------- p2p/protocol.go | 353 +++++++++++++++++++++++------------------ p2p/server.go | 150 ++++++++--------- p2p/server_test.go | 204 ++++++++++++++++-------- 14 files changed, 1017 insertions(+), 1282 deletions(-) delete mode 100644 p2p/connection.go delete mode 100644 p2p/connection_test.go (limited to 'p2p') diff --git a/p2p/connection.go b/p2p/connection.go deleted file mode 100644 index be366235d..000000000 --- a/p2p/connection.go +++ /dev/null @@ -1,275 +0,0 @@ -package p2p - -import ( - "bytes" - // "fmt" - "net" - "time" - - "github.com/ethereum/go-ethereum/ethutil" -) - -type Connection struct { - conn net.Conn - // conn NetworkConnection - timeout time.Duration - in chan []byte - out chan []byte - err chan *PeerError - closingIn chan chan bool - closingOut chan chan bool -} - -// const readBufferLength = 2 //for testing - -const readBufferLength = 1440 -const partialsQueueSize = 10 -const maxPendingQueueSize = 1 -const defaultTimeout = 500 - -var magicToken = []byte{34, 64, 8, 145} - -func (self *Connection) Open() { - go self.startRead() - go self.startWrite() -} - -func (self *Connection) Close() { - self.closeIn() - self.closeOut() -} - -func (self *Connection) closeIn() { - errc := make(chan bool) - self.closingIn <- errc - <-errc -} - -func (self *Connection) closeOut() { - errc := make(chan bool) - self.closingOut <- errc - <-errc -} - -func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { - return &Connection{ - conn: conn, - timeout: defaultTimeout, - in: make(chan []byte), - out: make(chan []byte), - err: errchan, - closingIn: make(chan chan bool, 1), - closingOut: make(chan chan bool, 1), - } -} - -func (self *Connection) Read() <-chan []byte { - return self.in -} - -func (self *Connection) Write() chan<- []byte { - return self.out -} - -func (self *Connection) Error() <-chan *PeerError { - return self.err -} - -func (self *Connection) startRead() { - payloads := make(chan []byte) - done := make(chan *PeerError) - pending := [][]byte{} - var head []byte - var wait time.Duration // initally 0 (no delay) - read := time.After(wait * time.Millisecond) - - for { - // if pending empty, nil channel blocks - var in chan []byte - if len(pending) > 0 { - in = self.in // enable send case - head = pending[0] - } else { - in = nil - } - - select { - case <-read: - go self.read(payloads, done) - case err := <-done: - if err == nil { // no error but nothing to read - if len(pending) < maxPendingQueueSize { - wait = 100 - } else if wait == 0 { - wait = 100 - } else { - wait = 2 * wait - } - } else { - self.err <- err // report error - wait = 100 - } - read = time.After(wait * time.Millisecond) - case payload := <-payloads: - pending = append(pending, payload) - if len(pending) < maxPendingQueueSize { - wait = 0 - } else { - wait = 100 - } - read = time.After(wait * time.Millisecond) - case in <- head: - pending = pending[1:] - case errc := <-self.closingIn: - errc <- true - close(self.in) - return - } - - } -} - -func (self *Connection) startWrite() { - pending := [][]byte{} - done := make(chan *PeerError) - writing := false - for { - if len(pending) > 0 && !writing { - writing = true - go self.write(pending[0], done) - } - select { - case payload := <-self.out: - pending = append(pending, payload) - case err := <-done: - if err == nil { - pending = pending[1:] - writing = false - } else { - self.err <- err // report error - } - case errc := <-self.closingOut: - errc <- true - close(self.out) - return - } - } -} - -func pack(payload []byte) (packet []byte) { - length := ethutil.NumberToBytes(uint32(len(payload)), 32) - // return error if too long? - // Write magic token and payload length (first 8 bytes) - packet = append(magicToken, length...) - packet = append(packet, payload...) - return -} - -func avoidPanic(done chan *PeerError) { - if rec := recover(); rec != nil { - err := NewPeerError(MiscError, " %v", rec) - logger.Debugln(err) - done <- err - } -} - -func (self *Connection) write(payload []byte, done chan *PeerError) { - defer avoidPanic(done) - var err *PeerError - _, ok := self.conn.Write(pack(payload)) - if ok != nil { - err = NewPeerError(WriteError, " %v", ok) - logger.Debugln(err) - } - done <- err -} - -func (self *Connection) read(payloads chan []byte, done chan *PeerError) { - //defer avoidPanic(done) - - partials := make(chan []byte, partialsQueueSize) - errc := make(chan *PeerError) - go self.readPartials(partials, errc) - - packet := []byte{} - length := 8 - start := true - var err *PeerError -out: - for { - // appends partials read via connection until packet is - // - either parseable (>=8bytes) - // - or complete (payload fully consumed) - for len(packet) < length { - partial, ok := <-partials - if !ok { // partials channel is closed - err = <-errc - if err == nil && len(packet) > 0 { - if start { - err = NewPeerError(PacketTooShort, "%v", packet) - } else { - err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) - } - } - break out - } - packet = append(packet, partial...) - } - if start { - // at least 8 bytes read, can validate packet - if bytes.Compare(magicToken, packet[:4]) != 0 { - err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) - break - } - length = int(ethutil.BytesToNumber(packet[4:8])) - packet = packet[8:] - - if length > 0 { - start = false // now consuming payload - } else { //penalize peer but read on - self.err <- NewPeerError(EmptyPayload, "") - length = 8 - } - } else { - // packet complete (payload fully consumed) - payloads <- packet[:length] - packet = packet[length:] // resclice packet - start = true - length = 8 - } - } - - // this stops partials read via the connection, should we? - //if err != nil { - // select { - // case errc <- err - // default: - //} - done <- err -} - -func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { - defer close(partials) - for { - // Give buffering some time - self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) - buffer := make([]byte, readBufferLength) - // read partial from connection - bytesRead, err := self.conn.Read(buffer) - if err == nil || err.Error() == "EOF" { - if bytesRead > 0 { - partials <- buffer[:bytesRead] - } - if err != nil && err.Error() == "EOF" { - break - } - } else { - // unexpected error, report to errc - err := NewPeerError(ReadError, " %v", err) - logger.Debugln(err) - errc <- err - return // will close partials channel - } - } - close(errc) -} diff --git a/p2p/connection_test.go b/p2p/connection_test.go deleted file mode 100644 index 76ee8021c..000000000 --- a/p2p/connection_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package p2p - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - "time" -) - -type TestNetworkConnection struct { - in chan []byte - current []byte - Out [][]byte - addr net.Addr -} - -func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { - return &TestNetworkConnection{ - in: make(chan []byte), - current: []byte{}, - Out: [][]byte{}, - addr: addr, - } -} - -func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { - time.Sleep(latency) - for _, s := range packets { - self.in <- s - } -} - -func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { - if len(self.current) == 0 { - select { - case self.current = <-self.in: - default: - return 0, io.EOF - } - } - length := len(self.current) - if length > len(buff) { - copy(buff[:], self.current[:len(buff)]) - self.current = self.current[len(buff):] - return len(buff), nil - } else { - copy(buff[:length], self.current[:]) - self.current = []byte{} - return length, io.EOF - } -} - -func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { - self.Out = append(self.Out, buff) - fmt.Printf("net write %v\n%v\n", len(self.Out), buff) - return len(buff), nil -} - -func (self *TestNetworkConnection) Close() (err error) { - return -} - -func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { - return -} - -func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { - return self.addr -} - -func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { - return -} - -func setupConnection() (*Connection, *TestNetworkConnection) { - addr := &TestAddr{"test:30303"} - net := NewTestNetworkConnection(addr) - conn := NewConnection(net, NewPeerErrorChannel()) - conn.Open() - return conn, net -} - -func TestReadingNilPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{}) - // time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - } - conn.Close() -} - -func TestReadingShortPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != PacketTooShort { - t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) - } - } - conn.Close() -} - -func TestReadingInvalidPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != MagicTokenMismatch { - t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) - } - } - conn.Close() -} - -func TestReadingInvalidPayload(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != PayloadTooShort { - t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) - } - } - conn.Close() -} - -func TestReadingEmptyPayload(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - default: - } - select { - case err := <-conn.Error(): - code := err.Code - if code != EmptyPayload { - t.Errorf("incorrect error, expected EmptyPayload, got %v", code) - } - default: - t.Errorf("no error, expected EmptyPayload") - } - conn.Close() -} - -func TestReadingCompletePacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - if bytes.Compare(packet, []byte{1}) != 0 { - t.Errorf("incorrect payload read") - } - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - t.Errorf("nothing read") - } - conn.Close() -} - -func TestReadingTwoCompletePackets(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) - - for i := 0; i < 2; i++ { - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - if bytes.Compare(packet, []byte{byte(i)}) != 0 { - t.Errorf("incorrect payload read") - } - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - t.Errorf("nothing read") - } - } - conn.Close() -} - -func TestWriting(t *testing.T) { - conn, net := setupConnection() - conn.Write() <- []byte{0} - time.Sleep(10 * time.Millisecond) - if len(net.Out) == 0 { - t.Errorf("no output") - } else { - out := net.Out[0] - if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { - t.Errorf("incorrect packet %v", out) - } - } - conn.Close() -} - -// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243 diff --git a/p2p/message.go b/p2p/message.go index 446e74dff..366cff5d7 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -1,75 +1,174 @@ package p2p import ( - // "fmt" + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "math/big" + "github.com/ethereum/go-ethereum/ethutil" ) -type MsgCode uint8 +type MsgCode uint64 +// Msg defines the structure of a p2p message. +// +// Note that a Msg can only be sent once since the Payload reader is +// consumed during sending. It is not possible to create a Msg and +// send it any number of times. If you want to reuse an encoded +// structure, encode the payload into a byte array and create a +// separate Msg with a bytes.Reader as Payload for each send. type Msg struct { - code MsgCode // this is the raw code as per adaptive msg code scheme - data *ethutil.Value - encoded []byte + Code MsgCode + Size uint32 // size of the paylod + Payload io.Reader } -func (self *Msg) Code() MsgCode { - return self.code +// NewMsg creates an RLP-encoded message with the given code. +func NewMsg(code MsgCode, params ...interface{}) Msg { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} } -func (self *Msg) Data() *ethutil.Value { - return self.data +func encodePayload(params ...interface{}) []byte { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return buf.Bytes() } -func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { - - // // data := [][]interface{}{} - // data := []interface{}{} - // for _, value := range params { - // if encodable, ok := value.(ethutil.RlpEncodeDecode); ok { - // data = append(data, encodable.RlpValue()) - // } else if raw, ok := value.([]interface{}); ok { - // data = append(data, raw) - // } else { - // // data = append(data, interface{}(raw)) - // err = fmt.Errorf("Unable to encode object of type %T", value) - // return - // } - // } - return &Msg{ - code: code, - data: ethutil.NewValue(interface{}(params)), - }, nil +// Data returns the decoded RLP payload items in a message. +func (msg Msg) Data() (*ethutil.Value, error) { + // TODO: avoid copying when we have a better RLP decoder + buf := new(bytes.Buffer) + var s []interface{} + if _, err := buf.ReadFrom(msg.Payload); err != nil { + return nil, err + } + for buf.Len() > 0 { + s = append(s, ethutil.DecodeWithReader(buf)) + } + return ethutil.NewValue(s), nil +} + +// Discard reads any remaining payload data into a black hole. +func (msg Msg) Discard() error { + _, err := io.Copy(ioutil.Discard, msg.Payload) + return err +} + +var magicToken = []byte{34, 64, 8, 145} + +func writeMsg(w io.Writer, msg Msg) error { + // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 + code := ethutil.Encode(uint32(msg.Code)) + listhdr := makeListHeader(msg.Size + uint32(len(code))) + payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size + + start := make([]byte, 8) + copy(start, magicToken) + binary.BigEndian.PutUint32(start[4:], payloadLen) + + for _, b := range [][]byte{start, listhdr, code} { + if _, err := w.Write(b); err != nil { + return err + } + } + _, err := io.CopyN(w, msg.Payload, int64(msg.Size)) + return err } -func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { - value := ethutil.NewValueFromBytes(encoded) - // Type of message - code := value.Get(0).Uint() - // Actual data - data := value.SliceFrom(1) - - msg = &Msg{ - code: MsgCode(code), - data: data, - // data: ethutil.NewValue(data), - encoded: encoded, +func makeListHeader(length uint32) []byte { + if length < 56 { + return []byte{byte(length + 0xc0)} } - return + enc := big.NewInt(int64(length)).Bytes() + lenb := byte(len(enc)) + 0xf7 + return append([]byte{lenb}, enc...) } -func (self *Msg) Decode(offset MsgCode) { - self.code = self.code - offset +type byteReader interface { + io.Reader + io.ByteReader } -// encode takes an offset argument to implement adaptive message coding -// the encoded message is memoized to make msgs relayed to several peers more efficient -func (self *Msg) Encode(offset MsgCode) (res []byte) { - if len(self.encoded) == 0 { - res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() - self.encoded = res +// readMsg reads a message header. +func readMsg(r byteReader) (msg Msg, err error) { + // read magic and payload size + start := make([]byte, 8) + if _, err = io.ReadFull(r, start); err != nil { + return msg, NewPeerError(ReadError, "%v", err) + } + if !bytes.HasPrefix(start, magicToken) { + return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + } + size := binary.BigEndian.Uint32(start[4:]) + + // decode start of RLP message to get the message code + _, hdrlen, err := readListHeader(r) + if err != nil { + return msg, err + } + code, codelen, err := readMsgCode(r) + if err != nil { + return msg, err + } + + rlpsize := size - hdrlen - codelen + return Msg{ + Code: code, + Size: rlpsize, + Payload: io.LimitReader(r, int64(rlpsize)), + }, nil +} + +// readListHeader reads an RLP list header from r. +func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { + b, err := r.ReadByte() + if err != nil { + return 0, 0, err + } + if b < 0xC0 { + return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b) + } else if b < 0xF7 { + len = uint64(b - 0xc0) + hdrlen = 1 } else { - res = self.encoded + lenlen := b - 0xF7 + lenbuf := make([]byte, 8) + if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil { + return 0, 0, err + } + len = binary.BigEndian.Uint64(lenbuf) + hdrlen = 1 + uint32(lenlen) + } + return len, hdrlen, nil +} + +// readUint reads an RLP-encoded unsigned integer from r. +func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { + b, err := r.ReadByte() + if err != nil { + return 0, 0, err + } + if b < 0x80 { + return MsgCode(b), 1, nil + } else if b < 0x89 { // max length for uint64 is 8 bytes + codelen = uint32(b - 0x80) + if codelen == 0 { + return 0, 1, nil + } + buf := make([]byte, 8) + if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { + return 0, 0, err + } + return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil } - return + return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) } diff --git a/p2p/message_test.go b/p2p/message_test.go index e9d46f2c3..1edabc4e7 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -1,38 +1,67 @@ package p2p import ( + "bytes" + "io/ioutil" "testing" + + "github.com/ethereum/go-ethereum/ethutil" ) func TestNewMsg(t *testing.T) { - msg, _ := NewMsg(3, 1, "000") - if msg.Code() != 3 { - t.Errorf("incorrect code %v", msg.Code()) + msg := NewMsg(3, 1, "000") + if msg.Code != 3 { + t.Errorf("incorrect code %d, want %d", msg.Code) } - data0 := msg.Data().Get(0).Uint() - data1 := string(msg.Data().Get(1).Bytes()) - if data0 != 1 { - t.Errorf("incorrect data %v", data0) + if msg.Size != 5 { + t.Errorf("incorrect size %d, want %d", msg.Size, 5) } - if data1 != "000" { - t.Errorf("incorrect data %v", data1) + pl, _ := ioutil.ReadAll(msg.Payload) + expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} + if !bytes.Equal(pl, expect) { + t.Errorf("incorrect payload content, got %x, want %x", pl, expect) } } func TestEncodeDecodeMsg(t *testing.T) { - msg, _ := NewMsg(3, 1, "000") - encoded := msg.Encode(3) - msg, _ = NewMsgFromBytes(encoded) - msg.Decode(3) - if msg.Code() != 3 { - t.Errorf("incorrect code %v", msg.Code()) - } - data0 := msg.Data().Get(0).Uint() - data1 := msg.Data().Get(1).Str() - if data0 != 1 { - t.Errorf("incorrect data %v", data0) - } - if data1 != "000" { - t.Errorf("incorrect data %v", data1) + msg := NewMsg(3, 1, "000") + buf := new(bytes.Buffer) + if err := writeMsg(buf, msg); err != nil { + t.Fatalf("encodeMsg error: %v", err) + } + + t.Logf("encoded: %x", buf.Bytes()) + + decmsg, err := readMsg(buf) + if err != nil { + t.Fatalf("readMsg error: %v", err) + } + if decmsg.Code != 3 { + t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) + } + if decmsg.Size != 5 { + t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) + } + data, err := decmsg.Data() + if err != nil { + t.Fatalf("first payload item decode error: %v", err) + } + if v := data.Get(0).Uint(); v != 1 { + t.Errorf("incorrect data[0]: got %v, expected %d", v, 1) + } + if v := data.Get(1).Str(); v != "000" { + t.Errorf("incorrect data[1]: got %q, expected %q", v, "000") + } +} + +func TestDecodeRealMsg(t *testing.T) { + data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") + msg, err := readMsg(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if msg.Code != 0 { + t.Errorf("incorrect code %d, want %d", msg.Code, 0) } } diff --git a/p2p/messenger.go b/p2p/messenger.go index d42ba1720..7375ecc07 100644 --- a/p2p/messenger.go +++ b/p2p/messenger.go @@ -1,220 +1,221 @@ package p2p import ( + "bufio" + "bytes" "fmt" + "io" + "io/ioutil" + "net" "sync" "time" ) -const ( - handlerTimeout = 1000 -) +type Handlers map[string]func() Protocol -type Handlers map[string](func(p *Peer) Protocol) - -type Messenger struct { - conn *Connection - peer *Peer - handlers Handlers - protocolLock sync.RWMutex - protocols []Protocol - offsets []MsgCode // offsets for adaptive message idss - protocolTable map[string]int - quit chan chan bool - err chan *PeerError - pulse chan bool -} - -func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { - baseProtocol := NewBaseProtocol(peer) - return &Messenger{ - conn: conn, - peer: peer, - offsets: []MsgCode{baseProtocol.Offset()}, - handlers: handlers, - protocols: []Protocol{baseProtocol}, - protocolTable: make(map[string]int), - err: errchan, - pulse: make(chan bool, 1), - quit: make(chan chan bool, 1), - } +type proto struct { + in chan Msg + maxcode, offset MsgCode + messenger *messenger } -func (self *Messenger) Start() { - self.conn.Open() - go self.messenger() - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - self.protocols[0].Start() +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return NewPeerError(InvalidMsgCode, "not handled") + } + return rw.messenger.writeMsg(msg) } -func (self *Messenger) Stop() { - // close pulse to stop ping pong monitoring - close(self.pulse) - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - for _, protocol := range self.protocols { - protocol.Stop() // could be parallel +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF } - q := make(chan bool) - self.quit <- q - <-q - self.conn.Close() + return msg, nil } -func (self *Messenger) messenger() { - in := self.conn.Read() - for { - select { - case payload, ok := <-in: - //dispatches message to the protocol asynchronously - if ok { - go self.handle(payload) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } +// eofSignal is used to 'lend' the network connection +// to a protocol. when the protocol's read loop has read the +// whole payload, the done channel is closed. +type eofSignal struct { + wrapped io.Reader + eof chan struct{} } -// handles each message by dispatching to the appropriate protocol -// using adaptive message codes -// this function is started as a separate go routine for each message -// it waits for the protocol response -// then encodes and sends outgoing messages to the connection's write channel -func (self *Messenger) handle(payload []byte) { - // send ping to heartbeat channel signalling time of last message - // select { - // case self.pulse <- true: - // default: - // } - self.pulse <- true - // initialise message from payload - msg, err := NewMsgFromBytes(payload) +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) if err != nil { - self.err <- NewPeerError(MiscError, " %v", err) - return + close(r.eof) // tell messenger that msg has been consumed } - // retrieves protocol based on message Code - protocol, offset, peerErr := self.getProtocol(msg.Code()) - if err != nil { - self.err <- peerErr - return + return n, err +} + +// messenger represents a message-oriented peer connection. +// It keeps track of the set of protocols understood +// by the remote peer. +type messenger struct { + peer *Peer + handlers Handlers + + // the mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + protocolLock sync.RWMutex + protocols map[string]*proto + offsets map[MsgCode]*proto + protoWG sync.WaitGroup + + err chan error + pulse chan bool +} + +func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { + return &messenger{ + conn: conn, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + peer: peer, + handlers: handlers, + protocols: make(map[string]*proto), + err: errchan, + pulse: make(chan bool, 1), } - // reset message code based on adaptive offset - msg.Decode(offset) - // dispatches - response := make(chan *Msg) - go protocol.HandleIn(msg, response) - // protocol reponse timeout to prevent leaks - timer := time.After(handlerTimeout * time.Millisecond) +} + +func (m *messenger) Start() { + m.protocols[""] = m.startProto(0, "", &baseProtocol{}) + go m.readLoop() +} + +func (m *messenger) Stop() { + m.conn.Close() + m.protoWG.Wait() +} + +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +func (m *messenger) readLoop() { + defer m.closeProtocols() for { - select { - case outgoing, ok := <-response: - // we check if response channel is not closed - if ok { - self.conn.Write() <- outgoing.Encode(offset) - } else { + m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + msg, err := readMsg(m.bufconn) + if err != nil { + m.err <- err + return + } + // send ping to heartbeat channel signalling time of last message + m.pulse <- true + proto, err := m.getProto(msg.Code) + if err != nil { + m.err <- err + return + } + msg.Code -= proto.offset + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + m.err <- err return } - case <-timer: - return + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + pr := &eofSignal{msg.Payload, make(chan struct{})} + msg.Payload = pr + proto.in <- msg + <-pr.eof } } } -// negotiated protocols -// stores offsets needed for adaptive message id scheme - -// based on offsets set at handshake -// get the right protocol to handle the message -func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - base := MsgCode(0) - for index, offset := range self.offsets { - if code < offset { - return self.protocols[index], base, nil - } - base = offset +func (m *messenger) closeProtocols() { + m.protocolLock.RLock() + for _, p := range m.protocols { + close(p.in) } - return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) + m.protocolLock.RUnlock() } -func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { - fmt.Printf("pingpong keepalive started at %v", time.Now()) +func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { + proto := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Offset(), + messenger: m, + } + m.protoWG.Add(1) + go func() { + if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { + logger.Errorf("protocol %q error: %v\n", name, err) + m.err <- err + } + m.protoWG.Done() + }() + return proto +} - timer := time.After(timeout) - pinged := false - for { - select { - case _, ok := <-self.pulse: - if ok { - pinged = false - timer = time.After(timeout) - } else { - // pulse is closed, stop monitoring - return - } - case <-timer: - if pinged { - fmt.Printf("timeout at %v", time.Now()) - timeoutCallback() - return - } else { - fmt.Printf("pinged at %v", time.Now()) - pingCallback() - timer = time.After(gracePeriod) - pinged = true - } +// getProto finds the protocol responsible for handling +// the given message code. +func (m *messenger) getProto(code MsgCode) (*proto, error) { + m.protocolLock.RLock() + defer m.protocolLock.RUnlock() + for _, proto := range m.protocols { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil } } + return nil, NewPeerError(InvalidMsgCode, "%d", code) } -func (self *Messenger) AddProtocols(protocols []string) { - self.protocolLock.Lock() - defer self.protocolLock.Unlock() - i := len(self.offsets) - offset := self.offsets[i-1] +// setProtocols starts all subprotocols shared with the +// remote peer. the protocols must be sorted alphabetically. +func (m *messenger) setRemoteProtocols(protocols []string) { + m.protocolLock.Lock() + defer m.protocolLock.Unlock() + offset := baseProtocolOffset for _, name := range protocols { - protocolFunc, ok := self.handlers[name] - if ok { - protocol := protocolFunc(self.peer) - self.protocolTable[name] = i - i++ - offset += protocol.Offset() - fmt.Println("offset ", name, offset) - - self.offsets = append(self.offsets, offset) - self.protocols = append(self.protocols, protocol) - protocol.Start() - } else { - fmt.Println("no ", name) - // protocol not handled + protocolFunc, ok := m.handlers[name] + if !ok { + continue // not handled } + inst := protocolFunc() + m.protocols[name] = m.startProto(offset, name, inst) + offset += inst.Offset() } } -func (self *Messenger) Write(protocol string, msg *Msg) error { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - i := 0 - offset := MsgCode(0) - if len(protocol) > 0 { - var ok bool - i, ok = self.protocolTable[protocol] - if !ok { - return fmt.Errorf("protocol %v not handled by peer", protocol) - } - offset = self.offsets[i-1] +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { + m.protocolLock.RLock() + proto, ok := m.protocols[protoName] + m.protocolLock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) } - handler := self.protocols[i] - // checking if protocol status/caps allows the message to be sent out - if handler.HandleOut(msg) { - self.conn.Write() <- msg.Encode(offset) + if msg.Code >= proto.maxcode { + return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return m.writeMsg(msg) +} + +// writeMsg writes a message to the connection. +func (m *messenger) writeMsg(msg Msg) error { + m.writeMu.Lock() + defer m.writeMu.Unlock() + if err := writeMsg(m.bufconn, msg); err != nil { + return err } - return nil + return m.bufconn.Flush() } diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go index 472d74515..f10469e2f 100644 --- a/p2p/messenger_test.go +++ b/p2p/messenger_test.go @@ -1,147 +1,157 @@ package p2p import ( - // "fmt" - "bytes" + "bufio" + "fmt" + "io" + "log" + "net" + "os" + "reflect" "testing" "time" "github.com/ethereum/go-ethereum/ethutil" ) -func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { - errchan := NewPeerErrorChannel() - addr := &TestAddr{"test:30303"} - net := NewTestNetworkConnection(addr) - conn := NewConnection(net, errchan) - mess := NewMessenger(nil, conn, errchan, handlers) - mess.Start() - return net, errchan, mess +func init() { + ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel)) } -type TestProtocol struct { - Msgs []*Msg +func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { + conn1, conn2 := net.Pipe() + id := NewSimpleClientIdentity("test", "0", "0", "public key") + server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) + peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0) + return conn2, peer, peer.messenger } -func (self *TestProtocol) Start() { -} - -func (self *TestProtocol) Stop() { -} - -func (self *TestProtocol) Offset() MsgCode { - return MsgCode(5) +func performTestHandshake(r *bufio.Reader, w io.Writer) error { + // read remote handshake + msg, err := readMsg(r) + if err != nil { + return fmt.Errorf("read error: %v", err) + } + if msg.Code != handshakeMsg { + return fmt.Errorf("first message should be handshake, got %x", msg.Code) + } + if err := msg.Discard(); err != nil { + return err + } + // send empty handshake + pubkey := make([]byte, 64) + msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey) + return writeMsg(w, msg) } -func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { - self.Msgs = append(self.Msgs, msg) - close(response) +type testMsg struct { + code MsgCode + data *ethutil.Value } -func (self *TestProtocol) HandleOut(msg *Msg) bool { - if msg.Code() > 3 { - return false - } else { - return true - } +type testProto struct { + recv chan testMsg } -func (self *TestProtocol) Name() string { - return "a" -} +func (*testProto) Offset() MsgCode { return 5 } -func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { - msg, _ := NewMsg(code, params...) - encoded := msg.Encode(offset) - packet := []byte{34, 64, 8, 145} - packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) - return append(packet, encoded...) +func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error { + return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error { + logger.Debugf("testprotocol got msg: %d\n", code) + tp.recv <- testMsg{code, data} + return nil + }) } func TestRead(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["a"] = func(p *Peer) Protocol { return testProtocol } - net, _, mess := setupMessenger(handlers) - mess.AddProtocols([]string{"a"}) - defer mess.Stop() - wait := 1 * time.Millisecond - packet := Packet(16, 1, uint32(1), "000") - go net.In(0, packet) - time.Sleep(wait) - if len(testProtocol.Msgs) != 1 { - t.Errorf("msg not relayed to correct protocol") - } else { - if testProtocol.Msgs[0].Code() != 1 { - t.Errorf("incorrect msg code relayed to protocol") + testProtocol := &testProto{make(chan testMsg)} + handlers := Handlers{"a": func() Protocol { return testProtocol }} + net, peer, mess := setupMessenger(handlers) + bufr := bufio.NewReader(net) + defer peer.Stop() + if err := performTestHandshake(bufr, net); err != nil { + t.Fatalf("handshake failed: %v", err) + } + + mess.setRemoteProtocols([]string{"a"}) + writeMsg(net, NewMsg(17, uint32(1), "000")) + select { + case msg := <-testProtocol.recv: + if msg.code != 1 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.code) + } + expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + if !reflect.DeepEqual(msg.data.Slice(), expdata) { + t.Errorf("incorrect msg data %#v", msg.data.Slice()) } + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") } } -func TestWrite(t *testing.T) { +func TestWriteProtoMsg(t *testing.T) { handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["a"] = func(p *Peer) Protocol { return testProtocol } - net, _, mess := setupMessenger(handlers) - mess.AddProtocols([]string{"a"}) - defer mess.Stop() - wait := 1 * time.Millisecond - msg, _ := NewMsg(3, uint32(1), "000") - err := mess.Write("b", msg) - if err == nil { - t.Errorf("expect error for unknown protocol") + testProtocol := &testProto{recv: make(chan testMsg, 1)} + handlers["a"] = func() Protocol { return testProtocol } + net, peer, mess := setupMessenger(handlers) + defer peer.Stop() + bufr := bufio.NewReader(net) + if err := performTestHandshake(bufr, net); err != nil { + t.Fatalf("handshake failed: %v", err) } - err = mess.Write("a", msg) - if err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(wait) - if len(net.Out) != 1 { - t.Errorf("msg not written") + mess.setRemoteProtocols([]string{"a"}) + + // test write errors + if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v") + } + + // test succcessful write + read, readerr := make(chan Msg), make(chan error) + go func() { + if msg, err := readMsg(bufr); err != nil { + readerr <- err } else { - out := net.Out[0] - packet := Packet(16, 3, uint32(1), "000") - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v", out) - } + read <- msg + } + }() + if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } + select { + case msg := <-read: + if msg.Code != 19 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 19) } + msg.Discard() + case err := <-readerr: + t.Errorf("read error: %v", err) } } func TestPulse(t *testing.T) { - net, _, mess := setupMessenger(make(Handlers)) - defer mess.Stop() - ping := false - timeout := false - pingTimeout := 10 * time.Millisecond - gracePeriod := 200 * time.Millisecond - go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) - net.In(0, Packet(0, 1)) - if ping { - t.Errorf("ping sent too early") - } - time.Sleep(pingTimeout + 100*time.Millisecond) - if !ping { - t.Errorf("no ping sent after timeout") - } - if timeout { - t.Errorf("timeout too early") + net, peer, _ := setupMessenger(nil) + defer peer.Stop() + bufr := bufio.NewReader(net) + if err := performTestHandshake(bufr, net); err != nil { + t.Fatalf("handshake failed: %v", err) } - ping = false - net.In(0, Packet(0, 1)) - time.Sleep(pingTimeout + 100*time.Millisecond) - if !ping { - t.Errorf("no ping sent after timeout") - } - if timeout { - t.Errorf("timeout too early") + + before := time.Now() + msg, err := readMsg(bufr) + if err != nil { + t.Fatalf("read error: %v", err) } - ping = false - time.Sleep(gracePeriod) - if ping { - t.Errorf("ping called twice") + after := time.Now() + if msg.Code != pingMsg { + t.Errorf("expected ping message, got %x", msg.Code) } - if !timeout { - t.Errorf("no timeout after grace period") + if d := after.Sub(before); d < pingTimeout { + t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) } } diff --git a/p2p/peer.go b/p2p/peer.go index f4b68a007..34b6152a3 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -7,7 +7,6 @@ import ( ) type Peer struct { - // quit chan chan bool Inbound bool // inbound (via listener) or outbound (via dialout) Address net.Addr Host []byte @@ -15,24 +14,12 @@ type Peer struct { Pubkey []byte Id string Caps []string - peerErrorChan chan *PeerError - messenger *Messenger + peerErrorChan chan error + messenger *messenger peerErrorHandler *PeerErrorHandler server *Server } -func (self *Peer) Messenger() *Messenger { - return self.messenger -} - -func (self *Peer) PeerErrorChan() chan *PeerError { - return self.peerErrorChan -} - -func (self *Peer) Server() *Server { - return self.server -} - func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { peerErrorChan := NewPeerErrorChannel() host, port, _ := net.SplitHostPort(address.String()) @@ -45,9 +32,8 @@ func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Pee peerErrorChan: peerErrorChan, server: server, } - connection := NewConnection(conn, peerErrorChan) - peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) - peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) + peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) + peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) return peer } @@ -61,8 +47,8 @@ func (self *Peer) String() string { return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) } -func (self *Peer) Write(protocol string, msg *Msg) error { - return self.messenger.Write(protocol, msg) +func (self *Peer) Write(protocol string, msg Msg) error { + return self.messenger.writeProtoMsg(protocol, msg) } func (self *Peer) Start() { @@ -73,9 +59,6 @@ func (self *Peer) Start() { func (self *Peer) Stop() { self.peerErrorHandler.Stop() self.messenger.Stop() - // q := make(chan bool) - // self.quit <- q - // <-q } func (p *Peer) Encode() []interface{} { diff --git a/p2p/peer_error.go b/p2p/peer_error.go index de921878a..f3ef98d98 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -9,10 +9,9 @@ type ErrorCode int const errorChanCapacity = 10 const ( - PacketTooShort = iota + PacketTooLong = iota PayloadTooShort MagicTokenMismatch - EmptyPayload ReadError WriteError MiscError @@ -31,10 +30,9 @@ const ( ) var errorToString = map[ErrorCode]string{ - PacketTooShort: "Packet too short", + PacketTooLong: "Packet too long", PayloadTooShort: "Payload too short", MagicTokenMismatch: "Magic token mismatch", - EmptyPayload: "Empty payload", ReadError: "Read error", WriteError: "Write error", MiscError: "Misc error", @@ -71,6 +69,6 @@ func (self *PeerError) Error() string { return self.message } -func NewPeerErrorChannel() chan *PeerError { - return make(chan *PeerError, errorChanCapacity) +func NewPeerErrorChannel() chan error { + return make(chan error, errorChanCapacity) } diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go index ca6cae4db..47dcd14ff 100644 --- a/p2p/peer_error_handler.go +++ b/p2p/peer_error_handler.go @@ -18,17 +18,15 @@ type PeerErrorHandler struct { address net.Addr peerDisconnect chan DisconnectRequest severity int - peerErrorChan chan *PeerError - blacklist Blacklist + errc chan error } -func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { +func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler { return &PeerErrorHandler{ quit: make(chan chan bool), address: address, peerDisconnect: peerDisconnect, - peerErrorChan: peerErrorChan, - blacklist: blacklist, + errc: errc, } } @@ -45,10 +43,10 @@ func (self *PeerErrorHandler) Stop() { func (self *PeerErrorHandler) listen() { for { select { - case peerError, ok := <-self.peerErrorChan: + case err, ok := <-self.errc: if ok { - logger.Debugf("error %v\n", peerError) - go self.handle(peerError) + logger.Debugf("error %v\n", err) + go self.handle(err) } else { return } @@ -59,8 +57,12 @@ func (self *PeerErrorHandler) listen() { } } -func (self *PeerErrorHandler) handle(peerError *PeerError) { +func (self *PeerErrorHandler) handle(err error) { reason := DiscReason(' ') + peerError, ok := err.(*PeerError) + if !ok { + peerError = NewPeerError(MiscError, " %v", err) + } switch peerError.Code { case P2PVersionMismatch: reason = DiscIncompatibleVersion @@ -68,11 +70,11 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) { reason = DiscInvalidIdentity case PubkeyForbidden: reason = DiscUselessPeer - case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: + case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach: reason = DiscProtocolError case PingTimeout: reason = DiscReadTimeout - case WriteError, MiscError: + case ReadError, WriteError, MiscError: reason = DiscNetworkError case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: reason = DiscSubprotocolError @@ -92,10 +94,5 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) { } func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { - switch peerError.Code { - case ReadError: - return 4 //tolerate 3 :) - default: - return 1 - } + return 1 } diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go index 790a7443b..b93252f6a 100644 --- a/p2p/peer_error_handler_test.go +++ b/p2p/peer_error_handler_test.go @@ -11,7 +11,7 @@ func TestPeerErrorHandler(t *testing.T) { address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} peerDisconnect := make(chan DisconnectRequest) peerErrorChan := NewPeerErrorChannel() - peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) + peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan) peh.Start() defer peh.Stop() for i := 0; i < 11; i++ { diff --git a/p2p/peer_test.go b/p2p/peer_test.go index c37540bef..da62cc380 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,96 +1,90 @@ package p2p -import ( - "bytes" - "fmt" - // "net" - "testing" - "time" -) +// "net" -func TestPeer(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } - handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } - addr := &TestAddr{"test:30"} - conn := NewTestNetworkConnection(addr) - _, server := SetupTestServer(handlers) - server.Handshake() - peer := NewPeer(conn, addr, true, server) - // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) - peer.Start() - defer peer.Stop() - time.Sleep(2 * time.Millisecond) - if len(conn.Out) != 1 { - t.Errorf("handshake not sent") - } else { - out := conn.Out[0] - packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect handshake packet %v != %v", out, packet) - } - } +// func TestPeer(t *testing.T) { +// handlers := make(Handlers) +// testProtocol := &TestProtocol{recv: make(chan testMsg)} +// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } +// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } +// addr := &TestAddr{"test:30"} +// conn := NewTestNetworkConnection(addr) +// _, server := SetupTestServer(handlers) +// server.Handshake() +// peer := NewPeer(conn, addr, true, server) +// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) +// peer.Start() +// defer peer.Stop() +// time.Sleep(2 * time.Millisecond) +// if len(conn.Out) != 1 { +// t.Errorf("handshake not sent") +// } else { +// out := conn.Out[0] +// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) +// if bytes.Compare(out, packet) != 0 { +// t.Errorf("incorrect handshake packet %v != %v", out, packet) +// } +// } - packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) - conn.In(0, packet) - time.Sleep(10 * time.Millisecond) +// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) +// conn.In(0, packet) +// time.Sleep(10 * time.Millisecond) - pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) - if pro.state != handshakeReceived { - t.Errorf("handshake not received") - } - if peer.Port != 30 { - t.Errorf("port incorrectly set") - } - if peer.Id != "peer" { - t.Errorf("id incorrectly set") - } - if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { - t.Errorf("pubkey incorrectly set") - } - fmt.Println(peer.Caps) - if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { - t.Errorf("protocols incorrectly set") - } +// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) +// if pro.state != handshakeReceived { +// t.Errorf("handshake not received") +// } +// if peer.Port != 30 { +// t.Errorf("port incorrectly set") +// } +// if peer.Id != "peer" { +// t.Errorf("id incorrectly set") +// } +// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { +// t.Errorf("pubkey incorrectly set") +// } +// fmt.Println(peer.Caps) +// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { +// t.Errorf("protocols incorrectly set") +// } - msg, _ := NewMsg(3) - err := peer.Write("aaa", msg) - if err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(1 * time.Millisecond) - if len(conn.Out) != 2 { - t.Errorf("msg not written") - } else { - out := conn.Out[1] - packet := Packet(16, 3) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v != %v", out, packet) - } - } - } +// msg := NewMsg(3) +// err := peer.Write("aaa", msg) +// if err != nil { +// t.Errorf("expect no error for known protocol: %v", err) +// } else { +// time.Sleep(1 * time.Millisecond) +// if len(conn.Out) != 2 { +// t.Errorf("msg not written") +// } else { +// out := conn.Out[1] +// packet := Packet(16, 3) +// if bytes.Compare(out, packet) != 0 { +// t.Errorf("incorrect packet %v != %v", out, packet) +// } +// } +// } - msg, _ = NewMsg(2) - err = peer.Write("ccc", msg) - if err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(1 * time.Millisecond) - if len(conn.Out) != 3 { - t.Errorf("msg not written") - } else { - out := conn.Out[2] - packet := Packet(21, 2) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v != %v", out, packet) - } - } - } +// msg = NewMsg(2) +// err = peer.Write("ccc", msg) +// if err != nil { +// t.Errorf("expect no error for known protocol: %v", err) +// } else { +// time.Sleep(1 * time.Millisecond) +// if len(conn.Out) != 3 { +// t.Errorf("msg not written") +// } else { +// out := conn.Out[2] +// packet := Packet(21, 2) +// if bytes.Compare(out, packet) != 0 { +// t.Errorf("incorrect packet %v != %v", out, packet) +// } +// } +// } - err = peer.Write("bbb", msg) - time.Sleep(1 * time.Millisecond) - if err == nil { - t.Errorf("expect error for unknown protocol") - } -} +// err = peer.Write("bbb", msg) +// time.Sleep(1 * time.Millisecond) +// if err == nil { +// t.Errorf("expect error for unknown protocol") +// } +// } diff --git a/p2p/protocol.go b/p2p/protocol.go index 5d05ced7d..ccc275287 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -2,43 +2,101 @@ package p2p import ( "bytes" - "fmt" "net" "sort" - "sync" "time" + + "github.com/ethereum/go-ethereum/ethutil" ) +// Protocol is implemented by P2P subprotocols. type Protocol interface { - Start() - Stop() - HandleIn(*Msg, chan *Msg) - HandleOut(*Msg) bool + // Start is called when the protocol becomes active. + // It should read and write messages from rw. + // Messages must be fully consumed. + // + // The connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Start(peer *Peer, rw MsgReadWriter) error + + // Offset should return the number of message codes + // used by the protocol. Offset() MsgCode - Name() string +} + +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + WriteMsg(Msg) error +} + +// MsgReadWriter is passed to protocols. Protocol implementations can +// use it to write messages back to a connected peer. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +type MsgHandler func(code MsgCode, data *ethutil.Value) error + +// MsgLoop reads messages off the given reader and +// calls the handler function for each decoded message until +// it returns an error or the peer connection is closed. +// +// If a message is larger than the given maximum size, RunProtocol +// returns an appropriate error.n +func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error { + for { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Size > maxsize { + return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) + } + value, err := msg.Data() + if err != nil { + return err + } + if err := handler(msg.Code, value); err != nil { + return err + } + } +} + +// the ÐΞVp2p base protocol +type baseProtocol struct { + rw MsgReadWriter + peer *Peer +} + +type bpMsg struct { + code MsgCode + data *ethutil.Value } const ( - P2PVersion = 0 - pingTimeout = 2 - pingGracePeriod = 2 + p2pVersion = 0 + pingTimeout = 2 * time.Second + pingGracePeriod = 2 * time.Second ) const ( - HandshakeMsg = iota - DiscMsg - PingMsg - PongMsg - GetPeersMsg - PeersMsg - offset = 16 + // message codes + handshakeMsg = iota + discMsg + pingMsg + pongMsg + getPeersMsg + peersMsg ) -type ProtocolState uint8 - const ( - nullState = iota - handshakeReceived + baseProtocolOffset MsgCode = 16 + baseProtocolMaxMsgSize = 500 * 1024 ) type DiscReason byte @@ -62,7 +120,7 @@ const ( DiscSubprotocolError = 0x10 ) -var discReasonToString = map[DiscReason]string{ +var discReasonToString = [DiscSubprotocolError + 1]string{ DiscRequested: "Disconnect requested", DiscNetworkError: "Network error", DiscProtocolError: "Breach of protocol", @@ -82,197 +140,178 @@ func (d DiscReason) String() string { if len(discReasonToString) < int(d) { return "Unknown" } - return discReasonToString[d] } -type BaseProtocol struct { - peer *Peer - state ProtocolState - stateLock sync.RWMutex +func (bp *baseProtocol) Ping() { } -func NewBaseProtocol(peer *Peer) *BaseProtocol { - self := &BaseProtocol{ - peer: peer, - } - - return self +func (bp *baseProtocol) Offset() MsgCode { + return baseProtocolOffset } -func (self *BaseProtocol) Start() { - if self.peer != nil { - self.peer.Write("", self.peer.Server().Handshake()) - go self.peer.Messenger().PingPong( - pingTimeout*time.Second, - pingGracePeriod*time.Second, - self.Ping, - self.Timeout, - ) +func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error { + bp.peer, bp.rw = peer, rw + + // Do the handshake. + // TODO: disconnect is valid before handshake, too. + rw.WriteMsg(bp.peer.server.handshakeMsg()) + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != handshakeMsg { + return NewPeerError(ProtocolBreach, " first message must be handshake") + } + data, err := msg.Data() + if err != nil { + return NewPeerError(InvalidMsg, "%v", err) + } + if err := bp.handleHandshake(data); err != nil { + return err } -} -func (self *BaseProtocol) Stop() { + msgin := make(chan bpMsg) + done := make(chan error, 1) + go func() { + done <- MsgLoop(rw, baseProtocolMaxMsgSize, + func(code MsgCode, data *ethutil.Value) error { + msgin <- bpMsg{code, data} + return nil + }) + }() + return bp.loop(msgin, done) } -func (self *BaseProtocol) Ping() { - msg, _ := NewMsg(PingMsg) - self.peer.Write("", msg) +func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { + logger.Debugf("pingpong keepalive started at %v\n", time.Now()) + messenger := bp.rw.(*proto).messenger + pingTimer := time.NewTimer(pingTimeout) + pinged := true + + for { + select { + case msg := <-msgin: + if err := bp.handle(msg.code, msg.data); err != nil { + return err + } + case err := <-quit: + return err + case <-messenger.pulse: + pingTimer.Reset(pingTimeout) + pinged = false + case <-pingTimer.C: + if pinged { + return NewPeerError(PingTimeout, "") + } + logger.Debugf("pinging at %v\n", time.Now()) + if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil { + return NewPeerError(WriteError, "%v", err) + } + pinged = true + pingTimer.Reset(pingTimeout) + } + } } -func (self *BaseProtocol) Timeout() { - self.peerError(PingTimeout, "") -} +func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { + switch code { + case handshakeMsg: + return NewPeerError(ProtocolBreach, " extra handshake received") -func (self *BaseProtocol) Name() string { - return "" -} + case discMsg: + logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) + bp.peer.server.PeerDisconnect() <- DisconnectRequest{ + addr: bp.peer.Address, + reason: DiscRequested, + } -func (self *BaseProtocol) Offset() MsgCode { - return offset -} + case pingMsg: + return bp.rw.WriteMsg(NewMsg(pongMsg)) -func (self *BaseProtocol) CheckState(state ProtocolState) bool { - self.stateLock.RLock() - self.stateLock.RUnlock() - if self.state != state { - return false - } else { - return true - } -} + case pongMsg: + // reply for ping -func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { - if msg.Code() == HandshakeMsg { - self.handleHandshake(msg) - } else { - if !self.CheckState(handshakeReceived) { - self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) - close(response) - return - } - switch msg.Code() { - case DiscMsg: - logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) - self.peer.Server().PeerDisconnect() <- DisconnectRequest{ - addr: self.peer.Address, - reason: DiscRequested, - } - case PingMsg: - out, _ := NewMsg(PongMsg) - response <- out - case PongMsg: - case GetPeersMsg: - // Peer asked for list of connected peers - if out, err := self.peer.Server().PeersMessage(); err != nil { - response <- out + case getPeersMsg: + // Peer asked for list of connected peers. + peersRLP := bp.peer.server.encodedPeerList() + if peersRLP != nil { + msg := Msg{ + Code: peersMsg, + Size: uint32(len(peersRLP)), + Payload: bytes.NewReader(peersRLP), } - case PeersMsg: - self.handlePeers(msg) - default: - self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) + return bp.rw.WriteMsg(msg) } - } - close(response) -} -func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { - // somewhat overly paranoid - allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) - return -} + case peersMsg: + bp.handlePeers(data) -func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { - err := NewPeerError(errorCode, format, v...) - logger.Warnln(err) - fmt.Println(self.peer, err) - if self.peer != nil { - self.peer.PeerErrorChan() <- err + default: + return NewPeerError(InvalidMsgCode, "unknown message code %v", code) } + return nil } -func (self *BaseProtocol) handlePeers(msg *Msg) { - it := msg.Data().NewIterator() +func (bp *baseProtocol) handlePeers(data *ethutil.Value) { + it := data.NewIterator() for it.Next() { ip := net.IP(it.Value().Get(0).Bytes()) port := it.Value().Get(1).Uint() address := &net.TCPAddr{IP: ip, Port: int(port)} - go self.peer.Server().PeerConnect(address) + go bp.peer.server.PeerConnect(address) } } -func (self *BaseProtocol) handleHandshake(msg *Msg) { - self.stateLock.Lock() - defer self.stateLock.Unlock() - if self.state != nullState { - self.peerError(ProtocolBreach, "extra handshake") - return - } - - c := msg.Data() - +func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { var ( - p2pVersion = c.Get(0).Uint() - id = c.Get(1).Str() - caps = c.Get(2) - port = c.Get(3).Uint() - pubkey = c.Get(4).Bytes() + remoteVersion = c.Get(0).Uint() + id = c.Get(1).Str() + caps = c.Get(2) + port = c.Get(3).Uint() + pubkey = c.Get(4).Bytes() ) - fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) - // Check correctness of p2p protocol version - if p2pVersion != P2PVersion { - self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) - return + if remoteVersion != p2pVersion { + return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion) } // Handle the pub key (validation, uniqueness) if len(pubkey) == 0 { - self.peerError(PubkeyMissing, "not supplied in handshake.") - return + return NewPeerError(PubkeyMissing, "not supplied in handshake.") } if len(pubkey) != 64 { - self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) - return + return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) } - // Self connect detection - if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { - self.peerError(PubkeyForbidden, "not allowed to connect to self") - return + // self connect detection + if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { + return NewPeerError(PubkeyForbidden, "not allowed to connect to bp") } // register pubkey on server. this also sets the pubkey on the peer (need lock) - if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { - self.peerError(PubkeyForbidden, err.Error()) - return + if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { + return NewPeerError(PubkeyForbidden, err.Error()) } // check port - if self.peer.Inbound { + if bp.peer.Inbound { uint16port := uint16(port) - if self.peer.Port > 0 && self.peer.Port != uint16port { - self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) - return + if bp.peer.Port > 0 && bp.peer.Port != uint16port { + return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) } else { - self.peer.Port = uint16port + bp.peer.Port = uint16port } } capsIt := caps.NewIterator() for capsIt.Next() { cap := capsIt.Value().Str() - self.peer.Caps = append(self.peer.Caps, cap) + bp.peer.Caps = append(bp.peer.Caps, cap) } - sort.Strings(self.peer.Caps) - self.peer.Messenger().AddProtocols(self.peer.Caps) - - self.peer.Id = id - - self.state = handshakeReceived - - //p.ethereum.PushPeer(p) - // p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) - return + sort.Strings(bp.peer.Caps) + bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps) + bp.peer.Id = id + return nil } diff --git a/p2p/server.go b/p2p/server.go index 91bc4af5c..54d2cde30 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -80,12 +80,12 @@ type Server struct { quit chan chan bool peersLock sync.RWMutex - maxPeers int - peers []*Peer - peerSlots chan int - peersTable map[string]int - peersMsg *Msg - peerCount int + maxPeers int + peers []*Peer + peerSlots chan int + peersTable map[string]int + peerCount int + cachedEncodedPeers []byte peerConnect chan net.Addr peerDisconnect chan DisconnectRequest @@ -147,27 +147,6 @@ func (self *Server) ClientIdentity() ClientIdentity { return self.identity } -func (self *Server) PeersMessage() (msg *Msg, err error) { - // TODO: memoize and reset when peers change - self.peersLock.RLock() - defer self.peersLock.RUnlock() - msg = self.peersMsg - if msg == nil { - var peerData []interface{} - for _, i := range self.peersTable { - peer := self.peers[i] - peerData = append(peerData, peer.Encode()) - } - if len(peerData) == 0 { - err = fmt.Errorf("no peers") - } else { - msg, err = NewMsg(PeersMsg, peerData...) - self.peersMsg = msg //memoize - } - } - return -} - func (self *Server) Peers() (peers []*Peer) { self.peersLock.RLock() defer self.peersLock.RUnlock() @@ -185,8 +164,6 @@ func (self *Server) PeerCount() int { return self.peerCount } -var getPeersMsg, _ = NewMsg(GetPeersMsg) - func (self *Server) PeerConnect(addr net.Addr) { // TODO: should buffer, filter and uniq // send GetPeersMsg if not blocking @@ -209,12 +186,21 @@ func (self *Server) Handlers() Handlers { return self.handlers } -func (self *Server) Broadcast(protocol string, msg *Msg) { +func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) { + var payload []byte + if data != nil { + payload = encodePayload(data...) + } self.peersLock.RLock() defer self.peersLock.RUnlock() for _, peer := range self.peers { if peer != nil { - peer.Write(protocol, msg) + var msg = Msg{Code: code} + if data != nil { + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } + peer.messenger.writeProtoMsg(protocol, msg) } } } @@ -296,7 +282,7 @@ FOR: select { case slot := <-self.peerSlots: i++ - fmt.Printf("%v: found slot %v", i, slot) + fmt.Printf("%v: found slot %v\n", i, slot) if i == self.maxPeers { break FOR } @@ -358,70 +344,68 @@ func (self *Server) outboundPeerHandler(dialer Dialer) { } // check if peer address already connected -func (self *Server) connected(address net.Addr) (err error) { +func (self *Server) isConnected(address net.Addr) bool { self.peersLock.RLock() defer self.peersLock.RUnlock() - // fmt.Printf("address: %v\n", address) - slot, found := self.peersTable[address.String()] - if found { - err = fmt.Errorf("already connected as peer %v (%v)", slot, address) - } - return + _, found := self.peersTable[address.String()] + return found } // connect to peer via listener.Accept() func (self *Server) connectInboundPeer(listener net.Listener, slot int) { var address net.Addr conn, err := listener.Accept() - if err == nil { - address = conn.RemoteAddr() - err = self.connected(address) - if err != nil { - conn.Close() - } - } if err != nil { logger.Debugln(err) self.peerSlots <- slot - } else { - fmt.Printf("adding %v\n", address) - go self.addPeer(conn, address, true, slot) + return + } + address = conn.RemoteAddr() + // XXX: this won't work because the remote socket + // address does not identify the peer. we should + // probably get rid of this check and rely on public + // key detection in the base protocol. + if self.isConnected(address) { + conn.Close() + self.peerSlots <- slot + return } + fmt.Printf("adding %v\n", address) + go self.addPeer(conn, address, true, slot) } // connect to peer via dial out func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { - var conn net.Conn - err := self.connected(address) - if err == nil { - conn, err = dialer.Dial(address.Network(), address.String()) + if self.isConnected(address) { + return } + conn, err := dialer.Dial(address.Network(), address.String()) if err != nil { - logger.Debugln(err) self.peerSlots <- slot - } else { - go self.addPeer(conn, address, false, slot) + return } + go self.addPeer(conn, address, false, slot) } // creates the new peer object and inserts it into its slot -func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { +func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer { self.peersLock.Lock() defer self.peersLock.Unlock() if self.closed { fmt.Println("oopsy, not no longer need peer") conn.Close() //oopsy our bad self.peerSlots <- slot // release slot - } else { - peer := NewPeer(conn, address, inbound, self) - self.peers[slot] = peer - self.peersTable[address.String()] = slot - self.peerCount++ - // reset peersmsg - self.peersMsg = nil - fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) - peer.Start() + return nil } + logger.Infoln("adding new peer", address) + peer := NewPeer(conn, address, inbound, self) + self.peers[slot] = peer + self.peersTable[address.String()] = slot + self.peerCount++ + self.cachedEncodedPeers = nil + fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) + peer.Start() + return peer } // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot @@ -441,13 +425,12 @@ func (self *Server) removePeer(request DisconnectRequest) { self.peerCount-- self.peers[slot] = nil delete(self.peersTable, address.String()) - // reset peersmsg - self.peersMsg = nil + self.cachedEncodedPeers = nil fmt.Printf("removed peer %v (slot %v)\n", peer, slot) self.peersLock.Unlock() // sending disconnect message - disconnectMsg, _ := NewMsg(DiscMsg, request.reason) + disconnectMsg := NewMsg(discMsg, request.reason) peer.Write("", disconnectMsg) // be nice and wait time.Sleep(disconnectGracePeriod * time.Second) @@ -459,11 +442,32 @@ func (self *Server) removePeer(request DisconnectRequest) { self.peerSlots <- slot } +// encodedPeerList returns an RLP-encoded list of peers. +// the returned slice will be nil if there are no peers. +func (self *Server) encodedPeerList() []byte { + // TODO: memoize and reset when peers change + self.peersLock.RLock() + defer self.peersLock.RUnlock() + if self.cachedEncodedPeers == nil && self.peerCount > 0 { + var peerData []interface{} + for _, i := range self.peersTable { + peer := self.peers[i] + peerData = append(peerData, peer.Encode()) + } + self.cachedEncodedPeers = encodePayload(peerData) + } + return self.cachedEncodedPeers +} + // fix handshake message to push to peers -func (self *Server) Handshake() *Msg { - fmt.Println(self.identity.Pubkey()[1:]) - msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) - return msg +func (self *Server) handshakeMsg() Msg { + return NewMsg(handshakeMsg, + p2pVersion, + []byte(self.identity.String()), + []interface{}{self.protocols}, + self.port, + self.identity.Pubkey()[1:], + ) } func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { diff --git a/p2p/server_test.go b/p2p/server_test.go index f749cc490..472759231 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -1,8 +1,8 @@ package p2p import ( - "bytes" "fmt" + "io" "net" "testing" "time" @@ -32,6 +32,7 @@ func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { connections: self.connections, addr: addr, max: self.maxinbound, + close: make(chan struct{}), }, nil } @@ -76,24 +77,25 @@ type TestListener struct { addr net.Addr max int i int + close chan struct{} } -func (self *TestListener) Accept() (conn net.Conn, err error) { +func (self *TestListener) Accept() (net.Conn, error) { self.i++ if self.i > self.max { - err = fmt.Errorf("no more") - } else { - addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} - tconn := NewTestNetworkConnection(addr) - key := tconn.RemoteAddr().String() - self.connections[key] = tconn - conn = net.Conn(tconn) - fmt.Printf("accepted connection from: %v \n", addr) + <-self.close + return nil, io.EOF } - return + addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} + tconn := NewTestNetworkConnection(addr) + key := tconn.RemoteAddr().String() + self.connections[key] = tconn + fmt.Printf("accepted connection from: %v \n", addr) + return tconn, nil } func (self *TestListener) Close() error { + close(self.close) return nil } @@ -101,6 +103,86 @@ func (self *TestListener) Addr() net.Addr { return self.addr } +type TestNetworkConnection struct { + in chan []byte + close chan struct{} + current []byte + Out [][]byte + addr net.Addr +} + +func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { + return &TestNetworkConnection{ + in: make(chan []byte), + close: make(chan struct{}), + current: []byte{}, + Out: [][]byte{}, + addr: addr, + } +} + +func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { + time.Sleep(latency) + for _, s := range packets { + self.in <- s + } +} + +func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { + if len(self.current) == 0 { + var ok bool + select { + case self.current, ok = <-self.in: + if !ok { + return 0, io.EOF + } + case <-self.close: + return 0, io.EOF + } + } + length := len(self.current) + if length > len(buff) { + copy(buff[:], self.current[:len(buff)]) + self.current = self.current[len(buff):] + return len(buff), nil + } else { + copy(buff[:length], self.current[:]) + self.current = []byte{} + return length, io.EOF + } +} + +func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { + self.Out = append(self.Out, buff) + fmt.Printf("net write(%d): %x\n", len(self.Out), buff) + return len(buff), nil +} + +func (self *TestNetworkConnection) Close() error { + close(self.close) + return nil +} + +func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { + return +} + +func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { + return self.addr +} + +func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { + return +} + +func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { + return +} + +func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { + return +} + func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { network = NewTestNetwork(1) addr := &TestAddr{"test:30303"} @@ -124,12 +206,10 @@ func TestServerListener(t *testing.T) { if !ok { t.Error("not found inbound peer 1") } else { - fmt.Printf("out: %v\n", peer1.Out) if len(peer1.Out) != 2 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) } } - } func TestServerDialer(t *testing.T) { @@ -142,65 +222,63 @@ func TestServerDialer(t *testing.T) { if !ok { t.Error("not found outbound peer 1") } else { - fmt.Printf("out: %v\n", peer1.Out) if len(peer1.Out) != 2 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) } } } -func TestServerBroadcast(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } - network, server := SetupTestServer(handlers) - server.Start(true, true) - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - msg, _ := NewMsg(0) - server.Broadcast("", msg) - packet := Packet(0, 0) - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["outboundpeer-1"] - if !ok { - t.Error("not found outbound peer 1") - } else { - fmt.Printf("out: %v\n", peer1.Out) - if len(peer1.Out) != 3 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) - } else { - if bytes.Compare(peer1.Out[1], packet) != 0 { - t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) - } - } - } - peer2, ok := network.connections["inboundpeer-1"] - if !ok { - t.Error("not found inbound peer 2") - } else { - fmt.Printf("out: %v\n", peer2.Out) - if len(peer1.Out) != 3 { - t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) - } else { - if bytes.Compare(peer2.Out[1], packet) != 0 { - t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) - } - } - } -} +// func TestServerBroadcast(t *testing.T) { +// handlers := make(Handlers) +// testProtocol := &TestProtocol{Msgs: []*Msg{}} +// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } +// network, server := SetupTestServer(handlers) +// server.Start(true, true) +// server.peerConnect <- &TestAddr{"outboundpeer-1"} +// time.Sleep(10 * time.Millisecond) +// msg := NewMsg(0) +// server.Broadcast("", msg) +// packet := Packet(0, 0) +// time.Sleep(10 * time.Millisecond) +// server.Stop() +// peer1, ok := network.connections["outboundpeer-1"] +// if !ok { +// t.Error("not found outbound peer 1") +// } else { +// fmt.Printf("out: %v\n", peer1.Out) +// if len(peer1.Out) != 3 { +// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) +// } else { +// if bytes.Compare(peer1.Out[1], packet) != 0 { +// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) +// } +// } +// } +// peer2, ok := network.connections["inboundpeer-1"] +// if !ok { +// t.Error("not found inbound peer 2") +// } else { +// fmt.Printf("out: %v\n", peer2.Out) +// if len(peer1.Out) != 3 { +// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) +// } else { +// if bytes.Compare(peer2.Out[1], packet) != 0 { +// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) +// } +// } +// } +// } func TestServerPeersMessage(t *testing.T) { - handlers := make(Handlers) - _, server := SetupTestServer(handlers) + _, server := SetupTestServer(nil) server.Start(true, true) defer server.Stop() server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - peersMsg, err := server.PeersMessage() - fmt.Println(peersMsg) - if err != nil { - t.Errorf("expect no error, got %v", err) + time.Sleep(2000 * time.Millisecond) + + pl := server.encodedPeerList() + if pl == nil { + t.Errorf("expect non-nil peer list") } if c := server.PeerCount(); c != 2 { t.Errorf("expect 2 peers, got %v", c) -- cgit v1.2.3 From 7149191dd999f4d192398e4b0821b656e62f3345 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 5 Nov 2014 01:28:46 +0100 Subject: p2p: fix issues found during review --- p2p/message.go | 2 +- p2p/messenger.go | 14 +++--- p2p/messenger_test.go | 128 ++++++++++++++++++++++++++++++++++---------------- p2p/protocol.go | 5 +- 4 files changed, 96 insertions(+), 53 deletions(-) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index 366cff5d7..97d440a27 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -98,7 +98,7 @@ type byteReader interface { io.ByteReader } -// readMsg reads a message header. +// readMsg reads a message header from r. func readMsg(r byteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) diff --git a/p2p/messenger.go b/p2p/messenger.go index 7375ecc07..c7948a9ac 100644 --- a/p2p/messenger.go +++ b/p2p/messenger.go @@ -11,7 +11,7 @@ import ( "time" ) -type Handlers map[string]func() Protocol +type Handlers map[string]Protocol type proto struct { in chan Msg @@ -23,6 +23,7 @@ func (rw *proto) WriteMsg(msg Msg) error { if msg.Code >= rw.maxcode { return NewPeerError(InvalidMsgCode, "not handled") } + msg.Code += rw.offset return rw.messenger.writeMsg(msg) } @@ -31,12 +32,13 @@ func (rw *proto) ReadMsg() (Msg, error) { if !ok { return msg, io.EOF } + msg.Code -= rw.offset return msg, nil } -// eofSignal is used to 'lend' the network connection -// to a protocol. when the protocol's read loop has read the -// whole payload, the done channel is closed. +// eofSignal wraps a reader with eof signaling. +// the eof channel is closed when the wrapped reader +// reaches EOF. type eofSignal struct { wrapped io.Reader eof chan struct{} @@ -119,7 +121,6 @@ func (m *messenger) readLoop() { m.err <- err return } - msg.Code -= proto.offset if msg.Size <= wholePayloadSize { // optimization: msg is small enough, read all // of it and move on to the next message @@ -185,11 +186,10 @@ func (m *messenger) setRemoteProtocols(protocols []string) { defer m.protocolLock.Unlock() offset := baseProtocolOffset for _, name := range protocols { - protocolFunc, ok := m.handlers[name] + inst, ok := m.handlers[name] if !ok { continue // not handled } - inst := protocolFunc() m.protocols[name] = m.startProto(offset, name, inst) offset += inst.Offset() } diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go index f10469e2f..2264e10d3 100644 --- a/p2p/messenger_test.go +++ b/p2p/messenger_test.go @@ -11,14 +11,14 @@ import ( "testing" "time" - "github.com/ethereum/go-ethereum/ethutil" + logpkg "github.com/ethereum/go-ethereum/logger" ) func init() { - ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel)) + logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel)) } -func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { +func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { conn1, conn2 := net.Pipe() id := NewSimpleClientIdentity("test", "0", "0", "public key") server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) @@ -33,7 +33,7 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error { return fmt.Errorf("read error: %v", err) } if msg.Code != handshakeMsg { - return fmt.Errorf("first message should be handshake, got %x", msg.Code) + return fmt.Errorf("first message should be handshake, got %d", msg.Code) } if err := msg.Discard(); err != nil { return err @@ -44,56 +44,102 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error { return writeMsg(w, msg) } -type testMsg struct { - code MsgCode - data *ethutil.Value +type testProtocol struct { + offset MsgCode + f func(MsgReadWriter) } -type testProto struct { - recv chan testMsg +func (p *testProtocol) Offset() MsgCode { + return p.offset } -func (*testProto) Offset() MsgCode { return 5 } - -func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error { - return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error { - logger.Debugf("testprotocol got msg: %d\n", code) - tp.recv <- testMsg{code, data} - return nil - }) +func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error { + p.f(rw) + return nil } func TestRead(t *testing.T) { - testProtocol := &testProto{make(chan testMsg)} - handlers := Handlers{"a": func() Protocol { return testProtocol }} - net, peer, mess := setupMessenger(handlers) - bufr := bufio.NewReader(net) + done := make(chan struct{}) + handlers := Handlers{ + "a": &testProtocol{5, func(rw MsgReadWriter) { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := msg.Data() + if err != nil { + t.Errorf("data decoding error: %v", err) + } + expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + if !reflect.DeepEqual(data.Slice(), expdata) { + t.Errorf("incorrect msg data %#v", data.Slice()) + } + close(done) + }}, + } + + net, peer, m := testMessenger(handlers) defer peer.Stop() + bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { t.Fatalf("handshake failed: %v", err) } + m.setRemoteProtocols([]string{"a"}) - mess.setRemoteProtocols([]string{"a"}) - writeMsg(net, NewMsg(17, uint32(1), "000")) + writeMsg(net, NewMsg(18, 1, "000")) select { - case msg := <-testProtocol.recv: - if msg.code != 1 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.code) - } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} - if !reflect.DeepEqual(msg.data.Slice(), expdata) { - t.Errorf("incorrect msg data %#v", msg.data.Slice()) - } + case <-done: case <-time.After(2 * time.Second): t.Errorf("receive timeout") } } -func TestWriteProtoMsg(t *testing.T) { - handlers := make(Handlers) - testProtocol := &testProto{recv: make(chan testMsg, 1)} - handlers["a"] = func() Protocol { return testProtocol } - net, peer, mess := setupMessenger(handlers) +func TestWriteFromProto(t *testing.T) { + handlers := Handlers{ + "a": &testProtocol{2, func(rw MsgReadWriter) { + if err := rw.WriteMsg(NewMsg(2)); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := rw.WriteMsg(NewMsg(1)); err != nil { + t.Errorf("write error: %v", err) + } + }}, + } + net, peer, mess := testMessenger(handlers) + defer peer.Stop() + bufr := bufio.NewReader(net) + if err := performTestHandshake(bufr, net); err != nil { + t.Fatalf("handshake failed: %v", err) + } + mess.setRemoteProtocols([]string{"a"}) + + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v") + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +var discardProto = &testProtocol{1, func(rw MsgReadWriter) { + for { + msg, err := rw.ReadMsg() + if err != nil { + return + } + if err = msg.Discard(); err != nil { + return + } + } +}} + +func TestMessengerWriteProtoMsg(t *testing.T) { + handlers := Handlers{"a": discardProto} + net, peer, mess := testMessenger(handlers) defer peer.Stop() bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { @@ -120,13 +166,13 @@ func TestWriteProtoMsg(t *testing.T) { read <- msg } }() - if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil { + if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil { t.Errorf("expect no error for known protocol: %v", err) } select { case msg := <-read: - if msg.Code != 19 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 19) + if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) } msg.Discard() case err := <-readerr: @@ -135,7 +181,7 @@ func TestWriteProtoMsg(t *testing.T) { } func TestPulse(t *testing.T) { - net, peer, _ := setupMessenger(nil) + net, peer, _ := testMessenger(nil) defer peer.Stop() bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { @@ -149,7 +195,7 @@ func TestPulse(t *testing.T) { } after := time.Now() if msg.Code != pingMsg { - t.Errorf("expected ping message, got %x", msg.Code) + t.Errorf("expected ping message, got %d", msg.Code) } if d := after.Sub(before); d < pingTimeout { t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) diff --git a/p2p/protocol.go b/p2p/protocol.go index ccc275287..d22ba70cb 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -143,9 +143,6 @@ func (d DiscReason) String() string { return discReasonToString[d] } -func (bp *baseProtocol) Ping() { -} - func (bp *baseProtocol) Offset() MsgCode { return baseProtocolOffset } @@ -287,7 +284,7 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { // self connect detection if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { - return NewPeerError(PubkeyForbidden, "not allowed to connect to bp") + return NewPeerError(PubkeyForbidden, "not allowed to connect to self") } // register pubkey on server. this also sets the pubkey on the peer (need lock) -- cgit v1.2.3 From e4a601c6444afdc11ce0cb80d7fd83116de2c8b9 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 10 Nov 2014 14:48:48 +0100 Subject: p2p: disable failing Server tests for now --- p2p/server_test.go | 3 +++ 1 file changed, 3 insertions(+) (limited to 'p2p') diff --git a/p2p/server_test.go b/p2p/server_test.go index 472759231..a2594acba 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -198,6 +198,8 @@ func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { } func TestServerListener(t *testing.T) { + t.SkipNow() + network, server := SetupTestServer(nil) server.Start(true, false) time.Sleep(10 * time.Millisecond) @@ -270,6 +272,7 @@ func TestServerDialer(t *testing.T) { // } func TestServerPeersMessage(t *testing.T) { + t.SkipNow() _, server := SetupTestServer(nil) server.Start(true, true) defer server.Stop() -- cgit v1.2.3 From 59b63caf5e4de64ceb7dcdf01551a080f53b1672 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 21 Nov 2014 21:48:49 +0100 Subject: p2p: API cleanup and PoC 7 compatibility Whoa, one more big commit. I didn't manage to untangle the changes while working towards compatibility. --- p2p/client_identity.go | 6 +- p2p/message.go | 62 +++- p2p/messenger.go | 221 ------------- p2p/messenger_test.go | 203 ------------ p2p/natpmp.go | 34 +- p2p/natupnp.go | 198 ++++++------ p2p/network.go | 196 ----------- p2p/peer.go | 476 ++++++++++++++++++++++++--- p2p/peer_error.go | 150 ++++++--- p2p/peer_error_handler.go | 98 ------ p2p/peer_error_handler_test.go | 34 -- p2p/peer_test.go | 308 +++++++++++++----- p2p/protocol.go | 412 +++++++++++------------- p2p/server.go | 713 ++++++++++++++++++++--------------------- p2p/server_test.go | 388 ++++++++-------------- p2p/testlog_test.go | 28 ++ p2p/testpoc7.go | 40 +++ 17 files changed, 1665 insertions(+), 1902 deletions(-) delete mode 100644 p2p/messenger.go delete mode 100644 p2p/messenger_test.go delete mode 100644 p2p/network.go delete mode 100644 p2p/peer_error_handler.go delete mode 100644 p2p/peer_error_handler_test.go create mode 100644 p2p/testlog_test.go create mode 100644 p2p/testpoc7.go (limited to 'p2p') diff --git a/p2p/client_identity.go b/p2p/client_identity.go index 236b23106..bc865b63b 100644 --- a/p2p/client_identity.go +++ b/p2p/client_identity.go @@ -5,10 +5,10 @@ import ( "runtime" ) -// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. +// ClientIdentity represents the identity of a peer. type ClientIdentity interface { - String() string - Pubkey() []byte + String() string // human readable identity + Pubkey() []byte // 512-bit public key } type SimpleClientIdentity struct { diff --git a/p2p/message.go b/p2p/message.go index 97d440a27..89ad189d7 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -11,8 +11,6 @@ import ( "github.com/ethereum/go-ethereum/ethutil" ) -type MsgCode uint64 - // Msg defines the structure of a p2p message. // // Note that a Msg can only be sent once since the Payload reader is @@ -21,13 +19,13 @@ type MsgCode uint64 // structure, encode the payload into a byte array and create a // separate Msg with a bytes.Reader as Payload for each send. type Msg struct { - Code MsgCode + Code uint64 Size uint32 // size of the paylod Payload io.Reader } // NewMsg creates an RLP-encoded message with the given code. -func NewMsg(code MsgCode, params ...interface{}) Msg { +func NewMsg(code uint64, params ...interface{}) Msg { buf := new(bytes.Buffer) for _, p := range params { buf.Write(ethutil.Encode(p)) @@ -63,6 +61,52 @@ func (msg Msg) Discard() error { return err } +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + // WriteMsg sends an existing message. + // The Payload reader of the message is consumed. + // Note that messages can be sent only once. + WriteMsg(Msg) error + + // EncodeMsg writes an RLP-encoded message with the given + // code and data elements. + EncodeMsg(code uint64, data ...interface{}) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +// MsgLoop reads messages off the given reader and +// calls the handler function for each decoded message until +// it returns an error or the peer connection is closed. +// +// If a message is larger than the given maximum size, +// MsgLoop returns an appropriate error. +func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error { + for { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Size > maxsize { + return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) + } + value, err := msg.Data() + if err != nil { + return err + } + if err := f(msg.Code, value); err != nil { + return err + } + } +} + var magicToken = []byte{34, 64, 8, 145} func writeMsg(w io.Writer, msg Msg) error { @@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) if _, err = io.ReadFull(r, start); err != nil { - return msg, NewPeerError(ReadError, "%v", err) + return msg, newPeerError(errRead, "%v", err) } if !bytes.HasPrefix(start, magicToken) { - return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) } size := binary.BigEndian.Uint32(start[4:]) @@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { } // readUint reads an RLP-encoded unsigned integer from r. -func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { +func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) { b, err := r.ReadByte() if err != nil { return 0, 0, err } if b < 0x80 { - return MsgCode(b), 1, nil + return uint64(b), 1, nil } else if b < 0x89 { // max length for uint64 is 8 bytes codelen = uint32(b - 0x80) if codelen == 0 { @@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { return 0, 0, err } - return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil + return binary.BigEndian.Uint64(buf), codelen, nil } return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) } diff --git a/p2p/messenger.go b/p2p/messenger.go deleted file mode 100644 index c7948a9ac..000000000 --- a/p2p/messenger.go +++ /dev/null @@ -1,221 +0,0 @@ -package p2p - -import ( - "bufio" - "bytes" - "fmt" - "io" - "io/ioutil" - "net" - "sync" - "time" -) - -type Handlers map[string]Protocol - -type proto struct { - in chan Msg - maxcode, offset MsgCode - messenger *messenger -} - -func (rw *proto) WriteMsg(msg Msg) error { - if msg.Code >= rw.maxcode { - return NewPeerError(InvalidMsgCode, "not handled") - } - msg.Code += rw.offset - return rw.messenger.writeMsg(msg) -} - -func (rw *proto) ReadMsg() (Msg, error) { - msg, ok := <-rw.in - if !ok { - return msg, io.EOF - } - msg.Code -= rw.offset - return msg, nil -} - -// eofSignal wraps a reader with eof signaling. -// the eof channel is closed when the wrapped reader -// reaches EOF. -type eofSignal struct { - wrapped io.Reader - eof chan struct{} -} - -func (r *eofSignal) Read(buf []byte) (int, error) { - n, err := r.wrapped.Read(buf) - if err != nil { - close(r.eof) // tell messenger that msg has been consumed - } - return n, err -} - -// messenger represents a message-oriented peer connection. -// It keeps track of the set of protocols understood -// by the remote peer. -type messenger struct { - peer *Peer - handlers Handlers - - // the mutex protects the connection - // so only one protocol can write at a time. - writeMu sync.Mutex - conn net.Conn - bufconn *bufio.ReadWriter - - protocolLock sync.RWMutex - protocols map[string]*proto - offsets map[MsgCode]*proto - protoWG sync.WaitGroup - - err chan error - pulse chan bool -} - -func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { - return &messenger{ - conn: conn, - bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - peer: peer, - handlers: handlers, - protocols: make(map[string]*proto), - err: errchan, - pulse: make(chan bool, 1), - } -} - -func (m *messenger) Start() { - m.protocols[""] = m.startProto(0, "", &baseProtocol{}) - go m.readLoop() -} - -func (m *messenger) Stop() { - m.conn.Close() - m.protoWG.Wait() -} - -const ( - // maximum amount of time allowed for reading a message - msgReadTimeout = 5 * time.Second - - // messages smaller than this many bytes will be read at - // once before passing them to a protocol. - wholePayloadSize = 64 * 1024 -) - -func (m *messenger) readLoop() { - defer m.closeProtocols() - for { - m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) - msg, err := readMsg(m.bufconn) - if err != nil { - m.err <- err - return - } - // send ping to heartbeat channel signalling time of last message - m.pulse <- true - proto, err := m.getProto(msg.Code) - if err != nil { - m.err <- err - return - } - if msg.Size <= wholePayloadSize { - // optimization: msg is small enough, read all - // of it and move on to the next message - buf, err := ioutil.ReadAll(msg.Payload) - if err != nil { - m.err <- err - return - } - msg.Payload = bytes.NewReader(buf) - proto.in <- msg - } else { - pr := &eofSignal{msg.Payload, make(chan struct{})} - msg.Payload = pr - proto.in <- msg - <-pr.eof - } - } -} - -func (m *messenger) closeProtocols() { - m.protocolLock.RLock() - for _, p := range m.protocols { - close(p.in) - } - m.protocolLock.RUnlock() -} - -func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { - proto := &proto{ - in: make(chan Msg), - offset: offset, - maxcode: impl.Offset(), - messenger: m, - } - m.protoWG.Add(1) - go func() { - if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { - logger.Errorf("protocol %q error: %v\n", name, err) - m.err <- err - } - m.protoWG.Done() - }() - return proto -} - -// getProto finds the protocol responsible for handling -// the given message code. -func (m *messenger) getProto(code MsgCode) (*proto, error) { - m.protocolLock.RLock() - defer m.protocolLock.RUnlock() - for _, proto := range m.protocols { - if code >= proto.offset && code < proto.offset+proto.maxcode { - return proto, nil - } - } - return nil, NewPeerError(InvalidMsgCode, "%d", code) -} - -// setProtocols starts all subprotocols shared with the -// remote peer. the protocols must be sorted alphabetically. -func (m *messenger) setRemoteProtocols(protocols []string) { - m.protocolLock.Lock() - defer m.protocolLock.Unlock() - offset := baseProtocolOffset - for _, name := range protocols { - inst, ok := m.handlers[name] - if !ok { - continue // not handled - } - m.protocols[name] = m.startProto(offset, name, inst) - offset += inst.Offset() - } -} - -// writeProtoMsg sends the given message on behalf of the given named protocol. -func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { - m.protocolLock.RLock() - proto, ok := m.protocols[protoName] - m.protocolLock.RUnlock() - if !ok { - return fmt.Errorf("protocol %s not handled by peer", protoName) - } - if msg.Code >= proto.maxcode { - return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) - } - msg.Code += proto.offset - return m.writeMsg(msg) -} - -// writeMsg writes a message to the connection. -func (m *messenger) writeMsg(msg Msg) error { - m.writeMu.Lock() - defer m.writeMu.Unlock() - if err := writeMsg(m.bufconn, msg); err != nil { - return err - } - return m.bufconn.Flush() -} diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go deleted file mode 100644 index 2264e10d3..000000000 --- a/p2p/messenger_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package p2p - -import ( - "bufio" - "fmt" - "io" - "log" - "net" - "os" - "reflect" - "testing" - "time" - - logpkg "github.com/ethereum/go-ethereum/logger" -) - -func init() { - logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel)) -} - -func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { - conn1, conn2 := net.Pipe() - id := NewSimpleClientIdentity("test", "0", "0", "public key") - server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) - peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0) - return conn2, peer, peer.messenger -} - -func performTestHandshake(r *bufio.Reader, w io.Writer) error { - // read remote handshake - msg, err := readMsg(r) - if err != nil { - return fmt.Errorf("read error: %v", err) - } - if msg.Code != handshakeMsg { - return fmt.Errorf("first message should be handshake, got %d", msg.Code) - } - if err := msg.Discard(); err != nil { - return err - } - // send empty handshake - pubkey := make([]byte, 64) - msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey) - return writeMsg(w, msg) -} - -type testProtocol struct { - offset MsgCode - f func(MsgReadWriter) -} - -func (p *testProtocol) Offset() MsgCode { - return p.offset -} - -func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error { - p.f(rw) - return nil -} - -func TestRead(t *testing.T) { - done := make(chan struct{}) - handlers := Handlers{ - "a": &testProtocol{5, func(rw MsgReadWriter) { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Code != 2 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) - } - data, err := msg.Data() - if err != nil { - t.Errorf("data decoding error: %v", err) - } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} - if !reflect.DeepEqual(data.Slice(), expdata) { - t.Errorf("incorrect msg data %#v", data.Slice()) - } - close(done) - }}, - } - - net, peer, m := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - m.setRemoteProtocols([]string{"a"}) - - writeMsg(net, NewMsg(18, 1, "000")) - select { - case <-done: - case <-time.After(2 * time.Second): - t.Errorf("receive timeout") - } -} - -func TestWriteFromProto(t *testing.T) { - handlers := Handlers{ - "a": &testProtocol{2, func(rw MsgReadWriter) { - if err := rw.WriteMsg(NewMsg(2)); err == nil { - t.Error("expected error for out-of-range msg code, got nil") - } - if err := rw.WriteMsg(NewMsg(1)); err != nil { - t.Errorf("write error: %v", err) - } - }}, - } - net, peer, mess := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - mess.setRemoteProtocols([]string{"a"}) - - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v") - } - if msg.Code != 17 { - t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) - } -} - -var discardProto = &testProtocol{1, func(rw MsgReadWriter) { - for { - msg, err := rw.ReadMsg() - if err != nil { - return - } - if err = msg.Discard(); err != nil { - return - } - } -}} - -func TestMessengerWriteProtoMsg(t *testing.T) { - handlers := Handlers{"a": discardProto} - net, peer, mess := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - mess.setRemoteProtocols([]string{"a"}) - - // test write errors - if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil { - t.Errorf("expected error for unknown protocol, got nil") - } - if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil { - t.Errorf("expected error for out-of-range msg code, got nil") - } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode { - t.Errorf("wrong error for out-of-range msg code, got %#v") - } - - // test succcessful write - read, readerr := make(chan Msg), make(chan error) - go func() { - if msg, err := readMsg(bufr); err != nil { - readerr <- err - } else { - read <- msg - } - }() - if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } - select { - case msg := <-read: - if msg.Code != 16 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) - } - msg.Discard() - case err := <-readerr: - t.Errorf("read error: %v", err) - } -} - -func TestPulse(t *testing.T) { - net, peer, _ := testMessenger(nil) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - - before := time.Now() - msg, err := readMsg(bufr) - if err != nil { - t.Fatalf("read error: %v", err) - } - after := time.Now() - if msg.Code != pingMsg { - t.Errorf("expected ping message, got %d", msg.Code) - } - if d := after.Sub(before); d < pingTimeout { - t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) - } -} diff --git a/p2p/natpmp.go b/p2p/natpmp.go index ff966d070..6714678c4 100644 --- a/p2p/natpmp.go +++ b/p2p/natpmp.go @@ -3,6 +3,7 @@ package p2p import ( "fmt" "net" + "time" natpmp "github.com/jackpal/go-nat-pmp" ) @@ -13,38 +14,37 @@ import ( // + Register for changes to the external address. // + Re-register port mapping when router reboots. // + A mechanism for keeping a port mapping registered. +// + Discover gateway address automatically. type natPMPClient struct { client *natpmp.Client } -func NewNatPMP(gateway net.IP) (nat NAT) { +// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway +// address should be the IP of your router. +func PMP(gateway net.IP) (nat NAT) { return &natPMPClient{natpmp.NewClient(gateway)} } -func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { +func (*natPMPClient) String() string { + return "NAT-PMP" +} + +func (n *natPMPClient) GetExternalAddress() (net.IP, error) { response, err := n.client.GetExternalAddress() if err != nil { - return + return nil, err } - ip := response.ExternalIPAddress - addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) - return + return response.ExternalIPAddress[:], nil } -func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, - description string, timeout int) (mappedExternalPort int, err error) { - if timeout <= 0 { - err = fmt.Errorf("timeout must not be <= 0") - return +func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") } // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. - response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) - if err != nil { - return - } - mappedExternalPort = int(response.MappedExternalPort) - return + _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) + return err } func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { diff --git a/p2p/natupnp.go b/p2p/natupnp.go index fa9798d4d..2e0d8ce8d 100644 --- a/p2p/natupnp.go +++ b/p2p/natupnp.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/xml" "errors" + "fmt" "net" "net/http" "os" @@ -15,28 +16,46 @@ import ( "time" ) +const ( + upnpDiscoverAttempts = 3 + upnpDiscoverTimeout = 5 * time.Second +) + +// UPNP returns a NAT port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPNP() NAT { + return &upnpNAT{} +} + type upnpNAT struct { serviceURL string ourIP string } -func upnpDiscover(attempts int) (nat NAT, err error) { +func (n *upnpNAT) String() string { + return "UPNP" +} + +func (n *upnpNAT) discover() error { + if n.serviceURL != "" { + // already discovered + return nil + } + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") if err != nil { - return + return err } + // TODO: try on all network interfaces simultaneously. + // Broadcasting on 0.0.0.0 could select a random interface + // to send on (platform specific). conn, err := net.ListenPacket("udp4", ":0") if err != nil { - return - } - socket := conn.(*net.UDPConn) - defer socket.Close() - - err = socket.SetDeadline(time.Now().Add(10 * time.Second)) - if err != nil { - return + return err } + defer conn.Close() + conn.SetDeadline(time.Now().Add(10 * time.Second)) st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" buf := bytes.NewBufferString( "M-SEARCH * HTTP/1.1\r\n" + @@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { "MX: 2\r\n\r\n") message := buf.Bytes() answerBytes := make([]byte, 1024) - for i := 0; i < attempts; i++ { - _, err = socket.WriteToUDP(message, ssdp) + for i := 0; i < upnpDiscoverAttempts; i++ { + _, err = conn.WriteTo(message, ssdp) if err != nil { - return + return err } - var n int - n, _, err = socket.ReadFromUDP(answerBytes) + nn, _, err := conn.ReadFrom(answerBytes) if err != nil { continue - // socket.Close() - // return } - answer := string(answerBytes[0:n]) + answer := string(answerBytes[0:nn]) if strings.Index(answer, "\r\n"+st) < 0 { continue } @@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { var serviceURL string serviceURL, err = getServiceURL(locURL) if err != nil { - return + return err } var ourIP string ourIP, err = getOurIP() if err != nil { - return + return err } - nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} + n.serviceURL = serviceURL + n.ourIP = ourIP + return nil + } + return errors.New("UPnP port discovery failed.") +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + if err := n.discover(); err != nil { + return nil, err + } + info, err := n.getStatusInfo() + return net.ParseIP(info.externalIpAddress), err +} + +func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { + if err := n.discover(); err != nil { + return err + } + + // A single concatenation would break ARM compilation. + message := "\r\n" + + "" + strconv.Itoa(extport) + message += "" + protocol + "" + message += "" + strconv.Itoa(extport) + "" + + "" + n.ourIP + "" + + "1" + message += description + + "" + fmt.Sprint(lifetime/time.Second) + + "" + + // TODO: check response to see if the port was forwarded + _, err := soapRequest(n.serviceURL, "AddPortMapping", message) + return err +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { + if err := n.discover(); err != nil { + return err + } + + message := "\r\n" + + "" + strconv.Itoa(externalPort) + + "" + protocol + "" + + "" + + // TODO: check response to see if the port was deleted + _, err := soapRequest(n.serviceURL, "DeletePortMapping", message) + return err +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { + message := "\r\n" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) + if err != nil { return } - err = errors.New("UPnP port discovery failed.") + + // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... + + response.Body.Close() return } @@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { } return } - -type statusInfo struct { - externalIpAddress string -} - -func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { - - message := "\r\n" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) - if err != nil { - return - } - - // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... - - response.Body.Close() - return -} - -func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - info, err := n.getStatusInfo() - if err != nil { - return - } - addr = net.ParseIP(info.externalIpAddress) - return -} - -func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { - // A single concatenation would break ARM compilation. - message := "\r\n" + - "" + strconv.Itoa(externalPort) - message += "" + protocol + "" - message += "" + strconv.Itoa(internalPort) + "" + - "" + n.ourIP + "" + - "1" - message += description + - "" + strconv.Itoa(timeout) + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "AddPortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was forwarded - // log.Println(message, response) - mappedExternalPort = externalPort - _ = response - return -} - -func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - - message := "\r\n" + - "" + strconv.Itoa(externalPort) + - "" + protocol + "" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was deleted - // log.Println(message, response) - _ = response - return -} diff --git a/p2p/network.go b/p2p/network.go deleted file mode 100644 index 820cef1a9..000000000 --- a/p2p/network.go +++ /dev/null @@ -1,196 +0,0 @@ -package p2p - -import ( - "fmt" - "math/rand" - "net" - "strconv" - "time" -) - -const ( - DialerTimeout = 180 //seconds - KeepAlivePeriod = 60 //minutes - portMappingUpdateInterval = 900 // seconds = 15 mins - upnpDiscoverAttempts = 3 -) - -// Dialer is not an interface in net, so we define one -// *net.Dialer conforms to this -type Dialer interface { - Dial(network, address string) (net.Conn, error) -} - -type Network interface { - Start() error - Listener(net.Addr) (net.Listener, error) - Dialer(net.Addr) (Dialer, error) - NewAddr(string, int) (addr net.Addr, err error) - ParseAddr(string) (addr net.Addr, err error) -} - -type NAT interface { - GetExternalAddress() (addr net.IP, err error) - AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) - DeletePortMapping(protocol string, externalPort, internalPort int) (err error) -} - -type TCPNetwork struct { - nat NAT - natType NATType - quit chan chan bool - ports chan string -} - -type NATType int - -const ( - NONE = iota - UPNP - PMP -) - -const ( - portMappingTimeout = 1200 // 20 mins -) - -func NewTCPNetwork(natType NATType) (net *TCPNetwork) { - return &TCPNetwork{ - natType: natType, - ports: make(chan string), - } -} - -func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { - return &net.Dialer{ - Timeout: DialerTimeout * time.Second, - // KeepAlive: KeepAlivePeriod * time.Minute, - LocalAddr: addr, - }, nil -} - -func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { - if self.natType == UPNP { - _, port, _ := net.SplitHostPort(addr.String()) - if self.quit == nil { - self.quit = make(chan chan bool) - go self.updatePortMappings() - } - self.ports <- port - } - return net.Listen(addr.Network(), addr.String()) -} - -func (self *TCPNetwork) Start() (err error) { - switch self.natType { - case NONE: - case UPNP: - nat, uerr := upnpDiscover(upnpDiscoverAttempts) - if uerr != nil { - err = fmt.Errorf("UPNP failed: ", uerr) - } else { - self.nat = nat - } - case PMP: - err = fmt.Errorf("PMP not implemented") - default: - err = fmt.Errorf("Invalid NAT type: %v", self.natType) - } - return -} - -func (self *TCPNetwork) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *TCPNetwork) addPortMapping(lport int) (err error) { - _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) - if err != nil { - logger.Errorf("unable to add port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully added port mapping on %v", lport) - } - return -} - -func (self *TCPNetwork) updatePortMappings() { - timer := time.NewTimer(portMappingUpdateInterval * time.Second) - lports := []int{} -out: - for { - select { - case port := <-self.ports: - int64lport, _ := strconv.ParseInt(port, 10, 16) - lport := int(int64lport) - if err := self.addPortMapping(lport); err != nil { - lports = append(lports, lport) - } - case <-timer.C: - for lport := range lports { - if err := self.addPortMapping(lport); err != nil { - } - } - case errc := <-self.quit: - errc <- true - break out - } - } - - timer.Stop() - for lport := range lports { - if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { - logger.Debugf("unable to remove port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully removed port mapping on %v", lport) - } - } -} - -func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { - ip, err := self.lookupIP(host) - if err == nil { - return &net.TCPAddr{ - IP: ip, - Port: port, - }, nil - } - return nil, err -} - -func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { - host, port, err := net.SplitHostPort(address) - if err == nil { - iport, _ := strconv.Atoi(port) - addr, e := self.NewAddr(host, iport) - return addr, e - } - return nil, err -} - -func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { - if ip = net.ParseIP(host); ip != nil { - return - } - - var ips []net.IP - ips, err = net.LookupIP(host) - if err != nil { - logger.Warnln(err) - return - } - if len(ips) == 0 { - err = fmt.Errorf("No IP addresses available for %v", host) - logger.Warnln(err) - return - } - if len(ips) > 1 { - // Pick a random IP address, simulating round-robin DNS. - rand.Seed(time.Now().UTC().UnixNano()) - ip = ips[rand.Intn(len(ips))] - } else { - ip = ips[0] - } - return -} diff --git a/p2p/peer.go b/p2p/peer.go index 34b6152a3..238d3d9c9 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,66 +1,454 @@ package p2p import ( + "bufio" + "bytes" "fmt" + "io" + "io/ioutil" "net" - "strconv" + "sort" + "sync" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/logger" ) +// peerAddr is the structure of a peer list element. +// It is also a valid net.Addr. +type peerAddr struct { + IP net.IP + Port uint64 + Pubkey []byte // optional +} + +func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { + n := addr.Network() + if n != "tcp" && n != "tcp4" && n != "tcp6" { + // for testing with non-TCP + return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} + } + ta := addr.(*net.TCPAddr) + return &peerAddr{ta.IP, uint64(ta.Port), pubkey} +} + +func (d peerAddr) Network() string { + if d.IP.To4() != nil { + return "tcp4" + } else { + return "tcp6" + } +} + +func (d peerAddr) String() string { + return fmt.Sprintf("%v:%d", d.IP, d.Port) +} + +func (d peerAddr) RlpData() interface{} { + return []interface{}{d.IP, d.Port, d.Pubkey} +} + +// Peer represents a remote peer. type Peer struct { - Inbound bool // inbound (via listener) or outbound (via dialout) - Address net.Addr - Host []byte - Port uint16 - Pubkey []byte - Id string - Caps []string - peerErrorChan chan error - messenger *messenger - peerErrorHandler *PeerErrorHandler - server *Server -} - -func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { - peerErrorChan := NewPeerErrorChannel() - host, port, _ := net.SplitHostPort(address.String()) - intport, _ := strconv.Atoi(port) - peer := &Peer{ - Inbound: inbound, - Address: address, - Port: uint16(intport), - Host: net.ParseIP(host), - peerErrorChan: peerErrorChan, - server: server, - } - peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) - peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) + // Peers have all the log methods. + // Use them to display messages related to the peer. + *logger.Logger + + infolock sync.Mutex + identity ClientIdentity + caps []Cap + listenAddr *peerAddr // what remote peer is listening on + dialAddr *peerAddr // non-nil if dialing + + // The mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + // These fields maintain the running protocols. + protocols []Protocol + runBaseProtocol bool // for testing + + runlock sync.RWMutex // protects running + running map[string]*proto + + protoWG sync.WaitGroup + protoErr chan error + closed chan struct{} + disc chan DiscReason + + activity event.TypeMux // for activity events + + slot int // index into Server peer list + + // These fields are kept so base protocol can access them. + // TODO: this should be one or more interfaces + ourID ClientIdentity // client id of the Server + ourListenAddr *peerAddr // listen addr of Server, nil if not listening + newPeerAddr chan<- *peerAddr // tell server about received peers + otherPeers func() []*Peer // should return the list of all peers + pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey +} + +// NewPeer returns a peer for testing purposes. +func NewPeer(id ClientIdentity, caps []Cap) *Peer { + conn, _ := net.Pipe() + peer := newPeer(conn, nil, nil) + peer.setHandshakeInfo(id, nil, caps) return peer } -func (self *Peer) String() string { - var kind string - if self.Inbound { - kind = "inbound" - } else { +func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + p := newPeer(conn, server.Protocols, dialAddr) + p.ourID = server.Identity + p.newPeerAddr = server.peerConnect + p.otherPeers = server.Peers + p.pubkeyHook = server.verifyPeer + p.runBaseProtocol = true + + // laddr can be updated concurrently by NAT traversal. + // newServerPeer must be called with the server lock held. + if server.laddr != nil { + p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) + } + return p +} + +func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { + p := &Peer{ + Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), + conn: conn, + dialAddr: dialAddr, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} + +// Identity returns the client identity of the remote peer. The +// identity can be nil if the peer has not yet completed the +// handshake. +func (p *Peer) Identity() ClientIdentity { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.identity +} + +// Caps returns the capabilities (supported subprotocols) of the remote peer. +func (p *Peer) Caps() []Cap { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.caps +} + +func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { + p.infolock.Lock() + p.identity = id + p.listenAddr = laddr + p.caps = caps + p.infolock.Unlock() +} + +// RemoteAddr returns the remote address of the network connection. +func (p *Peer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of the network connection. +func (p *Peer) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// Disconnect terminates the peer connection with the given reason. +// It returns immediately and does not wait until the connection is closed. +func (p *Peer) Disconnect(reason DiscReason) { + select { + case p.disc <- reason: + case <-p.closed: + } +} + +// String implements fmt.Stringer. +func (p *Peer) String() string { + kind := "inbound" + p.infolock.Lock() + if p.dialAddr != nil { kind = "outbound" } - return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) + p.infolock.Unlock() + return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) +} + +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + // maximum amount of time allowed for writing a message + msgWriteTimeout = 5 * time.Second + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +var ( + inactivityTimeout = 2 * time.Second + disconnectGracePeriod = 2 * time.Second +) + +func (p *Peer) loop() (reason DiscReason, err error) { + defer p.activity.Stop() + defer p.closeProtocols() + defer close(p.closed) + defer p.conn.Close() + + // read loop + readMsg := make(chan Msg) + readErr := make(chan error) + readNext := make(chan bool, 1) + protoDone := make(chan struct{}, 1) + go p.readLoop(readMsg, readErr, readNext) + readNext <- true + + if p.runBaseProtocol { + p.startBaseProtocol() + } + +loop: + for { + select { + case msg := <-readMsg: + // a new message has arrived. + var wait bool + if wait, err = p.dispatch(msg, protoDone); err != nil { + p.Errorf("msg dispatch error: %v\n", err) + reason = discReasonForError(err) + break loop + } + if !wait { + // Msg has already been read completely, continue with next message. + readNext <- true + } + p.activity.Post(time.Now()) + case <-protoDone: + // protocol has consumed the message payload, + // we can continue reading from the socket. + readNext <- true + + case err := <-readErr: + // read failed. there is no need to run the + // polite disconnect sequence because the connection + // is probably dead anyway. + // TODO: handle write errors as well + return DiscNetworkError, err + case err = <-p.protoErr: + reason = discReasonForError(err) + break loop + case reason = <-p.disc: + break loop + } + } + + // wait for read loop to return. + close(readNext) + <-readErr + // tell the remote end to disconnect + done := make(chan struct{}) + go func() { + p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) + p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) + io.Copy(ioutil.Discard, p.conn) + close(done) + }() + select { + case <-done: + case <-time.After(disconnectGracePeriod): + } + return reason, err +} + +func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { + for _ = range unblock { + p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + if msg, err := readMsg(p.bufconn); err != nil { + errc <- err + } else { + msgc <- msg + } + } + close(errc) +} + +func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { + proto, err := p.getProto(msg.Code) + if err != nil { + return false, err + } + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return false, err + } + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + wait = true + pr := &eofSignal{msg.Payload, protoDone} + msg.Payload = pr + proto.in <- msg + } + return wait, nil +} + +func (p *Peer) startBaseProtocol() { + p.runlock.Lock() + defer p.runlock.Unlock() + p.running[""] = p.startProto(0, Protocol{ + Length: baseProtocolLength, + Run: runBaseProtocol, + }) +} + +// startProtocols starts matching named subprotocols. +func (p *Peer) startSubprotocols(caps []Cap) { + sort.Sort(capsByName(caps)) + + p.runlock.Lock() + defer p.runlock.Unlock() + offset := baseProtocolLength +outer: + for _, cap := range caps { + for _, proto := range p.protocols { + if proto.Name == cap.Name && + proto.Version == cap.Version && + p.running[cap.Name] == nil { + p.running[cap.Name] = p.startProto(offset, proto) + offset += proto.Length + continue outer + } + } + } +} + +func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + rw := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Length, + peer: p, + } + p.protoWG.Add(1) + go func() { + err := impl.Run(p, rw) + if err == nil { + p.Infof("protocol %q returned", impl.Name) + err = newPeerError(errMisc, "protocol returned") + } else { + p.Errorf("protocol %q error: %v\n", impl.Name, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() + return rw +} + +// getProto finds the protocol responsible for handling +// the given message code. +func (p *Peer) getProto(code uint64) (*proto, error) { + p.runlock.RLock() + defer p.runlock.RUnlock() + for _, proto := range p.running { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil + } + } + return nil, newPeerError(errInvalidMsgCode, "%d", code) +} + +func (p *Peer) closeProtocols() { + p.runlock.RLock() + for _, p := range p.running { + close(p.in) + } + p.runlock.RUnlock() + p.protoWG.Wait() +} + +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { + p.runlock.RLock() + proto, ok := p.running[protoName] + p.runlock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) + } + if msg.Code >= proto.maxcode { + return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return p.writeMsg(msg, msgWriteTimeout) +} + +// writeMsg writes a message to the connection. +func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { + p.writeMu.Lock() + defer p.writeMu.Unlock() + p.conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := writeMsg(p.bufconn, msg); err != nil { + return newPeerError(errWrite, "%v", err) + } + return p.bufconn.Flush() +} + +type proto struct { + name string + in chan Msg + maxcode, offset uint64 + peer *Peer +} + +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return newPeerError(errInvalidMsgCode, "not handled") + } + msg.Code += rw.offset + return rw.peer.writeMsg(msg, msgWriteTimeout) } -func (self *Peer) Write(protocol string, msg Msg) error { - return self.messenger.writeProtoMsg(protocol, msg) +func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { + return rw.WriteMsg(NewMsg(code, data)) } -func (self *Peer) Start() { - self.peerErrorHandler.Start() - self.messenger.Start() +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF + } + msg.Code -= rw.offset + return msg, nil } -func (self *Peer) Stop() { - self.peerErrorHandler.Stop() - self.messenger.Stop() +// eofSignal wraps a reader with eof signaling. +// the eof channel is closed when the wrapped reader +// reaches EOF. +type eofSignal struct { + wrapped io.Reader + eof chan<- struct{} } -func (p *Peer) Encode() []interface{} { - return []interface{}{p.Host, p.Port, p.Pubkey} +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) + if err != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + } + return n, err } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index f3ef98d98..88b870fbd 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -4,71 +4,121 @@ import ( "fmt" ) -type ErrorCode int - -const errorChanCapacity = 10 - const ( - PacketTooLong = iota - PayloadTooShort - MagicTokenMismatch - ReadError - WriteError - MiscError - InvalidMsgCode - InvalidMsg - P2PVersionMismatch - PubkeyMissing - PubkeyInvalid - PubkeyForbidden - ProtocolBreach - PortMismatch - PingTimeout - InvalidGenesis - InvalidNetworkId - InvalidProtocolVersion + errMagicTokenMismatch = iota + errRead + errWrite + errMisc + errInvalidMsgCode + errInvalidMsg + errP2PVersionMismatch + errPubkeyMissing + errPubkeyInvalid + errPubkeyForbidden + errProtocolBreach + errPingTimeout + errInvalidNetworkId + errInvalidProtocolVersion ) -var errorToString = map[ErrorCode]string{ - PacketTooLong: "Packet too long", - PayloadTooShort: "Payload too short", - MagicTokenMismatch: "Magic token mismatch", - ReadError: "Read error", - WriteError: "Write error", - MiscError: "Misc error", - InvalidMsgCode: "Invalid message code", - InvalidMsg: "Invalid message", - P2PVersionMismatch: "P2P Version Mismatch", - PubkeyMissing: "Public key missing", - PubkeyInvalid: "Public key invalid", - PubkeyForbidden: "Public key forbidden", - ProtocolBreach: "Protocol Breach", - PortMismatch: "Port mismatch", - PingTimeout: "Ping timeout", - InvalidGenesis: "Invalid genesis block", - InvalidNetworkId: "Invalid network id", - InvalidProtocolVersion: "Invalid protocol version", +var errorToString = map[int]string{ + errMagicTokenMismatch: "Magic token mismatch", + errRead: "Read error", + errWrite: "Write error", + errMisc: "Misc error", + errInvalidMsgCode: "Invalid message code", + errInvalidMsg: "Invalid message", + errP2PVersionMismatch: "P2P Version Mismatch", + errPubkeyMissing: "Public key missing", + errPubkeyInvalid: "Public key invalid", + errPubkeyForbidden: "Public key forbidden", + errProtocolBreach: "Protocol Breach", + errPingTimeout: "Ping timeout", + errInvalidNetworkId: "Invalid network id", + errInvalidProtocolVersion: "Invalid protocol version", } -type PeerError struct { - Code ErrorCode +type peerError struct { + Code int message string } -func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { +func newPeerError(code int, format string, v ...interface{}) *peerError { desc, ok := errorToString[code] if !ok { panic("invalid error code") } - format = desc + ": " + format - message := fmt.Sprintf(format, v...) - return &PeerError{code, message} + err := &peerError{code, desc} + if format != "" { + err.message += ": " + fmt.Sprintf(format, v...) + } + return err } -func (self *PeerError) Error() string { +func (self *peerError) Error() string { return self.message } -func NewPeerErrorChannel() chan error { - return make(chan error, errorChanCapacity) +type DiscReason byte + +const ( + DiscRequested DiscReason = 0x00 + DiscNetworkError = 0x01 + DiscProtocolError = 0x02 + DiscUselessPeer = 0x03 + DiscTooManyPeers = 0x04 + DiscAlreadyConnected = 0x05 + DiscIncompatibleVersion = 0x06 + DiscInvalidIdentity = 0x07 + DiscQuitting = 0x08 + DiscUnexpectedIdentity = 0x09 + DiscSelf = 0x0a + DiscReadTimeout = 0x0b + DiscSubprotocolError = 0x10 +) + +var discReasonToString = [DiscSubprotocolError + 1]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return fmt.Sprintf("Unknown Reason(%d)", d) + } + return discReasonToString[d] +} + +func discReasonForError(err error) DiscReason { + peerError, ok := err.(*peerError) + if !ok { + return DiscSubprotocolError + } + switch peerError.Code { + case errP2PVersionMismatch: + return DiscIncompatibleVersion + case errPubkeyMissing, errPubkeyInvalid: + return DiscInvalidIdentity + case errPubkeyForbidden: + return DiscUselessPeer + case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: + return DiscProtocolError + case errPingTimeout: + return DiscReadTimeout + case errRead, errWrite, errMisc: + return DiscNetworkError + default: + return DiscSubprotocolError + } } diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go deleted file mode 100644 index 47dcd14ff..000000000 --- a/p2p/peer_error_handler.go +++ /dev/null @@ -1,98 +0,0 @@ -package p2p - -import ( - "net" -) - -const ( - severityThreshold = 10 -) - -type DisconnectRequest struct { - addr net.Addr - reason DiscReason -} - -type PeerErrorHandler struct { - quit chan chan bool - address net.Addr - peerDisconnect chan DisconnectRequest - severity int - errc chan error -} - -func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler { - return &PeerErrorHandler{ - quit: make(chan chan bool), - address: address, - peerDisconnect: peerDisconnect, - errc: errc, - } -} - -func (self *PeerErrorHandler) Start() { - go self.listen() -} - -func (self *PeerErrorHandler) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *PeerErrorHandler) listen() { - for { - select { - case err, ok := <-self.errc: - if ok { - logger.Debugf("error %v\n", err) - go self.handle(err) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } -} - -func (self *PeerErrorHandler) handle(err error) { - reason := DiscReason(' ') - peerError, ok := err.(*PeerError) - if !ok { - peerError = NewPeerError(MiscError, " %v", err) - } - switch peerError.Code { - case P2PVersionMismatch: - reason = DiscIncompatibleVersion - case PubkeyMissing, PubkeyInvalid: - reason = DiscInvalidIdentity - case PubkeyForbidden: - reason = DiscUselessPeer - case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach: - reason = DiscProtocolError - case PingTimeout: - reason = DiscReadTimeout - case ReadError, WriteError, MiscError: - reason = DiscNetworkError - case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: - reason = DiscSubprotocolError - default: - self.severity += self.getSeverity(peerError) - } - - if self.severity >= severityThreshold { - reason = DiscSubprotocolError - } - if reason != DiscReason(' ') { - self.peerDisconnect <- DisconnectRequest{ - addr: self.address, - reason: reason, - } - } -} - -func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { - return 1 -} diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go deleted file mode 100644 index b93252f6a..000000000 --- a/p2p/peer_error_handler_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package p2p - -import ( - // "fmt" - "net" - "testing" - "time" -) - -func TestPeerErrorHandler(t *testing.T) { - address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} - peerDisconnect := make(chan DisconnectRequest) - peerErrorChan := NewPeerErrorChannel() - peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan) - peh.Start() - defer peh.Stop() - for i := 0; i < 11; i++ { - select { - case <-peerDisconnect: - t.Errorf("expected no disconnect request") - default: - } - peerErrorChan <- NewPeerError(MiscError, "") - } - time.Sleep(1 * time.Millisecond) - select { - case request := <-peerDisconnect: - if request.addr.String() != address.String() { - t.Errorf("incorrect address %v != %v", request.addr, address) - } - default: - t.Errorf("expected disconnect request") - } -} diff --git a/p2p/peer_test.go b/p2p/peer_test.go index da62cc380..1afa0ab17 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,90 +1,222 @@ package p2p -// "net" - -// func TestPeer(t *testing.T) { -// handlers := make(Handlers) -// testProtocol := &TestProtocol{recv: make(chan testMsg)} -// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } -// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } -// addr := &TestAddr{"test:30"} -// conn := NewTestNetworkConnection(addr) -// _, server := SetupTestServer(handlers) -// server.Handshake() -// peer := NewPeer(conn, addr, true, server) -// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) -// peer.Start() -// defer peer.Stop() -// time.Sleep(2 * time.Millisecond) -// if len(conn.Out) != 1 { -// t.Errorf("handshake not sent") -// } else { -// out := conn.Out[0] -// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect handshake packet %v != %v", out, packet) -// } -// } - -// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) -// conn.In(0, packet) -// time.Sleep(10 * time.Millisecond) - -// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) -// if pro.state != handshakeReceived { -// t.Errorf("handshake not received") -// } -// if peer.Port != 30 { -// t.Errorf("port incorrectly set") -// } -// if peer.Id != "peer" { -// t.Errorf("id incorrectly set") -// } -// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { -// t.Errorf("pubkey incorrectly set") -// } -// fmt.Println(peer.Caps) -// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { -// t.Errorf("protocols incorrectly set") -// } - -// msg := NewMsg(3) -// err := peer.Write("aaa", msg) -// if err != nil { -// t.Errorf("expect no error for known protocol: %v", err) -// } else { -// time.Sleep(1 * time.Millisecond) -// if len(conn.Out) != 2 { -// t.Errorf("msg not written") -// } else { -// out := conn.Out[1] -// packet := Packet(16, 3) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect packet %v != %v", out, packet) -// } -// } -// } - -// msg = NewMsg(2) -// err = peer.Write("ccc", msg) -// if err != nil { -// t.Errorf("expect no error for known protocol: %v", err) -// } else { -// time.Sleep(1 * time.Millisecond) -// if len(conn.Out) != 3 { -// t.Errorf("msg not written") -// } else { -// out := conn.Out[2] -// packet := Packet(21, 2) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect packet %v != %v", out, packet) -// } -// } -// } - -// err = peer.Write("bbb", msg) -// time.Sleep(1 * time.Millisecond) -// if err == nil { -// t.Errorf("expect error for unknown protocol") -// } -// } +import ( + "bufio" + "net" + "reflect" + "testing" + "time" +) + +var discard = Protocol{ + Name: "discard", + Length: 1, + Run: func(p *Peer, rw MsgReadWriter) error { + for { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if err = msg.Discard(); err != nil { + return err + } + } + }, +} + +func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { + conn1, conn2 := net.Pipe() + id := NewSimpleClientIdentity("test", "0", "0", "public key") + peer := newPeer(conn1, protos, nil) + peer.ourID = id + peer.pubkeyHook = func(*peerAddr) error { return nil } + errc := make(chan error, 1) + go func() { + _, err := peer.loop() + errc <- err + }() + return conn2, peer, errc +} + +func TestPeerProtoReadMsg(t *testing.T) { + defer testlog(t).detach() + + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := msg.Data() + if err != nil { + t.Errorf("data decoding error: %v", err) + } + expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + if !reflect.DeepEqual(data.Slice(), expdata) { + t.Errorf("incorrect msg data %#v", data.Slice()) + } + close(done) + return nil + }, + } + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, 1, "000")) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoReadLargeMsg(t *testing.T) { + defer testlog(t).detach() + + msgsize := uint32(10 * 1024 * 1024) + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Size != msgsize+4 { + t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) + } + msg.Discard() + close(done) + return nil + }, + } + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, make([]byte, msgsize))) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoEncodeMsg(t *testing.T) { + defer testlog(t).detach() + + proto := Protocol{ + Name: "a", + Length: 2, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := rw.EncodeMsg(2); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := rw.EncodeMsg(1); err != nil { + t.Errorf("write error: %v", err) + } + return nil + }, + } + net, peer, _ := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +func TestPeerWrite(t *testing.T) { + defer testlog(t).detach() + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + // test write errors + if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v", err) + } + + // setup for reading the message on the other end + read := make(chan struct{}) + go func() { + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } else if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + } + msg.Discard() + close(read) + }() + + // test succcessful write + if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } + select { + case <-read: + case err := <-peerErr: + t.Fatalf("peer stopped: %v", err) + } +} + +func TestPeerActivity(t *testing.T) { + // shorten inactivityTimeout while this test is running + oldT := inactivityTimeout + defer func() { inactivityTimeout = oldT }() + inactivityTimeout = 20 * time.Millisecond + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + sub := peer.activity.Subscribe(time.Time{}) + defer sub.Unsubscribe() + + for i := 0; i < 6; i++ { + writeMsg(net, NewMsg(16)) + select { + case <-sub.Chan(): + case <-time.After(inactivityTimeout / 2): + t.Fatal("no event within ", inactivityTimeout/2) + case err := <-peerErr: + t.Fatal("peer error", err) + } + } + + select { + case <-time.After(inactivityTimeout * 2): + case <-sub.Chan(): + t.Fatal("got activity event while connection was inactive") + case err := <-peerErr: + t.Fatal("peer error", err) + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go index d22ba70cb..169dcdb6e 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -3,249 +3,185 @@ package p2p import ( "bytes" "net" - "sort" "time" "github.com/ethereum/go-ethereum/ethutil" ) -// Protocol is implemented by P2P subprotocols. -type Protocol interface { - // Start is called when the protocol becomes active. - // It should read and write messages from rw. - // Messages must be fully consumed. - // - // The connection is closed when Start returns. It should return - // any protocol-level error (such as an I/O error) that is - // encountered. - Start(peer *Peer, rw MsgReadWriter) error +// Protocol represents a P2P subprotocol implementation. +type Protocol struct { + // Name should contain the official protocol name, + // often a three-letter word. + Name string - // Offset should return the number of message codes - // used by the protocol. - Offset() MsgCode -} + // Version should contain the version number of the protocol. + Version uint -type MsgReader interface { - ReadMsg() (Msg, error) -} - -type MsgWriter interface { - WriteMsg(Msg) error -} - -// MsgReadWriter is passed to protocols. Protocol implementations can -// use it to write messages back to a connected peer. -type MsgReadWriter interface { - MsgReader - MsgWriter -} + // Length should contain the number of message codes used + // by the protocol. + Length uint64 -type MsgHandler func(code MsgCode, data *ethutil.Value) error - -// MsgLoop reads messages off the given reader and -// calls the handler function for each decoded message until -// it returns an error or the peer connection is closed. -// -// If a message is larger than the given maximum size, RunProtocol -// returns an appropriate error.n -func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error { - for { - msg, err := r.ReadMsg() - if err != nil { - return err - } - if msg.Size > maxsize { - return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) - } - value, err := msg.Data() - if err != nil { - return err - } - if err := handler(msg.Code, value); err != nil { - return err - } - } -} - -// the ÐΞVp2p base protocol -type baseProtocol struct { - rw MsgReadWriter - peer *Peer + // Run is called in a new groutine when the protocol has been + // negotiated with a peer. It should read and write messages from + // rw. The Payload for each message must be fully consumed. + // + // The peer connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Run func(peer *Peer, rw MsgReadWriter) error } -type bpMsg struct { - code MsgCode - data *ethutil.Value +func (p Protocol) cap() Cap { + return Cap{p.Name, p.Version} } const ( - p2pVersion = 0 - pingTimeout = 2 * time.Second - pingGracePeriod = 2 * time.Second + baseProtocolVersion = 2 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 ) const ( - // message codes - handshakeMsg = iota - discMsg - pingMsg - pongMsg - getPeersMsg - peersMsg + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 ) -const ( - baseProtocolOffset MsgCode = 16 - baseProtocolMaxMsgSize = 500 * 1024 -) - -type DiscReason byte +// handshake is the structure of a handshake list. +type handshake struct { + Version uint64 + ID string + Caps []Cap + ListenPort uint64 + NodeID []byte +} -const ( - // Values are given explicitly instead of by iota because these values are - // defined by the wire protocol spec; it is easier for humans to ensure - // correctness when values are explicit. - DiscRequested = 0x00 - DiscNetworkError = 0x01 - DiscProtocolError = 0x02 - DiscUselessPeer = 0x03 - DiscTooManyPeers = 0x04 - DiscAlreadyConnected = 0x05 - DiscIncompatibleVersion = 0x06 - DiscInvalidIdentity = 0x07 - DiscQuitting = 0x08 - DiscUnexpectedIdentity = 0x09 - DiscSelf = 0x0a - DiscReadTimeout = 0x0b - DiscSubprotocolError = 0x10 -) +func (h *handshake) String() string { + return h.ID +} +func (h *handshake) Pubkey() []byte { + return h.NodeID +} -var discReasonToString = [DiscSubprotocolError + 1]string{ - DiscRequested: "Disconnect requested", - DiscNetworkError: "Network error", - DiscProtocolError: "Breach of protocol", - DiscUselessPeer: "Useless peer", - DiscTooManyPeers: "Too many peers", - DiscAlreadyConnected: "Already connected", - DiscIncompatibleVersion: "Incompatible P2P protocol version", - DiscInvalidIdentity: "Invalid node identity", - DiscQuitting: "Client quitting", - DiscUnexpectedIdentity: "Unexpected identity", - DiscSelf: "Connected to self", - DiscReadTimeout: "Read timeout", - DiscSubprotocolError: "Subprotocol error", +// Cap is the structure of a peer capability. +type Cap struct { + Name string + Version uint } -func (d DiscReason) String() string { - if len(discReasonToString) < int(d) { - return "Unknown" - } - return discReasonToString[d] +func (cap Cap) RlpData() interface{} { + return []interface{}{cap.Name, cap.Version} } -func (bp *baseProtocol) Offset() MsgCode { - return baseProtocolOffset +type capsByName []Cap + +func (cs capsByName) Len() int { return len(cs) } +func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } +func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } + +type baseProtocol struct { + rw MsgReadWriter + peer *Peer } -func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error { - bp.peer, bp.rw = peer, rw +func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { + bp := &baseProtocol{rw, peer} - // Do the handshake. - // TODO: disconnect is valid before handshake, too. - rw.WriteMsg(bp.peer.server.handshakeMsg()) + // do handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err + } msg, err := rw.ReadMsg() if err != nil { return err } if msg.Code != handshakeMsg { - return NewPeerError(ProtocolBreach, " first message must be handshake") + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) } data, err := msg.Data() if err != nil { - return NewPeerError(InvalidMsg, "%v", err) + return newPeerError(errInvalidMsg, "%v", err) } if err := bp.handleHandshake(data); err != nil { return err } - msgin := make(chan bpMsg) - done := make(chan error, 1) + // run main loop + quit := make(chan error, 1) go func() { - done <- MsgLoop(rw, baseProtocolMaxMsgSize, - func(code MsgCode, data *ethutil.Value) error { - msgin <- bpMsg{code, data} - return nil - }) + quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle) }() - return bp.loop(msgin, done) + return bp.loop(quit) } -func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { - logger.Debugf("pingpong keepalive started at %v\n", time.Now()) - messenger := bp.rw.(*proto).messenger - pingTimer := time.NewTimer(pingTimeout) - pinged := true +var pingTimeout = 2 * time.Second + +func (bp *baseProtocol) loop(quit <-chan error) error { + ping := time.NewTimer(pingTimeout) + activity := bp.peer.activity.Subscribe(time.Time{}) + lastActive := time.Time{} + defer ping.Stop() + defer activity.Unsubscribe() - for { + getPeersTick := time.NewTicker(10 * time.Second) + defer getPeersTick.Stop() + err := bp.rw.EncodeMsg(getPeersMsg) + + for err == nil { select { - case msg := <-msgin: - if err := bp.handle(msg.code, msg.data); err != nil { - return err - } - case err := <-quit: + case err = <-quit: return err - case <-messenger.pulse: - pingTimer.Reset(pingTimeout) - pinged = false - case <-pingTimer.C: - if pinged { - return NewPeerError(PingTimeout, "") + case <-getPeersTick.C: + err = bp.rw.EncodeMsg(getPeersMsg) + case event := <-activity.Chan(): + ping.Reset(pingTimeout) + lastActive = event.(time.Time) + case t := <-ping.C: + if lastActive.Add(pingTimeout * 2).Before(t) { + err = newPeerError(errPingTimeout, "") + } else if lastActive.Add(pingTimeout).Before(t) { + err = bp.rw.EncodeMsg(pingMsg) } - logger.Debugf("pinging at %v\n", time.Now()) - if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil { - return NewPeerError(WriteError, "%v", err) - } - pinged = true - pingTimer.Reset(pingTimeout) } } + return err } -func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { +func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { switch code { case handshakeMsg: - return NewPeerError(ProtocolBreach, " extra handshake received") + return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) - bp.peer.server.PeerDisconnect() <- DisconnectRequest{ - addr: bp.peer.Address, - reason: DiscRequested, - } + bp.peer.Disconnect(DiscReason(data.Get(0).Uint())) + return nil case pingMsg: - return bp.rw.WriteMsg(NewMsg(pongMsg)) + return bp.rw.EncodeMsg(pongMsg) case pongMsg: - // reply for ping case getPeersMsg: - // Peer asked for list of connected peers. - peersRLP := bp.peer.server.encodedPeerList() - if peersRLP != nil { - msg := Msg{ - Code: peersMsg, - Size: uint32(len(peersRLP)), - Payload: bytes.NewReader(peersRLP), - } - return bp.rw.WriteMsg(msg) + peers := bp.peerList() + // this is dangerous. the spec says that we should _delay_ + // sending the response if no new information is available. + // this means that would need to send a response later when + // new peers become available. + // + // TODO: add event mechanism to notify baseProtocol for new peers + if len(peers) > 0 { + return bp.rw.EncodeMsg(peersMsg, peers) } case peersMsg: bp.handlePeers(data) default: - return NewPeerError(InvalidMsgCode, "unknown message code %v", code) + return newPeerError(errInvalidMsgCode, "unknown message code %v", code) } return nil } @@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { func (bp *baseProtocol) handlePeers(data *ethutil.Value) { it := data.NewIterator() for it.Next() { - ip := net.IP(it.Value().Get(0).Bytes()) - port := it.Value().Get(1).Uint() - address := &net.TCPAddr{IP: ip, Port: int(port)} - go bp.peer.server.PeerConnect(address) + addr := &peerAddr{ + IP: net.IP(it.Value().Get(0).Bytes()), + Port: it.Value().Get(1).Uint(), + Pubkey: it.Value().Get(2).Bytes(), + } + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr } } func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { - var ( - remoteVersion = c.Get(0).Uint() - id = c.Get(1).Str() - caps = c.Get(2) - port = c.Get(3).Uint() - pubkey = c.Get(4).Bytes() - ) - // Check correctness of p2p protocol version - if remoteVersion != p2pVersion { - return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion) + hs := handshake{ + Version: c.Get(0).Uint(), + ID: c.Get(1).Str(), + Caps: nil, // decoded below + ListenPort: c.Get(3).Uint(), + NodeID: c.Get(4).Bytes(), } - - // Handle the pub key (validation, uniqueness) - if len(pubkey) == 0 { - return NewPeerError(PubkeyMissing, "not supplied in handshake.") + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", + baseProtocolVersion, hs.Version) } - - if len(pubkey) != 64 { - return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) + if len(hs.NodeID) == 0 { + return newPeerError(errPubkeyMissing, "") + } + if len(hs.NodeID) != 64 { + return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) + } + if da := bp.peer.dialAddr; da != nil { + // verify that the peer we wanted to connect to + // actually holds the target public key. + if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { + return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") + } + } + pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + if err := bp.peer.pubkeyHook(pa); err != nil { + return newPeerError(errPubkeyForbidden, "%v", err) + } + capsIt := c.Get(2).NewIterator() + for capsIt.Next() { + cap := capsIt.Value() + name := cap.Get(0).Str() + if name != "" { + hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())}) + } } - // self connect detection - if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { - return NewPeerError(PubkeyForbidden, "not allowed to connect to self") + var addr *peerAddr + if hs.ListenPort != 0 { + addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + addr.Port = hs.ListenPort } + bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) + bp.peer.startSubprotocols(hs.Caps) + return nil +} - // register pubkey on server. this also sets the pubkey on the peer (need lock) - if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { - return NewPeerError(PubkeyForbidden, err.Error()) +func (bp *baseProtocol) handshakeMsg() Msg { + var ( + port uint64 + caps []interface{} + ) + if bp.peer.ourListenAddr != nil { + port = bp.peer.ourListenAddr.Port } + for _, proto := range bp.peer.protocols { + caps = append(caps, proto.cap()) + } + return NewMsg(handshakeMsg, + baseProtocolVersion, + bp.peer.ourID.String(), + caps, + port, + bp.peer.ourID.Pubkey()[1:], + ) +} - // check port - if bp.peer.Inbound { - uint16port := uint16(port) - if bp.peer.Port > 0 && bp.peer.Port != uint16port { - return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) - } else { - bp.peer.Port = uint16port +func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { + peers := bp.peer.otherPeers() + ds := make([]ethutil.RlpEncodable, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == bp.peer || addr == nil { + continue } + ds = append(ds, addr) } - - capsIt := caps.NewIterator() - for capsIt.Next() { - cap := capsIt.Value().Str() - bp.peer.Caps = append(bp.peer.Caps, cap) + ourAddr := bp.peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) } - sort.Strings(bp.peer.Caps) - bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps) - bp.peer.Id = id - return nil + return ds } diff --git a/p2p/server.go b/p2p/server.go index 54d2cde30..8a6087566 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,155 +2,101 @@ package p2p import ( "bytes" + "errors" "fmt" "net" - "sort" - "strconv" "sync" "time" - logpkg "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger" ) const ( - outboundAddressPoolSize = 10 - disconnectGracePeriod = 2 + outboundAddressPoolSize = 500 + defaultDialTimeout = 10 * time.Second + portMappingUpdateInterval = 15 * time.Minute + portMappingTimeout = 20 * time.Minute ) -type Blacklist interface { - Get([]byte) (bool, error) - Put([]byte) error - Delete([]byte) error - Exists(pubkey []byte) (ok bool) -} - -type BlacklistMap struct { - blacklist map[string]bool - lock sync.RWMutex -} - -func NewBlacklist() *BlacklistMap { - return &BlacklistMap{ - blacklist: make(map[string]bool), - } -} - -func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { - self.lock.RLock() - defer self.lock.RUnlock() - v, ok := self.blacklist[string(pubkey)] - var err error - if !ok { - err = fmt.Errorf("not found") - } - return v, err -} - -func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { - self.lock.RLock() - defer self.lock.RUnlock() - _, ok = self.blacklist[string(pubkey)] - return -} - -func (self *BlacklistMap) Put(pubkey []byte) error { - self.lock.RLock() - defer self.lock.RUnlock() - self.blacklist[string(pubkey)] = true - return nil -} - -func (self *BlacklistMap) Delete(pubkey []byte) error { - self.lock.RLock() - defer self.lock.RUnlock() - delete(self.blacklist, string(pubkey)) - return nil -} +var srvlog = logger.NewLogger("P2P Server") +// Server manages all peer connections. +// +// The fields of Server are used as configuration parameters. +// You should set them before starting the Server. Fields may not be +// modified while the server is running. type Server struct { - network Network - listening bool //needed? - dialing bool //needed? - closed bool - identity ClientIdentity - addr net.Addr - port uint16 - protocols []string - - quit chan chan bool - peersLock sync.RWMutex - - maxPeers int - peers []*Peer - peerSlots chan int - peersTable map[string]int - peerCount int - cachedEncodedPeers []byte - - peerConnect chan net.Addr - peerDisconnect chan DisconnectRequest - blacklist Blacklist - handlers Handlers -} - -var logger = logpkg.NewLogger("P2P") - -func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { - // get alphabetical list of protocol names from handlers map - protocols := []string{} - for protocol := range handlers { - protocols = append(protocols, protocol) - } - sort.Strings(protocols) - - _, port, _ := net.SplitHostPort(addr.String()) - intport, _ := strconv.Atoi(port) - - self := &Server{ - // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier) - network: network, - identity: identity, - addr: addr, - port: uint16(intport), - protocols: protocols, - - quit: make(chan chan bool), - - maxPeers: maxPeers, - peers: make([]*Peer, maxPeers), - peerSlots: make(chan int, maxPeers), - peersTable: make(map[string]int), + // This field must be set to a valid client identity. + Identity ClientIdentity + + // MaxPeers is the maximum number of peers that can be + // connected. It must be greater than zero. + MaxPeers int + + // Protocols should contain the protocols supported + // by the server. Matching protocols are launched for + // each peer. + Protocols []Protocol + + // If Blacklist is set to a non-nil value, the given Blacklist + // is used to verify peer connections. + Blacklist Blacklist + + // If ListenAddr is set to a non-nil address, the server + // will listen for incoming connections. + // + // If the port is zero, the operating system will pick a port. The + // ListenAddr field will be updated with the actual address when + // the server is started. + ListenAddr string + + // If set to a non-nil value, the given NAT port mapper + // is used to make the listening port available to the + // Internet. + NAT NAT + + // If Dialer is set to a non-nil value, the given Dialer + // is used to dial outbound peer connections. + Dialer *net.Dialer + + // If NoDial is true, the server will not dial any peers. + NoDial bool + + // Hook for testing. This is useful because we can inhibit + // the whole protocol stack. + newPeerFunc peerFunc - peerConnect: make(chan net.Addr, outboundAddressPoolSize), - peerDisconnect: make(chan DisconnectRequest), - blacklist: blacklist, - - handlers: handlers, - } - for i := 0; i < maxPeers; i++ { - self.peerSlots <- i // fill up with indexes - } - return self + lock sync.RWMutex + running bool + listener net.Listener + laddr *net.TCPAddr // real listen addr + peers []*Peer + peerSlots chan int + peerCount int + + quit chan struct{} + wg sync.WaitGroup + peerConnect chan *peerAddr + peerDisconnect chan *Peer } -func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { - addr, err = self.network.NewAddr(host, port) - return -} +// NAT is implemented by NAT traversal methods. +type NAT interface { + GetExternalAddress() (net.IP, error) + AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeletePortMapping(protocol string, extport, intport int) error -func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { - addr, err = self.network.ParseAddr(address) - return + // Should return name of the method. + String() string } -func (self *Server) ClientIdentity() ClientIdentity { - return self.identity -} +type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer -func (self *Server) Peers() (peers []*Peer) { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { +// Peers returns all connected peers. +func (srv *Server) Peers() (peers []*Peer) { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { if peer != nil { peers = append(peers, peer) } @@ -158,331 +104,364 @@ func (self *Server) Peers() (peers []*Peer) { return } -func (self *Server) PeerCount() int { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - return self.peerCount +// PeerCount returns the number of connected peers. +func (srv *Server) PeerCount() int { + srv.lock.RLock() + defer srv.lock.RUnlock() + return srv.peerCount } -func (self *Server) PeerConnect(addr net.Addr) { - // TODO: should buffer, filter and uniq - // send GetPeersMsg if not blocking +// SuggestPeer injects an address into the outbound address pool. +func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { select { - case self.peerConnect <- addr: // not enough peers - self.Broadcast("", getPeersMsg) - default: // we dont care + case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: + default: // don't block } } -func (self *Server) PeerDisconnect() chan DisconnectRequest { - return self.peerDisconnect -} - -func (self *Server) Blacklist() Blacklist { - return self.blacklist -} - -func (self *Server) Handlers() Handlers { - return self.handlers -} - -func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) { +// Broadcast sends an RLP-encoded message to all connected peers. +// This method is deprecated and will be removed later. +func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { var payload []byte if data != nil { payload = encodePayload(data...) } - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { if peer != nil { var msg = Msg{Code: code} if data != nil { msg.Payload = bytes.NewReader(payload) msg.Size = uint32(len(payload)) } - peer.messenger.writeProtoMsg(protocol, msg) + peer.writeProtoMsg(protocol, msg) } } } -// Start the server -func (self *Server) Start(listen bool, dial bool) { - self.network.Start() - if listen { - listener, err := self.network.Listener(self.addr) - if err != nil { - logger.Warnf("Error initializing listener: %v", err) - logger.Warnf("Connection listening disabled") - self.listening = false - } else { - self.listening = true - logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) - go self.inboundPeerHandler(listener) - } +// Start starts running the server. +// Servers can be re-used and started again after stopping. +func (srv *Server) Start() (err error) { + srv.lock.Lock() + defer srv.lock.Unlock() + if srv.running { + return errors.New("server already running") + } + srvlog.Infoln("Starting Server") + + // initialize fields + if srv.Identity == nil { + return fmt.Errorf("Server.Identity must be set to a non-nil identity") } - if dial { - dialer, err := self.network.Dialer(self.addr) - if err != nil { - logger.Warnf("Error initializing dialer: %v", err) - logger.Warnf("Connection dialout disabled") - self.dialing = false - } else { - self.dialing = true - logger.Infoln("Dial peers watching outbound address pool") - go self.outboundPeerHandler(dialer) + if srv.MaxPeers <= 0 { + return fmt.Errorf("Server.MaxPeers must be > 0") + } + srv.quit = make(chan struct{}) + srv.peers = make([]*Peer, srv.MaxPeers) + srv.peerSlots = make(chan int, srv.MaxPeers) + srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) + srv.peerDisconnect = make(chan *Peer) + if srv.newPeerFunc == nil { + srv.newPeerFunc = newServerPeer + } + if srv.Blacklist == nil { + srv.Blacklist = NewBlacklist() + } + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + } + + if srv.ListenAddr != "" { + if err := srv.startListening(); err != nil { + return err } } - logger.Infoln("server started") + if !srv.NoDial { + srv.wg.Add(1) + go srv.dialLoop() + } + if srv.NoDial && srv.ListenAddr == "" { + srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") + } + + // make all slots available + for i := range srv.peers { + srv.peerSlots <- i + } + // note: discLoop is not part of WaitGroup + go srv.discLoop() + srv.running = true + return nil } -func (self *Server) Stop() { - logger.Infoln("server stopping...") - // // quit one loop if dialing - if self.dialing { - logger.Infoln("stop dialout...") - dialq := make(chan bool) - self.quit <- dialq - <-dialq - fmt.Println("quit another") - } - // quit the other loop if listening - if self.listening { - logger.Infoln("stop listening...") - listenq := make(chan bool) - self.quit <- listenq - <-listenq - fmt.Println("quit one") - } - - fmt.Println("quit waited") - - logger.Infoln("stopping peers...") - peers := []net.Addr{} - self.peersLock.RLock() - self.closed = true - for _, peer := range self.peers { - if peer != nil { - peers = append(peers, peer.Address) - } +func (srv *Server) startListening() error { + listener, err := net.Listen("tcp", srv.ListenAddr) + if err != nil { + return err + } + srv.ListenAddr = listener.Addr().String() + srv.laddr = listener.Addr().(*net.TCPAddr) + srv.listener = listener + srv.wg.Add(1) + go srv.listenLoop() + if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { + srv.wg.Add(1) + go srv.natLoop(srv.laddr.Port) + } + return nil +} + +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + if !srv.running { + srv.lock.Unlock() + return } - self.peersLock.RUnlock() - for _, address := range peers { - go self.removePeer(DisconnectRequest{ - addr: address, - reason: DiscQuitting, - }) + srv.running = false + srv.lock.Unlock() + + srvlog.Infoln("Stopping server") + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + for _, peer := range srv.Peers() { + peer.Disconnect(DiscQuitting) } + srv.wg.Wait() + // wait till they actually disconnect - // this is checked by draining the peerSlots (slots are released back if a peer is removed) - i := 0 - fmt.Println("draining peers") + // this is checked by claiming all peerSlots. + // slots become available as the peers disconnect. + for i := 0; i < cap(srv.peerSlots); i++ { + <-srv.peerSlots + } + // terminate discLoop + close(srv.peerDisconnect) +} + +func (srv *Server) discLoop() { + for peer := range srv.peerDisconnect { + // peer has just disconnected. free up its slot. + srvlog.Infof("%v is gone", peer) + srv.peerSlots <- peer.slot + srv.lock.Lock() + srv.peers[peer.slot] = nil + srv.lock.Unlock() + } +} -FOR: +// main loop for adding connections via listening +func (srv *Server) listenLoop() { + defer srv.wg.Done() + + srvlog.Infoln("Listening on", srv.listener.Addr()) for { select { - case slot := <-self.peerSlots: - i++ - fmt.Printf("%v: found slot %v\n", i, slot) - if i == self.maxPeers { - break FOR + case slot := <-srv.peerSlots: + conn, err := srv.listener.Accept() + if err != nil { + srv.peerSlots <- slot + return } + srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) + srv.addPeer(conn, nil, slot) + case <-srv.quit: + return } } - logger.Infoln("server stopped") } -// main loop for adding connections via listening -func (self *Server) inboundPeerHandler(listener net.Listener) { +func (srv *Server) natLoop(port int) { + defer srv.wg.Done() for { + srv.updatePortMapping(port) select { - case slot := <-self.peerSlots: - go self.connectInboundPeer(listener, slot) - case errc := <-self.quit: - listener.Close() - fmt.Println("quit listenloop") - errc <- true + case <-time.After(portMappingUpdateInterval): + // one more round + case <-srv.quit: + srv.removePortMapping(port) return } } } -// main loop for adding outbound peers based on peerConnect address pool -// this same loop handles peer disconnect requests as well -func (self *Server) outboundPeerHandler(dialer Dialer) { - // addressChan initially set to nil (only watches peerConnect if we need more peers) - var addressChan chan net.Addr - slots := self.peerSlots - var slot *int +func (srv *Server) updatePortMapping(port int) { + srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) + err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) + if err != nil { + srvlog.Errorln("Port mapping error:", err) + return + } + extip, err := srv.NAT.GetExternalAddress() + if err != nil { + srvlog.Errorln("Error getting external IP:", err) + return + } + srv.lock.Lock() + extaddr := *(srv.listener.Addr().(*net.TCPAddr)) + extaddr.IP = extip + srvlog.Infoln("Mapped port, external addr is", &extaddr) + srv.laddr = &extaddr + srv.lock.Unlock() +} + +func (srv *Server) removePortMapping(port int) { + srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) + srv.NAT.DeletePortMapping("tcp", port, port) +} + +func (srv *Server) dialLoop() { + defer srv.wg.Done() + var ( + suggest chan *peerAddr + slot *int + slots = srv.peerSlots + ) for { select { case i := <-slots: // we need a peer in slot i, slot reserved slot = &i // now we can watch for candidate peers in the next loop - addressChan = self.peerConnect + suggest = srv.peerConnect // do not consume more until candidate peer is found slots = nil - case address := <-addressChan: + + case desc := <-suggest: // candidate peer found, will dial out asyncronously // if connection fails slot will be released - go self.connectOutboundPeer(dialer, address, *slot) + go srv.dialPeer(desc, *slot) // we can watch if more peers needed in the next loop - slots = self.peerSlots + slots = srv.peerSlots // until then we dont care about candidate peers - addressChan = nil - case request := <-self.peerDisconnect: - go self.removePeer(request) - case errc := <-self.quit: - if addressChan != nil && slot != nil { - self.peerSlots <- *slot + suggest = nil + + case <-srv.quit: + // give back the currently reserved slot + if slot != nil { + srv.peerSlots <- *slot } - fmt.Println("quit dialloop") - errc <- true return } } } -// check if peer address already connected -func (self *Server) isConnected(address net.Addr) bool { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - _, found := self.peersTable[address.String()] - return found -} - -// connect to peer via listener.Accept() -func (self *Server) connectInboundPeer(listener net.Listener, slot int) { - var address net.Addr - conn, err := listener.Accept() - if err != nil { - logger.Debugln(err) - self.peerSlots <- slot - return - } - address = conn.RemoteAddr() - // XXX: this won't work because the remote socket - // address does not identify the peer. we should - // probably get rid of this check and rely on public - // key detection in the base protocol. - if self.isConnected(address) { - conn.Close() - self.peerSlots <- slot - return - } - fmt.Printf("adding %v\n", address) - go self.addPeer(conn, address, true, slot) -} - // connect to peer via dial out -func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { - if self.isConnected(address) { - return - } - conn, err := dialer.Dial(address.Network(), address.String()) +func (srv *Server) dialPeer(desc *peerAddr, slot int) { + srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) + conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) if err != nil { - self.peerSlots <- slot + srvlog.Errorf("Dial error: %v", err) + srv.peerSlots <- slot return } - go self.addPeer(conn, address, false, slot) + go srv.addPeer(conn, desc, slot) } // creates the new peer object and inserts it into its slot -func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer { - self.peersLock.Lock() - defer self.peersLock.Unlock() - if self.closed { - fmt.Println("oopsy, not no longer need peer") - conn.Close() //oopsy our bad - self.peerSlots <- slot // release slot +func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + conn.Close() + srv.peerSlots <- slot // release slot return nil } - logger.Infoln("adding new peer", address) - peer := NewPeer(conn, address, inbound, self) - self.peers[slot] = peer - self.peersTable[address.String()] = slot - self.peerCount++ - self.cachedEncodedPeers = nil - fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) - peer.Start() + peer := srv.newPeerFunc(srv, conn, desc) + peer.slot = slot + srv.peers[slot] = peer + srv.peerCount++ + go func() { peer.loop(); srv.peerDisconnect <- peer }() return peer } // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot -func (self *Server) removePeer(request DisconnectRequest) { - self.peersLock.Lock() - - address := request.addr - slot := self.peersTable[address.String()] - peer := self.peers[slot] - fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) - if peer == nil { - logger.Debugf("already removed peer on %v", address) - self.peersLock.Unlock() +func (srv *Server) removePeer(peer *Peer) { + srv.lock.Lock() + defer srv.lock.Unlock() + srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) + if srv.peers[peer.slot] != peer { + srvlog.Warnln("Invalid peer to remove:", peer) return } // remove from list and index - self.peerCount-- - self.peers[slot] = nil - delete(self.peersTable, address.String()) - self.cachedEncodedPeers = nil - fmt.Printf("removed peer %v (slot %v)\n", peer, slot) - self.peersLock.Unlock() - - // sending disconnect message - disconnectMsg := NewMsg(discMsg, request.reason) - peer.Write("", disconnectMsg) - // be nice and wait - time.Sleep(disconnectGracePeriod * time.Second) - // switch off peer and close connections etc. - fmt.Println("stopping peer") - peer.Stop() - fmt.Println("stopped peer") + srv.peerCount-- + srv.peers[peer.slot] = nil // release slot to signal need for a new peer, last! - self.peerSlots <- slot + srv.peerSlots <- peer.slot } -// encodedPeerList returns an RLP-encoded list of peers. -// the returned slice will be nil if there are no peers. -func (self *Server) encodedPeerList() []byte { - // TODO: memoize and reset when peers change - self.peersLock.RLock() - defer self.peersLock.RUnlock() - if self.cachedEncodedPeers == nil && self.peerCount > 0 { - var peerData []interface{} - for _, i := range self.peersTable { - peer := self.peers[i] - peerData = append(peerData, peer.Encode()) +func (srv *Server) verifyPeer(addr *peerAddr) error { + if srv.Blacklist.Exists(addr.Pubkey) { + return errors.New("blacklisted") + } + if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { + return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + id := peer.Identity() + if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { + return errors.New("already connected") + } } - self.cachedEncodedPeers = encodePayload(peerData) } - return self.cachedEncodedPeers + return nil } -// fix handshake message to push to peers -func (self *Server) handshakeMsg() Msg { - return NewMsg(handshakeMsg, - p2pVersion, - []byte(self.identity.String()), - []interface{}{self.protocols}, - self.port, - self.identity.Pubkey()[1:], - ) +type Blacklist interface { + Get([]byte) (bool, error) + Put([]byte) error + Delete([]byte) error + Exists(pubkey []byte) (ok bool) +} + +type BlacklistMap struct { + blacklist map[string]bool + lock sync.RWMutex } -func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { - // Check for blacklisting - if self.blacklist.Exists(pubkey) { - return fmt.Errorf("blacklisted") +func NewBlacklist() *BlacklistMap { + return &BlacklistMap{ + blacklist: make(map[string]bool), } +} - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { - return fmt.Errorf("already connected") - } +func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { + self.lock.RLock() + defer self.lock.RUnlock() + v, ok := self.blacklist[string(pubkey)] + var err error + if !ok { + err = fmt.Errorf("not found") } - candidate.Pubkey = pubkey + return v, err +} + +func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { + self.lock.RLock() + defer self.lock.RUnlock() + _, ok = self.blacklist[string(pubkey)] + return +} + +func (self *BlacklistMap) Put(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + self.blacklist[string(pubkey)] = true + return nil +} + +func (self *BlacklistMap) Delete(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + delete(self.blacklist, string(pubkey)) return nil } diff --git a/p2p/server_test.go b/p2p/server_test.go index a2594acba..5c0d08d39 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -1,289 +1,161 @@ package p2p import ( - "fmt" + "bytes" "io" "net" + "sync" "testing" "time" ) -type TestNetwork struct { - connections map[string]*TestNetworkConnection - dialer Dialer - maxinbound int -} - -func NewTestNetwork(maxinbound int) *TestNetwork { - connections := make(map[string]*TestNetworkConnection) - return &TestNetwork{ - connections: connections, - dialer: &TestDialer{connections}, - maxinbound: maxinbound, +func startTestServer(t *testing.T, pf peerFunc) *Server { + server := &Server{ + Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + newPeerFunc: pf, } -} - -func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { - return self.dialer, nil -} - -func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { - return &TestListener{ - connections: self.connections, - addr: addr, - max: self.maxinbound, - close: make(chan struct{}), - }, nil -} - -func (self *TestNetwork) Start() error { - return nil -} - -func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { - return -} - -func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { - return -} - -type TestAddr struct { - name string -} - -func (self *TestAddr) String() string { - return self.name -} - -func (*TestAddr) Network() string { - return "test" -} - -type TestDialer struct { - connections map[string]*TestNetworkConnection -} - -func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { - address := &TestAddr{addr} - tconn := NewTestNetworkConnection(address) - self.connections[addr] = tconn - conn = net.Conn(tconn) - return -} - -type TestListener struct { - connections map[string]*TestNetworkConnection - addr net.Addr - max int - i int - close chan struct{} -} - -func (self *TestListener) Accept() (net.Conn, error) { - self.i++ - if self.i > self.max { - <-self.close - return nil, io.EOF + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) } - addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} - tconn := NewTestNetworkConnection(addr) - key := tconn.RemoteAddr().String() - self.connections[key] = tconn - fmt.Printf("accepted connection from: %v \n", addr) - return tconn, nil -} - -func (self *TestListener) Close() error { - close(self.close) - return nil -} - -func (self *TestListener) Addr() net.Addr { - return self.addr + return server } -type TestNetworkConnection struct { - in chan []byte - close chan struct{} - current []byte - Out [][]byte - addr net.Addr -} +func TestServerListen(t *testing.T) { + defer testlog(t).detach() -func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { - return &TestNetworkConnection{ - in: make(chan []byte), - close: make(chan struct{}), - current: []byte{}, - Out: [][]byte{}, - addr: addr, + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + if dialAddr != nil { + t.Error("peer func called with non-nil dialAddr") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) } -} + defer conn.Close() -func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { - time.Sleep(latency) - for _, s := range packets { - self.in <- s + select { + case peer := <-connected: + if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.LocalAddr(), conn.RemoteAddr()) + } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") } } -func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { - if len(self.current) == 0 { - var ok bool +func TestServerDial(t *testing.T) { + defer testlog(t).detach() + + // run a fake TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") + } + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("acccept error:", err) + } + conn.Close() + accepted <- conn + }() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // tell the server to connect. + connAddr := newPeerAddr(listener.Addr(), nil) + srv.peerConnect <- connAddr + + select { + case conn := <-accepted: select { - case self.current, ok = <-self.in: - if !ok { - return 0, io.EOF + case peer := <-connected: + if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.RemoteAddr(), conn.LocalAddr()) + } + if peer.dialAddr != connAddr { + t.Errorf("peer started with wrong dialAddr: got %v, want %v", + peer.dialAddr, connAddr) } - case <-self.close: - return 0, io.EOF + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") } - } - length := len(self.current) - if length > len(buff) { - copy(buff[:], self.current[:len(buff)]) - self.current = self.current[len(buff):] - return len(buff), nil - } else { - copy(buff[:length], self.current[:]) - self.current = []byte{} - return length, io.EOF - } -} - -func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { - self.Out = append(self.Out, buff) - fmt.Printf("net write(%d): %x\n", len(self.Out), buff) - return len(buff), nil -} - -func (self *TestNetworkConnection) Close() error { - close(self.close) - return nil -} - -func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { - return -} -func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { - return self.addr -} - -func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { - return -} - -func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { - network = NewTestNetwork(1) - addr := &TestAddr{"test:30303"} - identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") - maxPeers := 2 - if handlers == nil { - handlers = make(Handlers) + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") } - blackist := NewBlacklist() - server = New(network, addr, identity, handlers, maxPeers, blackist) - fmt.Println(server.identity.Pubkey()) - return } -func TestServerListener(t *testing.T) { - t.SkipNow() - - network, server := SetupTestServer(nil) - server.Start(true, false) - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["inboundpeer-1"] - if !ok { - t.Error("not found inbound peer 1") - } else { - if len(peer1.Out) != 2 { - t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + var connected sync.WaitGroup + srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { + peer := newPeer(c, []Protocol{discard}, dialAddr) + peer.startSubprotocols([]Cap{discard.cap()}) + connected.Done() + return peer + }) + defer srv.Stop() + + // dial a bunch of conns + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn } -} - -func TestServerDialer(t *testing.T) { - network, server := SetupTestServer(nil) - server.Start(false, true) - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["outboundpeer-1"] - if !ok { - t.Error("not found outbound peer 1") - } else { - if len(peer1.Out) != 2 { - t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) + connected.Wait() + + // broadcast one message + srv.Broadcast("discard", 0, "foo") + goldbuf := new(bytes.Buffer) + writeMsg(goldbuf, NewMsg(16, "foo")) + golden := goldbuf.Bytes() + + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) } } } - -// func TestServerBroadcast(t *testing.T) { -// handlers := make(Handlers) -// testProtocol := &TestProtocol{Msgs: []*Msg{}} -// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } -// network, server := SetupTestServer(handlers) -// server.Start(true, true) -// server.peerConnect <- &TestAddr{"outboundpeer-1"} -// time.Sleep(10 * time.Millisecond) -// msg := NewMsg(0) -// server.Broadcast("", msg) -// packet := Packet(0, 0) -// time.Sleep(10 * time.Millisecond) -// server.Stop() -// peer1, ok := network.connections["outboundpeer-1"] -// if !ok { -// t.Error("not found outbound peer 1") -// } else { -// fmt.Printf("out: %v\n", peer1.Out) -// if len(peer1.Out) != 3 { -// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) -// } else { -// if bytes.Compare(peer1.Out[1], packet) != 0 { -// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) -// } -// } -// } -// peer2, ok := network.connections["inboundpeer-1"] -// if !ok { -// t.Error("not found inbound peer 2") -// } else { -// fmt.Printf("out: %v\n", peer2.Out) -// if len(peer1.Out) != 3 { -// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) -// } else { -// if bytes.Compare(peer2.Out[1], packet) != 0 { -// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) -// } -// } -// } -// } - -func TestServerPeersMessage(t *testing.T) { - t.SkipNow() - _, server := SetupTestServer(nil) - server.Start(true, true) - defer server.Stop() - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(2000 * time.Millisecond) - - pl := server.encodedPeerList() - if pl == nil { - t.Errorf("expect non-nil peer list") - } - if c := server.PeerCount(); c != 2 { - t.Errorf("expect 2 peers, got %v", c) - } -} diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go new file mode 100644 index 000000000..951d43243 --- /dev/null +++ b/p2p/testlog_test.go @@ -0,0 +1,28 @@ +package p2p + +import ( + "testing" + + "github.com/ethereum/go-ethereum/logger" +) + +type testLogger struct{ t *testing.T } + +func testlog(t *testing.T) testLogger { + logger.Reset() + l := testLogger{t} + logger.AddLogSystem(l) + return l +} + +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } +func (testLogger) SetLogLevel(logger.LogLevel) {} + +func (l testLogger) LogPrint(level logger.LogLevel, msg string) { + l.t.Logf("%s", msg) +} + +func (testLogger) detach() { + logger.Flush() + logger.Reset() +} diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go new file mode 100644 index 000000000..c0cc5c544 --- /dev/null +++ b/p2p/testpoc7.go @@ -0,0 +1,40 @@ +// +build none + +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p" + "github.com/obscuren/secp256k1-go" +) + +func main() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) + + pub, _ := secp256k1.GenerateKeyPair() + srv := p2p.Server{ + MaxPeers: 10, + Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), + ListenAddr: ":30303", + NAT: p2p.PMP(net.ParseIP("10.0.0.1")), + } + if err := srv.Start(); err != nil { + fmt.Println("could not start server:", err) + os.Exit(1) + } + + // add seed peers + seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") + if err != nil { + fmt.Println("couldn't resolve:", err) + os.Exit(1) + } + srv.SuggestPeer(seed.IP, seed.Port, nil) + + select {} +} -- cgit v1.2.3 From c1fca72552386868d28ce7541691e53e55673549 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 24 Nov 2014 19:02:48 +0100 Subject: p2p: use package rlp --- p2p/message.go | 93 ++++++++++++++++------------------------------------- p2p/message_test.go | 3 ++ p2p/peer_test.go | 2 +- 3 files changed, 31 insertions(+), 67 deletions(-) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index 89ad189d7..ade39d25a 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -3,12 +3,12 @@ package p2p import ( "bytes" "encoding/binary" - "fmt" "io" "io/ioutil" "math/big" "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/rlp" ) // Msg defines the structure of a p2p message. @@ -43,16 +43,10 @@ func encodePayload(params ...interface{}) []byte { // Data returns the decoded RLP payload items in a message. func (msg Msg) Data() (*ethutil.Value, error) { - // TODO: avoid copying when we have a better RLP decoder - buf := new(bytes.Buffer) - var s []interface{} - if _, err := buf.ReadFrom(msg.Payload); err != nil { - return nil, err - } - for buf.Len() > 0 { - s = append(s, ethutil.DecodeWithReader(buf)) - } - return ethutil.NewValue(s), nil + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + var v []interface{} + err := s.Decode(&v) + return ethutil.NewValue(v), err } // Discard reads any remaining payload data into a black hole. @@ -137,13 +131,9 @@ func makeListHeader(length uint32) []byte { return append([]byte{lenb}, enc...) } -type byteReader interface { - io.Reader - io.ByteReader -} - // readMsg reads a message header from r. -func readMsg(r byteReader) (msg Msg, err error) { +// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. +func readMsg(r rlp.ByteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) if _, err = io.ReadFull(r, start); err != nil { @@ -155,64 +145,35 @@ func readMsg(r byteReader) (msg Msg, err error) { size := binary.BigEndian.Uint32(start[4:]) // decode start of RLP message to get the message code - _, hdrlen, err := readListHeader(r) - if err != nil { + posr := &postrack{r, 0} + s := rlp.NewStream(posr) + if _, err := s.List(); err != nil { return msg, err } - code, codelen, err := readMsgCode(r) + code, err := s.Uint() if err != nil { return msg, err } + payloadsize := size - posr.p + return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil +} - rlpsize := size - hdrlen - codelen - return Msg{ - Code: code, - Size: rlpsize, - Payload: io.LimitReader(r, int64(rlpsize)), - }, nil +// postrack wraps an rlp.ByteReader with a position counter. +type postrack struct { + r rlp.ByteReader + p uint32 } -// readListHeader reads an RLP list header from r. -func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { - b, err := r.ReadByte() - if err != nil { - return 0, 0, err - } - if b < 0xC0 { - return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b) - } else if b < 0xF7 { - len = uint64(b - 0xc0) - hdrlen = 1 - } else { - lenlen := b - 0xF7 - lenbuf := make([]byte, 8) - if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil { - return 0, 0, err - } - len = binary.BigEndian.Uint64(lenbuf) - hdrlen = 1 + uint32(lenlen) - } - return len, hdrlen, nil +func (r *postrack) Read(buf []byte) (int, error) { + n, err := r.r.Read(buf) + r.p += uint32(n) + return n, err } -// readUint reads an RLP-encoded unsigned integer from r. -func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) { - b, err := r.ReadByte() - if err != nil { - return 0, 0, err - } - if b < 0x80 { - return uint64(b), 1, nil - } else if b < 0x89 { // max length for uint64 is 8 bytes - codelen = uint32(b - 0x80) - if codelen == 0 { - return 0, 1, nil - } - buf := make([]byte, 8) - if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { - return 0, 0, err - } - return binary.BigEndian.Uint64(buf), codelen, nil +func (r *postrack) ReadByte() (byte, error) { + b, err := r.r.ReadByte() + if err == nil { + r.p++ } - return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) + return b, err } diff --git a/p2p/message_test.go b/p2p/message_test.go index 1edabc4e7..02d70a28b 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -46,6 +46,9 @@ func TestEncodeDecodeMsg(t *testing.T) { if err != nil { t.Fatalf("first payload item decode error: %v", err) } + if v := data.Len(); v != 2 { + t.Errorf("incorrect data.Len(): got %v, expected %d", v, 1) + } if v := data.Get(0).Uint(); v != 1 { t.Errorf("incorrect data[0]: got %v, expected %d", v, 1) } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 1afa0ab17..56cd4d890 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -57,7 +57,7 @@ func TestPeerProtoReadMsg(t *testing.T) { if err != nil { t.Errorf("data decoding error: %v", err) } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + expdata := []interface{}{[]byte{0x01}, []byte{0x30, 0x30, 0x30}} if !reflect.DeepEqual(data.Slice(), expdata) { t.Errorf("incorrect msg data %#v", data.Slice()) } -- cgit v1.2.3 From 6049fcd52ab10362721a352cfd7a93a01c3ffa97 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 25 Nov 2014 12:25:31 +0100 Subject: p2p: use package rlp for baseProtocol --- p2p/message.go | 18 ++++++--- p2p/message_test.go | 2 +- p2p/peer_test.go | 2 +- p2p/protocol.go | 107 +++++++++++++++++++++++++++------------------------- 4 files changed, 71 insertions(+), 58 deletions(-) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index ade39d25a..845c832f0 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte { return buf.Bytes() } -// Data returns the decoded RLP payload items in a message. -func (msg Msg) Data() (*ethutil.Value, error) { - s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) +// Value returns the decoded RLP payload items in a message. +func (msg Msg) Value() (*ethutil.Value, error) { var v []interface{} - err := s.Decode(&v) + err := msg.Decode(&v) return ethutil.NewValue(v), err } +// Decode parse the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + return s.Decode(val) +} + // Discard reads any remaining payload data into a black hole. func (msg Msg) Discard() error { _, err := io.Copy(ioutil.Discard, msg.Payload) @@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu if msg.Size > maxsize { return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) } - value, err := msg.Data() + value, err := msg.Value() if err != nil { return err } diff --git a/p2p/message_test.go b/p2p/message_test.go index 02d70a28b..0f51f759e 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) { if decmsg.Size != 5 { t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) } - data, err := decmsg.Data() + data, err := decmsg.Value() if err != nil { t.Fatalf("first payload item decode error: %v", err) } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 56cd4d890..629475421 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) { if msg.Code != 2 { t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) } - data, err := msg.Data() + data, err := msg.Value() if err != nil { t.Errorf("data decoding error: %v", err) } diff --git a/p2p/protocol.go b/p2p/protocol.go index 169dcdb6e..28eab87cd 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -2,7 +2,6 @@ package p2p import ( "bytes" - "net" "time" "github.com/ethereum/go-ethereum/ethutil" @@ -90,30 +89,18 @@ type baseProtocol struct { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { bp := &baseProtocol{rw, peer} - - // do handshake - if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { - return err - } - msg, err := rw.ReadMsg() - if err != nil { + if err := bp.doHandshake(rw); err != nil { return err } - if msg.Code != handshakeMsg { - return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) - } - data, err := msg.Data() - if err != nil { - return newPeerError(errInvalidMsg, "%v", err) - } - if err := bp.handleHandshake(data); err != nil { - return err - } - // run main loop quit := make(chan error, 1) go func() { - quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle) + for { + if err := bp.handle(rw); err != nil { + quit <- err + break + } + } }() return bp.loop(quit) } @@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error { return err } -func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { - switch code { +func (bp *baseProtocol) handle(rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + switch msg.Code { case handshakeMsg: return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - bp.peer.Disconnect(DiscReason(data.Get(0).Uint())) + var reason DiscReason + if err := msg.Decode(&reason); err != nil { + return err + } + bp.peer.Disconnect(reason) return nil case pingMsg: @@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { } case peersMsg: - bp.handlePeers(data) + var peers []*peerAddr + if err := msg.Decode(&peers); err != nil { + return err + } + for _, addr := range peers { + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr + } default: - return newPeerError(errInvalidMsgCode, "unknown message code %v", code) + return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) } return nil } -func (bp *baseProtocol) handlePeers(data *ethutil.Value) { - it := data.NewIterator() - for it.Next() { - addr := &peerAddr{ - IP: net.IP(it.Value().Get(0).Bytes()), - Port: it.Value().Get(1).Uint(), - Pubkey: it.Value().Get(2).Bytes(), - } - bp.peer.Debugf("received peer suggestion: %v", addr) - bp.peer.newPeerAddr <- addr +func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { + // send our handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err + } + + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") } -} -func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { - hs := handshake{ - Version: c.Get(0).Uint(), - ID: c.Get(1).Str(), - Caps: nil, // decoded below - ListenPort: c.Get(3).Uint(), - NodeID: c.Get(4).Bytes(), + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err } + + // validate handshake info if hs.Version != baseProtocolVersion { return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", baseProtocolVersion, hs.Version) @@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { if err := bp.peer.pubkeyHook(pa); err != nil { return newPeerError(errPubkeyForbidden, "%v", err) } - capsIt := c.Get(2).NewIterator() - for capsIt.Next() { - cap := capsIt.Value() - name := cap.Get(0).Str() - if name != "" { - hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())}) - } - } + + // TODO: remove Caps with empty name var addr *peerAddr if hs.ListenPort != 0 { -- cgit v1.2.3 From 9b85002b700500d421ba7e13ac2062a6b8090a83 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 25 Nov 2014 16:01:39 +0100 Subject: p2p: remove Msg.Value and MsgLoop --- p2p/message.go | 32 -------------------------------- p2p/message_test.go | 22 +++++++++++----------- p2p/peer_test.go | 14 ++++++++------ 3 files changed, 19 insertions(+), 49 deletions(-) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index 845c832f0..d3b8b74d4 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -41,13 +41,6 @@ func encodePayload(params ...interface{}) []byte { return buf.Bytes() } -// Value returns the decoded RLP payload items in a message. -func (msg Msg) Value() (*ethutil.Value, error) { - var v []interface{} - err := msg.Decode(&v) - return ethutil.NewValue(v), err -} - // Decode parse the RLP content of a message into // the given value, which must be a pointer. // @@ -84,31 +77,6 @@ type MsgReadWriter interface { MsgWriter } -// MsgLoop reads messages off the given reader and -// calls the handler function for each decoded message until -// it returns an error or the peer connection is closed. -// -// If a message is larger than the given maximum size, -// MsgLoop returns an appropriate error. -func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error { - for { - msg, err := r.ReadMsg() - if err != nil { - return err - } - if msg.Size > maxsize { - return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) - } - value, err := msg.Value() - if err != nil { - return err - } - if err := f(msg.Code, value); err != nil { - return err - } - } -} - var magicToken = []byte{34, 64, 8, 145} func writeMsg(w io.Writer, msg Msg) error { diff --git a/p2p/message_test.go b/p2p/message_test.go index 0f51f759e..7b39b061d 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -29,8 +29,7 @@ func TestEncodeDecodeMsg(t *testing.T) { if err := writeMsg(buf, msg); err != nil { t.Fatalf("encodeMsg error: %v", err) } - - t.Logf("encoded: %x", buf.Bytes()) + // t.Logf("encoded: %x", buf.Bytes()) decmsg, err := readMsg(buf) if err != nil { @@ -42,18 +41,19 @@ func TestEncodeDecodeMsg(t *testing.T) { if decmsg.Size != 5 { t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) } - data, err := decmsg.Value() - if err != nil { - t.Fatalf("first payload item decode error: %v", err) + + var data struct { + I int + S string } - if v := data.Len(); v != 2 { - t.Errorf("incorrect data.Len(): got %v, expected %d", v, 1) + if err := decmsg.Decode(&data); err != nil { + t.Fatalf("Decode error: %v", err) } - if v := data.Get(0).Uint(); v != 1 { - t.Errorf("incorrect data[0]: got %v, expected %d", v, 1) + if data.I != 1 { + t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) } - if v := data.Get(1).Str(); v != "000" { - t.Errorf("incorrect data[1]: got %q, expected %q", v, "000") + if data.S != "000" { + t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") } } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 629475421..0994683a2 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -2,8 +2,10 @@ package p2p import ( "bufio" + "bytes" + "encoding/hex" + "io/ioutil" "net" - "reflect" "testing" "time" ) @@ -53,13 +55,13 @@ func TestPeerProtoReadMsg(t *testing.T) { if msg.Code != 2 { t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) } - data, err := msg.Value() + data, err := ioutil.ReadAll(msg.Payload) if err != nil { - t.Errorf("data decoding error: %v", err) + t.Errorf("payload read error: %v", err) } - expdata := []interface{}{[]byte{0x01}, []byte{0x30, 0x30, 0x30}} - if !reflect.DeepEqual(data.Slice(), expdata) { - t.Errorf("incorrect msg data %#v", data.Slice()) + expdata, _ := hex.DecodeString("0183303030") + if !bytes.Equal(expdata, data) { + t.Errorf("incorrect msg data %x", data) } close(done) return nil -- cgit v1.2.3 From 3a09459c4c3c6d4edefa57a9b245402003ae191e Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 26 Nov 2014 22:08:54 +0100 Subject: p2p: make Disconnect not hang for peers created with NewPeer --- p2p/peer.go | 1 + 1 file changed, 1 insertion(+) (limited to 'p2p') diff --git a/p2p/peer.go b/p2p/peer.go index 238d3d9c9..893ba86d7 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -97,6 +97,7 @@ func NewPeer(id ClientIdentity, caps []Cap) *Peer { conn, _ := net.Pipe() peer := newPeer(conn, nil, nil) peer.setHandshakeInfo(id, nil, caps) + close(peer.closed) return peer } -- cgit v1.2.3 From cfd7e74c25fa7d1b443f8527fca8afad14ef4419 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 26 Nov 2014 22:49:40 +0100 Subject: p2p: add test for NewPeer --- p2p/peer_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'p2p') diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 0994683a2..d9640292f 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "io/ioutil" "net" + "reflect" "testing" "time" ) @@ -222,3 +223,17 @@ func TestPeerActivity(t *testing.T) { t.Fatal("peer error", err) } } + +func TestNewPeer(t *testing.T) { + id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") + caps := []Cap{{"foo", 2}, {"bar", 3}} + p := NewPeer(id, caps) + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) + } + if p.Identity() != id { + t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) + } + // Should not hang. + p.Disconnect(DiscAlreadyConnected) +} -- cgit v1.2.3 From 1fb84d3c5f486ef0d42a21f3cd8416d6ef211604 Mon Sep 17 00:00:00 2001 From: obscuren Date: Wed, 10 Dec 2014 10:57:19 +0100 Subject: Fixed tests --- p2p/message_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'p2p') diff --git a/p2p/message_test.go b/p2p/message_test.go index 7b39b061d..557bfed26 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -43,7 +43,7 @@ func TestEncodeDecodeMsg(t *testing.T) { } var data struct { - I int + I uint S string } if err := decmsg.Decode(&data); err != nil { -- cgit v1.2.3 From 9423401d73562b2a55398559a9d21af75210d955 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 9 Dec 2014 13:39:56 +0100 Subject: p2p: fix decoding of disconnect reason (fixes #200) --- p2p/protocol.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'p2p') diff --git a/p2p/protocol.go b/p2p/protocol.go index 28eab87cd..5af586f13 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -154,11 +154,11 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - var reason DiscReason + var reason [1]DiscReason if err := msg.Decode(&reason); err != nil { return err } - bp.peer.Disconnect(reason) + bp.peer.Disconnect(reason[0]) return nil case pingMsg: -- cgit v1.2.3 From e28c60caf9a31669451124a9add2b9036bec1e73 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 12 Dec 2014 11:38:42 +0100 Subject: p2p: improve and test eofSignal --- p2p/peer.go | 17 ++++++++++++----- p2p/peer_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) (limited to 'p2p') diff --git a/p2p/peer.go b/p2p/peer.go index 893ba86d7..86c4d7ab5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) proto.in <- msg } else { wait = true - pr := &eofSignal{msg.Payload, protoDone} + pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} msg.Payload = pr proto.in <- msg } @@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) { return msg, nil } -// eofSignal wraps a reader with eof signaling. -// the eof channel is closed when the wrapped reader -// reaches EOF. +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +// type eofSignal struct { wrapped io.Reader + count int64 eof chan<- struct{} } +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. + func (r *eofSignal) Read(buf []byte) (int, error) { n, err := r.wrapped.Read(buf) - if err != nil { + r.count -= int64(n) + if (err != nil || r.count <= 0) && r.eof != nil { r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil } return n, err } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index d9640292f..f7759786e 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/hex" + "io" "io/ioutil" "net" "reflect" @@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) { // Should not hang. p.Disconnect(DiscAlreadyConnected) } + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 8 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} -- cgit v1.2.3 From 65e39bf20eecf7f3296fba531efb72ac28dc0124 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 12 Dec 2014 11:39:07 +0100 Subject: p2p: add MsgPipe for protocol testing --- p2p/message.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ p2p/message_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index d3b8b74d4..f5418ff47 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -3,9 +3,11 @@ package p2p import ( "bytes" "encoding/binary" + "errors" "io" "io/ioutil" "math/big" + "sync/atomic" "github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/rlp" @@ -153,3 +155,78 @@ func (r *postrack) ReadByte() (byte, error) { } return b, err } + +// MsgPipe creates a message pipe. Reads on one end are matched +// with writes on the other. The pipe is full-duplex, both ends +// implement MsgReadWriter. +func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { + var ( + c1, c2 = make(chan Msg), make(chan Msg) + closing = make(chan struct{}) + closed = new(int32) + rw1 = &MsgPipeRW{c1, c2, closing, closed} + rw2 = &MsgPipeRW{c2, c1, closing, closed} + ) + return rw1, rw2 +} + +// ErrPipeClosed is returned from pipe operations after the +// pipe has been closed. +var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") + +// MsgPipeRW is an endpoint of a MsgReadWriter pipe. +type MsgPipeRW struct { + w chan<- Msg + r <-chan Msg + closing chan struct{} + closed *int32 +} + +// WriteMsg sends a messsage on the pipe. +// It blocks until the receiver has consumed the message payload. +func (p *MsgPipeRW) WriteMsg(msg Msg) error { + if atomic.LoadInt32(p.closed) == 0 { + consumed := make(chan struct{}, 1) + msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} + select { + case p.w <- msg: + if msg.Size > 0 { + // wait for payload read or discard + <-consumed + } + return nil + case <-p.closing: + } + } + return ErrPipeClosed +} + +// EncodeMsg is a convenient shorthand for sending an RLP-encoded message. +func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error { + return p.WriteMsg(NewMsg(code, data...)) +} + +// ReadMsg returns a message sent on the other end of the pipe. +func (p *MsgPipeRW) ReadMsg() (Msg, error) { + if atomic.LoadInt32(p.closed) == 0 { + select { + case msg := <-p.r: + return msg, nil + case <-p.closing: + } + } + return Msg{}, ErrPipeClosed +} + +// Close unblocks any pending ReadMsg and WriteMsg calls on both ends +// of the pipe. They will return ErrPipeClosed. Note that Close does +// not interrupt any reads from a message payload. +func (p *MsgPipeRW) Close() error { + if atomic.AddInt32(p.closed, 1) != 1 { + // someone else is already closing + atomic.StoreInt32(p.closed, 1) // avoid overflow + return nil + } + close(p.closing) + return nil +} diff --git a/p2p/message_test.go b/p2p/message_test.go index 7b39b061d..0fbcfeef0 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -2,8 +2,11 @@ package p2p import ( "bytes" + "fmt" "io/ioutil" + "runtime" "testing" + "time" "github.com/ethereum/go-ethereum/ethutil" ) @@ -68,3 +71,63 @@ func TestDecodeRealMsg(t *testing.T) { t.Errorf("incorrect code %d, want %d", msg.Code, 0) } } + +func ExampleMsgPipe() { + rw1, rw2 := MsgPipe() + go func() { + rw1.EncodeMsg(8, []byte{0, 0}) + rw1.EncodeMsg(5, []byte{1, 1}) + rw1.Close() + }() + + for { + msg, err := rw2.ReadMsg() + if err != nil { + break + } + var data [1][]byte + msg.Decode(&data) + fmt.Printf("msg: %d, %x\n", msg.Code, data[0]) + } + // Output: + // msg: 8, 0000 + // msg: 5, 0101 +} + +func TestMsgPipeUnblockWrite(t *testing.T) { +loop: + for i := 0; i < 100; i++ { + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := rw1.EncodeMsg(1); err == nil { + t.Error("EncodeMsg returned nil error") + } else if err != ErrPipeClosed { + t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) + } + close(done) + }() + + // this call should ensure that EncodeMsg is waiting to + // deliver sometimes. if this isn't done, Close is likely to + // be executed before EncodeMsg starts and then we won't test + // all the cases. + runtime.Gosched() + + rw2.Close() + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Errorf("write didn't unblock") + break loop + } + } +} + +// This test should panic if concurrent close isn't implemented correctly. +func TestMsgPipeConcurrentClose(t *testing.T) { + rw1, _ := MsgPipe() + for i := 0; i < 10; i++ { + go rw1.Close() + } +} -- cgit v1.2.3 From f0f672777866a524c36e767a6c313f93f574dd8c Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 12 Dec 2014 11:58:39 +0100 Subject: p2p: use an error type for disconnect requests Test-tastic. --- p2p/peer_error.go | 9 +++++++++ p2p/protocol.go | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'p2p') diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 88b870fbd..0eb7ec838 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -100,7 +100,16 @@ func (d DiscReason) String() string { return discReasonToString[d] } +type discRequestedError DiscReason + +func (err discRequestedError) Error() string { + return fmt.Sprintf("disconnect requested: %v", DiscReason(err)) +} + func discReasonForError(err error) DiscReason { + if reason, ok := err.(discRequestedError); ok { + return DiscReason(reason) + } peerError, ok := err.(*peerError) if !ok { return DiscSubprotocolError diff --git a/p2p/protocol.go b/p2p/protocol.go index 5af586f13..3f52205f5 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -158,8 +158,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { if err := msg.Decode(&reason); err != nil { return err } - bp.peer.Disconnect(reason[0]) - return nil + return discRequestedError(reason[0]) case pingMsg: return bp.rw.EncodeMsg(pongMsg) -- cgit v1.2.3 From da900f94358a9b293a286066b0922a6f7b5d571c Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 12 Dec 2014 11:39:29 +0100 Subject: p2p: add test for base protocol disconnect --- p2p/protocol_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 p2p/protocol_test.go (limited to 'p2p') diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go new file mode 100644 index 000000000..65f26fb12 --- /dev/null +++ b/p2p/protocol_test.go @@ -0,0 +1,58 @@ +package p2p + +import ( + "fmt" + "testing" +) + +func TestBaseProtocolDisconnect(t *testing.T) { + peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) + peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") + peer.pubkeyHook = func(*peerAddr) error { return nil } + + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := expectMsg(rw2, handshakeMsg); err != nil { + t.Error(err) + } + err := rw2.EncodeMsg(handshakeMsg, + baseProtocolVersion, + "", + []interface{}{}, + 0, + make([]byte, 64), + ) + if err != nil { + t.Error(err) + } + if err := expectMsg(rw2, getPeersMsg); err != nil { + t.Error(err) + } + if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { + t.Error(err) + } + close(done) + }() + + if err := runBaseProtocol(peer, rw1); err == nil { + t.Errorf("base protocol returned without error") + } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting { + t.Errorf("base protocol returned wrong error: %v", err) + } + <-done +} + +func expectMsg(r MsgReader, code uint64) error { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if err := msg.Discard(); err != nil { + return err + } + if msg.Code != code { + return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code) + } + return nil +} -- cgit v1.2.3 From 56dac74f71d4d1fffa3e240c7d3cee803dec788d Mon Sep 17 00:00:00 2001 From: obscuren Date: Mon, 15 Dec 2014 13:00:09 +0100 Subject: made mist in a compilable, workable state using the new refactored packages --- p2p/server.go | 1 + 1 file changed, 1 insertion(+) (limited to 'p2p') diff --git a/p2p/server.go b/p2p/server.go index 8a6087566..9353e12ea 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -416,6 +416,7 @@ func (srv *Server) verifyPeer(addr *peerAddr) error { return nil } +// TODO replace with "Set" type Blacklist interface { Get([]byte) (bool, error) Put([]byte) error -- cgit v1.2.3 From aa3b91b8026c665ee53f768f8b94c3abe1713bf6 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 15 Dec 2014 22:33:18 +0100 Subject: p2p: fix call to Server.removePeer (might help with #209) --- p2p/server.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'p2p') diff --git a/p2p/server.go b/p2p/server.go index 9353e12ea..326781234 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -246,12 +246,7 @@ func (srv *Server) Stop() { func (srv *Server) discLoop() { for peer := range srv.peerDisconnect { - // peer has just disconnected. free up its slot. - srvlog.Infof("%v is gone", peer) - srv.peerSlots <- peer.slot - srv.lock.Lock() - srv.peers[peer.slot] = nil - srv.lock.Unlock() + srv.removePeer(peer) } } @@ -384,7 +379,7 @@ func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { func (srv *Server) removePeer(peer *Peer) { srv.lock.Lock() defer srv.lock.Unlock() - srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) + srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot) if srv.peers[peer.slot] != peer { srvlog.Warnln("Invalid peer to remove:", peer) return -- cgit v1.2.3 From 09841b1c9b2553a4572590128580df37c8fa83ad Mon Sep 17 00:00:00 2001 From: obscuren Date: Sun, 4 Jan 2015 14:20:16 +0100 Subject: Cleaned up some of that util --- p2p/client_identity.go | 4 ++-- p2p/nat.go | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 p2p/nat.go (limited to 'p2p') diff --git a/p2p/client_identity.go b/p2p/client_identity.go index bc865b63b..f15fd01bf 100644 --- a/p2p/client_identity.go +++ b/p2p/client_identity.go @@ -17,10 +17,10 @@ type SimpleClientIdentity struct { customIdentifier string os string implementation string - pubkey string + pubkey []byte } -func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity { +func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity { clientIdentity := &SimpleClientIdentity{ clientIdentifier: clientIdentifier, version: version, diff --git a/p2p/nat.go b/p2p/nat.go new file mode 100644 index 000000000..9b771c3e8 --- /dev/null +++ b/p2p/nat.go @@ -0,0 +1,23 @@ +package p2p + +import ( + "fmt" + "net" +) + +func ParseNAT(natType string, gateway string) (nat NAT, err error) { + switch natType { + case "UPNP": + nat = UPNP() + case "PMP": + ip := net.ParseIP(gateway) + if ip == nil { + return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway) + } + nat = PMP(ip) + case "": + default: + return nil, fmt.Errorf("unrecognised NAT type '%s'", natType) + } + return +} -- cgit v1.2.3 From 6abf8ef78f1474fdeb7a6a6ce084bf994cc055f2 Mon Sep 17 00:00:00 2001 From: obscuren Date: Mon, 5 Jan 2015 17:10:42 +0100 Subject: Merge --- p2p/client_identity_test.go | 2 +- p2p/message.go | 10 +++++- p2p/peer.go | 28 ++++++++++++++-- p2p/peer_test.go | 14 +++++--- p2p/protocol.go | 54 +++++++---------------------- p2p/protocol_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++-- p2p/server.go | 6 +++- p2p/server_test.go | 2 +- 8 files changed, 144 insertions(+), 54 deletions(-) (limited to 'p2p') diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go index 40b0e6f5e..7248a7b1a 100644 --- a/p2p/client_identity_test.go +++ b/p2p/client_identity_test.go @@ -7,7 +7,7 @@ import ( ) func TestClientIdentity(t *testing.T) { - clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") + clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey")) clientString := clientIdentity.String() expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) if clientString != expected { diff --git a/p2p/message.go b/p2p/message.go index f5418ff47..a6f62ec4c 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "io/ioutil" "math/big" @@ -49,7 +50,14 @@ func encodePayload(params ...interface{}) []byte { // For the decoding rules, please see package rlp. func (msg Msg) Decode(val interface{}) error { s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) - return s.Decode(val) + if err := s.Decode(val); err != nil { + return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err) + } + return nil +} + +func (msg Msg) String() string { + return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size) } // Discard reads any remaining payload data into a black hole. diff --git a/p2p/peer.go b/p2p/peer.go index 86c4d7ab5..0d7eec9f4 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -45,8 +45,8 @@ func (d peerAddr) String() string { return fmt.Sprintf("%v:%d", d.IP, d.Port) } -func (d peerAddr) RlpData() interface{} { - return []interface{}{d.IP, d.Port, d.Pubkey} +func (d *peerAddr) RlpData() interface{} { + return []interface{}{string(d.IP), d.Port, d.Pubkey} } // Peer represents a remote peer. @@ -426,7 +426,7 @@ func (rw *proto) WriteMsg(msg Msg) error { } func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { - return rw.WriteMsg(NewMsg(code, data)) + return rw.WriteMsg(NewMsg(code, data...)) } func (rw *proto) ReadMsg() (Msg, error) { @@ -460,3 +460,25 @@ func (r *eofSignal) Read(buf []byte) (int, error) { } return n, err } + +func (peer *Peer) PeerList() []interface{} { + peers := peer.otherPeers() + ds := make([]interface{}, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == peer || addr == nil { + continue + } + ds = append(ds, addr) + } + ourAddr := peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) + } + return ds +} diff --git a/p2p/peer_test.go b/p2p/peer_test.go index f7759786e..5b9e9e784 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -30,9 +30,8 @@ var discard = Protocol{ func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { conn1, conn2 := net.Pipe() - id := NewSimpleClientIdentity("test", "0", "0", "public key") peer := newPeer(conn1, protos, nil) - peer.ourID = id + peer.ourID = &peerId{} peer.pubkeyHook = func(*peerAddr) error { return nil } errc := make(chan error, 1) go func() { @@ -130,7 +129,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { if err := rw.EncodeMsg(2); err == nil { t.Error("expected error for out-of-range msg code, got nil") } - if err := rw.EncodeMsg(1); err != nil { + if err := rw.EncodeMsg(1, "foo", "bar"); err != nil { t.Errorf("write error: %v", err) } return nil @@ -148,6 +147,13 @@ func TestPeerProtoEncodeMsg(t *testing.T) { if msg.Code != 17 { t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) } + var data []string + if err := msg.Decode(&data); err != nil { + t.Errorf("payload decode error: %v", err) + } + if !reflect.DeepEqual(data, []string{"foo", "bar"}) { + t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"}) + } } func TestPeerWrite(t *testing.T) { @@ -226,8 +232,8 @@ func TestPeerActivity(t *testing.T) { } func TestNewPeer(t *testing.T) { - id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") caps := []Cap{{"foo", 2}, {"bar", 3}} + id := &peerId{} p := NewPeer(id, caps) if !reflect.DeepEqual(p.Caps(), caps) { t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) diff --git a/p2p/protocol.go b/p2p/protocol.go index 3f52205f5..dd8cbc4ec 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -3,8 +3,6 @@ package p2p import ( "bytes" "time" - - "github.com/ethereum/go-ethereum/ethutil" ) // Protocol represents a P2P subprotocol implementation. @@ -89,20 +87,25 @@ type baseProtocol struct { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { bp := &baseProtocol{rw, peer} - if err := bp.doHandshake(rw); err != nil { + errc := make(chan error, 1) + go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }() + if err := bp.readHandshake(); err != nil { + return err + } + // handle write error + if err := <-errc; err != nil { return err } // run main loop - quit := make(chan error, 1) go func() { for { if err := bp.handle(rw); err != nil { - quit <- err + errc <- err break } } }() - return bp.loop(quit) + return bp.loop(errc) } var pingTimeout = 2 * time.Second @@ -166,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { case pongMsg: case getPeersMsg: - peers := bp.peerList() + peers := bp.peer.PeerList() // this is dangerous. the spec says that we should _delay_ // sending the response if no new information is available. // this means that would need to send a response later when @@ -174,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { // // TODO: add event mechanism to notify baseProtocol for new peers if len(peers) > 0 { - return bp.rw.EncodeMsg(peersMsg, peers) + return bp.rw.EncodeMsg(peersMsg, peers...) } case peersMsg: @@ -193,14 +196,9 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { return nil } -func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { - // send our handshake - if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { - return err - } - +func (bp *baseProtocol) readHandshake() error { // read and handle remote handshake - msg, err := rw.ReadMsg() + msg, err := bp.rw.ReadMsg() if err != nil { return err } @@ -210,12 +208,10 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { if msg.Size > baseProtocolMaxMsgSize { return newPeerError(errMisc, "message too big") } - var hs handshake if err := msg.Decode(&hs); err != nil { return err } - // validate handshake info if hs.Version != baseProtocolVersion { return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", @@ -238,9 +234,7 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { if err := bp.peer.pubkeyHook(pa); err != nil { return newPeerError(errPubkeyForbidden, "%v", err) } - // TODO: remove Caps with empty name - var addr *peerAddr if hs.ListenPort != 0 { addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) @@ -270,25 +264,3 @@ func (bp *baseProtocol) handshakeMsg() Msg { bp.peer.ourID.Pubkey()[1:], ) } - -func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { - peers := bp.peer.otherPeers() - ds := make([]ethutil.RlpEncodable, 0, len(peers)) - for _, p := range peers { - p.infolock.Lock() - addr := p.listenAddr - p.infolock.Unlock() - // filter out this peer and peers that are not listening or - // have not completed the handshake. - // TODO: track previously sent peers and exclude them as well. - if p == bp.peer || addr == nil { - continue - } - ds = append(ds, addr) - } - ourAddr := bp.peer.ourListenAddr - if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { - ds = append(ds, ourAddr) - } - return ds -} diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go index 65f26fb12..ce25b3e1b 100644 --- a/p2p/protocol_test.go +++ b/p2p/protocol_test.go @@ -2,12 +2,89 @@ package p2p import ( "fmt" + "net" + "reflect" "testing" + + "github.com/ethereum/go-ethereum/crypto" ) +type peerId struct { + pubkey []byte +} + +func (self *peerId) String() string { + return fmt.Sprintf("test peer %x", self.Pubkey()[:4]) +} + +func (self *peerId) Pubkey() (pubkey []byte) { + pubkey = self.pubkey + if len(pubkey) == 0 { + pubkey = crypto.GenerateNewKeyPair().PublicKey + self.pubkey = pubkey + } + return +} + +func newTestPeer() (peer *Peer) { + peer = NewPeer(&peerId{}, []Cap{}) + peer.pubkeyHook = func(*peerAddr) error { return nil } + peer.ourID = &peerId{} + peer.listenAddr = &peerAddr{} + peer.otherPeers = func() []*Peer { return nil } + return +} + +func TestBaseProtocolPeers(t *testing.T) { + cannedPeerList := []*peerAddr{ + {IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}}, + {IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}}, + } + var ownAddr *peerAddr = &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}} + rw1, rw2 := MsgPipe() + // run matcher, close pipe when addresses have arrived + addrChan := make(chan *peerAddr, len(cannedPeerList)) + go func() { + for _, want := range cannedPeerList { + got := <-addrChan + t.Logf("got peer: %+v", got) + if !reflect.DeepEqual(want, got) { + t.Errorf("mismatch: got %#v, want %#v", got, want) + } + } + close(addrChan) + var own []*peerAddr + var got *peerAddr + for got = range addrChan { + own = append(own, got) + } + if len(own) != 1 || !reflect.DeepEqual(ownAddr, own[0]) { + t.Errorf("mismatch: peers own address is incorrectly or not given, got %v, want %#v", ownAddr) + } + rw2.Close() + }() + // run first peer + peer1 := newTestPeer() + peer1.ourListenAddr = ownAddr + peer1.otherPeers = func() []*Peer { + pl := make([]*Peer, len(cannedPeerList)) + for i, addr := range cannedPeerList { + pl[i] = &Peer{listenAddr: addr} + } + return pl + } + go runBaseProtocol(peer1, rw1) + // run second peer + peer2 := newTestPeer() + peer2.newPeerAddr = addrChan // feed peer suggestions into matcher + if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed { + t.Errorf("peer2 terminated with unexpected error: %v", err) + } +} + func TestBaseProtocolDisconnect(t *testing.T) { - peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) - peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") + peer := NewPeer(&peerId{}, nil) + peer.ourID = &peerId{} peer.pubkeyHook = func(*peerAddr) error { return nil } rw1, rw2 := MsgPipe() @@ -32,6 +109,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { t.Error(err) } + close(done) }() diff --git a/p2p/server.go b/p2p/server.go index 326781234..cfff442f7 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -113,9 +113,11 @@ func (srv *Server) PeerCount() int { // SuggestPeer injects an address into the outbound address pool. func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { + addr := &peerAddr{ip, uint64(port), nodeID} select { - case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: + case srv.peerConnect <- addr: default: // don't block + srvlog.Warnf("peer suggestion %v ignored", addr) } } @@ -258,6 +260,7 @@ func (srv *Server) listenLoop() { for { select { case slot := <-srv.peerSlots: + srvlog.Debugf("grabbed slot %v for listening", slot) conn, err := srv.listener.Accept() if err != nil { srv.peerSlots <- slot @@ -330,6 +333,7 @@ func (srv *Server) dialLoop() { case desc := <-suggest: // candidate peer found, will dial out asyncronously // if connection fails slot will be released + srvlog.Infof("dial %v (%v)", desc, *slot) go srv.dialPeer(desc, *slot) // we can watch if more peers needed in the next loop slots = srv.peerSlots diff --git a/p2p/server_test.go b/p2p/server_test.go index 5c0d08d39..ceb89e3f7 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -11,7 +11,7 @@ import ( func startTestServer(t *testing.T, pf peerFunc) *Server { server := &Server{ - Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), + Identity: &peerId{}, MaxPeers: 10, ListenAddr: "127.0.0.1:0", newPeerFunc: pf, -- cgit v1.2.3 From eb0e7b1b8120852a1d56aa0ebd3a98e652965635 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 6 Jan 2015 11:35:09 +0100 Subject: eth, p2p: remove EncodeMsg from p2p.MsgWriter ...and make it a top-level function instead. The original idea behind having EncodeMsg in the interface was that implementations might be able to encode RLP data to their underlying writer directly instead of buffering the encoded data. The encoder will buffer anyway, so that doesn't matter anymore. Given the recent problems with EncodeMsg (copy-pasted implementation bug) I'd rather implement once, correctly. --- p2p/message.go | 20 +++++++++----------- p2p/message_test.go | 6 +++--- p2p/peer_test.go | 4 ++-- p2p/protocol.go | 10 +++++----- p2p/protocol_test.go | 4 ++-- 5 files changed, 21 insertions(+), 23 deletions(-) (limited to 'p2p') diff --git a/p2p/message.go b/p2p/message.go index a6f62ec4c..daf2bf05c 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -71,14 +71,11 @@ type MsgReader interface { } type MsgWriter interface { - // WriteMsg sends an existing message. - // The Payload reader of the message is consumed. + // WriteMsg sends a message. It will block until the message's + // Payload has been consumed by the other end. + // // Note that messages can be sent only once. WriteMsg(Msg) error - - // EncodeMsg writes an RLP-encoded message with the given - // code and data elements. - EncodeMsg(code uint64, data ...interface{}) error } // MsgReadWriter provides reading and writing of encoded messages. @@ -87,6 +84,12 @@ type MsgReadWriter interface { MsgWriter } +// EncodeMsg writes an RLP-encoded message with the given code and +// data elements. +func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { + return w.WriteMsg(NewMsg(code, data...)) +} + var magicToken = []byte{34, 64, 8, 145} func writeMsg(w io.Writer, msg Msg) error { @@ -209,11 +212,6 @@ func (p *MsgPipeRW) WriteMsg(msg Msg) error { return ErrPipeClosed } -// EncodeMsg is a convenient shorthand for sending an RLP-encoded message. -func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error { - return p.WriteMsg(NewMsg(code, data...)) -} - // ReadMsg returns a message sent on the other end of the pipe. func (p *MsgPipeRW) ReadMsg() (Msg, error) { if atomic.LoadInt32(p.closed) == 0 { diff --git a/p2p/message_test.go b/p2p/message_test.go index 066d2516d..5cde9abf5 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -75,8 +75,8 @@ func TestDecodeRealMsg(t *testing.T) { func ExampleMsgPipe() { rw1, rw2 := MsgPipe() go func() { - rw1.EncodeMsg(8, []byte{0, 0}) - rw1.EncodeMsg(5, []byte{1, 1}) + EncodeMsg(rw1, 8, []byte{0, 0}) + EncodeMsg(rw1, 5, []byte{1, 1}) rw1.Close() }() @@ -100,7 +100,7 @@ loop: rw1, rw2 := MsgPipe() done := make(chan struct{}) go func() { - if err := rw1.EncodeMsg(1); err == nil { + if err := EncodeMsg(rw1, 1); err == nil { t.Error("EncodeMsg returned nil error") } else if err != ErrPipeClosed { t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 5b9e9e784..4ee88f112 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -126,10 +126,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) { Name: "a", Length: 2, Run: func(peer *Peer, rw MsgReadWriter) error { - if err := rw.EncodeMsg(2); err == nil { + if err := EncodeMsg(rw, 2); err == nil { t.Error("expected error for out-of-range msg code, got nil") } - if err := rw.EncodeMsg(1, "foo", "bar"); err != nil { + if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil { t.Errorf("write error: %v", err) } return nil diff --git a/p2p/protocol.go b/p2p/protocol.go index dd8cbc4ec..969937076 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -119,14 +119,14 @@ func (bp *baseProtocol) loop(quit <-chan error) error { getPeersTick := time.NewTicker(10 * time.Second) defer getPeersTick.Stop() - err := bp.rw.EncodeMsg(getPeersMsg) + err := EncodeMsg(bp.rw, getPeersMsg) for err == nil { select { case err = <-quit: return err case <-getPeersTick.C: - err = bp.rw.EncodeMsg(getPeersMsg) + err = EncodeMsg(bp.rw, getPeersMsg) case event := <-activity.Chan(): ping.Reset(pingTimeout) lastActive = event.(time.Time) @@ -134,7 +134,7 @@ func (bp *baseProtocol) loop(quit <-chan error) error { if lastActive.Add(pingTimeout * 2).Before(t) { err = newPeerError(errPingTimeout, "") } else if lastActive.Add(pingTimeout).Before(t) { - err = bp.rw.EncodeMsg(pingMsg) + err = EncodeMsg(bp.rw, pingMsg) } } } @@ -164,7 +164,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { return discRequestedError(reason[0]) case pingMsg: - return bp.rw.EncodeMsg(pongMsg) + return EncodeMsg(bp.rw, pongMsg) case pongMsg: @@ -177,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { // // TODO: add event mechanism to notify baseProtocol for new peers if len(peers) > 0 { - return bp.rw.EncodeMsg(peersMsg, peers...) + return EncodeMsg(bp.rw, peersMsg, peers...) } case peersMsg: diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go index ce25b3e1b..ba5e95c02 100644 --- a/p2p/protocol_test.go +++ b/p2p/protocol_test.go @@ -93,7 +93,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { if err := expectMsg(rw2, handshakeMsg); err != nil { t.Error(err) } - err := rw2.EncodeMsg(handshakeMsg, + err := EncodeMsg(rw2, handshakeMsg, baseProtocolVersion, "", []interface{}{}, @@ -106,7 +106,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { if err := expectMsg(rw2, getPeersMsg); err != nil { t.Error(err) } - if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { + if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil { t.Error(err) } -- cgit v1.2.3 From b0ff946b55c23f0fffc50a700bcb255f95855afc Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 6 Jan 2015 12:14:29 +0100 Subject: p2p: move peerList back into baseProtocol It had been moved to Peer, probably for debugging. --- p2p/peer.go | 22 ---------------------- p2p/protocol.go | 24 +++++++++++++++++++++++- 2 files changed, 23 insertions(+), 23 deletions(-) (limited to 'p2p') diff --git a/p2p/peer.go b/p2p/peer.go index 0d7eec9f4..2380a3285 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -460,25 +460,3 @@ func (r *eofSignal) Read(buf []byte) (int, error) { } return n, err } - -func (peer *Peer) PeerList() []interface{} { - peers := peer.otherPeers() - ds := make([]interface{}, 0, len(peers)) - for _, p := range peers { - p.infolock.Lock() - addr := p.listenAddr - p.infolock.Unlock() - // filter out this peer and peers that are not listening or - // have not completed the handshake. - // TODO: track previously sent peers and exclude them as well. - if p == peer || addr == nil { - continue - } - ds = append(ds, addr) - } - ourAddr := peer.ourListenAddr - if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { - ds = append(ds, ourAddr) - } - return ds -} diff --git a/p2p/protocol.go b/p2p/protocol.go index 969937076..1d121a885 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -169,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { case pongMsg: case getPeersMsg: - peers := bp.peer.PeerList() + peers := bp.peerList() // this is dangerous. the spec says that we should _delay_ // sending the response if no new information is available. // this means that would need to send a response later when @@ -264,3 +264,25 @@ func (bp *baseProtocol) handshakeMsg() Msg { bp.peer.ourID.Pubkey()[1:], ) } + +func (bp *baseProtocol) peerList() []interface{} { + peers := bp.peer.otherPeers() + ds := make([]interface{}, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == bp.peer || addr == nil { + continue + } + ds = append(ds, addr) + } + ourAddr := bp.peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) + } + return ds +} -- cgit v1.2.3 From 3caa4ad1baba3019c06733e1a80d78d9a57137bb Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 6 Jan 2015 12:15:51 +0100 Subject: p2p: improve test for peers message The test now checks that the number of of addresses is correct and terminates cleanly. --- p2p/protocol_test.go | 64 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 21 deletions(-) (limited to 'p2p') diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go index ba5e95c02..b1d10ac53 100644 --- a/p2p/protocol_test.go +++ b/p2p/protocol_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "reflect" + "sync" "testing" "github.com/ethereum/go-ethereum/crypto" @@ -36,50 +37,71 @@ func newTestPeer() (peer *Peer) { } func TestBaseProtocolPeers(t *testing.T) { - cannedPeerList := []*peerAddr{ + peerList := []*peerAddr{ {IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}}, {IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}}, } - var ownAddr *peerAddr = &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}} + listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}} rw1, rw2 := MsgPipe() + defer rw1.Close() + wg := new(sync.WaitGroup) + // run matcher, close pipe when addresses have arrived - addrChan := make(chan *peerAddr, len(cannedPeerList)) + numPeers := len(peerList) + 1 + addrChan := make(chan *peerAddr) + wg.Add(1) go func() { - for _, want := range cannedPeerList { - got := <-addrChan - t.Logf("got peer: %+v", got) + i := 0 + for got := range addrChan { + var want *peerAddr + switch { + case i < len(peerList): + want = peerList[i] + case i == len(peerList): + want = listenAddr // listenAddr should be the last thing sent + } + t.Logf("got peer %d/%d: %v", i+1, numPeers, got) if !reflect.DeepEqual(want, got) { - t.Errorf("mismatch: got %#v, want %#v", got, want) + t.Errorf("mismatch: got %+v, want %+v", got, want) + } + i++ + if i == numPeers { + break } } - close(addrChan) - var own []*peerAddr - var got *peerAddr - for got = range addrChan { - own = append(own, got) - } - if len(own) != 1 || !reflect.DeepEqual(ownAddr, own[0]) { - t.Errorf("mismatch: peers own address is incorrectly or not given, got %v, want %#v", ownAddr) + if i != numPeers { + t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers) } - rw2.Close() + rw1.Close() + wg.Done() }() - // run first peer + + // run first peer (in background) peer1 := newTestPeer() - peer1.ourListenAddr = ownAddr + peer1.ourListenAddr = listenAddr peer1.otherPeers = func() []*Peer { - pl := make([]*Peer, len(cannedPeerList)) - for i, addr := range cannedPeerList { + pl := make([]*Peer, len(peerList)) + for i, addr := range peerList { pl[i] = &Peer{listenAddr: addr} } return pl } - go runBaseProtocol(peer1, rw1) + wg.Add(1) + go func() { + runBaseProtocol(peer1, rw1) + wg.Done() + }() + // run second peer peer2 := newTestPeer() peer2.newPeerAddr = addrChan // feed peer suggestions into matcher if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed { t.Errorf("peer2 terminated with unexpected error: %v", err) } + + // terminate matcher + close(addrChan) + wg.Wait() } func TestBaseProtocolDisconnect(t *testing.T) { -- cgit v1.2.3 From 8d1637f5677304b3d540648cf9b4ff9f92ad7f1f Mon Sep 17 00:00:00 2001 From: obscuren Date: Mon, 19 Jan 2015 11:21:46 +0100 Subject: Moved connection errors to DebugDetail level --- p2p/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'p2p') diff --git a/p2p/server.go b/p2p/server.go index cfff442f7..4fd1f7d03 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -333,7 +333,7 @@ func (srv *Server) dialLoop() { case desc := <-suggest: // candidate peer found, will dial out asyncronously // if connection fails slot will be released - srvlog.Infof("dial %v (%v)", desc, *slot) + srvlog.DebugDetailf("dial %v (%v)", desc, *slot) go srv.dialPeer(desc, *slot) // we can watch if more peers needed in the next loop slots = srv.peerSlots @@ -355,7 +355,7 @@ func (srv *Server) dialPeer(desc *peerAddr, slot int) { srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) if err != nil { - srvlog.Errorf("Dial error: %v", err) + srvlog.DebugDetailf("dial error: %v", err) srv.peerSlots <- slot return } -- cgit v1.2.3 From 67f9783e6a0fa5613a031e05549b92adbee57399 Mon Sep 17 00:00:00 2001 From: obscuren Date: Thu, 22 Jan 2015 00:35:00 +0100 Subject: Moved `obscuren` secp256k1-go --- p2p/testpoc7.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'p2p') diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go index c0cc5c544..63b996f79 100644 --- a/p2p/testpoc7.go +++ b/p2p/testpoc7.go @@ -8,9 +8,9 @@ import ( "net" "os" + "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/p2p" - "github.com/obscuren/secp256k1-go" ) func main() { -- cgit v1.2.3