aboutsummaryrefslogtreecommitdiffstats
path: root/ecies.go
diff options
context:
space:
mode:
Diffstat (limited to 'ecies.go')
-rw-r--r--ecies.go326
1 files changed, 326 insertions, 0 deletions
diff --git a/ecies.go b/ecies.go
new file mode 100644
index 000000000..0e2403d47
--- /dev/null
+++ b/ecies.go
@@ -0,0 +1,326 @@
+package ecies
+
+import (
+ "crypto/cipher"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/hmac"
+ "crypto/subtle"
+ "fmt"
+ "hash"
+ "io"
+ "math/big"
+)
+
+var (
+ ErrImport = fmt.Errorf("ecies: failed to import key")
+ ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve")
+ ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters")
+ ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key")
+ ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key is too big")
+)
+
+// PublicKey is a representation of an elliptic curve public key.
+type PublicKey struct {
+ X *big.Int
+ Y *big.Int
+ elliptic.Curve
+ Params *ECIESParams
+}
+
+// Export an ECIES public key as an ECDSA public key.
+func (pub *PublicKey) ExportECDSA() *ecdsa.PublicKey {
+ return &ecdsa.PublicKey{pub.Curve, pub.X, pub.Y}
+}
+
+// Import an ECDSA public key as an ECIES public key.
+func ImportECDSAPublic(pub *ecdsa.PublicKey) *PublicKey {
+ return &PublicKey{
+ X: pub.X,
+ Y: pub.Y,
+ Curve: pub.Curve,
+ Params: ParamsFromCurve(pub.Curve),
+ }
+}
+
+// PrivateKey is a representation of an elliptic curve private key.
+type PrivateKey struct {
+ PublicKey
+ D *big.Int
+}
+
+// Export an ECIES private key as an ECDSA private key.
+func (prv *PrivateKey) ExportECDSA() *ecdsa.PrivateKey {
+ pub := &prv.PublicKey
+ pubECDSA := pub.ExportECDSA()
+ return &ecdsa.PrivateKey{*pubECDSA, prv.D}
+}
+
+// Import an ECDSA private key as an ECIES private key.
+func ImportECDSA(prv *ecdsa.PrivateKey) *PrivateKey {
+ pub := ImportECDSAPublic(&prv.PublicKey)
+ return &PrivateKey{*pub, prv.D}
+}
+
+// Generate an elliptic curve public / private keypair. If params is nil,
+// the recommended default paramters for the key will be chosen.
+func GenerateKey(rand io.Reader, curve elliptic.Curve, params *ECIESParams) (prv *PrivateKey, err error) {
+ pb, x, y, err := elliptic.GenerateKey(curve, rand)
+ if err != nil {
+ return
+ }
+ prv = new(PrivateKey)
+ prv.PublicKey.X = x
+ prv.PublicKey.Y = y
+ prv.PublicKey.Curve = curve
+ prv.D = new(big.Int).SetBytes(pb)
+ if params == nil {
+ params = ParamsFromCurve(curve)
+ }
+ prv.PublicKey.Params = params
+ return
+}
+
+// MaxSharedKeyLength returns the maximum length of the shared key the
+// public key can produce.
+func MaxSharedKeyLength(pub *PublicKey) int {
+ return (pub.Curve.Params().BitSize + 7) / 8
+}
+
+// ECDH key agreement method used to establish secret keys for encryption.
+func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []byte, err error) {
+ if prv.PublicKey.Curve != pub.Curve {
+ err = ErrInvalidCurve
+ return
+ }
+ x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes())
+ if x == nil || (x.BitLen()+7)/8 < (skLen+macLen) {
+ err = ErrSharedKeyTooBig
+ return
+ }
+ sk = x.Bytes()[:skLen+macLen]
+ return
+}
+
+var (
+ ErrKeyDataTooLong = fmt.Errorf("ecies: can't supply requested key data")
+ ErrSharedTooLong = fmt.Errorf("ecies: shared secret is too long")
+ ErrInvalidMessage = fmt.Errorf("ecies: invalid message")
+)
+
+var (
+ big2To32 = new(big.Int).Exp(big.NewInt(2), big.NewInt(32), nil)
+ big2To32M1 = new(big.Int).Sub(big2To32, big.NewInt(1))
+)
+
+func incCounter(ctr []byte) {
+ if ctr[3]++; ctr[3] != 0 {
+ return
+ } else if ctr[2]++; ctr[2] != 0 {
+ return
+ } else if ctr[1]++; ctr[1] != 0 {
+ return
+ } else if ctr[0]++; ctr[0] != 0 {
+ return
+ }
+ return
+}
+
+// NIST SP 800-56 Concatenation Key Derivation Function (see section 5.8.1).
+func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) (k []byte, err error) {
+ if s1 == nil {
+ s1 = make([]byte, 0)
+ }
+
+ reps := ((kdLen + 7) * 8) / (hash.BlockSize() * 8)
+ if big.NewInt(int64(reps)).Cmp(big2To32M1) > 0 {
+ fmt.Println(big2To32M1)
+ return nil, ErrKeyDataTooLong
+ }
+
+ counter := []byte{0, 0, 0, 1}
+ k = make([]byte, 0)
+
+ for i := 0; i <= reps; i++ {
+ hash.Write(counter)
+ hash.Write(z)
+ hash.Write(s1)
+ k = append(k, hash.Sum(nil)...)
+ hash.Reset()
+ incCounter(counter)
+ }
+
+ k = k[:kdLen]
+ return
+}
+
+// messageTag computes the MAC of a message (called the tag) as per
+// SEC 1, 3.5.
+func messageTag(hash func() hash.Hash, km, msg, shared []byte) []byte {
+ if shared == nil {
+ shared = make([]byte, 0)
+ }
+ mac := hmac.New(hash, km)
+ mac.Write(msg)
+ tag := mac.Sum(nil)
+ return tag
+}
+
+// Generate an initialisation vector for CTR mode.
+func generateIV(params *ECIESParams, rand io.Reader) (iv []byte, err error) {
+ iv = make([]byte, params.BlockSize)
+ _, err = io.ReadFull(rand, iv)
+ return
+}
+
+// symEncrypt carries out CTR encryption using the block cipher specified in the
+// parameters.
+func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, err error) {
+ c, err := params.Cipher(key)
+ if err != nil {
+ return
+ }
+
+ iv, err := generateIV(params, rand)
+ if err != nil {
+ return
+ }
+ ctr := cipher.NewCTR(c, iv)
+
+ ct = make([]byte, len(m)+params.BlockSize)
+ copy(ct, iv)
+ ctr.XORKeyStream(ct[params.BlockSize:], m)
+ return
+}
+
+// symDecrypt carries out CTR decryption using the block cipher specified in
+// the parameters
+func symDecrypt(rand io.Reader, params *ECIESParams, key, ct []byte) (m []byte, err error) {
+ c, err := params.Cipher(key)
+ if err != nil {
+ return
+ }
+
+ ctr := cipher.NewCTR(c, ct[:params.BlockSize])
+
+ m = make([]byte, len(ct)-params.BlockSize)
+ ctr.XORKeyStream(m, ct[params.BlockSize:])
+ return
+}
+
+// Encrypt encrypts a message using ECIES as specified in SEC 1, 5.1. If
+// the shared information parameters aren't being used, they should be
+// nil.
+func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err error) {
+ params := pub.Params
+ if params == nil {
+ if params = ParamsFromCurve(pub.Curve); params == nil {
+ err = ErrUnsupportedECIESParameters
+ return
+ }
+ }
+ R, err := GenerateKey(rand, pub.Curve, params)
+ if err != nil {
+ return
+ }
+
+ hash := params.Hash()
+ z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen)
+ if err != nil {
+ return
+ }
+ K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
+ if err != nil {
+ return
+ }
+ Ke := K[:params.KeyLen]
+ Km := K[params.KeyLen:]
+ hash.Write(Km)
+ Km = hash.Sum(nil)
+ hash.Reset()
+
+ em, err := symEncrypt(rand, params, Ke, m)
+ if err != nil || len(em) <= params.BlockSize {
+ return
+ }
+
+ d := messageTag(params.Hash, Km, em, s2)
+
+ Rb := elliptic.Marshal(pub.Curve, R.PublicKey.X, R.PublicKey.Y)
+ ct = make([]byte, len(Rb)+len(em)+len(d))
+ copy(ct, Rb)
+ copy(ct[len(Rb):], em)
+ copy(ct[len(Rb)+len(em):], d)
+ return
+}
+
+// Decrypt decrypts an ECIES ciphertext.
+func (prv *PrivateKey) Decrypt(rand io.Reader, c, s1, s2 []byte) (m []byte, err error) {
+ if c == nil || len(c) == 0 {
+ err = ErrInvalidMessage
+ return
+ }
+ params := prv.PublicKey.Params
+ if params == nil {
+ if params = ParamsFromCurve(prv.PublicKey.Curve); params == nil {
+ err = ErrUnsupportedECIESParameters
+ return
+ }
+ }
+ hash := params.Hash()
+
+ var (
+ rLen int
+ hLen int = hash.Size()
+ mStart int
+ mEnd int
+ )
+
+ switch c[0] {
+ case 2, 3, 4:
+ rLen = ((prv.PublicKey.Curve.Params().BitSize + 7) / 4)
+ if len(c) < (rLen + hLen + 1) {
+ err = ErrInvalidMessage
+ return
+ }
+ default:
+ err = ErrInvalidPublicKey
+ return
+ }
+
+ mStart = rLen
+ mEnd = len(c) - hLen
+
+ R := new(PublicKey)
+ R.Curve = prv.PublicKey.Curve
+ R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen])
+ if R.X == nil {
+ err = ErrInvalidPublicKey
+ return
+ }
+
+ z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen)
+ if err != nil {
+ return
+ }
+
+ K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
+ if err != nil {
+ return
+ }
+
+ Ke := K[:params.KeyLen]
+ Km := K[params.KeyLen:]
+ hash.Write(Km)
+ Km = hash.Sum(nil)
+ hash.Reset()
+
+ d := messageTag(params.Hash, Km, c[mStart:mEnd], s2)
+ if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 {
+ err = ErrInvalidMessage
+ return
+ }
+
+ m, err = symDecrypt(rand, params, Ke, c[mStart:mEnd])
+ return
+}