aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/enode
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/enode')
-rw-r--r--p2p/enode/localnode.go246
-rw-r--r--p2p/enode/localnode_test.go76
-rw-r--r--p2p/enode/node.go7
-rw-r--r--p2p/enode/nodedb.go124
-rw-r--r--p2p/enode/nodedb_test.go4
5 files changed, 416 insertions, 41 deletions
diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go
new file mode 100644
index 000000000..623f8eae1
--- /dev/null
+++ b/p2p/enode/localnode.go
@@ -0,0 +1,246 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/netutil"
+)
+
+const (
+ // IP tracker configuration
+ iptrackMinStatements = 10
+ iptrackWindow = 5 * time.Minute
+ iptrackContactWindow = 10 * time.Minute
+)
+
+// LocalNode produces the signed node record of a local node, i.e. a node run in the
+// current process. Setting ENR entries via the Set method updates the record. A new version
+// of the record is signed on demand when the Node method is called.
+type LocalNode struct {
+ cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
+ id ID
+ key *ecdsa.PrivateKey
+ db *DB
+
+ // everything below is protected by a lock
+ mu sync.Mutex
+ seq uint64
+ entries map[string]enr.Entry
+ udpTrack *netutil.IPTracker // predicts external UDP endpoint
+ staticIP net.IP
+ fallbackIP net.IP
+ fallbackUDP int
+}
+
+// NewLocalNode creates a local node.
+func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
+ ln := &LocalNode{
+ id: PubkeyToIDV4(&key.PublicKey),
+ db: db,
+ key: key,
+ udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
+ entries: make(map[string]enr.Entry),
+ }
+ ln.seq = db.localSeq(ln.id)
+ ln.invalidate()
+ return ln
+}
+
+// Database returns the node database associated with the local node.
+func (ln *LocalNode) Database() *DB {
+ return ln.db
+}
+
+// Node returns the current version of the local node record.
+func (ln *LocalNode) Node() *Node {
+ n := ln.cur.Load().(*Node)
+ if n != nil {
+ return n
+ }
+ // Record was invalidated, sign a new copy.
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+ ln.sign()
+ return ln.cur.Load().(*Node)
+}
+
+// ID returns the local node ID.
+func (ln *LocalNode) ID() ID {
+ return ln.id
+}
+
+// Set puts the given entry into the local record, overwriting
+// any existing value.
+func (ln *LocalNode) Set(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.set(e)
+}
+
+func (ln *LocalNode) set(e enr.Entry) {
+ val, exists := ln.entries[e.ENRKey()]
+ if !exists || !reflect.DeepEqual(val, e) {
+ ln.entries[e.ENRKey()] = e
+ ln.invalidate()
+ }
+}
+
+// Delete removes the given entry from the local record.
+func (ln *LocalNode) Delete(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.delete(e)
+}
+
+func (ln *LocalNode) delete(e enr.Entry) {
+ _, exists := ln.entries[e.ENRKey()]
+ if exists {
+ delete(ln.entries, e.ENRKey())
+ ln.invalidate()
+ }
+}
+
+// SetStaticIP sets the local IP to the given one unconditionally.
+// This disables endpoint prediction.
+func (ln *LocalNode) SetStaticIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.staticIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackIP sets the last-resort IP address. This address is used
+// if no endpoint prediction can be made and no static IP is set.
+func (ln *LocalNode) SetFallbackIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.fallbackIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackUDP sets the last-resort UDP port. This port is used
+// if no endpoint prediction can be made.
+func (ln *LocalNode) SetFallbackUDP(port int) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.fallbackUDP = port
+ ln.updateEndpoints()
+}
+
+// UDPEndpointStatement should be called whenever a statement about the local node's
+// UDP endpoint is received. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String())
+ ln.updateEndpoints()
+}
+
+// UDPContact should be called whenever the local node has announced itself to another node
+// via UDP. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.udpTrack.AddContact(toaddr.String())
+ ln.updateEndpoints()
+}
+
+func (ln *LocalNode) updateEndpoints() {
+ // Determine the endpoints.
+ newIP := ln.fallbackIP
+ newUDP := ln.fallbackUDP
+ if ln.staticIP != nil {
+ newIP = ln.staticIP
+ } else if ip, port := predictAddr(ln.udpTrack); ip != nil {
+ newIP = ip
+ newUDP = port
+ }
+
+ // Update the record.
+ if newIP != nil && !newIP.IsUnspecified() {
+ ln.set(enr.IP(newIP))
+ if newUDP != 0 {
+ ln.set(enr.UDP(newUDP))
+ } else {
+ ln.delete(enr.UDP(0))
+ }
+ } else {
+ ln.delete(enr.IP{})
+ }
+}
+
+// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
+// endpoint representation to IP and port types.
+func predictAddr(t *netutil.IPTracker) (net.IP, int) {
+ ep := t.PredictEndpoint()
+ if ep == "" {
+ return nil, 0
+ }
+ ipString, portString, _ := net.SplitHostPort(ep)
+ ip := net.ParseIP(ipString)
+ port, _ := strconv.Atoi(portString)
+ return ip, port
+}
+
+func (ln *LocalNode) invalidate() {
+ ln.cur.Store((*Node)(nil))
+}
+
+func (ln *LocalNode) sign() {
+ if n := ln.cur.Load().(*Node); n != nil {
+ return // no changes
+ }
+
+ var r enr.Record
+ for _, e := range ln.entries {
+ r.Set(e)
+ }
+ ln.bumpSeq()
+ r.SetSeq(ln.seq)
+ if err := SignV4(&r, ln.key); err != nil {
+ panic(fmt.Errorf("enode: can't sign record: %v", err))
+ }
+ n, err := New(ValidSchemes, &r)
+ if err != nil {
+ panic(fmt.Errorf("enode: can't verify local record: %v", err))
+ }
+ ln.cur.Store(n)
+ log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
+}
+
+func (ln *LocalNode) bumpSeq() {
+ ln.seq++
+ ln.db.storeLocalSeq(ln.id, ln.seq)
+}
diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go
new file mode 100644
index 000000000..f5e3496d6
--- /dev/null
+++ b/p2p/enode/localnode_test.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+func newLocalNodeForTesting() (*LocalNode, *DB) {
+ db, _ := OpenDB("")
+ key, _ := crypto.GenerateKey()
+ return NewLocalNode(db, key), db
+}
+
+func TestLocalNode(t *testing.T) {
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ if ln.Node().ID() != ln.ID() {
+ t.Fatal("inconsistent ID")
+ }
+
+ ln.Set(enr.WithEntry("x", uint(3)))
+ var x uint
+ if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil {
+ t.Fatal("can't load entry 'x':", err)
+ } else if x != 3 {
+ t.Fatal("wrong value for entry 'x':", x)
+ }
+}
+
+func TestLocalNodeSeqPersist(t *testing.T) {
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ if s := ln.Node().Seq(); s != 1 {
+ t.Fatalf("wrong initial seq %d, want 1", s)
+ }
+ ln.Set(enr.WithEntry("x", uint(1)))
+ if s := ln.Node().Seq(); s != 2 {
+ t.Fatalf("wrong seq %d after set, want 2", s)
+ }
+
+ // Create a new instance, it should reload the sequence number.
+ // The number increases just after that because a new record is
+ // created without the "x" entry.
+ ln2 := NewLocalNode(db, ln.key)
+ if s := ln2.Node().Seq(); s != 3 {
+ t.Fatalf("wrong seq %d on new instance, want 3", s)
+ }
+
+ // Create a new instance with a different node key on the same database.
+ // This should reset the sequence number.
+ key, _ := crypto.GenerateKey()
+ ln3 := NewLocalNode(db, key)
+ if s := ln3.Node().Seq(); s != 1 {
+ t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
+ }
+}
diff --git a/p2p/enode/node.go b/p2p/enode/node.go
index 84088fcd2..b454ab255 100644
--- a/p2p/enode/node.go
+++ b/p2p/enode/node.go
@@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey {
return &key
}
+// Record returns the node's record. The return value is a copy and may
+// be modified by the caller.
+func (n *Node) Record() *enr.Record {
+ cpy := n.r
+ return &cpy
+}
+
// checks whether n is a valid complete node.
func (n *Node) ValidateComplete() error {
if n.Incomplete() {
diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
index a929b75d7..7ee0c09a9 100644
--- a/p2p/enode/nodedb.go
+++ b/p2p/enode/nodedb.go
@@ -35,11 +35,24 @@ import (
"github.com/syndtr/goleveldb/leveldb/util"
)
+// Keys in the node database.
+const (
+ dbVersionKey = "version" // Version of the database to flush if changes
+ dbItemPrefix = "n:" // Identifier to prefix node entries with
+
+ dbDiscoverRoot = ":discover"
+ dbDiscoverSeq = dbDiscoverRoot + ":seq"
+ dbDiscoverPing = dbDiscoverRoot + ":lastping"
+ dbDiscoverPong = dbDiscoverRoot + ":lastpong"
+ dbDiscoverFindFails = dbDiscoverRoot + ":findfail"
+ dbLocalRoot = ":local"
+ dbLocalSeq = dbLocalRoot + ":seq"
+)
+
var (
- nodeDBNilID = ID{} // Special node ID to use as a nil element.
- nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
- nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
- nodeDBVersion = 6
+ dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
+ dbCleanupCycle = time.Hour // Time period for running the expiration task.
+ dbVersion = 7
)
// DB is the node database, storing previously seen nodes and any collected metadata about
@@ -50,17 +63,6 @@ type DB struct {
quit chan struct{} // Channel to signal the expiring thread to stop
}
-// Schema layout for the node database
-var (
- nodeDBVersionKey = []byte("version") // Version of the database to flush if changes
- nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with
-
- nodeDBDiscoverRoot = ":discover"
- nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping"
- nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong"
- nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
-)
-
// OpenDB opens a node database for storing and retrieving infos about known peers in the
// network. If no path is given an in-memory, temporary database is constructed.
func OpenDB(path string) (*DB, error) {
@@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) {
// The nodes contained in the cache correspond to a certain protocol version.
// Flush all nodes if the version doesn't match.
currentVer := make([]byte, binary.MaxVarintLen64)
- currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))]
+ currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
- blob, err := db.Get(nodeDBVersionKey, nil)
+ blob, err := db.Get([]byte(dbVersionKey), nil)
switch err {
case leveldb.ErrNotFound:
// Version not found (i.e. empty cache), insert it
- if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil {
+ if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
db.Close()
return nil, err
}
@@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) {
// makeKey generates the leveldb key-blob from a node id and its particular
// field of interest.
func makeKey(id ID, field string) []byte {
- if bytes.Equal(id[:], nodeDBNilID[:]) {
+ if (id == ID{}) {
return []byte(field)
}
- return append(nodeDBItemPrefix, append(id[:], field...)...)
+ return append([]byte(dbItemPrefix), append(id[:], field...)...)
}
// splitKey tries to split a database key into a node id and a field part.
func splitKey(key []byte) (id ID, field string) {
// If the key is not of a node, return it plainly
- if !bytes.HasPrefix(key, nodeDBItemPrefix) {
+ if !bytes.HasPrefix(key, []byte(dbItemPrefix)) {
return ID{}, string(key)
}
// Otherwise split the id and field
- item := key[len(nodeDBItemPrefix):]
+ item := key[len(dbItemPrefix):]
copy(id[:], item[:len(id)])
field = string(item[len(id):])
return id, field
}
-// fetchInt64 retrieves an integer instance associated with a particular
-// database key.
+// fetchInt64 retrieves an integer associated with a particular key.
func (db *DB) fetchInt64(key []byte) int64 {
blob, err := db.lvl.Get(key, nil)
if err != nil {
@@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 {
return val
}
-// storeInt64 update a specific database entry to the current time instance as a
-// unix timestamp.
+// storeInt64 stores an integer in the given key.
func (db *DB) storeInt64(key []byte, n int64) error {
blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutVarint(blob, n)]
+ return db.lvl.Put(key, blob, nil)
+}
+
+// fetchUint64 retrieves an integer associated with a particular key.
+func (db *DB) fetchUint64(key []byte) uint64 {
+ blob, err := db.lvl.Get(key, nil)
+ if err != nil {
+ return 0
+ }
+ val, _ := binary.Uvarint(blob)
+ return val
+}
+// storeUint64 stores an integer in the given key.
+func (db *DB) storeUint64(key []byte, n uint64) error {
+ blob := make([]byte, binary.MaxVarintLen64)
+ blob = blob[:binary.PutUvarint(blob, n)]
return db.lvl.Put(key, blob, nil)
}
// Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node {
- blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil)
+ blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil)
if err != nil {
return nil
}
@@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node {
// UpdateNode inserts - potentially overwriting - a node into the peer database.
func (db *DB) UpdateNode(node *Node) error {
+ if node.Seq() < db.NodeSeq(node.ID()) {
+ return nil
+ }
blob, err := rlp.EncodeToBytes(&node.r)
if err != nil {
return err
}
- return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil)
+ if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil {
+ return err
+ }
+ return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq())
+}
+
+// NodeSeq returns the stored record sequence number of the given node.
+func (db *DB) NodeSeq(id ID) uint64 {
+ return db.fetchUint64(makeKey(id, dbDiscoverSeq))
+}
+
+// Resolve returns the stored record of the node if it has a larger sequence
+// number than n.
+func (db *DB) Resolve(n *Node) *Node {
+ if n.Seq() > db.NodeSeq(n.ID()) {
+ return n
+ }
+ return db.Node(n.ID())
}
// DeleteNode deletes all information/keys associated with a node.
@@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() {
// expirer should be started in a go routine, and is responsible for looping ad
// infinitum and dropping stale data from the database.
func (db *DB) expirer() {
- tick := time.NewTicker(nodeDBCleanupCycle)
+ tick := time.NewTicker(dbCleanupCycle)
defer tick.Stop()
for {
select {
@@ -235,7 +271,7 @@ func (db *DB) expirer() {
// expireNodes iterates over the database and deletes all nodes that have not
// been seen (i.e. received a pong from) for some allotted time.
func (db *DB) expireNodes() error {
- threshold := time.Now().Add(-nodeDBNodeExpiration)
+ threshold := time.Now().Add(-dbNodeExpiration)
// Find discovered nodes that are older than the allowance
it := db.lvl.NewIterator(nil, nil)
@@ -244,7 +280,7 @@ func (db *DB) expireNodes() error {
for it.Next() {
// Skip the item if not a discovery node
id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
+ if field != dbDiscoverRoot {
continue
}
// Skip the node if not expired yet (and not self)
@@ -260,34 +296,44 @@ func (db *DB) expireNodes() error {
// LastPingReceived retrieves the time of the last ping packet received from
// a remote node.
func (db *DB) LastPingReceived(id ID) time.Time {
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
+ return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0)
}
// UpdateLastPingReceived updates the last time we tried contacting a remote node.
func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
+ return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix())
}
// LastPongReceived retrieves the time of the last successful pong from remote node.
func (db *DB) LastPongReceived(id ID) time.Time {
// Launch expirer
db.ensureExpirer()
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
+ return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0)
}
// UpdateLastPongReceived updates the last pong time of a node.
func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
+ return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix())
}
// FindFails retrieves the number of findnode failures since bonding.
func (db *DB) FindFails(id ID) int {
- return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails)))
+ return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails)))
}
// UpdateFindFails updates the number of findnode failures since bonding.
func (db *DB) UpdateFindFails(id ID, fails int) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails))
+ return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails))
+}
+
+// LocalSeq retrieves the local record sequence counter.
+func (db *DB) localSeq(id ID) uint64 {
+ return db.fetchUint64(makeKey(id, dbLocalSeq))
+}
+
+// storeLocalSeq stores the local record sequence counter.
+func (db *DB) storeLocalSeq(id ID, n uint64) {
+ db.storeUint64(makeKey(id, dbLocalSeq), n)
}
// QuerySeeds retrieves random nodes to be used as potential seed nodes
@@ -309,7 +355,7 @@ seek:
ctr := id[0]
rand.Read(id[:])
id[0] = ctr + id[0]%16
- it.Seek(makeKey(id, nodeDBDiscoverRoot))
+ it.Seek(makeKey(id, dbDiscoverRoot))
n := nextNode(it)
if n == nil {
@@ -334,7 +380,7 @@ seek:
func nextNode(it iterator.Iterator) *Node {
for end := false; !end; end = !it.Next() {
id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
+ if field != dbDiscoverRoot {
continue
}
return mustDecodeNode(id[:], it.Value())
diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go
index b476a3439..96794827c 100644
--- a/p2p/enode/nodedb_test.go
+++ b/p2p/enode/nodedb_test.go
@@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct {
30303,
30303,
),
- pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute),
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
exp: false,
}, {
node: NewV4(
@@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct {
30303,
30303,
),
- pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute),
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
exp: true,
},
}