aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/message.go18
-rw-r--r--p2p/message_test.go2
-rw-r--r--p2p/peer_test.go2
-rw-r--r--p2p/protocol.go107
4 files changed, 71 insertions, 58 deletions
diff --git a/p2p/message.go b/p2p/message.go
index ade39d25a..845c832f0 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte {
return buf.Bytes()
}
-// Data returns the decoded RLP payload items in a message.
-func (msg Msg) Data() (*ethutil.Value, error) {
- s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+// Value returns the decoded RLP payload items in a message.
+func (msg Msg) Value() (*ethutil.Value, error) {
var v []interface{}
- err := s.Decode(&v)
+ err := msg.Decode(&v)
return ethutil.NewValue(v), err
}
+// Decode parse the RLP content of a message into
+// the given value, which must be a pointer.
+//
+// For the decoding rules, please see package rlp.
+func (msg Msg) Decode(val interface{}) error {
+ s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+ return s.Decode(val)
+}
+
// Discard reads any remaining payload data into a black hole.
func (msg Msg) Discard() error {
_, err := io.Copy(ioutil.Discard, msg.Payload)
@@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu
if msg.Size > maxsize {
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
- value, err := msg.Data()
+ value, err := msg.Value()
if err != nil {
return err
}
diff --git a/p2p/message_test.go b/p2p/message_test.go
index 02d70a28b..0f51f759e 100644
--- a/p2p/message_test.go
+++ b/p2p/message_test.go
@@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) {
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
- data, err := decmsg.Data()
+ data, err := decmsg.Value()
if err != nil {
t.Fatalf("first payload item decode error: %v", err)
}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 56cd4d890..629475421 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
- data, err := msg.Data()
+ data, err := msg.Value()
if err != nil {
t.Errorf("data decoding error: %v", err)
}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 169dcdb6e..28eab87cd 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -2,7 +2,6 @@ package p2p
import (
"bytes"
- "net"
"time"
"github.com/ethereum/go-ethereum/ethutil"
@@ -90,30 +89,18 @@ type baseProtocol struct {
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
-
- // do handshake
- if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
- return err
- }
- msg, err := rw.ReadMsg()
- if err != nil {
+ if err := bp.doHandshake(rw); err != nil {
return err
}
- if msg.Code != handshakeMsg {
- return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
- }
- data, err := msg.Data()
- if err != nil {
- return newPeerError(errInvalidMsg, "%v", err)
- }
- if err := bp.handleHandshake(data); err != nil {
- return err
- }
-
// run main loop
quit := make(chan error, 1)
go func() {
- quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
+ for {
+ if err := bp.handle(rw); err != nil {
+ quit <- err
+ break
+ }
+ }
}()
return bp.loop(quit)
}
@@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error {
return err
}
-func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
- switch code {
+func (bp *baseProtocol) handle(rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return newPeerError(errMisc, "message too big")
+ }
+ // make sure that the payload has been fully consumed
+ defer msg.Discard()
+
+ switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
- bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
+ var reason DiscReason
+ if err := msg.Decode(&reason); err != nil {
+ return err
+ }
+ bp.peer.Disconnect(reason)
return nil
case pingMsg:
@@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
}
case peersMsg:
- bp.handlePeers(data)
+ var peers []*peerAddr
+ if err := msg.Decode(&peers); err != nil {
+ return err
+ }
+ for _, addr := range peers {
+ bp.peer.Debugf("received peer suggestion: %v", addr)
+ bp.peer.newPeerAddr <- addr
+ }
default:
- return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
+ return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
-func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
- it := data.NewIterator()
- for it.Next() {
- addr := &peerAddr{
- IP: net.IP(it.Value().Get(0).Bytes()),
- Port: it.Value().Get(1).Uint(),
- Pubkey: it.Value().Get(2).Bytes(),
- }
- bp.peer.Debugf("received peer suggestion: %v", addr)
- bp.peer.newPeerAddr <- addr
+func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
+ // send our handshake
+ if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
+ return err
+ }
+
+ // read and handle remote handshake
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Code != handshakeMsg {
+ return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return newPeerError(errMisc, "message too big")
}
-}
-func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
- hs := handshake{
- Version: c.Get(0).Uint(),
- ID: c.Get(1).Str(),
- Caps: nil, // decoded below
- ListenPort: c.Get(3).Uint(),
- NodeID: c.Get(4).Bytes(),
+ var hs handshake
+ if err := msg.Decode(&hs); err != nil {
+ return err
}
+
+ // validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
@@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
- capsIt := c.Get(2).NewIterator()
- for capsIt.Next() {
- cap := capsIt.Value()
- name := cap.Get(0).Str()
- if name != "" {
- hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
- }
- }
+
+ // TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {