package p2p

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/davecgh/go-spew/spew"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/crypto/ecies"
	"github.com/ethereum/go-ethereum/crypto/sha3"
	"github.com/ethereum/go-ethereum/p2p/discover"
	"github.com/ethereum/go-ethereum/rlp"
)

// TestSharedSecret checks that the shared secret derived from (prv0, pub1)
// equals the one derived from (prv1, pub0).
func TestSharedSecret(t *testing.T) {
	prv0, err := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
	if err != nil {
		t.Fatalf("can't generate key 0: %v", err)
	}
	pub0 := &prv0.PublicKey
	prv1, err := crypto.GenerateKey()
	if err != nil {
		t.Fatalf("can't generate key 1: %v", err)
	}
	pub1 := &prv1.PublicKey

	// A failure to derive either secret is a test failure; returning silently
	// here would let the test pass without comparing anything.
	ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
	if err != nil {
		t.Fatalf("can't derive shared secret 0: %v", err)
	}
	ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
	if err != nil {
		t.Fatalf("can't derive shared secret 1: %v", err)
	}
	t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss1), ss1)
	if !bytes.Equal(ss0, ss1) {
		t.Errorf("dont match :(")
	}
}

// TestEncHandshake runs testEncHandshake ten times with a nil token and ten
// times with a random token of shaLen bytes, logging how long each run took.
func TestEncHandshake(t *testing.T) {
	for i := 0; i < 10; i++ {
		start := time.Now()
		if err := testEncHandshake(nil); err != nil {
			t.Fatalf("i=%d %v", i, err)
		}
		t.Logf("(without token) %d %v\n", i+1, time.Since(start))
	}
	for i := 0; i < 10; i++ {
		tok := make([]byte, shaLen)
		// rand.Read fills the whole slice or reports an error.
		if _, err := rand.Read(tok); err != nil {
			t.Fatalf("i=%d can't generate token: %v", i, err)
		}
		start := time.Now()
		if err := testEncHandshake(tok); err != nil {
			t.Fatalf("i=%d %v", i, err)
		}
		t.Logf("(with token) %d %v\n", i+1, time.Since(start))
	}
}

// testEncHandshake performs one encryption handshake between an initiator and
// a receiver connected through net.Pipe. It returns an error if either side
// fails, if either side reports an unexpected remote ID, or if the frame MACs
// and ciphers set up by the two sides do not match each other.
//
// The token argument is not referenced by this function.
func testEncHandshake(token []byte) error {
	type result struct {
		side string
		id   discover.NodeID
		err  error
	}

	prv0, err := crypto.GenerateKey()
	if err != nil {
		return fmt.Errorf("can't generate initiator key: %v", err)
	}
	prv1, err := crypto.GenerateKey()
	if err != nil {
		return fmt.Errorf("can't generate receiver key: %v", err)
	}
	var (
		fd0, fd1 = net.Pipe()
		c0, c1   = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
		output   = make(chan result)
	)

	go func() {
		r := result{side: "initiator"}
		defer func() { output <- r }()
		// Closing our end of the pipe unblocks the peer if we bail out early.
		// The deferred calls run in reverse order, so the close happens before
		// the result is reported.
		defer fd0.Close()

		dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
		r.id, r.err = c0.doEncHandshake(prv0, dest)
		if r.err != nil {
			return
		}
		id1 := discover.PubkeyID(&prv1.PublicKey)
		if r.id != id1 {
			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
		}
	}()
	go func() {
		r := result{side: "receiver"}
		defer func() { output <- r }()
		defer fd1.Close()

		r.id, r.err = c1.doEncHandshake(prv1, nil)
		if r.err != nil {
			return
		}
		id0 := discover.PubkeyID(&prv0.PublicKey)
		if r.id != id0 {
			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
		}
	}()

	// wait for results from both sides
	r1, r2 := <-output, <-output
	if r1.err != nil {
		return fmt.Errorf("%s side error: %v", r1.side, r1.err)
	}
	if r2.err != nil {
		return fmt.Errorf("%s side error: %v", r2.side, r2.err)
	}

	// compare derived secrets
	if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
		return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
	}
	if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
		return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
	}
	if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
		return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
	}
	if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
		return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
	}
	return nil
}

// TestProtocolHandshake runs the encryption handshake followed by the protocol
// handshake between a dialing and a listening side connected through net.Pipe.
// Each side checks the remote ID and the protocol handshake it received. The
// dial side then closes with DiscQuitting and the listen side expects to
// receive that disconnect message.
func TestProtocolHandshake(t *testing.T) {
	prv0, err := crypto.GenerateKey()
	if err != nil {
		t.Fatalf("can't generate dial side key: %v", err)
	}
	prv1, err := crypto.GenerateKey()
	if err != nil {
		t.Fatalf("can't generate listen side key: %v", err)
	}
	var (
		node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
		hs0   = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}

		node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
		hs1   = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}

		fd0, fd1 = net.Pipe()
		wg       sync.WaitGroup
	)

	wg.Add(2)
	go func() {
		defer wg.Done()
		// Make sure the peer is never left blocked on the pipe when this side
		// returns early after reporting an error.
		defer fd0.Close()

		conn := newRLPX(fd0)
		remid, err := conn.doEncHandshake(prv0, node1)
		if err != nil {
			t.Errorf("dial side enc handshake failed: %v", err)
			return
		}
		if remid != node1.ID {
			t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
			return
		}

		phs, err := conn.doProtoHandshake(hs0)
		if err != nil {
			t.Errorf("dial side proto handshake error: %v", err)
			return
		}
		if !reflect.DeepEqual(phs, hs1) {
			t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
			return
		}
		conn.close(DiscQuitting)
	}()
	go func() {
		defer wg.Done()
		defer fd1.Close()

		conn := newRLPX(fd1)
		remid, err := conn.doEncHandshake(prv1, nil)
		if err != nil {
			t.Errorf("listen side enc handshake failed: %v", err)
			return
		}
		if remid != node0.ID {
			t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
			return
		}

		phs, err := conn.doProtoHandshake(hs1)
		if err != nil {
			t.Errorf("listen side proto handshake error: %v", err)
			return
		}
		if !reflect.DeepEqual(phs, hs0) {
			t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
			return
		}

		if err := ExpectMsg(conn, discMsg, []DiscReason{DiscQuitting}); err != nil {
			t.Errorf("error receiving disconnect: %v", err)
		}
	}()
	wg.Wait()
}

// TestProtocolHandshakeErrors sends a series of messages over a MsgPipe and
// checks that readProtocolHandshake returns the expected error for each one.
func TestProtocolHandshakeErrors(t *testing.T) {
	our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
	id := randomID()
	tests := []struct {
		code uint64
		msg  interface{}
		err  error
	}{
		{
			code: discMsg,
			msg:  []DiscReason{DiscQuitting},
			err:  DiscQuitting,
		},
		{
			code: 0x989898,
			msg:  []byte{1},
			err:  errors.New("expected handshake, got 989898"),
		},
		{
			code: handshakeMsg,
			msg:  make([]byte, baseProtocolMaxMsgSize+2),
			err:  errors.New("message too big"),
		},
		{
			code: handshakeMsg,
			msg:  []byte{1, 2, 3},
			err:  newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
		},
		{
			code: handshakeMsg,
			msg:  &protoHandshake{Version: 9944, ID: id},
			err:  DiscIncompatibleVersion,
		},
		{
			code: handshakeMsg,
			msg:  &protoHandshake{Version: 3},
			err:  DiscInvalidIdentity,
		},
	}

	for i, test := range tests {
		p1, p2 := MsgPipe()
		// NOTE(review): the sender goroutine is not waited for and the pipe is
		// not closed here; confirm whether the pipe should be closed after the
		// read so the sender cannot stay blocked.
		go Send(p1, test.code, test.msg)
		_, err := readProtocolHandshake(p2, our)
		if !reflect.DeepEqual(err, test.err) {
			t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
		}
	}
}
func TestRLPXFrameFake(t *testing.T) { buf := new(bytes.Buffer) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) rw := newRLPXFrameRW(buf, secrets{ AES: crypto.Sha3(), MAC: crypto.Sha3(), IngressMAC: hash, EgressMAC: hash, }) golden := unhex(` 00828ddae471818bb0bfa6b551d1cb42 01010101010101010101010101010101 ba628a4ba590cb43f7848f41c4382885 01010101010101010101010101010101 `) // Check WriteMsg. This puts a message into the buffer. if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil { t.Fatalf("WriteMsg error: %v", err) } written := buf.Bytes() if !bytes.Equal(written, golden) { t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden) } // Check ReadMsg. It reads the message encoded by WriteMsg, which // is equivalent to the golden message above. msg, err := rw.ReadMsg() if err != nil { t.Fatalf("ReadMsg error: %v", err) } if msg.Size != 5 { t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5) } if msg.Code != 8 { t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8) } payload, _ := ioutil.ReadAll(msg.Payload) wantPayload := unhex("C401020304") if !bytes.Equal(payload, wantPayload) { t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) } } type fakeHash []byte func (fakeHash) Write(p []byte) (int, error) { return len(p), nil } func (fakeHash) Reset() {} func (fakeHash) BlockSize() int { return 0 } func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) 
} func TestRLPXFrameRW(t *testing.T) { var ( aesSecret = make([]byte, 16) macSecret = make([]byte, 16) egressMACinit = make([]byte, 32) ingressMACinit = make([]byte, 32) ) for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} { rand.Read(s) } conn := new(bytes.Buffer) s1 := secrets{ AES: aesSecret, MAC: macSecret, EgressMAC: sha3.NewKeccak256(), IngressMAC: sha3.NewKeccak256(), } s1.EgressMAC.Write(egressMACinit) s1.IngressMAC.Write(ingressMACinit) rw1 := newRLPXFrameRW(conn, s1) s2 := secrets{ AES: aesSecret, MAC: macSecret, EgressMAC: sha3.NewKeccak256(), IngressMAC: sha3.NewKeccak256(), } s2.EgressMAC.Write(ingressMACinit) s2.IngressMAC.Write(egressMACinit) rw2 := newRLPXFrameRW(conn, s2) // send some messages for i := 0; i < 10; i++ { // write message into conn buffer wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)} err := Send(rw1, uint64(i), wmsg) if err != nil { t.Fatalf("WriteMsg error (i=%d): %v", i, err) } // read message that rw1 just wrote msg, err := rw2.ReadMsg() if err != nil { t.Fatalf("ReadMsg error (i=%d): %v", i, err) } if msg.Code != uint64(i) { t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i) } payload, _ := ioutil.ReadAll(msg.Payload) wantPayload, _ := rlp.EncodeToBytes(wmsg) if !bytes.Equal(payload, wantPayload) { t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) } } }