aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/rlpx_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/rlpx_test.go')
-rw-r--r--p2p/rlpx_test.go244
1 files changed, 239 insertions, 5 deletions
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index d98f1c2cd..44be46a99 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -3,19 +3,253 @@ package p2p
import (
"bytes"
"crypto/rand"
+ "errors"
+ "fmt"
"io/ioutil"
+ "net"
+ "reflect"
"strings"
+ "sync"
"testing"
+ "time"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
-func TestRlpxFrameFake(t *testing.T) {
+func TestSharedSecret(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ prv1, _ := crypto.GenerateKey()
+ pub1 := &prv1.PublicKey
+
+ ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
+ if !bytes.Equal(ss0, ss1) {
+ t.Errorf("dont match :(")
+ }
+}
+
+func TestEncHandshake(t *testing.T) {
+ for i := 0; i < 10; i++ {
+ start := time.Now()
+ if err := testEncHandshake(nil); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(without token) %d %v\n", i+1, time.Since(start))
+ }
+ for i := 0; i < 10; i++ {
+ tok := make([]byte, shaLen)
+ rand.Reader.Read(tok)
+ start := time.Now()
+ if err := testEncHandshake(tok); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(with token) %d %v\n", i+1, time.Since(start))
+ }
+}
+
+func testEncHandshake(token []byte) error {
+ type result struct {
+ side string
+ id discover.NodeID
+ err error
+ }
+ var (
+ prv0, _ = crypto.GenerateKey()
+ prv1, _ = crypto.GenerateKey()
+ fd0, fd1 = net.Pipe()
+ c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
+ output = make(chan result)
+ )
+
+ go func() {
+ r := result{side: "initiator"}
+ defer func() { output <- r }()
+
+ dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
+ r.id, r.err = c0.doEncHandshake(prv0, dest)
+ if r.err != nil {
+ return
+ }
+ id1 := discover.PubkeyID(&prv1.PublicKey)
+ if r.id != id1 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
+ }
+ }()
+ go func() {
+ r := result{side: "receiver"}
+ defer func() { output <- r }()
+
+ r.id, r.err = c1.doEncHandshake(prv1, nil)
+ if r.err != nil {
+ return
+ }
+ id0 := discover.PubkeyID(&prv0.PublicKey)
+ if r.id != id0 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
+ }
+ }()
+
+ // wait for results from both sides
+ r1, r2 := <-output, <-output
+ if r1.err != nil {
+ return fmt.Errorf("%s side error: %v", r1.side, r1.err)
+ }
+ if r2.err != nil {
+ return fmt.Errorf("%s side error: %v", r2.side, r2.err)
+ }
+
+ // compare derived secrets
+ if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
+ return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
+ }
+ if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
+ return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
+ }
+ if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
+ return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
+ }
+ if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
+ return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
+ }
+ return nil
+}
+
+func TestProtocolHandshake(t *testing.T) {
+ var (
+ prv0, _ = crypto.GenerateKey()
+ node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
+ hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
+
+ prv1, _ = crypto.GenerateKey()
+ node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
+ hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
+
+ fd0, fd1 = net.Pipe()
+ wg sync.WaitGroup
+ )
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ rlpx := newRLPX(fd0)
+ remid, err := rlpx.doEncHandshake(prv0, node1)
+ if err != nil {
+ t.Errorf("dial side enc handshake failed: %v", err)
+ return
+ }
+ if remid != node1.ID {
+ t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
+ return
+ }
+
+ phs, err := rlpx.doProtoHandshake(hs0)
+ if err != nil {
+ t.Errorf("dial side proto handshake error: %v", err)
+ return
+ }
+ if !reflect.DeepEqual(phs, hs1) {
+ t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
+ return
+ }
+ rlpx.close(DiscQuitting)
+ }()
+ go func() {
+ defer wg.Done()
+ rlpx := newRLPX(fd1)
+ remid, err := rlpx.doEncHandshake(prv1, nil)
+ if err != nil {
+ t.Errorf("listen side enc handshake failed: %v", err)
+ return
+ }
+ if remid != node0.ID {
+ t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
+ return
+ }
+
+ phs, err := rlpx.doProtoHandshake(hs1)
+ if err != nil {
+ t.Errorf("listen side proto handshake error: %v", err)
+ return
+ }
+ if !reflect.DeepEqual(phs, hs0) {
+ t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
+ return
+ }
+
+ if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
+ t.Errorf("error receiving disconnect: %v", err)
+ }
+ }()
+ wg.Wait()
+}
+
+func TestProtocolHandshakeErrors(t *testing.T) {
+ our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
+ id := randomID()
+ tests := []struct {
+ code uint64
+ msg interface{}
+ err error
+ }{
+ {
+ code: discMsg,
+ msg: []DiscReason{DiscQuitting},
+ err: DiscQuitting,
+ },
+ {
+ code: 0x989898,
+ msg: []byte{1},
+ err: errors.New("expected handshake, got 989898"),
+ },
+ {
+ code: handshakeMsg,
+ msg: make([]byte, baseProtocolMaxMsgSize+2),
+ err: errors.New("message too big"),
+ },
+ {
+ code: handshakeMsg,
+ msg: []byte{1, 2, 3},
+ err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
+ },
+ {
+ code: handshakeMsg,
+ msg: &protoHandshake{Version: 9944, ID: id},
+ err: DiscIncompatibleVersion,
+ },
+ {
+ code: handshakeMsg,
+ msg: &protoHandshake{Version: 3},
+ err: DiscInvalidIdentity,
+ },
+ }
+
+ for i, test := range tests {
+ p1, p2 := MsgPipe()
+ go Send(p1, test.code, test.msg)
+ _, err := readProtocolHandshake(p2, our)
+ if !reflect.DeepEqual(err, test.err) {
+ t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
+ }
+ }
+}
+
+func TestRLPXFrameFake(t *testing.T) {
buf := new(bytes.Buffer)
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
- rw := newRlpxFrameRW(buf, secrets{
+ rw := newRLPXFrameRW(buf, secrets{
AES: crypto.Sha3(),
MAC: crypto.Sha3(),
IngressMAC: hash,
@@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
func (h fakeHash) Size() int { return len(h) }
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
-func TestRlpxFrameRW(t *testing.T) {
+func TestRLPXFrameRW(t *testing.T) {
var (
aesSecret = make([]byte, 16)
macSecret = make([]byte, 16)
@@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
}
s1.EgressMAC.Write(egressMACinit)
s1.IngressMAC.Write(ingressMACinit)
- rw1 := newRlpxFrameRW(conn, s1)
+ rw1 := newRLPXFrameRW(conn, s1)
s2 := secrets{
AES: aesSecret,
@@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
}
s2.EgressMAC.Write(ingressMACinit)
s2.IngressMAC.Write(egressMACinit)
- rw2 := newRlpxFrameRW(conn, s2)
+ rw2 := newRLPXFrameRW(conn, s2)
// send some messages
for i := 0; i < 10; i++ {