aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/node.go15
-rw-r--r--p2p/discover/node_test.go18
2 files changed, 33 insertions, 0 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index c6d2e9766..de2588258 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
+ "math/big"
"math/rand"
"net"
"net/url"
@@ -14,6 +15,7 @@ import (
"strings"
"time"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -187,6 +189,19 @@ func PubkeyID(pub *ecdsa.PublicKey) NodeID {
return id
}
+// Pubkey returns the public key represented by the node ID.
+// It returns an error if the ID is not a point on the curve.
+func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
+ p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
+ half := len(id) / 2
+ p.X.SetBytes(id[:half])
+ p.Y.SetBytes(id[half:])
+ if !p.Curve.IsOnCurve(p.X, p.Y) {
+ return nil, errors.New("not a point on the S256 curve")
+ }
+ return p, nil
+}
+
// recoverNodeID computes the public key used to sign the
// given hash from the signature.
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
index ae82ae4f1..60b01b6ca 100644
--- a/p2p/discover/node_test.go
+++ b/p2p/discover/node_test.go
@@ -133,6 +133,24 @@ func TestNodeID_recover(t *testing.T) {
if pub != recpub {
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
}
+
+ ecdsa, err := pub.Pubkey()
+ if err != nil {
+ t.Errorf("Pubkey error: %v", err)
+ }
+ if !reflect.DeepEqual(ecdsa, &prv.PublicKey) {
+ t.Errorf("Pubkey mismatch:\n got: %#v\n want: %#v", ecdsa, &prv.PublicKey)
+ }
+}
+
+func TestNodeID_pubkeyBad(t *testing.T) {
+ ecdsa, err := NodeID{}.Pubkey()
+ if err == nil {
+ t.Error("expected error for zero ID")
+ }
+ if ecdsa != nil {
+ t.Error("expected nil result")
+ }
}
func TestNodeID_distcmp(t *testing.T) {