aboutsummaryrefslogtreecommitdiffstats
path: root/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'crypto')
-rw-r--r--crypto/crypto.go13
-rw-r--r--crypto/crypto_test.go31
2 files changed, 38 insertions, 6 deletions
diff --git a/crypto/crypto.go b/crypto/crypto.go
index 76d1ffaf6..619440e81 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -39,6 +39,8 @@ var (
secp256k1halfN = new(big.Int).Div(secp256k1N, big.NewInt(2))
)
+var errInvalidPubkey = errors.New("invalid secp256k1 public key")
+
// Keccak256 calculates and returns the Keccak256 hash of the input data.
func Keccak256(data ...[]byte) []byte {
d := sha3.NewKeccak256()
@@ -122,12 +124,13 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte {
return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8)
}
-func ToECDSAPub(pub []byte) *ecdsa.PublicKey {
- if len(pub) == 0 {
- return nil
- }
+// UnmarshalPubkey converts bytes to a secp256k1 public key.
+func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) {
x, y := elliptic.Unmarshal(S256(), pub)
- return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}
+ if x == nil {
+ return nil, errInvalidPubkey
+ }
+ return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil
}
func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go
index 804de3fe2..177c19c0c 100644
--- a/crypto/crypto_test.go
+++ b/crypto/crypto_test.go
@@ -23,9 +23,11 @@ import (
"io/ioutil"
"math/big"
"os"
+ "reflect"
"testing"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/hexutil"
)
var testAddrHex = "970e8128ab834e8eac17ab8e3812f010678cf791"
@@ -56,6 +58,33 @@ func BenchmarkSha3(b *testing.B) {
}
}
+func TestUnmarshalPubkey(t *testing.T) {
+ key, err := UnmarshalPubkey(nil)
+ if err != errInvalidPubkey || key != nil {
+ t.Fatalf("expected error, got %v, %v", err, key)
+ }
+ key, err = UnmarshalPubkey([]byte{1, 2, 3})
+ if err != errInvalidPubkey || key != nil {
+ t.Fatalf("expected error, got %v, %v", err, key)
+ }
+
+ var (
+ enc, _ = hex.DecodeString("04760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1b01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d")
+ dec = &ecdsa.PublicKey{
+ Curve: S256(),
+ X: hexutil.MustDecodeBig("0x760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1"),
+ Y: hexutil.MustDecodeBig("0xb01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d"),
+ }
+ )
+ key, err = UnmarshalPubkey(enc)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+ if !reflect.DeepEqual(key, dec) {
+ t.Fatal("wrong result")
+ }
+}
+
func TestSign(t *testing.T) {
key, _ := HexToECDSA(testPrivHex)
addr := common.HexToAddress(testAddrHex)
@@ -69,7 +98,7 @@ func TestSign(t *testing.T) {
if err != nil {
t.Errorf("ECRecover error: %s", err)
}
- pubKey := ToECDSAPub(recoveredPub)
+ pubKey, _ := UnmarshalPubkey(recoveredPub)
recoveredAddr := PubkeyToAddress(*pubKey)
if addr != recoveredAddr {
t.Errorf("Address mismatch: want: %x have: %x", addr, recoveredAddr)