diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/message.go | 2 | ||||
-rw-r--r-- | p2p/messenger.go | 14 | ||||
-rw-r--r-- | p2p/messenger_test.go | 128 | ||||
-rw-r--r-- | p2p/protocol.go | 5 |
4 files changed, 96 insertions, 53 deletions
diff --git a/p2p/message.go b/p2p/message.go index 366cff5d7..97d440a27 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -98,7 +98,7 @@ type byteReader interface { io.ByteReader } -// readMsg reads a message header. +// readMsg reads a message header from r. func readMsg(r byteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) diff --git a/p2p/messenger.go b/p2p/messenger.go index 7375ecc07..c7948a9ac 100644 --- a/p2p/messenger.go +++ b/p2p/messenger.go @@ -11,7 +11,7 @@ import ( "time" ) -type Handlers map[string]func() Protocol +type Handlers map[string]Protocol type proto struct { in chan Msg @@ -23,6 +23,7 @@ func (rw *proto) WriteMsg(msg Msg) error { if msg.Code >= rw.maxcode { return NewPeerError(InvalidMsgCode, "not handled") } + msg.Code += rw.offset return rw.messenger.writeMsg(msg) } @@ -31,12 +32,13 @@ func (rw *proto) ReadMsg() (Msg, error) { if !ok { return msg, io.EOF } + msg.Code -= rw.offset return msg, nil } -// eofSignal is used to 'lend' the network connection -// to a protocol. when the protocol's read loop has read the -// whole payload, the done channel is closed. +// eofSignal wraps a reader with eof signaling. +// the eof channel is closed when the wrapped reader +// reaches EOF. type eofSignal struct { wrapped io.Reader eof chan struct{} @@ -119,7 +121,6 @@ func (m *messenger) readLoop() { m.err <- err return } - msg.Code -= proto.offset if msg.Size <= wholePayloadSize { // optimization: msg is small enough, read all // of it and move on to the next message @@ -185,11 +186,10 @@ func (m *messenger) setRemoteProtocols(protocols []string) { defer m.protocolLock.Unlock() offset := baseProtocolOffset for _, name := range protocols { - protocolFunc, ok := m.handlers[name] + inst, ok := m.handlers[name] if !ok { continue // not handled } - inst := protocolFunc() m.protocols[name] = m.startProto(offset, name, inst) offset += inst.Offset() } diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go index f10469e2f..2264e10d3 100644 --- a/p2p/messenger_test.go +++ b/p2p/messenger_test.go @@ -11,14 +11,14 @@ import ( "testing" "time" - "github.com/ethereum/go-ethereum/ethutil" + logpkg "github.com/ethereum/go-ethereum/logger" ) func init() { - ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel)) + logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel)) } -func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { +func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { conn1, conn2 := net.Pipe() id := NewSimpleClientIdentity("test", "0", "0", "public key") server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) @@ -33,7 +33,7 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error { return fmt.Errorf("read error: %v", err) } if msg.Code != handshakeMsg { - return fmt.Errorf("first message should be handshake, got %x", msg.Code) + return fmt.Errorf("first message should be handshake, got %d", msg.Code) } if err := msg.Discard(); err != nil { return err @@ -44,56 +44,102 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error { return writeMsg(w, msg) } -type testMsg struct { - code MsgCode - data *ethutil.Value +type testProtocol struct { + offset MsgCode + f func(MsgReadWriter) } -type testProto struct { - recv chan testMsg +func (p *testProtocol) Offset() MsgCode { + return p.offset } -func (*testProto) Offset() MsgCode { return 5 } - -func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error { - return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error { - logger.Debugf("testprotocol got msg: %d\n", code) - tp.recv <- testMsg{code, data} - return nil - }) +func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error { + p.f(rw) + return nil } func TestRead(t *testing.T) { - testProtocol := &testProto{make(chan testMsg)} - handlers := Handlers{"a": func() Protocol { return testProtocol }} - net, peer, mess := setupMessenger(handlers) - bufr := bufio.NewReader(net) + done := make(chan struct{}) + handlers := Handlers{ + "a": &testProtocol{5, func(rw MsgReadWriter) { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := msg.Data() + if err != nil { + t.Errorf("data decoding error: %v", err) + } + expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + if !reflect.DeepEqual(data.Slice(), expdata) { + t.Errorf("incorrect msg data %#v", data.Slice()) + } + close(done) + }}, + } + + net, peer, m := testMessenger(handlers) defer peer.Stop() + bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { t.Fatalf("handshake failed: %v", err) } + m.setRemoteProtocols([]string{"a"}) - mess.setRemoteProtocols([]string{"a"}) - writeMsg(net, NewMsg(17, uint32(1), "000")) + writeMsg(net, NewMsg(18, 1, "000")) select { - case msg := <-testProtocol.recv: - if msg.code != 1 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.code) - } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} - if !reflect.DeepEqual(msg.data.Slice(), expdata) { - t.Errorf("incorrect msg data %#v", msg.data.Slice()) - } + case <-done: case <-time.After(2 * time.Second): t.Errorf("receive timeout") } } -func TestWriteProtoMsg(t *testing.T) { - handlers := make(Handlers) - testProtocol := &testProto{recv: make(chan testMsg, 1)} - handlers["a"] = func() Protocol { return testProtocol } - net, peer, mess := setupMessenger(handlers) +func TestWriteFromProto(t *testing.T) { + handlers := Handlers{ + "a": &testProtocol{2, func(rw MsgReadWriter) { + if err := rw.WriteMsg(NewMsg(2)); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := rw.WriteMsg(NewMsg(1)); err != nil { + t.Errorf("write error: %v", err) + } + }}, + } + net, peer, mess := testMessenger(handlers) + defer peer.Stop() + bufr := bufio.NewReader(net) + if err := performTestHandshake(bufr, net); err != nil { + t.Fatalf("handshake failed: %v", err) + } + mess.setRemoteProtocols([]string{"a"}) + + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v") + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +var discardProto = &testProtocol{1, func(rw MsgReadWriter) { + for { + msg, err := rw.ReadMsg() + if err != nil { + return + } + if err = msg.Discard(); err != nil { + return + } + } +}} + +func TestMessengerWriteProtoMsg(t *testing.T) { + handlers := Handlers{"a": discardProto} + net, peer, mess := testMessenger(handlers) defer peer.Stop() bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { @@ -120,13 +166,13 @@ func TestWriteProtoMsg(t *testing.T) { read <- msg } }() - if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil { + if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil { t.Errorf("expect no error for known protocol: %v", err) } select { case msg := <-read: - if msg.Code != 19 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 19) + if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) } msg.Discard() case err := <-readerr: @@ -135,7 +181,7 @@ func TestWriteProtoMsg(t *testing.T) { } func TestPulse(t *testing.T) { - net, peer, _ := setupMessenger(nil) + net, peer, _ := testMessenger(nil) defer peer.Stop() bufr := bufio.NewReader(net) if err := performTestHandshake(bufr, net); err != nil { @@ -149,7 +195,7 @@ func TestPulse(t *testing.T) { } after := time.Now() if msg.Code != pingMsg { - t.Errorf("expected ping message, got %x", msg.Code) + t.Errorf("expected ping message, got %d", msg.Code) } if d := after.Sub(before); d < pingTimeout { t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) diff --git a/p2p/protocol.go b/p2p/protocol.go index ccc275287..d22ba70cb 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -143,9 +143,6 @@ func (d DiscReason) String() string { return discReasonToString[d] } -func (bp *baseProtocol) Ping() { -} - func (bp *baseProtocol) Offset() MsgCode { return baseProtocolOffset } @@ -287,7 +284,7 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { // self connect detection if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { - return NewPeerError(PubkeyForbidden, "not allowed to connect to bp") + return NewPeerError(PubkeyForbidden, "not allowed to connect to self") } // register pubkey on server. this also sets the pubkey on the peer (need lock) |