diff options
Diffstat (limited to 'p2p/enode')
-rw-r--r-- | p2p/enode/localnode.go | 246 | ||||
-rw-r--r-- | p2p/enode/localnode_test.go | 76 | ||||
-rw-r--r-- | p2p/enode/node.go | 7 | ||||
-rw-r--r-- | p2p/enode/nodedb.go | 124 | ||||
-rw-r--r-- | p2p/enode/nodedb_test.go | 4 |
5 files changed, 416 insertions, 41 deletions
diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go new file mode 100644 index 000000000..623f8eae1 --- /dev/null +++ b/p2p/enode/localnode.go @@ -0,0 +1,246 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enode + +import ( + "crypto/ecdsa" + "fmt" + "net" + "reflect" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/netutil" +) + +const ( + // IP tracker configuration + iptrackMinStatements = 10 + iptrackWindow = 5 * time.Minute + iptrackContactWindow = 10 * time.Minute +) + +// LocalNode produces the signed node record of a local node, i.e. a node run in the +// current process. Setting ENR entries via the Set method updates the record. A new version +// of the record is signed on demand when the Node method is called. +type LocalNode struct { + cur atomic.Value // holds a non-nil node pointer while the record is up-to-date. + id ID + key *ecdsa.PrivateKey + db *DB + + // everything below is protected by a lock + mu sync.Mutex + seq uint64 + entries map[string]enr.Entry + udpTrack *netutil.IPTracker // predicts external UDP endpoint + staticIP net.IP + fallbackIP net.IP + fallbackUDP int +} + +// NewLocalNode creates a local node. +func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode { + ln := &LocalNode{ + id: PubkeyToIDV4(&key.PublicKey), + db: db, + key: key, + udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements), + entries: make(map[string]enr.Entry), + } + ln.seq = db.localSeq(ln.id) + ln.invalidate() + return ln +} + +// Database returns the node database associated with the local node. +func (ln *LocalNode) Database() *DB { + return ln.db +} + +// Node returns the current version of the local node record. +func (ln *LocalNode) Node() *Node { + n := ln.cur.Load().(*Node) + if n != nil { + return n + } + // Record was invalidated, sign a new copy. + ln.mu.Lock() + defer ln.mu.Unlock() + ln.sign() + return ln.cur.Load().(*Node) +} + +// ID returns the local node ID. +func (ln *LocalNode) ID() ID { + return ln.id +} + +// Set puts the given entry into the local record, overwriting +// any existing value. +func (ln *LocalNode) Set(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.set(e) +} + +func (ln *LocalNode) set(e enr.Entry) { + val, exists := ln.entries[e.ENRKey()] + if !exists || !reflect.DeepEqual(val, e) { + ln.entries[e.ENRKey()] = e + ln.invalidate() + } +} + +// Delete removes the given entry from the local record. +func (ln *LocalNode) Delete(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.delete(e) +} + +func (ln *LocalNode) delete(e enr.Entry) { + _, exists := ln.entries[e.ENRKey()] + if exists { + delete(ln.entries, e.ENRKey()) + ln.invalidate() + } +} + +// SetStaticIP sets the local IP to the given one unconditionally. +// This disables endpoint prediction. +func (ln *LocalNode) SetStaticIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.staticIP = ip + ln.updateEndpoints() +} + +// SetFallbackIP sets the last-resort IP address. This address is used +// if no endpoint prediction can be made and no static IP is set. +func (ln *LocalNode) SetFallbackIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.fallbackIP = ip + ln.updateEndpoints() +} + +// SetFallbackUDP sets the last-resort UDP port. This port is used +// if no endpoint prediction can be made. +func (ln *LocalNode) SetFallbackUDP(port int) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.fallbackUDP = port + ln.updateEndpoints() +} + +// UDPEndpointStatement should be called whenever a statement about the local node's +// UDP endpoint is received. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String()) + ln.updateEndpoints() +} + +// UDPContact should be called whenever the local node has announced itself to another node +// via UDP. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.udpTrack.AddContact(toaddr.String()) + ln.updateEndpoints() +} + +func (ln *LocalNode) updateEndpoints() { + // Determine the endpoints. + newIP := ln.fallbackIP + newUDP := ln.fallbackUDP + if ln.staticIP != nil { + newIP = ln.staticIP + } else if ip, port := predictAddr(ln.udpTrack); ip != nil { + newIP = ip + newUDP = port + } + + // Update the record. + if newIP != nil && !newIP.IsUnspecified() { + ln.set(enr.IP(newIP)) + if newUDP != 0 { + ln.set(enr.UDP(newUDP)) + } else { + ln.delete(enr.UDP(0)) + } + } else { + ln.delete(enr.IP{}) + } +} + +// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based +// endpoint representation to IP and port types. +func predictAddr(t *netutil.IPTracker) (net.IP, int) { + ep := t.PredictEndpoint() + if ep == "" { + return nil, 0 + } + ipString, portString, _ := net.SplitHostPort(ep) + ip := net.ParseIP(ipString) + port, _ := strconv.Atoi(portString) + return ip, port +} + +func (ln *LocalNode) invalidate() { + ln.cur.Store((*Node)(nil)) +} + +func (ln *LocalNode) sign() { + if n := ln.cur.Load().(*Node); n != nil { + return // no changes + } + + var r enr.Record + for _, e := range ln.entries { + r.Set(e) + } + ln.bumpSeq() + r.SetSeq(ln.seq) + if err := SignV4(&r, ln.key); err != nil { + panic(fmt.Errorf("enode: can't sign record: %v", err)) + } + n, err := New(ValidSchemes, &r) + if err != nil { + panic(fmt.Errorf("enode: can't verify local record: %v", err)) + } + ln.cur.Store(n) + log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP()) +} + +func (ln *LocalNode) bumpSeq() { + ln.seq++ + ln.db.storeLocalSeq(ln.id, ln.seq) +} diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go new file mode 100644 index 000000000..f5e3496d6 --- /dev/null +++ b/p2p/enode/localnode_test.go @@ -0,0 +1,76 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enode + +import ( + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +func newLocalNodeForTesting() (*LocalNode, *DB) { + db, _ := OpenDB("") + key, _ := crypto.GenerateKey() + return NewLocalNode(db, key), db +} + +func TestLocalNode(t *testing.T) { + ln, db := newLocalNodeForTesting() + defer db.Close() + + if ln.Node().ID() != ln.ID() { + t.Fatal("inconsistent ID") + } + + ln.Set(enr.WithEntry("x", uint(3))) + var x uint + if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil { + t.Fatal("can't load entry 'x':", err) + } else if x != 3 { + t.Fatal("wrong value for entry 'x':", x) + } +} + +func TestLocalNodeSeqPersist(t *testing.T) { + ln, db := newLocalNodeForTesting() + defer db.Close() + + if s := ln.Node().Seq(); s != 1 { + t.Fatalf("wrong initial seq %d, want 1", s) + } + ln.Set(enr.WithEntry("x", uint(1))) + if s := ln.Node().Seq(); s != 2 { + t.Fatalf("wrong seq %d after set, want 2", s) + } + + // Create a new instance, it should reload the sequence number. + // The number increases just after that because a new record is + // created without the "x" entry. + ln2 := NewLocalNode(db, ln.key) + if s := ln2.Node().Seq(); s != 3 { + t.Fatalf("wrong seq %d on new instance, want 3", s) + } + + // Create a new instance with a different node key on the same database. + // This should reset the sequence number. + key, _ := crypto.GenerateKey() + ln3 := NewLocalNode(db, key) + if s := ln3.Node().Seq(); s != 1 { + t.Fatalf("wrong seq %d on instance with changed key, want 1", s) + } +} diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 84088fcd2..b454ab255 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey { return &key } +// Record returns the node's record. The return value is a copy and may +// be modified by the caller. +func (n *Node) Record() *enr.Record { + cpy := n.r + return &cpy +} + // checks whether n is a valid complete node. func (n *Node) ValidateComplete() error { if n.Incomplete() { diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index a929b75d7..7ee0c09a9 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -35,11 +35,24 @@ import ( "github.com/syndtr/goleveldb/leveldb/util" ) +// Keys in the node database. +const ( + dbVersionKey = "version" // Version of the database to flush if changes + dbItemPrefix = "n:" // Identifier to prefix node entries with + + dbDiscoverRoot = ":discover" + dbDiscoverSeq = dbDiscoverRoot + ":seq" + dbDiscoverPing = dbDiscoverRoot + ":lastping" + dbDiscoverPong = dbDiscoverRoot + ":lastpong" + dbDiscoverFindFails = dbDiscoverRoot + ":findfail" + dbLocalRoot = ":local" + dbLocalSeq = dbLocalRoot + ":seq" +) + var ( - nodeDBNilID = ID{} // Special node ID to use as a nil element. - nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. - nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. - nodeDBVersion = 6 + dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. + dbCleanupCycle = time.Hour // Time period for running the expiration task. + dbVersion = 7 ) // DB is the node database, storing previously seen nodes and any collected metadata about @@ -50,17 +63,6 @@ type DB struct { quit chan struct{} // Channel to signal the expiring thread to stop } -// Schema layout for the node database -var ( - nodeDBVersionKey = []byte("version") // Version of the database to flush if changes - nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with - - nodeDBDiscoverRoot = ":discover" - nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping" - nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong" - nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail" -) - // OpenDB opens a node database for storing and retrieving infos about known peers in the // network. If no path is given an in-memory, temporary database is constructed. func OpenDB(path string) (*DB, error) { @@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) { // The nodes contained in the cache correspond to a certain protocol version. // Flush all nodes if the version doesn't match. currentVer := make([]byte, binary.MaxVarintLen64) - currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))] + currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))] - blob, err := db.Get(nodeDBVersionKey, nil) + blob, err := db.Get([]byte(dbVersionKey), nil) switch err { case leveldb.ErrNotFound: // Version not found (i.e. empty cache), insert it - if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil { + if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil { db.Close() return nil, err } @@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) { // makeKey generates the leveldb key-blob from a node id and its particular // field of interest. func makeKey(id ID, field string) []byte { - if bytes.Equal(id[:], nodeDBNilID[:]) { + if (id == ID{}) { return []byte(field) } - return append(nodeDBItemPrefix, append(id[:], field...)...) + return append([]byte(dbItemPrefix), append(id[:], field...)...) } // splitKey tries to split a database key into a node id and a field part. func splitKey(key []byte) (id ID, field string) { // If the key is not of a node, return it plainly - if !bytes.HasPrefix(key, nodeDBItemPrefix) { + if !bytes.HasPrefix(key, []byte(dbItemPrefix)) { return ID{}, string(key) } // Otherwise split the id and field - item := key[len(nodeDBItemPrefix):] + item := key[len(dbItemPrefix):] copy(id[:], item[:len(id)]) field = string(item[len(id):]) return id, field } -// fetchInt64 retrieves an integer instance associated with a particular -// database key. +// fetchInt64 retrieves an integer associated with a particular key. func (db *DB) fetchInt64(key []byte) int64 { blob, err := db.lvl.Get(key, nil) if err != nil { @@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 { return val } -// storeInt64 update a specific database entry to the current time instance as a -// unix timestamp. +// storeInt64 stores an integer in the given key. func (db *DB) storeInt64(key []byte, n int64) error { blob := make([]byte, binary.MaxVarintLen64) blob = blob[:binary.PutVarint(blob, n)] + return db.lvl.Put(key, blob, nil) +} + +// fetchUint64 retrieves an integer associated with a particular key. +func (db *DB) fetchUint64(key []byte) uint64 { + blob, err := db.lvl.Get(key, nil) + if err != nil { + return 0 + } + val, _ := binary.Uvarint(blob) + return val +} +// storeUint64 stores an integer in the given key. +func (db *DB) storeUint64(key []byte, n uint64) error { + blob := make([]byte, binary.MaxVarintLen64) + blob = blob[:binary.PutUvarint(blob, n)] return db.lvl.Put(key, blob, nil) } // Node retrieves a node with a given id from the database. func (db *DB) Node(id ID) *Node { - blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil) + blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil) if err != nil { return nil } @@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node { // UpdateNode inserts - potentially overwriting - a node into the peer database. func (db *DB) UpdateNode(node *Node) error { + if node.Seq() < db.NodeSeq(node.ID()) { + return nil + } blob, err := rlp.EncodeToBytes(&node.r) if err != nil { return err } - return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil) + if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil { + return err + } + return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq()) +} + +// NodeSeq returns the stored record sequence number of the given node. +func (db *DB) NodeSeq(id ID) uint64 { + return db.fetchUint64(makeKey(id, dbDiscoverSeq)) +} + +// Resolve returns the stored record of the node if it has a larger sequence +// number than n. +func (db *DB) Resolve(n *Node) *Node { + if n.Seq() > db.NodeSeq(n.ID()) { + return n + } + return db.Node(n.ID()) } // DeleteNode deletes all information/keys associated with a node. @@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. func (db *DB) expirer() { - tick := time.NewTicker(nodeDBCleanupCycle) + tick := time.NewTicker(dbCleanupCycle) defer tick.Stop() for { select { @@ -235,7 +271,7 @@ func (db *DB) expirer() { // expireNodes iterates over the database and deletes all nodes that have not // been seen (i.e. received a pong from) for some allotted time. func (db *DB) expireNodes() error { - threshold := time.Now().Add(-nodeDBNodeExpiration) + threshold := time.Now().Add(-dbNodeExpiration) // Find discovered nodes that are older than the allowance it := db.lvl.NewIterator(nil, nil) @@ -244,7 +280,7 @@ func (db *DB) expireNodes() error { for it.Next() { // Skip the item if not a discovery node id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { + if field != dbDiscoverRoot { continue } // Skip the node if not expired yet (and not self) @@ -260,34 +296,44 @@ func (db *DB) expireNodes() error { // LastPingReceived retrieves the time of the last ping packet received from // a remote node. func (db *DB) LastPingReceived(id ID) time.Time { - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) + return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0) } // UpdateLastPingReceived updates the last time we tried contacting a remote node. func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) + return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix()) } // LastPongReceived retrieves the time of the last successful pong from remote node. func (db *DB) LastPongReceived(id ID) time.Time { // Launch expirer db.ensureExpirer() - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) + return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0) } // UpdateLastPongReceived updates the last pong time of a node. func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) + return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix()) } // FindFails retrieves the number of findnode failures since bonding. func (db *DB) FindFails(id ID) int { - return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails))) + return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails))) } // UpdateFindFails updates the number of findnode failures since bonding. func (db *DB) UpdateFindFails(id ID, fails int) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails)) + return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails)) +} + +// LocalSeq retrieves the local record sequence counter. +func (db *DB) localSeq(id ID) uint64 { + return db.fetchUint64(makeKey(id, dbLocalSeq)) +} + +// storeLocalSeq stores the local record sequence counter. +func (db *DB) storeLocalSeq(id ID, n uint64) { + db.storeUint64(makeKey(id, dbLocalSeq), n) } // QuerySeeds retrieves random nodes to be used as potential seed nodes @@ -309,7 +355,7 @@ seek: ctr := id[0] rand.Read(id[:]) id[0] = ctr + id[0]%16 - it.Seek(makeKey(id, nodeDBDiscoverRoot)) + it.Seek(makeKey(id, dbDiscoverRoot)) n := nextNode(it) if n == nil { @@ -334,7 +380,7 @@ seek: func nextNode(it iterator.Iterator) *Node { for end := false; !end; end = !it.Next() { id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { + if field != dbDiscoverRoot { continue } return mustDecodeNode(id[:], it.Value()) diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index b476a3439..96794827c 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct { 30303, 30303, ), - pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute), + pong: time.Now().Add(-dbNodeExpiration + time.Minute), exp: false, }, { node: NewV4( @@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct { 30303, 30303, ), - pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute), + pong: time.Now().Add(-dbNodeExpiration - time.Minute), exp: true, }, } |