diff options
author | obscuren <geffobscura@gmail.com> | 2014-12-05 23:27:11 +0800 |
---|---|---|
committer | obscuren <geffobscura@gmail.com> | 2014-12-05 23:27:11 +0800 |
commit | 384b8c75f07b6811aa3012ad52a44844b3ab6e52 (patch) | |
tree | 9b1a96d7c65c4dcb5515eddf6c2b03e7e6b0d397 | |
parent | 710360bab61178cf7fbc52213ec4c612be37ad18 (diff) | |
parent | cfd7e74c25fa7d1b443f8527fca8afad14ef4419 (diff) | |
download | go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar.gz go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar.bz2 go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar.lz go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar.xz go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.tar.zst go-tangerine-384b8c75f07b6811aa3012ad52a44844b3ab6e52.zip |
Merge branch 'feature/p2p-protocol-interface' of https://github.com/fjl/go-ethereum into fjl-feature/p2p-protocol-interface
-rw-r--r-- | p2p/client_identity.go | 6 | ||||
-rw-r--r-- | p2p/connection.go | 275 | ||||
-rw-r--r-- | p2p/connection_test.go | 222 | ||||
-rw-r--r-- | p2p/message.go | 186 | ||||
-rw-r--r-- | p2p/message_test.go | 78 | ||||
-rw-r--r-- | p2p/messenger.go | 220 | ||||
-rw-r--r-- | p2p/messenger_test.go | 147 | ||||
-rw-r--r-- | p2p/natpmp.go | 34 | ||||
-rw-r--r-- | p2p/natupnp.go | 198 | ||||
-rw-r--r-- | p2p/network.go | 196 | ||||
-rw-r--r-- | p2p/peer.go | 494 | ||||
-rw-r--r-- | p2p/peer_error.go | 152 | ||||
-rw-r--r-- | p2p/peer_error_handler.go | 101 | ||||
-rw-r--r-- | p2p/peer_error_handler_test.go | 34 | ||||
-rw-r--r-- | p2p/peer_test.go | 283 | ||||
-rw-r--r-- | p2p/protocol.go | 455 | ||||
-rw-r--r-- | p2p/server.go | 735 | ||||
-rw-r--r-- | p2p/server_test.go | 309 | ||||
-rw-r--r-- | p2p/testlog_test.go | 28 | ||||
-rw-r--r-- | p2p/testpoc7.go | 40 | ||||
-rw-r--r-- | rlp/decode.go | 81 | ||||
-rw-r--r-- | rlp/decode_test.go | 89 |
22 files changed, 1985 insertions, 2378 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go index 236b23106..bc865b63b 100644 --- a/p2p/client_identity.go +++ b/p2p/client_identity.go @@ -5,10 +5,10 @@ import ( "runtime" ) -// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. +// ClientIdentity represents the identity of a peer. type ClientIdentity interface { - String() string - Pubkey() []byte + String() string // human readable identity + Pubkey() []byte // 512-bit public key } type SimpleClientIdentity struct { diff --git a/p2p/connection.go b/p2p/connection.go deleted file mode 100644 index be366235d..000000000 --- a/p2p/connection.go +++ /dev/null @@ -1,275 +0,0 @@ -package p2p - -import ( - "bytes" - // "fmt" - "net" - "time" - - "github.com/ethereum/go-ethereum/ethutil" -) - -type Connection struct { - conn net.Conn - // conn NetworkConnection - timeout time.Duration - in chan []byte - out chan []byte - err chan *PeerError - closingIn chan chan bool - closingOut chan chan bool -} - -// const readBufferLength = 2 //for testing - -const readBufferLength = 1440 -const partialsQueueSize = 10 -const maxPendingQueueSize = 1 -const defaultTimeout = 500 - -var magicToken = []byte{34, 64, 8, 145} - -func (self *Connection) Open() { - go self.startRead() - go self.startWrite() -} - -func (self *Connection) Close() { - self.closeIn() - self.closeOut() -} - -func (self *Connection) closeIn() { - errc := make(chan bool) - self.closingIn <- errc - <-errc -} - -func (self *Connection) closeOut() { - errc := make(chan bool) - self.closingOut <- errc - <-errc -} - -func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { - return &Connection{ - conn: conn, - timeout: defaultTimeout, - in: make(chan []byte), - out: make(chan []byte), - err: errchan, - closingIn: make(chan chan bool, 1), - closingOut: make(chan chan bool, 1), - } -} - -func (self *Connection) Read() <-chan []byte { - return self.in -} - -func (self *Connection) Write() chan<- []byte { - return self.out -} - -func (self *Connection) Error() <-chan *PeerError { - return self.err -} - -func (self *Connection) startRead() { - payloads := make(chan []byte) - done := make(chan *PeerError) - pending := [][]byte{} - var head []byte - var wait time.Duration // initally 0 (no delay) - read := time.After(wait * time.Millisecond) - - for { - // if pending empty, nil channel blocks - var in chan []byte - if len(pending) > 0 { - in = self.in // enable send case - head = pending[0] - } else { - in = nil - } - - select { - case <-read: - go self.read(payloads, done) - case err := <-done: - if err == nil { // no error but nothing to read - if len(pending) < maxPendingQueueSize { - wait = 100 - } else if wait == 0 { - wait = 100 - } else { - wait = 2 * wait - } - } else { - self.err <- err // report error - wait = 100 - } - read = time.After(wait * time.Millisecond) - case payload := <-payloads: - pending = append(pending, payload) - if len(pending) < maxPendingQueueSize { - wait = 0 - } else { - wait = 100 - } - read = time.After(wait * time.Millisecond) - case in <- head: - pending = pending[1:] - case errc := <-self.closingIn: - errc <- true - close(self.in) - return - } - - } -} - -func (self *Connection) startWrite() { - pending := [][]byte{} - done := make(chan *PeerError) - writing := false - for { - if len(pending) > 0 && !writing { - writing = true - go self.write(pending[0], done) - } - select { - case payload := <-self.out: - pending = append(pending, payload) - case err := <-done: - if err == nil { - pending = pending[1:] - writing = false - } else { - self.err <- err // report error - } - case errc := <-self.closingOut: - errc <- true - close(self.out) - return - } - } -} - -func pack(payload []byte) (packet []byte) { - length := ethutil.NumberToBytes(uint32(len(payload)), 32) - // return error if too long? - // Write magic token and payload length (first 8 bytes) - packet = append(magicToken, length...) - packet = append(packet, payload...) - return -} - -func avoidPanic(done chan *PeerError) { - if rec := recover(); rec != nil { - err := NewPeerError(MiscError, " %v", rec) - logger.Debugln(err) - done <- err - } -} - -func (self *Connection) write(payload []byte, done chan *PeerError) { - defer avoidPanic(done) - var err *PeerError - _, ok := self.conn.Write(pack(payload)) - if ok != nil { - err = NewPeerError(WriteError, " %v", ok) - logger.Debugln(err) - } - done <- err -} - -func (self *Connection) read(payloads chan []byte, done chan *PeerError) { - //defer avoidPanic(done) - - partials := make(chan []byte, partialsQueueSize) - errc := make(chan *PeerError) - go self.readPartials(partials, errc) - - packet := []byte{} - length := 8 - start := true - var err *PeerError -out: - for { - // appends partials read via connection until packet is - // - either parseable (>=8bytes) - // - or complete (payload fully consumed) - for len(packet) < length { - partial, ok := <-partials - if !ok { // partials channel is closed - err = <-errc - if err == nil && len(packet) > 0 { - if start { - err = NewPeerError(PacketTooShort, "%v", packet) - } else { - err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) - } - } - break out - } - packet = append(packet, partial...) - } - if start { - // at least 8 bytes read, can validate packet - if bytes.Compare(magicToken, packet[:4]) != 0 { - err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) - break - } - length = int(ethutil.BytesToNumber(packet[4:8])) - packet = packet[8:] - - if length > 0 { - start = false // now consuming payload - } else { //penalize peer but read on - self.err <- NewPeerError(EmptyPayload, "") - length = 8 - } - } else { - // packet complete (payload fully consumed) - payloads <- packet[:length] - packet = packet[length:] // resclice packet - start = true - length = 8 - } - } - - // this stops partials read via the connection, should we? - //if err != nil { - // select { - // case errc <- err - // default: - //} - done <- err -} - -func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { - defer close(partials) - for { - // Give buffering some time - self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) - buffer := make([]byte, readBufferLength) - // read partial from connection - bytesRead, err := self.conn.Read(buffer) - if err == nil || err.Error() == "EOF" { - if bytesRead > 0 { - partials <- buffer[:bytesRead] - } - if err != nil && err.Error() == "EOF" { - break - } - } else { - // unexpected error, report to errc - err := NewPeerError(ReadError, " %v", err) - logger.Debugln(err) - errc <- err - return // will close partials channel - } - } - close(errc) -} diff --git a/p2p/connection_test.go b/p2p/connection_test.go deleted file mode 100644 index 76ee8021c..000000000 --- a/p2p/connection_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package p2p - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - "time" -) - -type TestNetworkConnection struct { - in chan []byte - current []byte - Out [][]byte - addr net.Addr -} - -func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { - return &TestNetworkConnection{ - in: make(chan []byte), - current: []byte{}, - Out: [][]byte{}, - addr: addr, - } -} - -func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { - time.Sleep(latency) - for _, s := range packets { - self.in <- s - } -} - -func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { - if len(self.current) == 0 { - select { - case self.current = <-self.in: - default: - return 0, io.EOF - } - } - length := len(self.current) - if length > len(buff) { - copy(buff[:], self.current[:len(buff)]) - self.current = self.current[len(buff):] - return len(buff), nil - } else { - copy(buff[:length], self.current[:]) - self.current = []byte{} - return length, io.EOF - } -} - -func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { - self.Out = append(self.Out, buff) - fmt.Printf("net write %v\n%v\n", len(self.Out), buff) - return len(buff), nil -} - -func (self *TestNetworkConnection) Close() (err error) { - return -} - -func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { - return -} - -func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { - return self.addr -} - -func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { - return -} - -func setupConnection() (*Connection, *TestNetworkConnection) { - addr := &TestAddr{"test:30303"} - net := NewTestNetworkConnection(addr) - conn := NewConnection(net, NewPeerErrorChannel()) - conn.Open() - return conn, net -} - -func TestReadingNilPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{}) - // time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - } - conn.Close() -} - -func TestReadingShortPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != PacketTooShort { - t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) - } - } - conn.Close() -} - -func TestReadingInvalidPacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != MagicTokenMismatch { - t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) - } - } - conn.Close() -} - -func TestReadingInvalidPayload(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - case err := <-conn.Error(): - if err.Code != PayloadTooShort { - t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) - } - } - conn.Close() -} - -func TestReadingEmptyPayload(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - t.Errorf("read %v", packet) - default: - } - select { - case err := <-conn.Error(): - code := err.Code - if code != EmptyPayload { - t.Errorf("incorrect error, expected EmptyPayload, got %v", code) - } - default: - t.Errorf("no error, expected EmptyPayload") - } - conn.Close() -} - -func TestReadingCompletePacket(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - if bytes.Compare(packet, []byte{1}) != 0 { - t.Errorf("incorrect payload read") - } - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - t.Errorf("nothing read") - } - conn.Close() -} - -func TestReadingTwoCompletePackets(t *testing.T) { - conn, net := setupConnection() - go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) - - for i := 0; i < 2; i++ { - time.Sleep(10 * time.Millisecond) - select { - case packet := <-conn.Read(): - if bytes.Compare(packet, []byte{byte(i)}) != 0 { - t.Errorf("incorrect payload read") - } - case err := <-conn.Error(): - t.Errorf("incorrect error %v", err) - default: - t.Errorf("nothing read") - } - } - conn.Close() -} - -func TestWriting(t *testing.T) { - conn, net := setupConnection() - conn.Write() <- []byte{0} - time.Sleep(10 * time.Millisecond) - if len(net.Out) == 0 { - t.Errorf("no output") - } else { - out := net.Out[0] - if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { - t.Errorf("incorrect packet %v", out) - } - } - conn.Close() -} - -// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243 diff --git a/p2p/message.go b/p2p/message.go index 446e74dff..d3b8b74d4 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -1,75 +1,155 @@ package p2p import ( - // "fmt" + "bytes" + "encoding/binary" + "io" + "io/ioutil" + "math/big" + "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/rlp" ) -type MsgCode uint8 - +// Msg defines the structure of a p2p message. +// +// Note that a Msg can only be sent once since the Payload reader is +// consumed during sending. It is not possible to create a Msg and +// send it any number of times. If you want to reuse an encoded +// structure, encode the payload into a byte array and create a +// separate Msg with a bytes.Reader as Payload for each send. type Msg struct { - code MsgCode // this is the raw code as per adaptive msg code scheme - data *ethutil.Value - encoded []byte + Code uint64 + Size uint32 // size of the paylod + Payload io.Reader +} + +// NewMsg creates an RLP-encoded message with the given code. +func NewMsg(code uint64, params ...interface{}) Msg { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} +} + +func encodePayload(params ...interface{}) []byte { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return buf.Bytes() +} + +// Decode parse the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + return s.Decode(val) +} + +// Discard reads any remaining payload data into a black hole. +func (msg Msg) Discard() error { + _, err := io.Copy(ioutil.Discard, msg.Payload) + return err +} + +type MsgReader interface { + ReadMsg() (Msg, error) } -func (self *Msg) Code() MsgCode { - return self.code +type MsgWriter interface { + // WriteMsg sends an existing message. + // The Payload reader of the message is consumed. + // Note that messages can be sent only once. + WriteMsg(Msg) error + + // EncodeMsg writes an RLP-encoded message with the given + // code and data elements. + EncodeMsg(code uint64, data ...interface{}) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +type MsgReadWriter interface { + MsgReader + MsgWriter } -func (self *Msg) Data() *ethutil.Value { - return self.data +var magicToken = []byte{34, 64, 8, 145} + +func writeMsg(w io.Writer, msg Msg) error { + // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 + code := ethutil.Encode(uint32(msg.Code)) + listhdr := makeListHeader(msg.Size + uint32(len(code))) + payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size + + start := make([]byte, 8) + copy(start, magicToken) + binary.BigEndian.PutUint32(start[4:], payloadLen) + + for _, b := range [][]byte{start, listhdr, code} { + if _, err := w.Write(b); err != nil { + return err + } + } + _, err := io.CopyN(w, msg.Payload, int64(msg.Size)) + return err } -func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { - - // // data := [][]interface{}{} - // data := []interface{}{} - // for _, value := range params { - // if encodable, ok := value.(ethutil.RlpEncodeDecode); ok { - // data = append(data, encodable.RlpValue()) - // } else if raw, ok := value.([]interface{}); ok { - // data = append(data, raw) - // } else { - // // data = append(data, interface{}(raw)) - // err = fmt.Errorf("Unable to encode object of type %T", value) - // return - // } - // } - return &Msg{ - code: code, - data: ethutil.NewValue(interface{}(params)), - }, nil +func makeListHeader(length uint32) []byte { + if length < 56 { + return []byte{byte(length + 0xc0)} + } + enc := big.NewInt(int64(length)).Bytes() + lenb := byte(len(enc)) + 0xf7 + return append([]byte{lenb}, enc...) } -func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { - value := ethutil.NewValueFromBytes(encoded) - // Type of message - code := value.Get(0).Uint() - // Actual data - data := value.SliceFrom(1) - - msg = &Msg{ - code: MsgCode(code), - data: data, - // data: ethutil.NewValue(data), - encoded: encoded, +// readMsg reads a message header from r. +// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. +func readMsg(r rlp.ByteReader) (msg Msg, err error) { + // read magic and payload size + start := make([]byte, 8) + if _, err = io.ReadFull(r, start); err != nil { + return msg, newPeerError(errRead, "%v", err) + } + if !bytes.HasPrefix(start, magicToken) { + return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + } + size := binary.BigEndian.Uint32(start[4:]) + + // decode start of RLP message to get the message code + posr := &postrack{r, 0} + s := rlp.NewStream(posr) + if _, err := s.List(); err != nil { + return msg, err } - return + code, err := s.Uint() + if err != nil { + return msg, err + } + payloadsize := size - posr.p + return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil +} + +// postrack wraps an rlp.ByteReader with a position counter. +type postrack struct { + r rlp.ByteReader + p uint32 } -func (self *Msg) Decode(offset MsgCode) { - self.code = self.code - offset +func (r *postrack) Read(buf []byte) (int, error) { + n, err := r.r.Read(buf) + r.p += uint32(n) + return n, err } -// encode takes an offset argument to implement adaptive message coding -// the encoded message is memoized to make msgs relayed to several peers more efficient -func (self *Msg) Encode(offset MsgCode) (res []byte) { - if len(self.encoded) == 0 { - res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() - self.encoded = res - } else { - res = self.encoded +func (r *postrack) ReadByte() (byte, error) { + b, err := r.r.ReadByte() + if err == nil { + r.p++ } - return + return b, err } diff --git a/p2p/message_test.go b/p2p/message_test.go index e9d46f2c3..7b39b061d 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -1,38 +1,70 @@ package p2p import ( + "bytes" + "io/ioutil" "testing" + + "github.com/ethereum/go-ethereum/ethutil" ) func TestNewMsg(t *testing.T) { - msg, _ := NewMsg(3, 1, "000") - if msg.Code() != 3 { - t.Errorf("incorrect code %v", msg.Code()) + msg := NewMsg(3, 1, "000") + if msg.Code != 3 { + t.Errorf("incorrect code %d, want %d", msg.Code) } - data0 := msg.Data().Get(0).Uint() - data1 := string(msg.Data().Get(1).Bytes()) - if data0 != 1 { - t.Errorf("incorrect data %v", data0) + if msg.Size != 5 { + t.Errorf("incorrect size %d, want %d", msg.Size, 5) } - if data1 != "000" { - t.Errorf("incorrect data %v", data1) + pl, _ := ioutil.ReadAll(msg.Payload) + expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} + if !bytes.Equal(pl, expect) { + t.Errorf("incorrect payload content, got %x, want %x", pl, expect) } } func TestEncodeDecodeMsg(t *testing.T) { - msg, _ := NewMsg(3, 1, "000") - encoded := msg.Encode(3) - msg, _ = NewMsgFromBytes(encoded) - msg.Decode(3) - if msg.Code() != 3 { - t.Errorf("incorrect code %v", msg.Code()) - } - data0 := msg.Data().Get(0).Uint() - data1 := msg.Data().Get(1).Str() - if data0 != 1 { - t.Errorf("incorrect data %v", data0) - } - if data1 != "000" { - t.Errorf("incorrect data %v", data1) + msg := NewMsg(3, 1, "000") + buf := new(bytes.Buffer) + if err := writeMsg(buf, msg); err != nil { + t.Fatalf("encodeMsg error: %v", err) + } + // t.Logf("encoded: %x", buf.Bytes()) + + decmsg, err := readMsg(buf) + if err != nil { + t.Fatalf("readMsg error: %v", err) + } + if decmsg.Code != 3 { + t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) + } + if decmsg.Size != 5 { + t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) + } + + var data struct { + I int + S string + } + if err := decmsg.Decode(&data); err != nil { + t.Fatalf("Decode error: %v", err) + } + if data.I != 1 { + t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) + } + if data.S != "000" { + t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") + } +} + +func TestDecodeRealMsg(t *testing.T) { + data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") + msg, err := readMsg(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if msg.Code != 0 { + t.Errorf("incorrect code %d, want %d", msg.Code, 0) } } diff --git a/p2p/messenger.go b/p2p/messenger.go deleted file mode 100644 index d42ba1720..000000000 --- a/p2p/messenger.go +++ /dev/null @@ -1,220 +0,0 @@ -package p2p - -import ( - "fmt" - "sync" - "time" -) - -const ( - handlerTimeout = 1000 -) - -type Handlers map[string](func(p *Peer) Protocol) - -type Messenger struct { - conn *Connection - peer *Peer - handlers Handlers - protocolLock sync.RWMutex - protocols []Protocol - offsets []MsgCode // offsets for adaptive message idss - protocolTable map[string]int - quit chan chan bool - err chan *PeerError - pulse chan bool -} - -func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { - baseProtocol := NewBaseProtocol(peer) - return &Messenger{ - conn: conn, - peer: peer, - offsets: []MsgCode{baseProtocol.Offset()}, - handlers: handlers, - protocols: []Protocol{baseProtocol}, - protocolTable: make(map[string]int), - err: errchan, - pulse: make(chan bool, 1), - quit: make(chan chan bool, 1), - } -} - -func (self *Messenger) Start() { - self.conn.Open() - go self.messenger() - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - self.protocols[0].Start() -} - -func (self *Messenger) Stop() { - // close pulse to stop ping pong monitoring - close(self.pulse) - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - for _, protocol := range self.protocols { - protocol.Stop() // could be parallel - } - q := make(chan bool) - self.quit <- q - <-q - self.conn.Close() -} - -func (self *Messenger) messenger() { - in := self.conn.Read() - for { - select { - case payload, ok := <-in: - //dispatches message to the protocol asynchronously - if ok { - go self.handle(payload) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } -} - -// handles each message by dispatching to the appropriate protocol -// using adaptive message codes -// this function is started as a separate go routine for each message -// it waits for the protocol response -// then encodes and sends outgoing messages to the connection's write channel -func (self *Messenger) handle(payload []byte) { - // send ping to heartbeat channel signalling time of last message - // select { - // case self.pulse <- true: - // default: - // } - self.pulse <- true - // initialise message from payload - msg, err := NewMsgFromBytes(payload) - if err != nil { - self.err <- NewPeerError(MiscError, " %v", err) - return - } - // retrieves protocol based on message Code - protocol, offset, peerErr := self.getProtocol(msg.Code()) - if err != nil { - self.err <- peerErr - return - } - // reset message code based on adaptive offset - msg.Decode(offset) - // dispatches - response := make(chan *Msg) - go protocol.HandleIn(msg, response) - // protocol reponse timeout to prevent leaks - timer := time.After(handlerTimeout * time.Millisecond) - for { - select { - case outgoing, ok := <-response: - // we check if response channel is not closed - if ok { - self.conn.Write() <- outgoing.Encode(offset) - } else { - return - } - case <-timer: - return - } - } -} - -// negotiated protocols -// stores offsets needed for adaptive message id scheme - -// based on offsets set at handshake -// get the right protocol to handle the message -func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - base := MsgCode(0) - for index, offset := range self.offsets { - if code < offset { - return self.protocols[index], base, nil - } - base = offset - } - return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) -} - -func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { - fmt.Printf("pingpong keepalive started at %v", time.Now()) - - timer := time.After(timeout) - pinged := false - for { - select { - case _, ok := <-self.pulse: - if ok { - pinged = false - timer = time.After(timeout) - } else { - // pulse is closed, stop monitoring - return - } - case <-timer: - if pinged { - fmt.Printf("timeout at %v", time.Now()) - timeoutCallback() - return - } else { - fmt.Printf("pinged at %v", time.Now()) - pingCallback() - timer = time.After(gracePeriod) - pinged = true - } - } - } -} - -func (self *Messenger) AddProtocols(protocols []string) { - self.protocolLock.Lock() - defer self.protocolLock.Unlock() - i := len(self.offsets) - offset := self.offsets[i-1] - for _, name := range protocols { - protocolFunc, ok := self.handlers[name] - if ok { - protocol := protocolFunc(self.peer) - self.protocolTable[name] = i - i++ - offset += protocol.Offset() - fmt.Println("offset ", name, offset) - - self.offsets = append(self.offsets, offset) - self.protocols = append(self.protocols, protocol) - protocol.Start() - } else { - fmt.Println("no ", name) - // protocol not handled - } - } -} - -func (self *Messenger) Write(protocol string, msg *Msg) error { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - i := 0 - offset := MsgCode(0) - if len(protocol) > 0 { - var ok bool - i, ok = self.protocolTable[protocol] - if !ok { - return fmt.Errorf("protocol %v not handled by peer", protocol) - } - offset = self.offsets[i-1] - } - handler := self.protocols[i] - // checking if protocol status/caps allows the message to be sent out - if handler.HandleOut(msg) { - self.conn.Write() <- msg.Encode(offset) - } - return nil -} diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go deleted file mode 100644 index 472d74515..000000000 --- a/p2p/messenger_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package p2p - -import ( - // "fmt" - "bytes" - "testing" - "time" - - "github.com/ethereum/go-ethereum/ethutil" -) - -func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { - errchan := NewPeerErrorChannel() - addr := &TestAddr{"test:30303"} - net := NewTestNetworkConnection(addr) - conn := NewConnection(net, errchan) - mess := NewMessenger(nil, conn, errchan, handlers) - mess.Start() - return net, errchan, mess -} - -type TestProtocol struct { - Msgs []*Msg -} - -func (self *TestProtocol) Start() { -} - -func (self *TestProtocol) Stop() { -} - -func (self *TestProtocol) Offset() MsgCode { - return MsgCode(5) -} - -func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { - self.Msgs = append(self.Msgs, msg) - close(response) -} - -func (self *TestProtocol) HandleOut(msg *Msg) bool { - if msg.Code() > 3 { - return false - } else { - return true - } -} - -func (self *TestProtocol) Name() string { - return "a" -} - -func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { - msg, _ := NewMsg(code, params...) - encoded := msg.Encode(offset) - packet := []byte{34, 64, 8, 145} - packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) - return append(packet, encoded...) -} - -func TestRead(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["a"] = func(p *Peer) Protocol { return testProtocol } - net, _, mess := setupMessenger(handlers) - mess.AddProtocols([]string{"a"}) - defer mess.Stop() - wait := 1 * time.Millisecond - packet := Packet(16, 1, uint32(1), "000") - go net.In(0, packet) - time.Sleep(wait) - if len(testProtocol.Msgs) != 1 { - t.Errorf("msg not relayed to correct protocol") - } else { - if testProtocol.Msgs[0].Code() != 1 { - t.Errorf("incorrect msg code relayed to protocol") - } - } -} - -func TestWrite(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["a"] = func(p *Peer) Protocol { return testProtocol } - net, _, mess := setupMessenger(handlers) - mess.AddProtocols([]string{"a"}) - defer mess.Stop() - wait := 1 * time.Millisecond - msg, _ := NewMsg(3, uint32(1), "000") - err := mess.Write("b", msg) - if err == nil { - t.Errorf("expect error for unknown protocol") - } - err = mess.Write("a", msg) - if err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(wait) - if len(net.Out) != 1 { - t.Errorf("msg not written") - } else { - out := net.Out[0] - packet := Packet(16, 3, uint32(1), "000") - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v", out) - } - } - } -} - -func TestPulse(t *testing.T) { - net, _, mess := setupMessenger(make(Handlers)) - defer mess.Stop() - ping := false - timeout := false - pingTimeout := 10 * time.Millisecond - gracePeriod := 200 * time.Millisecond - go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) - net.In(0, Packet(0, 1)) - if ping { - t.Errorf("ping sent too early") - } - time.Sleep(pingTimeout + 100*time.Millisecond) - if !ping { - t.Errorf("no ping sent after timeout") - } - if timeout { - t.Errorf("timeout too early") - } - ping = false - net.In(0, Packet(0, 1)) - time.Sleep(pingTimeout + 100*time.Millisecond) - if !ping { - t.Errorf("no ping sent after timeout") - } - if timeout { - t.Errorf("timeout too early") - } - ping = false - time.Sleep(gracePeriod) - if ping { - t.Errorf("ping called twice") - } - if !timeout { - t.Errorf("no timeout after grace period") - } -} diff --git a/p2p/natpmp.go b/p2p/natpmp.go index ff966d070..6714678c4 100644 --- a/p2p/natpmp.go +++ b/p2p/natpmp.go @@ -3,6 +3,7 @@ package p2p import ( "fmt" "net" + "time" natpmp "github.com/jackpal/go-nat-pmp" ) @@ -13,38 +14,37 @@ import ( // + Register for changes to the external address. // + Re-register port mapping when router reboots. // + A mechanism for keeping a port mapping registered. +// + Discover gateway address automatically. type natPMPClient struct { client *natpmp.Client } -func NewNatPMP(gateway net.IP) (nat NAT) { +// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway +// address should be the IP of your router. +func PMP(gateway net.IP) (nat NAT) { return &natPMPClient{natpmp.NewClient(gateway)} } -func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { +func (*natPMPClient) String() string { + return "NAT-PMP" +} + +func (n *natPMPClient) GetExternalAddress() (net.IP, error) { response, err := n.client.GetExternalAddress() if err != nil { - return + return nil, err } - ip := response.ExternalIPAddress - addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) - return + return response.ExternalIPAddress[:], nil } -func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, - description string, timeout int) (mappedExternalPort int, err error) { - if timeout <= 0 { - err = fmt.Errorf("timeout must not be <= 0") - return +func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") } // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. - response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) - if err != nil { - return - } - mappedExternalPort = int(response.MappedExternalPort) - return + _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) + return err } func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { diff --git a/p2p/natupnp.go b/p2p/natupnp.go index fa9798d4d..2e0d8ce8d 100644 --- a/p2p/natupnp.go +++ b/p2p/natupnp.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/xml" "errors" + "fmt" "net" "net/http" "os" @@ -15,28 +16,46 @@ import ( "time" ) +const ( + upnpDiscoverAttempts = 3 + upnpDiscoverTimeout = 5 * time.Second +) + +// UPNP returns a NAT port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPNP() NAT { + return &upnpNAT{} +} + type upnpNAT struct { serviceURL string ourIP string } -func upnpDiscover(attempts int) (nat NAT, err error) { +func (n *upnpNAT) String() string { + return "UPNP" +} + +func (n *upnpNAT) discover() error { + if n.serviceURL != "" { + // already discovered + return nil + } + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") if err != nil { - return + return err } + // TODO: try on all network interfaces simultaneously. + // Broadcasting on 0.0.0.0 could select a random interface + // to send on (platform specific). conn, err := net.ListenPacket("udp4", ":0") if err != nil { - return - } - socket := conn.(*net.UDPConn) - defer socket.Close() - - err = socket.SetDeadline(time.Now().Add(10 * time.Second)) - if err != nil { - return + return err } + defer conn.Close() + conn.SetDeadline(time.Now().Add(10 * time.Second)) st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" buf := bytes.NewBufferString( "M-SEARCH * HTTP/1.1\r\n" + @@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { "MX: 2\r\n\r\n") message := buf.Bytes() answerBytes := make([]byte, 1024) - for i := 0; i < attempts; i++ { - _, err = socket.WriteToUDP(message, ssdp) + for i := 0; i < upnpDiscoverAttempts; i++ { + _, err = conn.WriteTo(message, ssdp) if err != nil { - return + return err } - var n int - n, _, err = socket.ReadFromUDP(answerBytes) + nn, _, err := conn.ReadFrom(answerBytes) if err != nil { continue - // socket.Close() - // return } - answer := string(answerBytes[0:n]) + answer := string(answerBytes[0:nn]) if strings.Index(answer, "\r\n"+st) < 0 { continue } @@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { var serviceURL string serviceURL, err = getServiceURL(locURL) if err != nil { - return + return err } var ourIP string ourIP, err = getOurIP() if err != nil { - return + return err } - nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} + n.serviceURL = serviceURL + n.ourIP = ourIP + return nil + } + return errors.New("UPnP port discovery failed.") +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + if err := n.discover(); err != nil { + return nil, err + } + info, err := n.getStatusInfo() + return net.ParseIP(info.externalIpAddress), err +} + +func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { + if err := n.discover(); err != nil { + return err + } + + // A single concatenation would break ARM compilation. + message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport) + message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" + + "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + + "<NewEnabled>1</NewEnabled><NewPortMappingDescription>" + message += description + + "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) + + "</NewLeaseDuration></u:AddPortMapping>" + + // TODO: check response to see if the port was forwarded + _, err := soapRequest(n.serviceURL, "AddPortMapping", message) + return err +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { + if err := n.discover(); err != nil { + return err + } + + message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + + "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + + "</u:DeletePortMapping>" + + // TODO: check response to see if the port was deleted + _, err := soapRequest(n.serviceURL, "DeletePortMapping", message) + return err +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { + message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "</u:GetStatusInfo>" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) + if err != nil { return } - err = errors.New("UPnP port discovery failed.") + + // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... + + response.Body.Close() return } @@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { } return } - -type statusInfo struct { - externalIpAddress string -} - -func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { - - message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "</u:GetStatusInfo>" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) - if err != nil { - return - } - - // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... - - response.Body.Close() - return -} - -func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - info, err := n.getStatusInfo() - if err != nil { - return - } - addr = net.ParseIP(info.externalIpAddress) - return -} - -func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { - // A single concatenation would break ARM compilation. - message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) - message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" - message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + - "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + - "<NewEnabled>1</NewEnabled><NewPortMappingDescription>" - message += description + - "</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + - "</NewLeaseDuration></u:AddPortMapping>" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "AddPortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was forwarded - // log.Println(message, response) - mappedExternalPort = externalPort - _ = response - return -} - -func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - - message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + - "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + - "</u:DeletePortMapping>" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was deleted - // log.Println(message, response) - _ = response - return -} diff --git a/p2p/network.go b/p2p/network.go deleted file mode 100644 index 820cef1a9..000000000 --- a/p2p/network.go +++ /dev/null @@ -1,196 +0,0 @@ -package p2p - -import ( - "fmt" - "math/rand" - "net" - "strconv" - "time" -) - -const ( - DialerTimeout = 180 //seconds - KeepAlivePeriod = 60 //minutes - portMappingUpdateInterval = 900 // seconds = 15 mins - upnpDiscoverAttempts = 3 -) - -// Dialer is not an interface in net, so we define one -// *net.Dialer conforms to this -type Dialer interface { - Dial(network, address string) (net.Conn, error) -} - -type Network interface { - Start() error - Listener(net.Addr) (net.Listener, error) - Dialer(net.Addr) (Dialer, error) - NewAddr(string, int) (addr net.Addr, err error) - ParseAddr(string) (addr net.Addr, err error) -} - -type NAT interface { - GetExternalAddress() (addr net.IP, err error) - AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) - DeletePortMapping(protocol string, externalPort, internalPort int) (err error) -} - -type TCPNetwork struct { - nat NAT - natType NATType - quit chan chan bool - ports chan string -} - -type NATType int - -const ( - NONE = iota - UPNP - PMP -) - -const ( - portMappingTimeout = 1200 // 20 mins -) - -func NewTCPNetwork(natType NATType) (net *TCPNetwork) { - return &TCPNetwork{ - natType: natType, - ports: make(chan string), - } -} - -func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { - return &net.Dialer{ - Timeout: DialerTimeout * time.Second, - // KeepAlive: KeepAlivePeriod * time.Minute, - LocalAddr: addr, - }, nil -} - -func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { - if self.natType == UPNP { - _, port, _ := net.SplitHostPort(addr.String()) - if self.quit == nil { - self.quit = make(chan chan bool) - go self.updatePortMappings() - } - self.ports <- port - } - return net.Listen(addr.Network(), addr.String()) -} - -func (self *TCPNetwork) Start() (err error) { - switch self.natType { - case NONE: - case UPNP: - nat, uerr := upnpDiscover(upnpDiscoverAttempts) - if uerr != nil { - err = fmt.Errorf("UPNP failed: ", uerr) - } else { - self.nat = nat - } - case PMP: - err = fmt.Errorf("PMP not implemented") - default: - err = fmt.Errorf("Invalid NAT type: %v", self.natType) - } - return -} - -func (self *TCPNetwork) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *TCPNetwork) addPortMapping(lport int) (err error) { - _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) - if err != nil { - logger.Errorf("unable to add port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully added port mapping on %v", lport) - } - return -} - -func (self *TCPNetwork) updatePortMappings() { - timer := time.NewTimer(portMappingUpdateInterval * time.Second) - lports := []int{} -out: - for { - select { - case port := <-self.ports: - int64lport, _ := strconv.ParseInt(port, 10, 16) - lport := int(int64lport) - if err := self.addPortMapping(lport); err != nil { - lports = append(lports, lport) - } - case <-timer.C: - for lport := range lports { - if err := self.addPortMapping(lport); err != nil { - } - } - case errc := <-self.quit: - errc <- true - break out - } - } - - timer.Stop() - for lport := range lports { - if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { - logger.Debugf("unable to remove port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully removed port mapping on %v", lport) - } - } -} - -func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { - ip, err := self.lookupIP(host) - if err == nil { - return &net.TCPAddr{ - IP: ip, - Port: port, - }, nil - } - return nil, err -} - -func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { - host, port, err := net.SplitHostPort(address) - if err == nil { - iport, _ := strconv.Atoi(port) - addr, e := self.NewAddr(host, iport) - return addr, e - } - return nil, err -} - -func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { - if ip = net.ParseIP(host); ip != nil { - return - } - - var ips []net.IP - ips, err = net.LookupIP(host) - if err != nil { - logger.Warnln(err) - return - } - if len(ips) == 0 { - err = fmt.Errorf("No IP addresses available for %v", host) - logger.Warnln(err) - return - } - if len(ips) > 1 { - // Pick a random IP address, simulating round-robin DNS. - rand.Seed(time.Now().UTC().UnixNano()) - ip = ips[rand.Intn(len(ips))] - } else { - ip = ips[0] - } - return -} diff --git a/p2p/peer.go b/p2p/peer.go index f4b68a007..893ba86d7 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,83 +1,455 @@ package p2p import ( + "bufio" + "bytes" "fmt" + "io" + "io/ioutil" "net" - "strconv" + "sort" + "sync" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/logger" ) +// peerAddr is the structure of a peer list element. +// It is also a valid net.Addr. +type peerAddr struct { + IP net.IP + Port uint64 + Pubkey []byte // optional +} + +func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { + n := addr.Network() + if n != "tcp" && n != "tcp4" && n != "tcp6" { + // for testing with non-TCP + return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} + } + ta := addr.(*net.TCPAddr) + return &peerAddr{ta.IP, uint64(ta.Port), pubkey} +} + +func (d peerAddr) Network() string { + if d.IP.To4() != nil { + return "tcp4" + } else { + return "tcp6" + } +} + +func (d peerAddr) String() string { + return fmt.Sprintf("%v:%d", d.IP, d.Port) +} + +func (d peerAddr) RlpData() interface{} { + return []interface{}{d.IP, d.Port, d.Pubkey} +} + +// Peer represents a remote peer. type Peer struct { - // quit chan chan bool - Inbound bool // inbound (via listener) or outbound (via dialout) - Address net.Addr - Host []byte - Port uint16 - Pubkey []byte - Id string - Caps []string - peerErrorChan chan *PeerError - messenger *Messenger - peerErrorHandler *PeerErrorHandler - server *Server -} - -func (self *Peer) Messenger() *Messenger { - return self.messenger -} - -func (self *Peer) PeerErrorChan() chan *PeerError { - return self.peerErrorChan -} - -func (self *Peer) Server() *Server { - return self.server -} - -func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { - peerErrorChan := NewPeerErrorChannel() - host, port, _ := net.SplitHostPort(address.String()) - intport, _ := strconv.Atoi(port) - peer := &Peer{ - Inbound: inbound, - Address: address, - Port: uint16(intport), - Host: net.ParseIP(host), - peerErrorChan: peerErrorChan, - server: server, - } - connection := NewConnection(conn, peerErrorChan) - peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) - peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) + // Peers have all the log methods. + // Use them to display messages related to the peer. + *logger.Logger + + infolock sync.Mutex + identity ClientIdentity + caps []Cap + listenAddr *peerAddr // what remote peer is listening on + dialAddr *peerAddr // non-nil if dialing + + // The mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + // These fields maintain the running protocols. + protocols []Protocol + runBaseProtocol bool // for testing + + runlock sync.RWMutex // protects running + running map[string]*proto + + protoWG sync.WaitGroup + protoErr chan error + closed chan struct{} + disc chan DiscReason + + activity event.TypeMux // for activity events + + slot int // index into Server peer list + + // These fields are kept so base protocol can access them. + // TODO: this should be one or more interfaces + ourID ClientIdentity // client id of the Server + ourListenAddr *peerAddr // listen addr of Server, nil if not listening + newPeerAddr chan<- *peerAddr // tell server about received peers + otherPeers func() []*Peer // should return the list of all peers + pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey +} + +// NewPeer returns a peer for testing purposes. +func NewPeer(id ClientIdentity, caps []Cap) *Peer { + conn, _ := net.Pipe() + peer := newPeer(conn, nil, nil) + peer.setHandshakeInfo(id, nil, caps) + close(peer.closed) return peer } -func (self *Peer) String() string { - var kind string - if self.Inbound { - kind = "inbound" - } else { +func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + p := newPeer(conn, server.Protocols, dialAddr) + p.ourID = server.Identity + p.newPeerAddr = server.peerConnect + p.otherPeers = server.Peers + p.pubkeyHook = server.verifyPeer + p.runBaseProtocol = true + + // laddr can be updated concurrently by NAT traversal. + // newServerPeer must be called with the server lock held. + if server.laddr != nil { + p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) + } + return p +} + +func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { + p := &Peer{ + Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), + conn: conn, + dialAddr: dialAddr, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} + +// Identity returns the client identity of the remote peer. The +// identity can be nil if the peer has not yet completed the +// handshake. +func (p *Peer) Identity() ClientIdentity { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.identity +} + +// Caps returns the capabilities (supported subprotocols) of the remote peer. +func (p *Peer) Caps() []Cap { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.caps +} + +func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { + p.infolock.Lock() + p.identity = id + p.listenAddr = laddr + p.caps = caps + p.infolock.Unlock() +} + +// RemoteAddr returns the remote address of the network connection. +func (p *Peer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of the network connection. +func (p *Peer) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// Disconnect terminates the peer connection with the given reason. +// It returns immediately and does not wait until the connection is closed. +func (p *Peer) Disconnect(reason DiscReason) { + select { + case p.disc <- reason: + case <-p.closed: + } +} + +// String implements fmt.Stringer. +func (p *Peer) String() string { + kind := "inbound" + p.infolock.Lock() + if p.dialAddr != nil { kind = "outbound" } - return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) + p.infolock.Unlock() + return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) +} + +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + // maximum amount of time allowed for writing a message + msgWriteTimeout = 5 * time.Second + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +var ( + inactivityTimeout = 2 * time.Second + disconnectGracePeriod = 2 * time.Second +) + +func (p *Peer) loop() (reason DiscReason, err error) { + defer p.activity.Stop() + defer p.closeProtocols() + defer close(p.closed) + defer p.conn.Close() + + // read loop + readMsg := make(chan Msg) + readErr := make(chan error) + readNext := make(chan bool, 1) + protoDone := make(chan struct{}, 1) + go p.readLoop(readMsg, readErr, readNext) + readNext <- true + + if p.runBaseProtocol { + p.startBaseProtocol() + } + +loop: + for { + select { + case msg := <-readMsg: + // a new message has arrived. + var wait bool + if wait, err = p.dispatch(msg, protoDone); err != nil { + p.Errorf("msg dispatch error: %v\n", err) + reason = discReasonForError(err) + break loop + } + if !wait { + // Msg has already been read completely, continue with next message. + readNext <- true + } + p.activity.Post(time.Now()) + case <-protoDone: + // protocol has consumed the message payload, + // we can continue reading from the socket. + readNext <- true + + case err := <-readErr: + // read failed. there is no need to run the + // polite disconnect sequence because the connection + // is probably dead anyway. + // TODO: handle write errors as well + return DiscNetworkError, err + case err = <-p.protoErr: + reason = discReasonForError(err) + break loop + case reason = <-p.disc: + break loop + } + } + + // wait for read loop to return. + close(readNext) + <-readErr + // tell the remote end to disconnect + done := make(chan struct{}) + go func() { + p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) + p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) + io.Copy(ioutil.Discard, p.conn) + close(done) + }() + select { + case <-done: + case <-time.After(disconnectGracePeriod): + } + return reason, err +} + +func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { + for _ = range unblock { + p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + if msg, err := readMsg(p.bufconn); err != nil { + errc <- err + } else { + msgc <- msg + } + } + close(errc) +} + +func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { + proto, err := p.getProto(msg.Code) + if err != nil { + return false, err + } + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return false, err + } + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + wait = true + pr := &eofSignal{msg.Payload, protoDone} + msg.Payload = pr + proto.in <- msg + } + return wait, nil +} + +func (p *Peer) startBaseProtocol() { + p.runlock.Lock() + defer p.runlock.Unlock() + p.running[""] = p.startProto(0, Protocol{ + Length: baseProtocolLength, + Run: runBaseProtocol, + }) +} + +// startProtocols starts matching named subprotocols. +func (p *Peer) startSubprotocols(caps []Cap) { + sort.Sort(capsByName(caps)) + + p.runlock.Lock() + defer p.runlock.Unlock() + offset := baseProtocolLength +outer: + for _, cap := range caps { + for _, proto := range p.protocols { + if proto.Name == cap.Name && + proto.Version == cap.Version && + p.running[cap.Name] == nil { + p.running[cap.Name] = p.startProto(offset, proto) + offset += proto.Length + continue outer + } + } + } +} + +func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + rw := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Length, + peer: p, + } + p.protoWG.Add(1) + go func() { + err := impl.Run(p, rw) + if err == nil { + p.Infof("protocol %q returned", impl.Name) + err = newPeerError(errMisc, "protocol returned") + } else { + p.Errorf("protocol %q error: %v\n", impl.Name, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() + return rw +} + +// getProto finds the protocol responsible for handling +// the given message code. +func (p *Peer) getProto(code uint64) (*proto, error) { + p.runlock.RLock() + defer p.runlock.RUnlock() + for _, proto := range p.running { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil + } + } + return nil, newPeerError(errInvalidMsgCode, "%d", code) +} + +func (p *Peer) closeProtocols() { + p.runlock.RLock() + for _, p := range p.running { + close(p.in) + } + p.runlock.RUnlock() + p.protoWG.Wait() +} + +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { + p.runlock.RLock() + proto, ok := p.running[protoName] + p.runlock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) + } + if msg.Code >= proto.maxcode { + return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return p.writeMsg(msg, msgWriteTimeout) +} + +// writeMsg writes a message to the connection. +func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { + p.writeMu.Lock() + defer p.writeMu.Unlock() + p.conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := writeMsg(p.bufconn, msg); err != nil { + return newPeerError(errWrite, "%v", err) + } + return p.bufconn.Flush() +} + +type proto struct { + name string + in chan Msg + maxcode, offset uint64 + peer *Peer +} + +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return newPeerError(errInvalidMsgCode, "not handled") + } + msg.Code += rw.offset + return rw.peer.writeMsg(msg, msgWriteTimeout) } -func (self *Peer) Write(protocol string, msg *Msg) error { - return self.messenger.Write(protocol, msg) +func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { + return rw.WriteMsg(NewMsg(code, data)) } -func (self *Peer) Start() { - self.peerErrorHandler.Start() - self.messenger.Start() +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF + } + msg.Code -= rw.offset + return msg, nil } -func (self *Peer) Stop() { - self.peerErrorHandler.Stop() - self.messenger.Stop() - // q := make(chan bool) - // self.quit <- q - // <-q +// eofSignal wraps a reader with eof signaling. +// the eof channel is closed when the wrapped reader +// reaches EOF. +type eofSignal struct { + wrapped io.Reader + eof chan<- struct{} } -func (p *Peer) Encode() []interface{} { - return []interface{}{p.Host, p.Port, p.Pubkey} +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) + if err != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + } + return n, err } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index de921878a..88b870fbd 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -4,73 +4,121 @@ import ( "fmt" ) -type ErrorCode int - -const errorChanCapacity = 10 - const ( - PacketTooShort = iota - PayloadTooShort - MagicTokenMismatch - EmptyPayload - ReadError - WriteError - MiscError - InvalidMsgCode - InvalidMsg - P2PVersionMismatch - PubkeyMissing - PubkeyInvalid - PubkeyForbidden - ProtocolBreach - PortMismatch - PingTimeout - InvalidGenesis - InvalidNetworkId - InvalidProtocolVersion + errMagicTokenMismatch = iota + errRead + errWrite + errMisc + errInvalidMsgCode + errInvalidMsg + errP2PVersionMismatch + errPubkeyMissing + errPubkeyInvalid + errPubkeyForbidden + errProtocolBreach + errPingTimeout + errInvalidNetworkId + errInvalidProtocolVersion ) -var errorToString = map[ErrorCode]string{ - PacketTooShort: "Packet too short", - PayloadTooShort: "Payload too short", - MagicTokenMismatch: "Magic token mismatch", - EmptyPayload: "Empty payload", - ReadError: "Read error", - WriteError: "Write error", - MiscError: "Misc error", - InvalidMsgCode: "Invalid message code", - InvalidMsg: "Invalid message", - P2PVersionMismatch: "P2P Version Mismatch", - PubkeyMissing: "Public key missing", - PubkeyInvalid: "Public key invalid", - PubkeyForbidden: "Public key forbidden", - ProtocolBreach: "Protocol Breach", - PortMismatch: "Port mismatch", - PingTimeout: "Ping timeout", - InvalidGenesis: "Invalid genesis block", - InvalidNetworkId: "Invalid network id", - InvalidProtocolVersion: "Invalid protocol version", +var errorToString = map[int]string{ + errMagicTokenMismatch: "Magic token mismatch", + errRead: "Read error", + errWrite: "Write error", + errMisc: "Misc error", + errInvalidMsgCode: "Invalid message code", + errInvalidMsg: "Invalid message", + errP2PVersionMismatch: "P2P Version Mismatch", + errPubkeyMissing: "Public key missing", + errPubkeyInvalid: "Public key invalid", + errPubkeyForbidden: "Public key forbidden", + errProtocolBreach: "Protocol Breach", + errPingTimeout: "Ping timeout", + errInvalidNetworkId: "Invalid network id", + errInvalidProtocolVersion: "Invalid protocol version", } -type PeerError struct { - Code ErrorCode +type peerError struct { + Code int message string } -func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { +func newPeerError(code int, format string, v ...interface{}) *peerError { desc, ok := errorToString[code] if !ok { panic("invalid error code") } - format = desc + ": " + format - message := fmt.Sprintf(format, v...) - return &PeerError{code, message} + err := &peerError{code, desc} + if format != "" { + err.message += ": " + fmt.Sprintf(format, v...) + } + return err } -func (self *PeerError) Error() string { +func (self *peerError) Error() string { return self.message } -func NewPeerErrorChannel() chan *PeerError { - return make(chan *PeerError, errorChanCapacity) +type DiscReason byte + +const ( + DiscRequested DiscReason = 0x00 + DiscNetworkError = 0x01 + DiscProtocolError = 0x02 + DiscUselessPeer = 0x03 + DiscTooManyPeers = 0x04 + DiscAlreadyConnected = 0x05 + DiscIncompatibleVersion = 0x06 + DiscInvalidIdentity = 0x07 + DiscQuitting = 0x08 + DiscUnexpectedIdentity = 0x09 + DiscSelf = 0x0a + DiscReadTimeout = 0x0b + DiscSubprotocolError = 0x10 +) + +var discReasonToString = [DiscSubprotocolError + 1]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return fmt.Sprintf("Unknown Reason(%d)", d) + } + return discReasonToString[d] +} + +func discReasonForError(err error) DiscReason { + peerError, ok := err.(*peerError) + if !ok { + return DiscSubprotocolError + } + switch peerError.Code { + case errP2PVersionMismatch: + return DiscIncompatibleVersion + case errPubkeyMissing, errPubkeyInvalid: + return DiscInvalidIdentity + case errPubkeyForbidden: + return DiscUselessPeer + case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: + return DiscProtocolError + case errPingTimeout: + return DiscReadTimeout + case errRead, errWrite, errMisc: + return DiscNetworkError + default: + return DiscSubprotocolError + } } diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go deleted file mode 100644 index ca6cae4db..000000000 --- a/p2p/peer_error_handler.go +++ /dev/null @@ -1,101 +0,0 @@ -package p2p - -import ( - "net" -) - -const ( - severityThreshold = 10 -) - -type DisconnectRequest struct { - addr net.Addr - reason DiscReason -} - -type PeerErrorHandler struct { - quit chan chan bool - address net.Addr - peerDisconnect chan DisconnectRequest - severity int - peerErrorChan chan *PeerError - blacklist Blacklist -} - -func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { - return &PeerErrorHandler{ - quit: make(chan chan bool), - address: address, - peerDisconnect: peerDisconnect, - peerErrorChan: peerErrorChan, - blacklist: blacklist, - } -} - -func (self *PeerErrorHandler) Start() { - go self.listen() -} - -func (self *PeerErrorHandler) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *PeerErrorHandler) listen() { - for { - select { - case peerError, ok := <-self.peerErrorChan: - if ok { - logger.Debugf("error %v\n", peerError) - go self.handle(peerError) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } -} - -func (self *PeerErrorHandler) handle(peerError *PeerError) { - reason := DiscReason(' ') - switch peerError.Code { - case P2PVersionMismatch: - reason = DiscIncompatibleVersion - case PubkeyMissing, PubkeyInvalid: - reason = DiscInvalidIdentity - case PubkeyForbidden: - reason = DiscUselessPeer - case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: - reason = DiscProtocolError - case PingTimeout: - reason = DiscReadTimeout - case WriteError, MiscError: - reason = DiscNetworkError - case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: - reason = DiscSubprotocolError - default: - self.severity += self.getSeverity(peerError) - } - - if self.severity >= severityThreshold { - reason = DiscSubprotocolError - } - if reason != DiscReason(' ') { - self.peerDisconnect <- DisconnectRequest{ - addr: self.address, - reason: reason, - } - } -} - -func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { - switch peerError.Code { - case ReadError: - return 4 //tolerate 3 :) - default: - return 1 - } -} diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go deleted file mode 100644 index 790a7443b..000000000 --- a/p2p/peer_error_handler_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package p2p - -import ( - // "fmt" - "net" - "testing" - "time" -) - -func TestPeerErrorHandler(t *testing.T) { - address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} - peerDisconnect := make(chan DisconnectRequest) - peerErrorChan := NewPeerErrorChannel() - peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) - peh.Start() - defer peh.Stop() - for i := 0; i < 11; i++ { - select { - case <-peerDisconnect: - t.Errorf("expected no disconnect request") - default: - } - peerErrorChan <- NewPeerError(MiscError, "") - } - time.Sleep(1 * time.Millisecond) - select { - case request := <-peerDisconnect: - if request.addr.String() != address.String() { - t.Errorf("incorrect address %v != %v", request.addr, address) - } - default: - t.Errorf("expected disconnect request") - } -} diff --git a/p2p/peer_test.go b/p2p/peer_test.go index c37540bef..d9640292f 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,96 +1,239 @@ package p2p import ( + "bufio" "bytes" - "fmt" - // "net" + "encoding/hex" + "io/ioutil" + "net" + "reflect" "testing" "time" ) -func TestPeer(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } - handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } - addr := &TestAddr{"test:30"} - conn := NewTestNetworkConnection(addr) - _, server := SetupTestServer(handlers) - server.Handshake() - peer := NewPeer(conn, addr, true, server) - // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) - peer.Start() - defer peer.Stop() - time.Sleep(2 * time.Millisecond) - if len(conn.Out) != 1 { - t.Errorf("handshake not sent") - } else { - out := conn.Out[0] - packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect handshake packet %v != %v", out, packet) +var discard = Protocol{ + Name: "discard", + Length: 1, + Run: func(p *Peer, rw MsgReadWriter) error { + for { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if err = msg.Discard(); err != nil { + return err + } } - } + }, +} - packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) - conn.In(0, packet) - time.Sleep(10 * time.Millisecond) +func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { + conn1, conn2 := net.Pipe() + id := NewSimpleClientIdentity("test", "0", "0", "public key") + peer := newPeer(conn1, protos, nil) + peer.ourID = id + peer.pubkeyHook = func(*peerAddr) error { return nil } + errc := make(chan error, 1) + go func() { + _, err := peer.loop() + errc <- err + }() + return conn2, peer, errc +} - pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) - if pro.state != handshakeReceived { - t.Errorf("handshake not received") - } - if peer.Port != 30 { - t.Errorf("port incorrectly set") +func TestPeerProtoReadMsg(t *testing.T) { + defer testlog(t).detach() + + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := ioutil.ReadAll(msg.Payload) + if err != nil { + t.Errorf("payload read error: %v", err) + } + expdata, _ := hex.DecodeString("0183303030") + if !bytes.Equal(expdata, data) { + t.Errorf("incorrect msg data %x", data) + } + close(done) + return nil + }, } - if peer.Id != "peer" { - t.Errorf("id incorrectly set") + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, 1, "000")) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") } - if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { - t.Errorf("pubkey incorrectly set") +} + +func TestPeerProtoReadLargeMsg(t *testing.T) { + defer testlog(t).detach() + + msgsize := uint32(10 * 1024 * 1024) + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Size != msgsize+4 { + t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) + } + msg.Discard() + close(done) + return nil + }, } - fmt.Println(peer.Caps) - if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { - t.Errorf("protocols incorrectly set") + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, make([]byte, msgsize))) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") } +} - msg, _ := NewMsg(3) - err := peer.Write("aaa", msg) - if err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(1 * time.Millisecond) - if len(conn.Out) != 2 { - t.Errorf("msg not written") - } else { - out := conn.Out[1] - packet := Packet(16, 3) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v != %v", out, packet) +func TestPeerProtoEncodeMsg(t *testing.T) { + defer testlog(t).detach() + + proto := Protocol{ + Name: "a", + Length: 2, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := rw.EncodeMsg(2); err == nil { + t.Error("expected error for out-of-range msg code, got nil") } - } + if err := rw.EncodeMsg(1); err != nil { + t.Errorf("write error: %v", err) + } + return nil + }, } + net, peer, _ := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) - msg, _ = NewMsg(2) - err = peer.Write("ccc", msg) + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +func TestPeerWrite(t *testing.T) { + defer testlog(t).detach() + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + // test write errors + if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v", err) + } + + // setup for reading the message on the other end + read := make(chan struct{}) + go func() { + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } else if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + } + msg.Discard() + close(read) + }() + + // test succcessful write + if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { t.Errorf("expect no error for known protocol: %v", err) - } else { - time.Sleep(1 * time.Millisecond) - if len(conn.Out) != 3 { - t.Errorf("msg not written") - } else { - out := conn.Out[2] - packet := Packet(21, 2) - if bytes.Compare(out, packet) != 0 { - t.Errorf("incorrect packet %v != %v", out, packet) - } + } + select { + case <-read: + case err := <-peerErr: + t.Fatalf("peer stopped: %v", err) + } +} + +func TestPeerActivity(t *testing.T) { + // shorten inactivityTimeout while this test is running + oldT := inactivityTimeout + defer func() { inactivityTimeout = oldT }() + inactivityTimeout = 20 * time.Millisecond + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + sub := peer.activity.Subscribe(time.Time{}) + defer sub.Unsubscribe() + + for i := 0; i < 6; i++ { + writeMsg(net, NewMsg(16)) + select { + case <-sub.Chan(): + case <-time.After(inactivityTimeout / 2): + t.Fatal("no event within ", inactivityTimeout/2) + case err := <-peerErr: + t.Fatal("peer error", err) } } - err = peer.Write("bbb", msg) - time.Sleep(1 * time.Millisecond) - if err == nil { - t.Errorf("expect error for unknown protocol") + select { + case <-time.After(inactivityTimeout * 2): + case <-sub.Chan(): + t.Fatal("got activity event while connection was inactive") + case err := <-peerErr: + t.Fatal("peer error", err) + } +} + +func TestNewPeer(t *testing.T) { + id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") + caps := []Cap{{"foo", 2}, {"bar", 3}} + p := NewPeer(id, caps) + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) + } + if p.Identity() != id { + t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) } + // Should not hang. + p.Disconnect(DiscAlreadyConnected) } diff --git a/p2p/protocol.go b/p2p/protocol.go index 5d05ced7d..28eab87cd 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -2,277 +2,294 @@ package p2p import ( "bytes" - "fmt" - "net" - "sort" - "sync" "time" -) - -type Protocol interface { - Start() - Stop() - HandleIn(*Msg, chan *Msg) - HandleOut(*Msg) bool - Offset() MsgCode - Name() string -} -const ( - P2PVersion = 0 - pingTimeout = 2 - pingGracePeriod = 2 + "github.com/ethereum/go-ethereum/ethutil" ) -const ( - HandshakeMsg = iota - DiscMsg - PingMsg - PongMsg - GetPeersMsg - PeersMsg - offset = 16 -) +// Protocol represents a P2P subprotocol implementation. +type Protocol struct { + // Name should contain the official protocol name, + // often a three-letter word. + Name string + + // Version should contain the version number of the protocol. + Version uint + + // Length should contain the number of message codes used + // by the protocol. + Length uint64 + + // Run is called in a new groutine when the protocol has been + // negotiated with a peer. It should read and write messages from + // rw. The Payload for each message must be fully consumed. + // + // The peer connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Run func(peer *Peer, rw MsgReadWriter) error +} -type ProtocolState uint8 +func (p Protocol) cap() Cap { + return Cap{p.Name, p.Version} +} const ( - nullState = iota - handshakeReceived + baseProtocolVersion = 2 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 ) -type DiscReason byte - const ( - // Values are given explicitly instead of by iota because these values are - // defined by the wire protocol spec; it is easier for humans to ensure - // correctness when values are explicit. - DiscRequested = 0x00 - DiscNetworkError = 0x01 - DiscProtocolError = 0x02 - DiscUselessPeer = 0x03 - DiscTooManyPeers = 0x04 - DiscAlreadyConnected = 0x05 - DiscIncompatibleVersion = 0x06 - DiscInvalidIdentity = 0x07 - DiscQuitting = 0x08 - DiscUnexpectedIdentity = 0x09 - DiscSelf = 0x0a - DiscReadTimeout = 0x0b - DiscSubprotocolError = 0x10 + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 ) -var discReasonToString = map[DiscReason]string{ - DiscRequested: "Disconnect requested", - DiscNetworkError: "Network error", - DiscProtocolError: "Breach of protocol", - DiscUselessPeer: "Useless peer", - DiscTooManyPeers: "Too many peers", - DiscAlreadyConnected: "Already connected", - DiscIncompatibleVersion: "Incompatible P2P protocol version", - DiscInvalidIdentity: "Invalid node identity", - DiscQuitting: "Client quitting", - DiscUnexpectedIdentity: "Unexpected identity", - DiscSelf: "Connected to self", - DiscReadTimeout: "Read timeout", - DiscSubprotocolError: "Subprotocol error", -} - -func (d DiscReason) String() string { - if len(discReasonToString) < int(d) { - return "Unknown" - } - - return discReasonToString[d] -} - -type BaseProtocol struct { - peer *Peer - state ProtocolState - stateLock sync.RWMutex +// handshake is the structure of a handshake list. +type handshake struct { + Version uint64 + ID string + Caps []Cap + ListenPort uint64 + NodeID []byte } -func NewBaseProtocol(peer *Peer) *BaseProtocol { - self := &BaseProtocol{ - peer: peer, - } - - return self +func (h *handshake) String() string { + return h.ID } - -func (self *BaseProtocol) Start() { - if self.peer != nil { - self.peer.Write("", self.peer.Server().Handshake()) - go self.peer.Messenger().PingPong( - pingTimeout*time.Second, - pingGracePeriod*time.Second, - self.Ping, - self.Timeout, - ) - } +func (h *handshake) Pubkey() []byte { + return h.NodeID } -func (self *BaseProtocol) Stop() { +// Cap is the structure of a peer capability. +type Cap struct { + Name string + Version uint } -func (self *BaseProtocol) Ping() { - msg, _ := NewMsg(PingMsg) - self.peer.Write("", msg) +func (cap Cap) RlpData() interface{} { + return []interface{}{cap.Name, cap.Version} } -func (self *BaseProtocol) Timeout() { - self.peerError(PingTimeout, "") -} +type capsByName []Cap -func (self *BaseProtocol) Name() string { - return "" -} +func (cs capsByName) Len() int { return len(cs) } +func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } +func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } -func (self *BaseProtocol) Offset() MsgCode { - return offset +type baseProtocol struct { + rw MsgReadWriter + peer *Peer } -func (self *BaseProtocol) CheckState(state ProtocolState) bool { - self.stateLock.RLock() - self.stateLock.RUnlock() - if self.state != state { - return false - } else { - return true +func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { + bp := &baseProtocol{rw, peer} + if err := bp.doHandshake(rw); err != nil { + return err } + // run main loop + quit := make(chan error, 1) + go func() { + for { + if err := bp.handle(rw); err != nil { + quit <- err + break + } + } + }() + return bp.loop(quit) } -func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { - if msg.Code() == HandshakeMsg { - self.handleHandshake(msg) - } else { - if !self.CheckState(handshakeReceived) { - self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) - close(response) - return - } - switch msg.Code() { - case DiscMsg: - logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) - self.peer.Server().PeerDisconnect() <- DisconnectRequest{ - addr: self.peer.Address, - reason: DiscRequested, - } - case PingMsg: - out, _ := NewMsg(PongMsg) - response <- out - case PongMsg: - case GetPeersMsg: - // Peer asked for list of connected peers - if out, err := self.peer.Server().PeersMessage(); err != nil { - response <- out +var pingTimeout = 2 * time.Second + +func (bp *baseProtocol) loop(quit <-chan error) error { + ping := time.NewTimer(pingTimeout) + activity := bp.peer.activity.Subscribe(time.Time{}) + lastActive := time.Time{} + defer ping.Stop() + defer activity.Unsubscribe() + + getPeersTick := time.NewTicker(10 * time.Second) + defer getPeersTick.Stop() + err := bp.rw.EncodeMsg(getPeersMsg) + + for err == nil { + select { + case err = <-quit: + return err + case <-getPeersTick.C: + err = bp.rw.EncodeMsg(getPeersMsg) + case event := <-activity.Chan(): + ping.Reset(pingTimeout) + lastActive = event.(time.Time) + case t := <-ping.C: + if lastActive.Add(pingTimeout * 2).Before(t) { + err = newPeerError(errPingTimeout, "") + } else if lastActive.Add(pingTimeout).Before(t) { + err = bp.rw.EncodeMsg(pingMsg) } - case PeersMsg: - self.handlePeers(msg) - default: - self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) } } - close(response) + return err } -func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { - // somewhat overly paranoid - allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) - return -} - -func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { - err := NewPeerError(errorCode, format, v...) - logger.Warnln(err) - fmt.Println(self.peer, err) - if self.peer != nil { - self.peer.PeerErrorChan() <- err +func (bp *baseProtocol) handle(rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + return err } -} - -func (self *BaseProtocol) handlePeers(msg *Msg) { - it := msg.Data().NewIterator() - for it.Next() { - ip := net.IP(it.Value().Get(0).Bytes()) - port := it.Value().Get(1).Uint() - address := &net.TCPAddr{IP: ip, Port: int(port)} - go self.peer.Server().PeerConnect(address) + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") } -} + // make sure that the payload has been fully consumed + defer msg.Discard() -func (self *BaseProtocol) handleHandshake(msg *Msg) { - self.stateLock.Lock() - defer self.stateLock.Unlock() - if self.state != nullState { - self.peerError(ProtocolBreach, "extra handshake") - return - } + switch msg.Code { + case handshakeMsg: + return newPeerError(errProtocolBreach, "extra handshake received") - c := msg.Data() + case discMsg: + var reason DiscReason + if err := msg.Decode(&reason); err != nil { + return err + } + bp.peer.Disconnect(reason) + return nil + + case pingMsg: + return bp.rw.EncodeMsg(pongMsg) + + case pongMsg: + + case getPeersMsg: + peers := bp.peerList() + // this is dangerous. the spec says that we should _delay_ + // sending the response if no new information is available. + // this means that would need to send a response later when + // new peers become available. + // + // TODO: add event mechanism to notify baseProtocol for new peers + if len(peers) > 0 { + return bp.rw.EncodeMsg(peersMsg, peers) + } - var ( - p2pVersion = c.Get(0).Uint() - id = c.Get(1).Str() - caps = c.Get(2) - port = c.Get(3).Uint() - pubkey = c.Get(4).Bytes() - ) - fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) + case peersMsg: + var peers []*peerAddr + if err := msg.Decode(&peers); err != nil { + return err + } + for _, addr := range peers { + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr + } - // Check correctness of p2p protocol version - if p2pVersion != P2PVersion { - self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) - return + default: + return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) } + return nil +} - // Handle the pub key (validation, uniqueness) - if len(pubkey) == 0 { - self.peerError(PubkeyMissing, "not supplied in handshake.") - return +func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { + // send our handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err } - if len(pubkey) != 64 { - self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) - return + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err } - - // Self connect detection - if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { - self.peerError(PubkeyForbidden, "not allowed to connect to self") - return + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") } - // register pubkey on server. this also sets the pubkey on the peer (need lock) - if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { - self.peerError(PubkeyForbidden, err.Error()) - return + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err } - // check port - if self.peer.Inbound { - uint16port := uint16(port) - if self.peer.Port > 0 && self.peer.Port != uint16port { - self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) - return - } else { - self.peer.Port = uint16port + // validate handshake info + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", + baseProtocolVersion, hs.Version) + } + if len(hs.NodeID) == 0 { + return newPeerError(errPubkeyMissing, "") + } + if len(hs.NodeID) != 64 { + return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) + } + if da := bp.peer.dialAddr; da != nil { + // verify that the peer we wanted to connect to + // actually holds the target public key. + if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { + return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") } } - - capsIt := caps.NewIterator() - for capsIt.Next() { - cap := capsIt.Value().Str() - self.peer.Caps = append(self.peer.Caps, cap) + pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + if err := bp.peer.pubkeyHook(pa); err != nil { + return newPeerError(errPubkeyForbidden, "%v", err) } - sort.Strings(self.peer.Caps) - self.peer.Messenger().AddProtocols(self.peer.Caps) - self.peer.Id = id + // TODO: remove Caps with empty name - self.state = handshakeReceived + var addr *peerAddr + if hs.ListenPort != 0 { + addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + addr.Port = hs.ListenPort + } + bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) + bp.peer.startSubprotocols(hs.Caps) + return nil +} + +func (bp *baseProtocol) handshakeMsg() Msg { + var ( + port uint64 + caps []interface{} + ) + if bp.peer.ourListenAddr != nil { + port = bp.peer.ourListenAddr.Port + } + for _, proto := range bp.peer.protocols { + caps = append(caps, proto.cap()) + } + return NewMsg(handshakeMsg, + baseProtocolVersion, + bp.peer.ourID.String(), + caps, + port, + bp.peer.ourID.Pubkey()[1:], + ) +} - //p.ethereum.PushPeer(p) - // p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) - return +func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { + peers := bp.peer.otherPeers() + ds := make([]ethutil.RlpEncodable, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == bp.peer || addr == nil { + continue + } + ds = append(ds, addr) + } + ourAddr := bp.peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) + } + return ds } diff --git a/p2p/server.go b/p2p/server.go index 91bc4af5c..8a6087566 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,483 +2,466 @@ package p2p import ( "bytes" + "errors" "fmt" "net" - "sort" - "strconv" "sync" "time" - logpkg "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger" ) const ( - outboundAddressPoolSize = 10 - disconnectGracePeriod = 2 + outboundAddressPoolSize = 500 + defaultDialTimeout = 10 * time.Second + portMappingUpdateInterval = 15 * time.Minute + portMappingTimeout = 20 * time.Minute ) -type Blacklist interface { - Get([]byte) (bool, error) - Put([]byte) error - Delete([]byte) error - Exists(pubkey []byte) (ok bool) -} +var srvlog = logger.NewLogger("P2P Server") + +// Server manages all peer connections. +// +// The fields of Server are used as configuration parameters. +// You should set them before starting the Server. Fields may not be +// modified while the server is running. +type Server struct { + // This field must be set to a valid client identity. + Identity ClientIdentity + + // MaxPeers is the maximum number of peers that can be + // connected. It must be greater than zero. + MaxPeers int + + // Protocols should contain the protocols supported + // by the server. Matching protocols are launched for + // each peer. + Protocols []Protocol + + // If Blacklist is set to a non-nil value, the given Blacklist + // is used to verify peer connections. + Blacklist Blacklist + + // If ListenAddr is set to a non-nil address, the server + // will listen for incoming connections. + // + // If the port is zero, the operating system will pick a port. The + // ListenAddr field will be updated with the actual address when + // the server is started. + ListenAddr string + + // If set to a non-nil value, the given NAT port mapper + // is used to make the listening port available to the + // Internet. + NAT NAT + + // If Dialer is set to a non-nil value, the given Dialer + // is used to dial outbound peer connections. + Dialer *net.Dialer + + // If NoDial is true, the server will not dial any peers. + NoDial bool + + // Hook for testing. This is useful because we can inhibit + // the whole protocol stack. + newPeerFunc peerFunc -type BlacklistMap struct { - blacklist map[string]bool lock sync.RWMutex + running bool + listener net.Listener + laddr *net.TCPAddr // real listen addr + peers []*Peer + peerSlots chan int + peerCount int + + quit chan struct{} + wg sync.WaitGroup + peerConnect chan *peerAddr + peerDisconnect chan *Peer } -func NewBlacklist() *BlacklistMap { - return &BlacklistMap{ - blacklist: make(map[string]bool), - } -} +// NAT is implemented by NAT traversal methods. +type NAT interface { + GetExternalAddress() (net.IP, error) + AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeletePortMapping(protocol string, extport, intport int) error -func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { - self.lock.RLock() - defer self.lock.RUnlock() - v, ok := self.blacklist[string(pubkey)] - var err error - if !ok { - err = fmt.Errorf("not found") - } - return v, err + // Should return name of the method. + String() string } -func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { - self.lock.RLock() - defer self.lock.RUnlock() - _, ok = self.blacklist[string(pubkey)] +type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer + +// Peers returns all connected peers. +func (srv *Server) Peers() (peers []*Peer) { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + peers = append(peers, peer) + } + } return } -func (self *BlacklistMap) Put(pubkey []byte) error { - self.lock.RLock() - defer self.lock.RUnlock() - self.blacklist[string(pubkey)] = true - return nil +// PeerCount returns the number of connected peers. +func (srv *Server) PeerCount() int { + srv.lock.RLock() + defer srv.lock.RUnlock() + return srv.peerCount } -func (self *BlacklistMap) Delete(pubkey []byte) error { - self.lock.RLock() - defer self.lock.RUnlock() - delete(self.blacklist, string(pubkey)) - return nil +// SuggestPeer injects an address into the outbound address pool. +func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { + select { + case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: + default: // don't block + } } -type Server struct { - network Network - listening bool //needed? - dialing bool //needed? - closed bool - identity ClientIdentity - addr net.Addr - port uint16 - protocols []string - - quit chan chan bool - peersLock sync.RWMutex - - maxPeers int - peers []*Peer - peerSlots chan int - peersTable map[string]int - peersMsg *Msg - peerCount int - - peerConnect chan net.Addr - peerDisconnect chan DisconnectRequest - blacklist Blacklist - handlers Handlers +// Broadcast sends an RLP-encoded message to all connected peers. +// This method is deprecated and will be removed later. +func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { + var payload []byte + if data != nil { + payload = encodePayload(data...) + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + var msg = Msg{Code: code} + if data != nil { + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } + peer.writeProtoMsg(protocol, msg) + } + } } -var logger = logpkg.NewLogger("P2P") - -func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { - // get alphabetical list of protocol names from handlers map - protocols := []string{} - for protocol := range handlers { - protocols = append(protocols, protocol) +// Start starts running the server. +// Servers can be re-used and started again after stopping. +func (srv *Server) Start() (err error) { + srv.lock.Lock() + defer srv.lock.Unlock() + if srv.running { + return errors.New("server already running") } - sort.Strings(protocols) + srvlog.Infoln("Starting Server") - _, port, _ := net.SplitHostPort(addr.String()) - intport, _ := strconv.Atoi(port) - - self := &Server{ - // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier) - network: network, - identity: identity, - addr: addr, - port: uint16(intport), - protocols: protocols, - - quit: make(chan chan bool), - - maxPeers: maxPeers, - peers: make([]*Peer, maxPeers), - peerSlots: make(chan int, maxPeers), - peersTable: make(map[string]int), - - peerConnect: make(chan net.Addr, outboundAddressPoolSize), - peerDisconnect: make(chan DisconnectRequest), - blacklist: blacklist, - - handlers: handlers, + // initialize fields + if srv.Identity == nil { + return fmt.Errorf("Server.Identity must be set to a non-nil identity") } - for i := 0; i < maxPeers; i++ { - self.peerSlots <- i // fill up with indexes + if srv.MaxPeers <= 0 { + return fmt.Errorf("Server.MaxPeers must be > 0") } - return self -} - -func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { - addr, err = self.network.NewAddr(host, port) - return -} - -func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { - addr, err = self.network.ParseAddr(address) - return -} - -func (self *Server) ClientIdentity() ClientIdentity { - return self.identity -} - -func (self *Server) PeersMessage() (msg *Msg, err error) { - // TODO: memoize and reset when peers change - self.peersLock.RLock() - defer self.peersLock.RUnlock() - msg = self.peersMsg - if msg == nil { - var peerData []interface{} - for _, i := range self.peersTable { - peer := self.peers[i] - peerData = append(peerData, peer.Encode()) - } - if len(peerData) == 0 { - err = fmt.Errorf("no peers") - } else { - msg, err = NewMsg(PeersMsg, peerData...) - self.peersMsg = msg //memoize - } + srv.quit = make(chan struct{}) + srv.peers = make([]*Peer, srv.MaxPeers) + srv.peerSlots = make(chan int, srv.MaxPeers) + srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) + srv.peerDisconnect = make(chan *Peer) + if srv.newPeerFunc == nil { + srv.newPeerFunc = newServerPeer + } + if srv.Blacklist == nil { + srv.Blacklist = NewBlacklist() + } + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} } - return -} -func (self *Server) Peers() (peers []*Peer) { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil { - peers = append(peers, peer) + if srv.ListenAddr != "" { + if err := srv.startListening(); err != nil { + return err } } - return -} - -func (self *Server) PeerCount() int { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - return self.peerCount -} - -var getPeersMsg, _ = NewMsg(GetPeersMsg) + if !srv.NoDial { + srv.wg.Add(1) + go srv.dialLoop() + } + if srv.NoDial && srv.ListenAddr == "" { + srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") + } -func (self *Server) PeerConnect(addr net.Addr) { - // TODO: should buffer, filter and uniq - // send GetPeersMsg if not blocking - select { - case self.peerConnect <- addr: // not enough peers - self.Broadcast("", getPeersMsg) - default: // we dont care + // make all slots available + for i := range srv.peers { + srv.peerSlots <- i } + // note: discLoop is not part of WaitGroup + go srv.discLoop() + srv.running = true + return nil } -func (self *Server) PeerDisconnect() chan DisconnectRequest { - return self.peerDisconnect +func (srv *Server) startListening() error { + listener, err := net.Listen("tcp", srv.ListenAddr) + if err != nil { + return err + } + srv.ListenAddr = listener.Addr().String() + srv.laddr = listener.Addr().(*net.TCPAddr) + srv.listener = listener + srv.wg.Add(1) + go srv.listenLoop() + if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { + srv.wg.Add(1) + go srv.natLoop(srv.laddr.Port) + } + return nil } -func (self *Server) Blacklist() Blacklist { - return self.blacklist -} +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + if !srv.running { + srv.lock.Unlock() + return + } + srv.running = false + srv.lock.Unlock() -func (self *Server) Handlers() Handlers { - return self.handlers -} + srvlog.Infoln("Stopping server") + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + for _, peer := range srv.Peers() { + peer.Disconnect(DiscQuitting) + } + srv.wg.Wait() -func (self *Server) Broadcast(protocol string, msg *Msg) { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil { - peer.Write(protocol, msg) - } + // wait till they actually disconnect + // this is checked by claiming all peerSlots. + // slots become available as the peers disconnect. + for i := 0; i < cap(srv.peerSlots); i++ { + <-srv.peerSlots } + // terminate discLoop + close(srv.peerDisconnect) } -// Start the server -func (self *Server) Start(listen bool, dial bool) { - self.network.Start() - if listen { - listener, err := self.network.Listener(self.addr) - if err != nil { - logger.Warnf("Error initializing listener: %v", err) - logger.Warnf("Connection listening disabled") - self.listening = false - } else { - self.listening = true - logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) - go self.inboundPeerHandler(listener) - } +func (srv *Server) discLoop() { + for peer := range srv.peerDisconnect { + // peer has just disconnected. free up its slot. + srvlog.Infof("%v is gone", peer) + srv.peerSlots <- peer.slot + srv.lock.Lock() + srv.peers[peer.slot] = nil + srv.lock.Unlock() } - if dial { - dialer, err := self.network.Dialer(self.addr) - if err != nil { - logger.Warnf("Error initializing dialer: %v", err) - logger.Warnf("Connection dialout disabled") - self.dialing = false - } else { - self.dialing = true - logger.Infoln("Dial peers watching outbound address pool") - go self.outboundPeerHandler(dialer) - } - } - logger.Infoln("server started") } -func (self *Server) Stop() { - logger.Infoln("server stopping...") - // // quit one loop if dialing - if self.dialing { - logger.Infoln("stop dialout...") - dialq := make(chan bool) - self.quit <- dialq - <-dialq - fmt.Println("quit another") - } - // quit the other loop if listening - if self.listening { - logger.Infoln("stop listening...") - listenq := make(chan bool) - self.quit <- listenq - <-listenq - fmt.Println("quit one") - } - - fmt.Println("quit waited") - - logger.Infoln("stopping peers...") - peers := []net.Addr{} - self.peersLock.RLock() - self.closed = true - for _, peer := range self.peers { - if peer != nil { - peers = append(peers, peer.Address) - } - } - self.peersLock.RUnlock() - for _, address := range peers { - go self.removePeer(DisconnectRequest{ - addr: address, - reason: DiscQuitting, - }) - } - // wait till they actually disconnect - // this is checked by draining the peerSlots (slots are released back if a peer is removed) - i := 0 - fmt.Println("draining peers") +// main loop for adding connections via listening +func (srv *Server) listenLoop() { + defer srv.wg.Done() -FOR: + srvlog.Infoln("Listening on", srv.listener.Addr()) for { select { - case slot := <-self.peerSlots: - i++ - fmt.Printf("%v: found slot %v", i, slot) - if i == self.maxPeers { - break FOR + case slot := <-srv.peerSlots: + conn, err := srv.listener.Accept() + if err != nil { + srv.peerSlots <- slot + return } + srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) + srv.addPeer(conn, nil, slot) + case <-srv.quit: + return } } - logger.Infoln("server stopped") } -// main loop for adding connections via listening -func (self *Server) inboundPeerHandler(listener net.Listener) { +func (srv *Server) natLoop(port int) { + defer srv.wg.Done() for { + srv.updatePortMapping(port) select { - case slot := <-self.peerSlots: - go self.connectInboundPeer(listener, slot) - case errc := <-self.quit: - listener.Close() - fmt.Println("quit listenloop") - errc <- true + case <-time.After(portMappingUpdateInterval): + // one more round + case <-srv.quit: + srv.removePortMapping(port) return } } } -// main loop for adding outbound peers based on peerConnect address pool -// this same loop handles peer disconnect requests as well -func (self *Server) outboundPeerHandler(dialer Dialer) { - // addressChan initially set to nil (only watches peerConnect if we need more peers) - var addressChan chan net.Addr - slots := self.peerSlots - var slot *int +func (srv *Server) updatePortMapping(port int) { + srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) + err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) + if err != nil { + srvlog.Errorln("Port mapping error:", err) + return + } + extip, err := srv.NAT.GetExternalAddress() + if err != nil { + srvlog.Errorln("Error getting external IP:", err) + return + } + srv.lock.Lock() + extaddr := *(srv.listener.Addr().(*net.TCPAddr)) + extaddr.IP = extip + srvlog.Infoln("Mapped port, external addr is", &extaddr) + srv.laddr = &extaddr + srv.lock.Unlock() +} + +func (srv *Server) removePortMapping(port int) { + srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) + srv.NAT.DeletePortMapping("tcp", port, port) +} + +func (srv *Server) dialLoop() { + defer srv.wg.Done() + var ( + suggest chan *peerAddr + slot *int + slots = srv.peerSlots + ) for { select { case i := <-slots: // we need a peer in slot i, slot reserved slot = &i // now we can watch for candidate peers in the next loop - addressChan = self.peerConnect + suggest = srv.peerConnect // do not consume more until candidate peer is found slots = nil - case address := <-addressChan: + + case desc := <-suggest: // candidate peer found, will dial out asyncronously // if connection fails slot will be released - go self.connectOutboundPeer(dialer, address, *slot) + go srv.dialPeer(desc, *slot) // we can watch if more peers needed in the next loop - slots = self.peerSlots + slots = srv.peerSlots // until then we dont care about candidate peers - addressChan = nil - case request := <-self.peerDisconnect: - go self.removePeer(request) - case errc := <-self.quit: - if addressChan != nil && slot != nil { - self.peerSlots <- *slot + suggest = nil + + case <-srv.quit: + // give back the currently reserved slot + if slot != nil { + srv.peerSlots <- *slot } - fmt.Println("quit dialloop") - errc <- true return } } } -// check if peer address already connected -func (self *Server) connected(address net.Addr) (err error) { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - // fmt.Printf("address: %v\n", address) - slot, found := self.peersTable[address.String()] - if found { - err = fmt.Errorf("already connected as peer %v (%v)", slot, address) - } - return -} - -// connect to peer via listener.Accept() -func (self *Server) connectInboundPeer(listener net.Listener, slot int) { - var address net.Addr - conn, err := listener.Accept() - if err == nil { - address = conn.RemoteAddr() - err = self.connected(address) - if err != nil { - conn.Close() - } - } - if err != nil { - logger.Debugln(err) - self.peerSlots <- slot - } else { - fmt.Printf("adding %v\n", address) - go self.addPeer(conn, address, true, slot) - } -} - // connect to peer via dial out -func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { - var conn net.Conn - err := self.connected(address) - if err == nil { - conn, err = dialer.Dial(address.Network(), address.String()) - } +func (srv *Server) dialPeer(desc *peerAddr, slot int) { + srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) + conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) if err != nil { - logger.Debugln(err) - self.peerSlots <- slot - } else { - go self.addPeer(conn, address, false, slot) + srvlog.Errorf("Dial error: %v", err) + srv.peerSlots <- slot + return } + go srv.addPeer(conn, desc, slot) } // creates the new peer object and inserts it into its slot -func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { - self.peersLock.Lock() - defer self.peersLock.Unlock() - if self.closed { - fmt.Println("oopsy, not no longer need peer") - conn.Close() //oopsy our bad - self.peerSlots <- slot // release slot - } else { - peer := NewPeer(conn, address, inbound, self) - self.peers[slot] = peer - self.peersTable[address.String()] = slot - self.peerCount++ - // reset peersmsg - self.peersMsg = nil - fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) - peer.Start() +func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + conn.Close() + srv.peerSlots <- slot // release slot + return nil } + peer := srv.newPeerFunc(srv, conn, desc) + peer.slot = slot + srv.peers[slot] = peer + srv.peerCount++ + go func() { peer.loop(); srv.peerDisconnect <- peer }() + return peer } // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot -func (self *Server) removePeer(request DisconnectRequest) { - self.peersLock.Lock() - - address := request.addr - slot := self.peersTable[address.String()] - peer := self.peers[slot] - fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) - if peer == nil { - logger.Debugf("already removed peer on %v", address) - self.peersLock.Unlock() +func (srv *Server) removePeer(peer *Peer) { + srv.lock.Lock() + defer srv.lock.Unlock() + srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) + if srv.peers[peer.slot] != peer { + srvlog.Warnln("Invalid peer to remove:", peer) return } // remove from list and index - self.peerCount-- - self.peers[slot] = nil - delete(self.peersTable, address.String()) - // reset peersmsg - self.peersMsg = nil - fmt.Printf("removed peer %v (slot %v)\n", peer, slot) - self.peersLock.Unlock() - - // sending disconnect message - disconnectMsg, _ := NewMsg(DiscMsg, request.reason) - peer.Write("", disconnectMsg) - // be nice and wait - time.Sleep(disconnectGracePeriod * time.Second) - // switch off peer and close connections etc. - fmt.Println("stopping peer") - peer.Stop() - fmt.Println("stopped peer") + srv.peerCount-- + srv.peers[peer.slot] = nil // release slot to signal need for a new peer, last! - self.peerSlots <- slot + srv.peerSlots <- peer.slot +} + +func (srv *Server) verifyPeer(addr *peerAddr) error { + if srv.Blacklist.Exists(addr.Pubkey) { + return errors.New("blacklisted") + } + if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { + return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + id := peer.Identity() + if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { + return errors.New("already connected") + } + } + } + return nil } -// fix handshake message to push to peers -func (self *Server) Handshake() *Msg { - fmt.Println(self.identity.Pubkey()[1:]) - msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) - return msg +type Blacklist interface { + Get([]byte) (bool, error) + Put([]byte) error + Delete([]byte) error + Exists(pubkey []byte) (ok bool) } -func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { - // Check for blacklisting - if self.blacklist.Exists(pubkey) { - return fmt.Errorf("blacklisted") +type BlacklistMap struct { + blacklist map[string]bool + lock sync.RWMutex +} + +func NewBlacklist() *BlacklistMap { + return &BlacklistMap{ + blacklist: make(map[string]bool), } +} - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { - return fmt.Errorf("already connected") - } +func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { + self.lock.RLock() + defer self.lock.RUnlock() + v, ok := self.blacklist[string(pubkey)] + var err error + if !ok { + err = fmt.Errorf("not found") } - candidate.Pubkey = pubkey + return v, err +} + +func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { + self.lock.RLock() + defer self.lock.RUnlock() + _, ok = self.blacklist[string(pubkey)] + return +} + +func (self *BlacklistMap) Put(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + self.blacklist[string(pubkey)] = true + return nil +} + +func (self *BlacklistMap) Delete(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + delete(self.blacklist, string(pubkey)) return nil } diff --git a/p2p/server_test.go b/p2p/server_test.go index f749cc490..5c0d08d39 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -2,207 +2,160 @@ package p2p import ( "bytes" - "fmt" + "io" "net" + "sync" "testing" "time" ) -type TestNetwork struct { - connections map[string]*TestNetworkConnection - dialer Dialer - maxinbound int -} - -func NewTestNetwork(maxinbound int) *TestNetwork { - connections := make(map[string]*TestNetworkConnection) - return &TestNetwork{ - connections: connections, - dialer: &TestDialer{connections}, - maxinbound: maxinbound, +func startTestServer(t *testing.T, pf peerFunc) *Server { + server := &Server{ + Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + newPeerFunc: pf, } -} - -func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { - return self.dialer, nil -} - -func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { - return &TestListener{ - connections: self.connections, - addr: addr, - max: self.maxinbound, - }, nil -} - -func (self *TestNetwork) Start() error { - return nil -} - -func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { - return -} - -func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { - return -} - -type TestAddr struct { - name string -} - -func (self *TestAddr) String() string { - return self.name -} - -func (*TestAddr) Network() string { - return "test" -} - -type TestDialer struct { - connections map[string]*TestNetworkConnection -} - -func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { - address := &TestAddr{addr} - tconn := NewTestNetworkConnection(address) - self.connections[addr] = tconn - conn = net.Conn(tconn) - return -} - -type TestListener struct { - connections map[string]*TestNetworkConnection - addr net.Addr - max int - i int -} - -func (self *TestListener) Accept() (conn net.Conn, err error) { - self.i++ - if self.i > self.max { - err = fmt.Errorf("no more") - } else { - addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} - tconn := NewTestNetworkConnection(addr) - key := tconn.RemoteAddr().String() - self.connections[key] = tconn - conn = net.Conn(tconn) - fmt.Printf("accepted connection from: %v \n", addr) + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) } - return -} - -func (self *TestListener) Close() error { - return nil + return server } -func (self *TestListener) Addr() net.Addr { - return self.addr -} - -func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { - network = NewTestNetwork(1) - addr := &TestAddr{"test:30303"} - identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") - maxPeers := 2 - if handlers == nil { - handlers = make(Handlers) - } - blackist := NewBlacklist() - server = New(network, addr, identity, handlers, maxPeers, blackist) - fmt.Println(server.identity.Pubkey()) - return -} +func TestServerListen(t *testing.T) { + defer testlog(t).detach() -func TestServerListener(t *testing.T) { - network, server := SetupTestServer(nil) - server.Start(true, false) - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["inboundpeer-1"] - if !ok { - t.Error("not found inbound peer 1") - } else { - fmt.Printf("out: %v\n", peer1.Out) - if len(peer1.Out) != 2 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + if dialAddr != nil { + t.Error("peer func called with non-nil dialAddr") } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) } + defer conn.Close() -} - -func TestServerDialer(t *testing.T) { - network, server := SetupTestServer(nil) - server.Start(false, true) - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["outboundpeer-1"] - if !ok { - t.Error("not found outbound peer 1") - } else { - fmt.Printf("out: %v\n", peer1.Out) - if len(peer1.Out) != 2 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + select { + case peer := <-connected: + if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.LocalAddr(), conn.RemoteAddr()) } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") } } -func TestServerBroadcast(t *testing.T) { - handlers := make(Handlers) - testProtocol := &TestProtocol{Msgs: []*Msg{}} - handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } - network, server := SetupTestServer(handlers) - server.Start(true, true) - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - msg, _ := NewMsg(0) - server.Broadcast("", msg) - packet := Packet(0, 0) - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["outboundpeer-1"] - if !ok { - t.Error("not found outbound peer 1") - } else { - fmt.Printf("out: %v\n", peer1.Out) - if len(peer1.Out) != 3 { - t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) - } else { - if bytes.Compare(peer1.Out[1], packet) != 0 { - t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) - } - } +func TestServerDial(t *testing.T) { + defer testlog(t).detach() + + // run a fake TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") } - peer2, ok := network.connections["inboundpeer-1"] - if !ok { - t.Error("not found inbound peer 2") - } else { - fmt.Printf("out: %v\n", peer2.Out) - if len(peer1.Out) != 3 { - t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) - } else { - if bytes.Compare(peer2.Out[1], packet) != 0 { - t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("acccept error:", err) + } + conn.Close() + accepted <- conn + }() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // tell the server to connect. + connAddr := newPeerAddr(listener.Addr(), nil) + srv.peerConnect <- connAddr + + select { + case conn := <-accepted: + select { + case peer := <-connected: + if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.RemoteAddr(), conn.LocalAddr()) } + if peer.dialAddr != connAddr { + t.Errorf("peer started with wrong dialAddr: got %v, want %v", + peer.dialAddr, connAddr) + } + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") } + + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") } } -func TestServerPeersMessage(t *testing.T) { - handlers := make(Handlers) - _, server := SetupTestServer(handlers) - server.Start(true, true) - defer server.Stop() - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - peersMsg, err := server.PeersMessage() - fmt.Println(peersMsg) - if err != nil { - t.Errorf("expect no error, got %v", err) +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + var connected sync.WaitGroup + srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { + peer := newPeer(c, []Protocol{discard}, dialAddr) + peer.startSubprotocols([]Cap{discard.cap()}) + connected.Done() + return peer + }) + defer srv.Stop() + + // dial a bunch of conns + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) + } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn } - if c := server.PeerCount(); c != 2 { - t.Errorf("expect 2 peers, got %v", c) + connected.Wait() + + // broadcast one message + srv.Broadcast("discard", 0, "foo") + goldbuf := new(bytes.Buffer) + writeMsg(goldbuf, NewMsg(16, "foo")) + golden := goldbuf.Bytes() + + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) + } } } diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go new file mode 100644 index 000000000..951d43243 --- /dev/null +++ b/p2p/testlog_test.go @@ -0,0 +1,28 @@ +package p2p + +import ( + "testing" + + "github.com/ethereum/go-ethereum/logger" +) + +type testLogger struct{ t *testing.T } + +func testlog(t *testing.T) testLogger { + logger.Reset() + l := testLogger{t} + logger.AddLogSystem(l) + return l +} + +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } +func (testLogger) SetLogLevel(logger.LogLevel) {} + +func (l testLogger) LogPrint(level logger.LogLevel, msg string) { + l.t.Logf("%s", msg) +} + +func (testLogger) detach() { + logger.Flush() + logger.Reset() +} diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go new file mode 100644 index 000000000..c0cc5c544 --- /dev/null +++ b/p2p/testpoc7.go @@ -0,0 +1,40 @@ +// +build none + +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p" + "github.com/obscuren/secp256k1-go" +) + +func main() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) + + pub, _ := secp256k1.GenerateKeyPair() + srv := p2p.Server{ + MaxPeers: 10, + Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), + ListenAddr: ":30303", + NAT: p2p.PMP(net.ParseIP("10.0.0.1")), + } + if err := srv.Start(); err != nil { + fmt.Println("could not start server:", err) + os.Exit(1) + } + + // add seed peers + seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") + if err != nil { + fmt.Println("couldn't resolve:", err) + os.Exit(1) + } + srv.SuggestPeer(seed.IP, seed.Port, nil) + + select {} +} diff --git a/rlp/decode.go b/rlp/decode.go index 96d912f56..7d95af02b 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -1,6 +1,7 @@ package rlp import ( + "bufio" "encoding/binary" "errors" "fmt" @@ -24,8 +25,9 @@ type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result -// in the value pointed to by val. Val must be a non-nil pointer. +// Decode parses RLP-encoded data from r and stores the result in the +// value pointed to by val. Val must be a non-nil pointer. If r does +// not implement ByteReader, Decode will do its own buffering. // // Decode uses the following type-dependent decoding rules: // @@ -66,10 +68,19 @@ type Decoder interface { // // Non-empty interface types are not supported, nor are bool, float32, // float64, maps, channel types and functions. -func Decode(r ByteReader, val interface{}) error { +func Decode(r io.Reader, val interface{}) error { return NewStream(r).Decode(val) } +type decodeError struct { + msg string + typ reflect.Type +} + +func (err decodeError) Error() string { + return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) +} + func makeNumDecoder(typ reflect.Type) decoder { kind := typ.Kind() switch { @@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { } func decodeInt(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too long", typ} + } else if err != nil { return err } val.SetInt(int64(num)) @@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { } func decodeUint(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too big", typ} + } else if err != nil { return err } val.SetUint(num) @@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro i := 0 for { if i > maxelem { - return fmt.Errorf("rlp: input List has more than %d elements", maxelem) + return decodeError{"input list has too many elements", val.Type()} } if val.Kind() == reflect.Slice { // grow slice if necessary @@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { return err } -var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") - func decodeByteArray(s *Stream, val reflect.Value) error { kind, size, err := s.Kind() if err != nil { @@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { switch kind { case Byte: if val.Len() == 0 { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } bv, _ := s.Uint() val.Index(0).SetUint(bv) zero(val, 1) case String: if uint64(val.Len()) < size { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } slice := val.Slice(0, int(size)).Interface().([]byte) if err := s.readFull(slice); err != nil { @@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { } } if err = s.ListEnd(); err == errNotAtEOL { - err = errors.New("rlp: input List has too many elements") + err = decodeError{"input list has too many elements", typ} } return err } @@ -432,8 +447,23 @@ type Stream struct { type listpos struct{ pos, size uint64 } -func NewStream(r ByteReader) *Stream { - return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} +// NewStream creates a new stream reading from r. +// If r does not implement ByteReader, the Stream will +// introduce its own buffering. +func NewStream(r io.Reader) *Stream { + s := new(Stream) + s.Reset(r) + return s +} + +// NewListStream creates a new stream that pretends to be positioned +// at an encoded list of the given length. +func NewListStream(r io.Reader, len uint64) *Stream { + s := new(Stream) + s.Reset(r) + s.kind = List + s.size = len + return s } // Bytes reads an RLP string and returns its contents as a byte slice. @@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { } } +var errUintOverflow = errors.New("rlp: uint overflow") + // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { return uint64(s.byteval), nil case String: if size > uint64(maxbits/8) { - return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) + return 0, errUintOverflow } return s.readUint(byte(size)) default: @@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { return info.decoder(s, rval.Elem()) } +// Reset discards any information about the current decoding context +// and starts reading from r. If r does not also implement ByteReader, +// Stream will do its own buffering. +func (s *Stream) Reset(r io.Reader) { + bufr, ok := r.(ByteReader) + if !ok { + bufr = bufio.NewReader(r) + } + s.r = bufr + s.stack = s.stack[:0] + s.size = 0 + s.kind = -1 + if s.uintbuf == nil { + s.uintbuf = make([]byte, 8) + } +} + // Kind returns the kind and size of the next value in the // input stream. // diff --git a/rlp/decode_test.go b/rlp/decode_test.go index eb1618299..3b60234dd 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -3,7 +3,6 @@ package rlp import ( "bytes" "encoding/hex" - "errors" "fmt" "io" "math/big" @@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) { } } +func TestNewListStream(t *testing.T) { + ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3) + if k, size, err := ls.Kind(); k != List || size != 3 || err != nil { + t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err) + } + if size, err := ls.List(); size != 3 || err != nil { + t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err) + } + for i := 0; i < 3; i++ { + if val, err := ls.Uint(); val != 1 || err != nil { + t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err) + } + } + if err := ls.ListEnd(); err != nil { + t.Errorf("ListEnd() returned %v, expected (3, nil)", err) + } +} + func TestStreamErrors(t *testing.T) { type calls []string tests := []struct { @@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) { {"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"81", calls{"Uint"}, io.ErrUnexpectedEOF}, {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, + {"89000000000000000001", calls{"Uint"}, errUintOverflow}, {"00", calls{"List"}, ErrExpectedList}, {"80", calls{"List"}, ErrExpectedList}, {"C0", calls{"List", "Uint"}, EOL}, @@ -163,7 +180,7 @@ type decodeTest struct { input string ptr interface{} value interface{} - error error + error string } type simplestruct struct { @@ -196,8 +213,8 @@ var decodeTests = []decodeTest{ {input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, - {input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, - {input: "C0", ptr: new(uint32), error: ErrExpectedString}, + {input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"}, + {input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()}, // slices {input: "C0", ptr: new([]int), value: []int{}}, @@ -206,7 +223,7 @@ var decodeTests = []decodeTest{ // arrays {input: "C0", ptr: new([5]int), value: [5]int{}}, {input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, - {input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, + {input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"}, // byte slices {input: "01", ptr: new([]byte), value: []byte{1}}, @@ -214,7 +231,7 @@ var decodeTests = []decodeTest{ {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, - {input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, + {input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"}, // byte arrays {input: "01", ptr: new([5]byte), value: [5]byte{1}}, @@ -222,9 +239,9 @@ var decodeTests = []decodeTest{ {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, - {input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, - {input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, - {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, + {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"}, + {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"}, + {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, // byte array reuse (should be zeroed) {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, @@ -237,25 +254,25 @@ var decodeTests = []decodeTest{ // zero sized byte arrays {input: "80", ptr: new([0]byte), value: [0]byte{}}, {input: "C0", ptr: new([0]byte), value: [0]byte{}}, - {input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, - {input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, + {input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, + {input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, // strings {input: "00", ptr: new(string), value: "\000"}, {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, - {input: "C0", ptr: new(string), error: ErrExpectedString}, + {input: "C0", ptr: new(string), error: ErrExpectedString.Error()}, // big ints {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works - {input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, + {input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()}, // structs {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, - {input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, + {input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"}, { input: "C501C302C103", ptr: new(recstruct), @@ -286,20 +303,20 @@ var decodeTests = []decodeTest{ func intp(i int) *int { return &i } -func TestDecode(t *testing.T) { +func runTests(t *testing.T, decode func([]byte, interface{}) error) { for i, test := range decodeTests { input, err := hex.DecodeString(test.input) if err != nil { t.Errorf("test %d: invalid hex input %q", i, test.input) continue } - err = Decode(bytes.NewReader(input), test.ptr) - if err != nil && test.error == nil { + err = decode(input, test.ptr) + if err != nil && test.error == "" { t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", i, err, test.ptr, test.input) continue } - if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { + if test.error != "" && fmt.Sprint(err) != test.error { t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q", i, err, test.error, test.ptr, test.input) continue @@ -312,6 +329,40 @@ func TestDecode(t *testing.T) { } } +func TestDecodeWithByteReader(t *testing.T) { + runTests(t, func(input []byte, into interface{}) error { + return Decode(bytes.NewReader(input), into) + }) +} + +// dumbReader reads from a byte slice but does not +// implement ReadByte. +type dumbReader []byte + +func (r *dumbReader) Read(buf []byte) (n int, err error) { + if len(*r) == 0 { + return 0, io.EOF + } + n = copy(buf, *r) + *r = (*r)[n:] + return n, nil +} + +func TestDecodeWithNonByteReader(t *testing.T) { + runTests(t, func(input []byte, into interface{}) error { + r := dumbReader(input) + return Decode(&r, into) + }) +} + +func TestDecodeStreamReset(t *testing.T) { + s := NewStream(nil) + runTests(t, func(input []byte, into interface{}) error { + s.Reset(bytes.NewReader(input)) + return s.Decode(into) + }) +} + type testDecoder struct{ called bool } func (t *testDecoder) DecodeRLP(s *Stream) error { |