aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/message.go18
-rw-r--r--p2p/message_test.go16
-rw-r--r--p2p/peer.go4
-rw-r--r--p2p/peer_test.go43
-rw-r--r--p2p/rlpx_test.go14
-rw-r--r--p2p/server.go3
-rw-r--r--p2p/server_test.go2
7 files changed, 25 insertions, 75 deletions
diff --git a/p2p/message.go b/p2p/message.go
index 2ef84f99d..04b9e71f3 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -51,19 +51,8 @@ type Msg struct {
// NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code uint64, params ...interface{}) Msg {
- buf := new(bytes.Buffer)
- for _, p := range params {
- buf.Write(ethutil.Encode(p))
- }
- return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
-}
-
-func encodePayload(params ...interface{}) []byte {
- buf := new(bytes.Buffer)
- for _, p := range params {
- buf.Write(ethutil.Encode(p))
- }
- return buf.Bytes()
+ p := bytes.NewReader(ethutil.Encode(params))
+ return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
}
// Decode parse the RLP content of a message into
@@ -71,8 +60,7 @@ func encodePayload(params ...interface{}) []byte {
//
// For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error {
- s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
- if err := s.Decode(val); err != nil {
+ if err := rlp.Decode(msg.Payload, val); err != nil {
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
}
return nil
diff --git a/p2p/message_test.go b/p2p/message_test.go
index 1757cbe7a..31ed61d87 100644
--- a/p2p/message_test.go
+++ b/p2p/message_test.go
@@ -2,10 +2,12 @@ package p2p
import (
"bytes"
+ "encoding/hex"
"fmt"
"io"
"io/ioutil"
"runtime"
+ "strings"
"testing"
"time"
)
@@ -15,11 +17,11 @@ func TestNewMsg(t *testing.T) {
if msg.Code != 3 {
t.Errorf("incorrect code %d, want %d", msg.Code)
}
- if msg.Size != 5 {
- t.Errorf("incorrect size %d, want %d", msg.Size, 5)
+ expect := unhex("c50183303030")
+ if msg.Size != uint32(len(expect)) {
+ t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
}
pl, _ := ioutil.ReadAll(msg.Payload)
- expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
}
@@ -139,3 +141,11 @@ func TestEOFSignal(t *testing.T) {
default:
}
}
+
+func unhex(str string) []byte {
+ b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1))
+ if err != nil {
+ panic(fmt.Sprintf("invalid hex string: %q", str))
+ }
+ return b
+}
diff --git a/p2p/peer.go b/p2p/peer.go
index 4982c4612..025be4ba9 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -193,12 +193,12 @@ func (p *Peer) handle(msg Msg) error {
msg.Discard()
go EncodeMsg(p.rw, pongMsg)
case msg.Code == discMsg:
- var reason DiscReason
+ var reason [1]DiscReason
// no need to discard or for error checking, we'll close the
// connection after this.
rlp.Decode(msg.Payload, &reason)
p.Disconnect(DiscRequested)
- return discRequestedError(reason)
+ return discRequestedError(reason[0])
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 1ba43bed5..cc9f1f0cd 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -85,41 +85,6 @@ func TestPeerProtoReadMsg(t *testing.T) {
}
}
-func TestPeerProtoReadLargeMsg(t *testing.T) {
- defer testlog(t).detach()
-
- msgsize := uint32(10 * 1024 * 1024)
- done := make(chan struct{})
- proto := Protocol{
- Name: "a",
- Length: 5,
- Run: func(peer *Peer, rw MsgReadWriter) error {
- msg, err := rw.ReadMsg()
- if err != nil {
- t.Errorf("read error: %v", err)
- }
- if msg.Size != msgsize+4 {
- t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
- }
- msg.Discard()
- close(done)
- return nil
- },
- }
-
- closer, rw, _, errc := testPeer([]Protocol{proto})
- defer closer.Close()
-
- EncodeMsg(rw, 18, make([]byte, msgsize))
- select {
- case <-done:
- case err := <-errc:
- t.Errorf("peer returned: %v", err)
- case <-time.After(2 * time.Second):
- t.Errorf("receive timeout")
- }
-}
-
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
@@ -246,13 +211,9 @@ func expectMsg(r MsgReader, code uint64, content interface{}) error {
if err != nil {
panic("content encode error: " + err.Error())
}
- // skip over list header in encoded value. this is temporary.
- contentEncR := bytes.NewReader(contentEnc)
- if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
- panic("content must encode as RLP list")
+ if int(msg.Size) != len(contentEnc) {
+ return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc))
}
- contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
-
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index 077dd1309..49354c7ed 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -3,8 +3,6 @@ package p2p
import (
"bytes"
"crypto/rand"
- "encoding/hex"
- "fmt"
"io/ioutil"
"strings"
"testing"
@@ -32,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
`)
// Check WriteMsg. This puts a message into the buffer.
- if err := EncodeMsg(rw, 8, []interface{}{1, 2, 3, 4}); err != nil {
+ if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
t.Fatalf("WriteMsg error: %v", err)
}
written := buf.Bytes()
@@ -68,14 +66,6 @@ func (fakeHash) BlockSize() int { return 0 }
func (h fakeHash) Size() int { return len(h) }
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
-func unhex(str string) []byte {
- b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1))
- if err != nil {
- panic(fmt.Sprintf("invalid hex string: %q", str))
- }
- return b
-}
-
func TestRlpxFrameRW(t *testing.T) {
var (
aesSecret = make([]byte, 16)
@@ -112,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
for i := 0; i < 10; i++ {
// write message into conn buffer
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
- err := EncodeMsg(rw1, uint64(i), wmsg)
+ err := EncodeMsg(rw1, uint64(i), wmsg...)
if err != nil {
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
}
diff --git a/p2p/server.go b/p2p/server.go
index e53e832aa..67d5514b4 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -10,6 +10,7 @@ import (
"sync"
"time"
+ "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
@@ -135,7 +136,7 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
var payload []byte
if data != nil {
- payload = encodePayload(data...)
+ payload = ethutil.Encode(data)
}
srv.lock.RLock()
defer srv.lock.RUnlock()
diff --git a/p2p/server_test.go b/p2p/server_test.go
index c348f5a9a..30447050c 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -150,7 +150,7 @@ func TestServerBroadcast(t *testing.T) {
// broadcast one message
srv.Broadcast("discard", 0, "foo")
- golden := unhex("66e94e166f0a2c3b884cfa59ca34")
+ golden := unhex("66e94d166f0a2c3b884cfa59ca34")
// check that the message has been written everywhere
for i, conn := range conns {