diff options
Diffstat (limited to 'p2p/rlpx_test.go')
-rw-r--r-- | p2p/rlpx_test.go | 244 |
1 files changed, 239 insertions, 5 deletions
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index d98f1c2cd..44be46a99 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -3,19 +3,253 @@ package p2p import ( "bytes" "crypto/rand" + "errors" + "fmt" "io/ioutil" + "net" + "reflect" "strings" + "sync" "testing" + "time" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) -func TestRlpxFrameFake(t *testing.T) { +func TestSharedSecret(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + prv1, _ := crypto.GenerateKey() + pub1 := &prv1.PublicKey + + ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) + if err != nil { + return + } + ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) + if err != nil { + return + } + t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) + if !bytes.Equal(ss0, ss1) { + t.Errorf("dont match :(") + } +} + +func TestEncHandshake(t *testing.T) { + for i := 0; i < 10; i++ { + start := time.Now() + if err := testEncHandshake(nil); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(without token) %d %v\n", i+1, time.Since(start)) + } + for i := 0; i < 10; i++ { + tok := make([]byte, shaLen) + rand.Reader.Read(tok) + start := time.Now() + if err := testEncHandshake(tok); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(with token) %d %v\n", i+1, time.Since(start)) + } +} + +func testEncHandshake(token []byte) error { + type result struct { + side string + id discover.NodeID + err error + } + var ( + prv0, _ = crypto.GenerateKey() + prv1, _ = crypto.GenerateKey() + fd0, fd1 = net.Pipe() + c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx) + output = make(chan result) + ) + + go func() { + r := result{side: "initiator"} + defer func() { output <- r }() + + dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} + r.id, r.err = c0.doEncHandshake(prv0, dest) + if r.err != nil { + return + } + id1 := discover.PubkeyID(&prv1.PublicKey) + if r.id != id1 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1) + } + }() + go func() { + r := result{side: "receiver"} + defer func() { output <- r }() + + r.id, r.err = c1.doEncHandshake(prv1, nil) + if r.err != nil { + return + } + id0 := discover.PubkeyID(&prv0.PublicKey) + if r.id != id0 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0) + } + }() + + // wait for results from both sides + r1, r2 := <-output, <-output + if r1.err != nil { + return fmt.Errorf("%s side error: %v", r1.side, r1.err) + } + if r2.err != nil { + return fmt.Errorf("%s side error: %v", r2.side, r2.err) + } + + // compare derived secrets + if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) { + return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC) + } + if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) { + return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC) + } + if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) { + return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc) + } + if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) { + return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec) + } + return nil +} + +func TestProtocolHandshake(t *testing.T) { + var ( + prv0, _ = crypto.GenerateKey() + node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33} + hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}} + + prv1, _ = crypto.GenerateKey() + node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} + hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} + + fd0, fd1 = net.Pipe() + wg sync.WaitGroup + ) + + wg.Add(2) + go func() { + defer wg.Done() + rlpx := newRLPX(fd0) + remid, err := rlpx.doEncHandshake(prv0, node1) + if err != nil { + t.Errorf("dial side enc handshake failed: %v", err) + return + } + if remid != node1.ID { + t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs0) + if err != nil { + t.Errorf("dial side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs1) { + t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) + return + } + rlpx.close(DiscQuitting) + }() + go func() { + defer wg.Done() + rlpx := newRLPX(fd1) + remid, err := rlpx.doEncHandshake(prv1, nil) + if err != nil { + t.Errorf("listen side enc handshake failed: %v", err) + return + } + if remid != node0.ID { + t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs1) + if err != nil { + t.Errorf("listen side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs0) { + t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) + return + } + + if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { + t.Errorf("error receiving disconnect: %v", err) + } + }() + wg.Wait() +} + +func TestProtocolHandshakeErrors(t *testing.T) { + our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"} + id := randomID() + tests := []struct { + code uint64 + msg interface{} + err error + }{ + { + code: discMsg, + msg: []DiscReason{DiscQuitting}, + err: DiscQuitting, + }, + { + code: 0x989898, + msg: []byte{1}, + err: errors.New("expected handshake, got 989898"), + }, + { + code: handshakeMsg, + msg: make([]byte, baseProtocolMaxMsgSize+2), + err: errors.New("message too big"), + }, + { + code: handshakeMsg, + msg: []byte{1, 2, 3}, + err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 9944, ID: id}, + err: DiscIncompatibleVersion, + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 3}, + err: DiscInvalidIdentity, + }, + } + + for i, test := range tests { + p1, p2 := MsgPipe() + go Send(p1, test.code, test.msg) + _, err := readProtocolHandshake(p2, our) + if !reflect.DeepEqual(err, test.err) { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } + } +} + +func TestRLPXFrameFake(t *testing.T) { buf := new(bytes.Buffer) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - rw := newRlpxFrameRW(buf, secrets{ + rw := newRLPXFrameRW(buf, secrets{ AES: crypto.Sha3(), MAC: crypto.Sha3(), IngressMAC: hash, @@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 } func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } -func TestRlpxFrameRW(t *testing.T) { +func TestRLPXFrameRW(t *testing.T) { var ( aesSecret = make([]byte, 16) macSecret = make([]byte, 16) @@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) { } s1.EgressMAC.Write(egressMACinit) s1.IngressMAC.Write(ingressMACinit) - rw1 := newRlpxFrameRW(conn, s1) + rw1 := newRLPXFrameRW(conn, s1) s2 := secrets{ AES: aesSecret, @@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) { } s2.EgressMAC.Write(ingressMACinit) s2.IngressMAC.Write(egressMACinit) - rw2 := newRlpxFrameRW(conn, s2) + rw2 := newRLPXFrameRW(conn, s2) // send some messages for i := 0; i < 10; i++ { |