aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/handshake_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/handshake_test.go')
-rw-r--r--p2p/handshake_test.go171
1 files changed, 171 insertions, 0 deletions
diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go
new file mode 100644
index 000000000..19423bb82
--- /dev/null
+++ b/p2p/handshake_test.go
@@ -0,0 +1,171 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+func TestSharedSecret(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ prv1, _ := crypto.GenerateKey()
+ pub1 := &prv1.PublicKey
+
+ ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
+ if !bytes.Equal(ss0, ss1) {
+ t.Errorf("dont match :(")
+ }
+}
+
+func TestEncHandshake(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ start := time.Now()
+ if err := testEncHandshake(nil); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(without token) %d %v\n", i+1, time.Since(start))
+ }
+
+ for i := 0; i < 20; i++ {
+ tok := make([]byte, shaLen)
+ rand.Reader.Read(tok)
+ start := time.Now()
+ if err := testEncHandshake(tok); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(with token) %d %v\n", i+1, time.Since(start))
+ }
+}
+
+func testEncHandshake(token []byte) error {
+ type result struct {
+ side string
+ s secrets
+ err error
+ }
+ var (
+ prv0, _ = crypto.GenerateKey()
+ prv1, _ = crypto.GenerateKey()
+ rw0, rw1 = net.Pipe()
+ output = make(chan result)
+ )
+
+ go func() {
+ r := result{side: "initiator"}
+ defer func() { output <- r }()
+
+ pub1s := discover.PubkeyID(&prv1.PublicKey)
+ r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
+ if r.err != nil {
+ return
+ }
+ id1 := discover.PubkeyID(&prv1.PublicKey)
+ if r.s.RemoteID != id1 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
+ }
+ }()
+ go func() {
+ r := result{side: "receiver"}
+ defer func() { output <- r }()
+
+ r.s, r.err = receiverEncHandshake(rw1, prv1, token)
+ if r.err != nil {
+ return
+ }
+ id0 := discover.PubkeyID(&prv0.PublicKey)
+ if r.s.RemoteID != id0 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
+ }
+ }()
+
+ // wait for results from both sides
+ r1, r2 := <-output, <-output
+
+ if r1.err != nil {
+ return fmt.Errorf("%s side error: %v", r1.side, r1.err)
+ }
+ if r2.err != nil {
+ return fmt.Errorf("%s side error: %v", r2.side, r2.err)
+ }
+
+ // don't compare remote node IDs
+ r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
+ // flip MACs on one of them so they compare equal
+ r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
+ if !reflect.DeepEqual(r1.s, r2.s) {
+ return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
+ }
+ return nil
+}
+
+func TestSetupConn(t *testing.T) {
+ prv0, _ := crypto.GenerateKey()
+ prv1, _ := crypto.GenerateKey()
+ node0 := &discover.Node{
+ ID: discover.PubkeyID(&prv0.PublicKey),
+ IP: net.IP{1, 2, 3, 4},
+ TCPPort: 33,
+ }
+ node1 := &discover.Node{
+ ID: discover.PubkeyID(&prv1.PublicKey),
+ IP: net.IP{5, 6, 7, 8},
+ TCPPort: 44,
+ }
+ hs0 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node0.ID,
+ Caps: []Cap{{"a", 0}, {"b", 2}},
+ }
+ hs1 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node1.ID,
+ Caps: []Cap{{"c", 1}, {"d", 3}},
+ }
+ fd0, fd1 := net.Pipe()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ conn0, err := setupConn(fd0, prv0, hs0, node1)
+ if err != nil {
+ t.Errorf("outbound side error: %v", err)
+ return
+ }
+ if conn0.ID != node1.ID {
+ t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
+ }
+ if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
+ t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
+ }
+ }()
+
+ conn1, err := setupConn(fd1, prv1, hs1, nil)
+ if err != nil {
+ t.Fatalf("inbound side error: %v", err)
+ }
+ if conn1.ID != node0.ID {
+ t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
+ }
+ if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
+ t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
+ }
+
+ <-done
+}