aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/message.go93
-rw-r--r--p2p/message_test.go3
-rw-r--r--p2p/peer_test.go2
3 files changed, 31 insertions, 67 deletions
diff --git a/p2p/message.go b/p2p/message.go
index 89ad189d7..ade39d25a 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -3,12 +3,12 @@ package p2p
import (
"bytes"
"encoding/binary"
- "fmt"
"io"
"io/ioutil"
"math/big"
"github.com/ethereum/go-ethereum/ethutil"
+ "github.com/ethereum/go-ethereum/rlp"
)
// Msg defines the structure of a p2p message.
@@ -43,16 +43,10 @@ func encodePayload(params ...interface{}) []byte {
// Data returns the decoded RLP payload items in a message.
func (msg Msg) Data() (*ethutil.Value, error) {
- // TODO: avoid copying when we have a better RLP decoder
- buf := new(bytes.Buffer)
- var s []interface{}
- if _, err := buf.ReadFrom(msg.Payload); err != nil {
- return nil, err
- }
- for buf.Len() > 0 {
- s = append(s, ethutil.DecodeWithReader(buf))
- }
- return ethutil.NewValue(s), nil
+ s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+ var v []interface{}
+ err := s.Decode(&v)
+ return ethutil.NewValue(v), err
}
// Discard reads any remaining payload data into a black hole.
@@ -137,13 +131,9 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...)
}
-type byteReader interface {
- io.Reader
- io.ByteReader
-}
-
// readMsg reads a message header from r.
-func readMsg(r byteReader) (msg Msg, err error) {
+// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
+func readMsg(r rlp.ByteReader) (msg Msg, err error) {
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
@@ -155,64 +145,35 @@ func readMsg(r byteReader) (msg Msg, err error) {
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
- _, hdrlen, err := readListHeader(r)
- if err != nil {
+ posr := &postrack{r, 0}
+ s := rlp.NewStream(posr)
+ if _, err := s.List(); err != nil {
return msg, err
}
- code, codelen, err := readMsgCode(r)
+ code, err := s.Uint()
if err != nil {
return msg, err
}
+ payloadsize := size - posr.p
+ return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
+}
- rlpsize := size - hdrlen - codelen
- return Msg{
- Code: code,
- Size: rlpsize,
- Payload: io.LimitReader(r, int64(rlpsize)),
- }, nil
+// postrack wraps an rlp.ByteReader with a position counter.
+type postrack struct {
+ r rlp.ByteReader
+ p uint32
}
-// readListHeader reads an RLP list header from r.
-func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
- b, err := r.ReadByte()
- if err != nil {
- return 0, 0, err
- }
- if b < 0xC0 {
- return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b)
- } else if b < 0xF7 {
- len = uint64(b - 0xc0)
- hdrlen = 1
- } else {
- lenlen := b - 0xF7
- lenbuf := make([]byte, 8)
- if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil {
- return 0, 0, err
- }
- len = binary.BigEndian.Uint64(lenbuf)
- hdrlen = 1 + uint32(lenlen)
- }
- return len, hdrlen, nil
+func (r *postrack) Read(buf []byte) (int, error) {
+ n, err := r.r.Read(buf)
+ r.p += uint32(n)
+ return n, err
}
-// readUint reads an RLP-encoded unsigned integer from r.
-func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
- b, err := r.ReadByte()
- if err != nil {
- return 0, 0, err
- }
- if b < 0x80 {
- return uint64(b), 1, nil
- } else if b < 0x89 { // max length for uint64 is 8 bytes
- codelen = uint32(b - 0x80)
- if codelen == 0 {
- return 0, 1, nil
- }
- buf := make([]byte, 8)
- if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
- return 0, 0, err
- }
- return binary.BigEndian.Uint64(buf), codelen, nil
+func (r *postrack) ReadByte() (byte, error) {
+ b, err := r.r.ReadByte()
+ if err == nil {
+ r.p++
}
- return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
+ return b, err
}
diff --git a/p2p/message_test.go b/p2p/message_test.go
index 1edabc4e7..02d70a28b 100644
--- a/p2p/message_test.go
+++ b/p2p/message_test.go
@@ -46,6 +46,9 @@ func TestEncodeDecodeMsg(t *testing.T) {
if err != nil {
t.Fatalf("first payload item decode error: %v", err)
}
+ if v := data.Len(); v != 2 {
+ t.Errorf("incorrect data.Len(): got %v, expected %d", v, 1)
+ }
if v := data.Get(0).Uint(); v != 1 {
t.Errorf("incorrect data[0]: got %v, expected %d", v, 1)
}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 1afa0ab17..56cd4d890 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -57,7 +57,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
if err != nil {
t.Errorf("data decoding error: %v", err)
}
- expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
+ expdata := []interface{}{[]byte{0x01}, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}