author | Felix Lange <fjl@users.noreply.github.com> | 2018-10-12 17:47:24 +0800
---|---|---
committer | GitHub <noreply@github.com> | 2018-10-12 17:47:24 +0800
commit | 6f607de5d590ff2fbe8798b04e5924be3b7ca0b4 (patch)
tree | 2905b3462c0d4f162914a948dac6d1836ace4b77
parent | dcae0d348bb7f5d9052e50a83383a33538ce376a (diff)
p2p, p2p/discover: add signed ENR generation (#17753)
This PR adds enode.LocalNode and integrates it into the p2p
subsystem. This new object is the keeper of the local node
record. For now, a new version of the record is produced every
time the client restarts. We'll make it smarter to avoid that in
the future.
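
As a rough illustration (not part of this commit), the sketch below shows how enode.LocalNode might be used: a node database and key produce a LocalNode, entries set on it update the record, and a freshly signed version is returned on demand by Node(). The calls are the ones visible in the diff further down; the concrete values are made up.

package example

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
)

func localNodeSketch() {
	// An empty path opens an in-memory, temporary node database.
	db, _ := enode.OpenDB("")
	defer db.Close()

	key, _ := crypto.GenerateKey()
	ln := enode.NewLocalNode(db, key)

	// Setting entries invalidates the current record; a new version is
	// signed lazily when Node() is called.
	ln.Set(enr.TCP(30303))
	ln.SetFallbackIP(net.IP{127, 0, 0, 1})
	ln.SetFallbackUDP(30303)

	n := ln.Node()
	fmt.Println("seq:", n.Seq(), "id:", n.ID())
}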
There are a couple of other changes in this commit: discovery now waits
for all of its goroutines at shutdown, and the p2p server now closes the
node database after discovery has shut down. This fixes a leveldb crash
in tests. P2P server startup is also faster because it no longer needs
to wait for the external IP query.
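
For context, here is a hedged sketch of the new discovery wiring (a hypothetical helper, not code from this commit): the UDP socket is bound first, ListenUDP receives the LocalNode instead of an announced address, and the externally visible endpoint is filled in later from endpoint statements, which is why startup no longer blocks on the external IP query.

package example

import (
	"crypto/ecdsa"
	"net"

	"github.com/ethereum/go-ethereum/p2p/discover"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

// startDiscovery is a hypothetical helper mirroring the cmd/bootnode change below.
func startDiscovery(key *ecdsa.PrivateKey, laddr string) (*discover.Table, *enode.LocalNode, error) {
	addr, err := net.ResolveUDPAddr("udp", laddr)
	if err != nil {
		return nil, nil, err
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return nil, nil, err
	}
	db, _ := enode.OpenDB("") // in-memory node database
	ln := enode.NewLocalNode(db, key)
	// The listening port is only the fallback; the externally visible
	// endpoint is predicted from other nodes' statements as traffic flows.
	ln.SetFallbackUDP(conn.LocalAddr().(*net.UDPAddr).Port)

	tab, err := discover.ListenUDP(conn, ln, discover.Config{PrivateKey: key})
	if err != nil {
		return nil, nil, err
	}
	return tab, ln, nil
}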
-rw-r--r-- | cmd/bootnode/main.go | 11
-rw-r--r-- | node/node_test.go | 8
-rw-r--r-- | p2p/dial.go | 7
-rw-r--r-- | p2p/dial_test.go | 16
-rw-r--r-- | p2p/discover/table.go | 41
-rw-r--r-- | p2p/discover/table_test.go | 10
-rw-r--r-- | p2p/discover/table_util_test.go | 25
-rw-r--r-- | p2p/discover/udp.go | 92
-rw-r--r-- | p2p/discover/udp_test.go | 11
-rw-r--r-- | p2p/discv5/udp.go | 3
-rw-r--r-- | p2p/enode/localnode.go | 246
-rw-r--r-- | p2p/enode/localnode_test.go | 76
-rw-r--r-- | p2p/enode/node.go | 7
-rw-r--r-- | p2p/enode/nodedb.go | 124
-rw-r--r-- | p2p/enode/nodedb_test.go | 4
-rw-r--r-- | p2p/enr/enr.go | 2
-rw-r--r-- | p2p/enr/enr_test.go | 12
-rw-r--r-- | p2p/nat/nat.go | 18
-rw-r--r-- | p2p/nat/nat_test.go | 2
-rw-r--r-- | p2p/netutil/iptrack.go | 130
-rw-r--r-- | p2p/netutil/iptrack_test.go | 138
-rw-r--r-- | p2p/protocol.go | 10
-rw-r--r-- | p2p/server.go | 232
-rw-r--r-- | p2p/server_test.go | 26
24 files changed, 976 insertions, 275 deletions
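
Among the new files listed above, p2p/netutil/iptrack.go adds the endpoint predictor that helps replace the old AnnounceAddr mechanism. A minimal usage sketch follows; the hosts, addresses, and window values are made up, and the only assumption is that enough agreeing statements arrive within the window.

package example

import (
	"time"

	"github.com/ethereum/go-ethereum/p2p/netutil"
)

func predictSketch() string {
	// Keep statements for 5 minutes, contacts for 10, and require at
	// least 3 statements before predicting anything.
	it := netutil.NewIPTracker(5*time.Minute, 10*time.Minute, 3)

	// Three different hosts claim the same external endpoint for us.
	it.AddStatement("198.51.100.1:30303", "203.0.113.7:30303")
	it.AddStatement("198.51.100.2:30303", "203.0.113.7:30303")
	it.AddStatement("198.51.100.3:30303", "203.0.113.7:30303")

	// The endpoint with the most statements wins once minStatements is met.
	return it.PredictEndpoint() // "203.0.113.7:30303"
}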
diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index 845900865..346523ddb 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -119,16 +119,17 @@ func main() { } if *runv5 { - if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil { + if _, err := discv5.ListenUDP(nodeKey, conn, "", restrictList); err != nil { utils.Fatalf("%v", err) } } else { + db, _ := enode.OpenDB("") + ln := enode.NewLocalNode(db, nodeKey) cfg := discover.Config{ - PrivateKey: nodeKey, - AnnounceAddr: realaddr, - NetRestrict: restrictList, + PrivateKey: nodeKey, + NetRestrict: restrictList, } - if _, err := discover.ListenUDP(conn, cfg); err != nil { + if _, err := discover.ListenUDP(conn, ln, cfg); err != nil { utils.Fatalf("%v", err) } } diff --git a/node/node_test.go b/node/node_test.go index e51900bd1..f833cd688 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -454,9 +454,9 @@ func TestProtocolGather(t *testing.T) { Count int Maker InstrumentingWrapper }{ - "Zero Protocols": {0, InstrumentedServiceMakerA}, - "Single Protocol": {1, InstrumentedServiceMakerB}, - "Many Protocols": {25, InstrumentedServiceMakerC}, + "zero": {0, InstrumentedServiceMakerA}, + "one": {1, InstrumentedServiceMakerB}, + "many": {10, InstrumentedServiceMakerC}, } for id, config := range services { protocols := make([]p2p.Protocol, config.Count) @@ -480,7 +480,7 @@ func TestProtocolGather(t *testing.T) { defer stack.Stop() protocols := stack.Server().Protocols - if len(protocols) != 26 { + if len(protocols) != 11 { t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26) } for id, config := range services { diff --git a/p2p/dial.go b/p2p/dial.go index 359cdbcbb..d228514fc 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -71,6 +71,7 @@ type dialstate struct { maxDynDials int ntab discoverTable netrestrict *netutil.Netlist + self enode.ID lookupRunning bool dialing map[enode.ID]connFlag @@ -84,7 +85,6 @@ type dialstate struct { } type discoverTable interface { - Self() *enode.Node Close() Resolve(*enode.Node) *enode.Node LookupRandom() []*enode.Node @@ -126,10 +126,11 @@ type waitExpireTask struct { time.Duration } -func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { +func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, + self: self, netrestrict: netrestrict, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), @@ -266,7 +267,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { return errAlreadyDialing case peers[n.ID()] != nil: return errAlreadyConnected - case s.ntab != nil && n.ID() == s.ntab.Self().ID(): + case n.ID() == s.self: return errSelf case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): return errNotWhitelisted diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 2de2c5999..f41ab7752 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -89,7 +89,7 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { runDialTest(t, dialtest{ - init: newDialState(nil, nil, fakeTable{}, 5, nil), + init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil), rounds: []round{ // A discovery query is launched. 
{ @@ -236,7 +236,7 @@ func TestDialStateDynDialBootnode(t *testing.T) { newNode(uintID(8), nil), } runDialTest(t, dialtest{ - init: newDialState(nil, bootnodes, table, 5, nil), + init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil), rounds: []round{ // 2 dynamic dials attempted, bootnodes pending fallback interval { @@ -324,7 +324,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(nil, nil, table, 10, nil), + init: newDialState(enode.ID{}, nil, nil, table, 10, nil), rounds: []round{ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { @@ -430,7 +430,7 @@ func TestDialStateNetRestrict(t *testing.T) { restrict.Add("127.0.2.0/24") runDialTest(t, dialtest{ - init: newDialState(nil, nil, table, 10, restrict), + init: newDialState(enode.ID{}, nil, nil, table, 10, restrict), rounds: []round{ { new: []task{ @@ -453,7 +453,7 @@ func TestDialStateStaticDial(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -557,7 +557,7 @@ func TestDialStaticAfterReset(t *testing.T) { }, } dTest := dialtest{ - init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), rounds: rounds, } runDialTest(t, dTest) @@ -578,7 +578,7 @@ func TestDialStateCache(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -640,7 +640,7 @@ func TestDialStateCache(t *testing.T) { func TestDialResolve(t *testing.T) { resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) table := &resolveMock{answer: resolved} - state := newDialState(nil, nil, table, 0, nil) + state := newDialState(enode.ID{}, nil, nil, table, 0, nil) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 7a3e41de1..afd4c9a27 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -72,21 +72,20 @@ type Table struct { ips netutil.DistinctNetSet db *enode.DB // database of known nodes + net transport refreshReq chan chan struct{} initDone chan struct{} closeReq chan struct{} closed chan struct{} nodeAddedHook func(*node) // for testing - - net transport - self *node // metadata of the local node } // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. 
type transport interface { + self() *enode.Node ping(enode.ID, *net.UDPAddr) error findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error) close() @@ -100,11 +99,10 @@ type bucket struct { ips netutil.DistinctNetSet } -func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { +func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { tab := &Table{ net: t, db: db, - self: wrapNode(self), refreshReq: make(chan chan struct{}), initDone: make(chan struct{}), closeReq: make(chan struct{}), @@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No return tab, nil } +func (tab *Table) self() *enode.Node { + return tab.net.self() +} + func (tab *Table) seedRand() { var b [8]byte crand.Read(b[:]) @@ -136,11 +138,6 @@ func (tab *Table) seedRand() { tab.mutex.Unlock() } -// Self returns the local node. -func (tab *Table) Self() *enode.Node { - return unwrapNode(tab.self) -} - // ReadRandomNodes fills the given slice with random nodes from the table. The results // are guaranteed to be unique for a single invocation, no node will appear twice. func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { @@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { + if tab.net != nil { + tab.net.close() + } + select { case <-tab.closed: // already closed. @@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { ) // don't query further if we hit ourself. // unlikely to happen often in practice. - asked[tab.self.ID()] = true + asked[tab.self().ID()] = true for { tab.mutex.Lock() @@ -340,8 +341,8 @@ func (tab *Table) loop() { revalidate = time.NewTimer(tab.nextRevalidateTime()) refresh = time.NewTicker(refreshInterval) copyNodes = time.NewTicker(copyNodesInterval) - revalidateDone = make(chan struct{}) refreshDone = make(chan struct{}) // where doRefresh reports completion + revalidateDone chan struct{} // where doRevalidate reports completion waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs ) defer refresh.Stop() @@ -372,9 +373,11 @@ loop: } waiting, refreshDone = nil, nil case <-revalidate.C: + revalidateDone = make(chan struct{}) go tab.doRevalidate(revalidateDone) case <-revalidateDone: revalidate.Reset(tab.nextRevalidateTime()) + revalidateDone = nil case <-copyNodes.C: go tab.copyLiveNodes() case <-tab.closeReq: @@ -382,15 +385,15 @@ loop: } } - if tab.net != nil { - tab.net.close() - } if refreshDone != nil { <-refreshDone } for _, ch := range waiting { close(ch) } + if revalidateDone != nil { + <-revalidateDone + } close(tab.closed) } @@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) { // Run self lookup to discover new neighbor nodes. // We can only do this if we have a secp256k1 identity. var key ecdsa.PublicKey - if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil { + if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil { tab.lookup(encodePubkey(&key), false) } @@ -530,7 +533,7 @@ func (tab *Table) len() (n int) { // bucket returns the bucket for the given node ID hash. 
func (tab *Table) bucket(id enode.ID) *bucket { - d := enode.LogDist(tab.self.ID(), id) + d := enode.LogDist(tab.self().ID(), id) if d <= bucketMinDistance { return tab.buckets[0] } @@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket { // // The caller must not hold tab.mutex. func (tab *Table) add(n *node) { - if n.ID() == tab.self.ID() { + if n.ID() == tab.self().ID() { return } @@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) { defer tab.mutex.Unlock() for _, n := range nodes { - if n.ID() == tab.self.ID() { + if n.ID() == tab.self().ID() { continue // don't add self } b := tab.bucket(n.ID()) diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index e8631024b..6b4cd2d18 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -141,7 +141,7 @@ func TestTable_IPLimit(t *testing.T) { defer db.Close() for i := 0; i < tableIPLimit+1; i++ { - n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)}) + n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) tab.add(n) } if tab.len() > tableIPLimit { @@ -158,7 +158,7 @@ func TestTable_BucketIPLimit(t *testing.T) { d := 3 for i := 0; i < bucketIPLimit+1; i++ { - n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)}) + n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) tab.add(n) } if tab.len() > bucketIPLimit { @@ -240,7 +240,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) - tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))}) + tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) } gotN := tab.ReadRandomNodes(buf) if gotN != tab.len() { @@ -510,6 +510,10 @@ type preminedTestnet struct { dists [hashBits + 1][]encPubkey } +func (tn *preminedTestnet) self() *enode.Node { + return nullNode +} + func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { // current log distance is encoded in port number // fmt.Println("findnode query at dist", toaddr.Port) diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 05ae0b6c0..d41519452 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -28,12 +28,17 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" ) -func newTestTable(t transport) (*Table, *enode.DB) { +var nullNode *enode.Node + +func init() { var r enr.Record r.Set(enr.IP{0, 0, 0, 0}) - n := enode.SignNull(&r, enode.ID{}) + nullNode = enode.SignNull(&r, enode.ID{}) +} + +func newTestTable(t transport) (*Table, *enode.DB) { db, _ := enode.OpenDB("") - tab, _ := newTable(t, n, db, nil) + tab, _ := newTable(t, db, nil) return tab, db } @@ -70,10 +75,10 @@ func intIP(i int) net.IP { // fillBucket inserts nodes into the given bucket until it is full. 
func fillBucket(tab *Table, n *node) (last *node) { - ld := enode.LogDist(tab.self.ID(), n.ID()) + ld := enode.LogDist(tab.self().ID(), n.ID()) b := tab.bucket(n.ID()) for len(b.entries) < bucketSize { - b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld))) + b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld))) } return b.entries[bucketSize-1] } @@ -81,15 +86,25 @@ func fillBucket(tab *Table, n *node) (last *node) { type pingRecorder struct { mu sync.Mutex dead, pinged map[enode.ID]bool + n *enode.Node } func newPingRecorder() *pingRecorder { + var r enr.Record + r.Set(enr.IP{0, 0, 0, 0}) + n := enode.SignNull(&r, enode.ID{}) + return &pingRecorder{ dead: make(map[enode.ID]bool), pinged: make(map[enode.ID]bool), + n: n, } } +func (t *pingRecorder) self() *enode.Node { + return nullNode +} + func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { return nil, nil } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 45fcce282..37a044902 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -23,12 +23,12 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -118,9 +118,11 @@ type ( ) func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { - ip := addr.IP.To4() - if ip == nil { - ip = addr.IP.To16() + ip := net.IP{} + if ip4 := addr.IP.To4(); ip4 != nil { + ip = ip4 + } else if ip6 := addr.IP.To16(); ip6 != nil { + ip = ip6 } return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} } @@ -165,20 +167,19 @@ type conn interface { LocalAddr() net.Addr } -// udp implements the RPC protocol. +// udp implements the discovery v4 UDP wire protocol. type udp struct { conn conn netrestrict *netutil.Netlist priv *ecdsa.PrivateKey - ourEndpoint rpcEndpoint + localNode *enode.LocalNode + db *enode.DB + tab *Table + wg sync.WaitGroup addpending chan *pending gotreply chan reply - - closing chan struct{} - nat nat.Interface - - *Table + closing chan struct{} } // pending represents a pending reply. @@ -230,60 +231,57 @@ type Config struct { PrivateKey *ecdsa.PrivateKey // These settings are optional: - AnnounceAddr *net.UDPAddr // local address announced in the DHT - NodeDBPath string // if set, the node database is stored at this filesystem location - NetRestrict *netutil.Netlist // network whitelist - Bootnodes []*enode.Node // list of bootstrap nodes - Unhandled chan<- ReadPacket // unhandled packets are sent on this channel + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*enode.Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel } // ListenUDP returns a new table that listens for UDP packets on laddr. 
-func ListenUDP(c conn, cfg Config) (*Table, error) { - tab, _, err := newUDP(c, cfg) +func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) { + tab, _, err := newUDP(c, ln, cfg) if err != nil { return nil, err } - log.Info("UDP listener up", "self", tab.self) return tab, nil } -func newUDP(c conn, cfg Config) (*Table, *udp, error) { - realaddr := c.LocalAddr().(*net.UDPAddr) - if cfg.AnnounceAddr != nil { - realaddr = cfg.AnnounceAddr - } - self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port) - db, err := enode.OpenDB(cfg.NodeDBPath) - if err != nil { - return nil, nil, err - } - +func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) { udp := &udp{ conn: c, priv: cfg.PrivateKey, netrestrict: cfg.NetRestrict, + localNode: ln, + db: ln.Database(), closing: make(chan struct{}), gotreply: make(chan reply), addpending: make(chan *pending), } - // TODO: separate TCP port - udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, self, db, cfg.Bootnodes) + tab, err := newTable(udp, ln.Database(), cfg.Bootnodes) if err != nil { return nil, nil, err } - udp.Table = tab + udp.tab = tab + udp.wg.Add(2) go udp.loop() go udp.readLoop(cfg.Unhandled) - return udp.Table, udp, nil + return udp.tab, udp, nil +} + +func (t *udp) self() *enode.Node { + return t.localNode.Node() } func (t *udp) close() { close(t.closing) t.conn.Close() - t.db.Close() - // TODO: wait for the loops to end. + t.wg.Wait() +} + +func (t *udp) ourEndpoint() rpcEndpoint { + n := t.self() + a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} + return makeEndpoint(a, uint16(n.TCP())) } // ping sends a ping message to the given node and waits for a reply. @@ -296,7 +294,7 @@ func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error { func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error { req := &ping{ Version: 4, - From: t.ourEndpoint, + From: t.ourEndpoint(), To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), } @@ -313,6 +311,7 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch } return ok }) + t.localNode.UDPContact(toaddr) t.write(toaddr, req.name(), packet) return errc } @@ -381,6 +380,8 @@ func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool { // loop runs in its own goroutine. it keeps track of // the refresh timer and the pending reply queue. func (t *udp) loop() { + defer t.wg.Done() + var ( plist = list.New() timeout = time.NewTimer(0) @@ -542,10 +543,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, // readLoop runs in its own goroutine. it handles incoming UDP packets. func (t *udp) readLoop(unhandled chan<- ReadPacket) { - defer t.conn.Close() + defer t.wg.Done() if unhandled != nil { defer close(unhandled) } + // Discovery packets are defined to be no larger than 1280 bytes. // Packets larger than this size will be cut at the end and treated // as invalid because their hash won't match. 
@@ -629,10 +631,11 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port)) t.handleReply(n.ID(), pingPacket, req) if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration { - t.sendPing(n.ID(), from, func() { t.addThroughPing(n) }) + t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) }) } else { - t.addThroughPing(n) + t.tab.addThroughPing(n) } + t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) t.db.UpdateLastPingReceived(n.ID(), time.Now()) return nil } @@ -647,6 +650,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte if !t.handleReply(fromID, pongPacket, req) { return errUnsolicitedReply } + t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) t.db.UpdateLastPongReceived(fromID, time.Now()) return nil } @@ -668,9 +672,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac [] return errUnknownNode } target := enode.ID(crypto.Keccak256Hash(req.Target[:])) - t.mutex.Lock() - closest := t.closest(target, bucketSize).entries - t.mutex.Unlock() + t.tab.mutex.Lock() + closest := t.tab.closest(target, bucketSize).entries + t.tab.mutex.Unlock() p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index da95c4f5c..a4ddaf750 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -71,7 +71,9 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey}) + db, _ := enode.OpenDB("") + ln := enode.NewLocalNode(db, test.localkey) + test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey}) // Wait for initial refresh so the table doesn't send unexpected findnode. <-test.table.initDone return test @@ -355,12 +357,13 @@ func TestUDP_successfulPing(t *testing.T) { // remote is unknown, the table pings back. hash, _ := test.waitPacketOut(func(p *ping) error { - if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { - t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) + if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { + t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } wantTo := rpcEndpoint{ // The mirrored UDP address is the UDP packet sender. - IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), + IP: test.remoteaddr.IP, + UDP: uint16(test.remoteaddr.Port), TCP: 0, } if !reflect.DeepEqual(p.To, wantTo) { diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 49e1cb811..ff5ed983b 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -230,7 +230,8 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. 
-func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { +func ListenUDP(priv *ecdsa.PrivateKey, conn conn, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { + realaddr := conn.LocalAddr().(*net.UDPAddr) transport, err := listenUDP(priv, conn, realaddr) if err != nil { return nil, err diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go new file mode 100644 index 000000000..623f8eae1 --- /dev/null +++ b/p2p/enode/localnode.go @@ -0,0 +1,246 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enode + +import ( + "crypto/ecdsa" + "fmt" + "net" + "reflect" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/netutil" +) + +const ( + // IP tracker configuration + iptrackMinStatements = 10 + iptrackWindow = 5 * time.Minute + iptrackContactWindow = 10 * time.Minute +) + +// LocalNode produces the signed node record of a local node, i.e. a node run in the +// current process. Setting ENR entries via the Set method updates the record. A new version +// of the record is signed on demand when the Node method is called. +type LocalNode struct { + cur atomic.Value // holds a non-nil node pointer while the record is up-to-date. + id ID + key *ecdsa.PrivateKey + db *DB + + // everything below is protected by a lock + mu sync.Mutex + seq uint64 + entries map[string]enr.Entry + udpTrack *netutil.IPTracker // predicts external UDP endpoint + staticIP net.IP + fallbackIP net.IP + fallbackUDP int +} + +// NewLocalNode creates a local node. +func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode { + ln := &LocalNode{ + id: PubkeyToIDV4(&key.PublicKey), + db: db, + key: key, + udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements), + entries: make(map[string]enr.Entry), + } + ln.seq = db.localSeq(ln.id) + ln.invalidate() + return ln +} + +// Database returns the node database associated with the local node. +func (ln *LocalNode) Database() *DB { + return ln.db +} + +// Node returns the current version of the local node record. +func (ln *LocalNode) Node() *Node { + n := ln.cur.Load().(*Node) + if n != nil { + return n + } + // Record was invalidated, sign a new copy. + ln.mu.Lock() + defer ln.mu.Unlock() + ln.sign() + return ln.cur.Load().(*Node) +} + +// ID returns the local node ID. +func (ln *LocalNode) ID() ID { + return ln.id +} + +// Set puts the given entry into the local record, overwriting +// any existing value. 
+func (ln *LocalNode) Set(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.set(e) +} + +func (ln *LocalNode) set(e enr.Entry) { + val, exists := ln.entries[e.ENRKey()] + if !exists || !reflect.DeepEqual(val, e) { + ln.entries[e.ENRKey()] = e + ln.invalidate() + } +} + +// Delete removes the given entry from the local record. +func (ln *LocalNode) Delete(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.delete(e) +} + +func (ln *LocalNode) delete(e enr.Entry) { + _, exists := ln.entries[e.ENRKey()] + if exists { + delete(ln.entries, e.ENRKey()) + ln.invalidate() + } +} + +// SetStaticIP sets the local IP to the given one unconditionally. +// This disables endpoint prediction. +func (ln *LocalNode) SetStaticIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.staticIP = ip + ln.updateEndpoints() +} + +// SetFallbackIP sets the last-resort IP address. This address is used +// if no endpoint prediction can be made and no static IP is set. +func (ln *LocalNode) SetFallbackIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.fallbackIP = ip + ln.updateEndpoints() +} + +// SetFallbackUDP sets the last-resort UDP port. This port is used +// if no endpoint prediction can be made. +func (ln *LocalNode) SetFallbackUDP(port int) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.fallbackUDP = port + ln.updateEndpoints() +} + +// UDPEndpointStatement should be called whenever a statement about the local node's +// UDP endpoint is received. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String()) + ln.updateEndpoints() +} + +// UDPContact should be called whenever the local node has announced itself to another node +// via UDP. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.udpTrack.AddContact(toaddr.String()) + ln.updateEndpoints() +} + +func (ln *LocalNode) updateEndpoints() { + // Determine the endpoints. + newIP := ln.fallbackIP + newUDP := ln.fallbackUDP + if ln.staticIP != nil { + newIP = ln.staticIP + } else if ip, port := predictAddr(ln.udpTrack); ip != nil { + newIP = ip + newUDP = port + } + + // Update the record. + if newIP != nil && !newIP.IsUnspecified() { + ln.set(enr.IP(newIP)) + if newUDP != 0 { + ln.set(enr.UDP(newUDP)) + } else { + ln.delete(enr.UDP(0)) + } + } else { + ln.delete(enr.IP{}) + } +} + +// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based +// endpoint representation to IP and port types. 
+func predictAddr(t *netutil.IPTracker) (net.IP, int) { + ep := t.PredictEndpoint() + if ep == "" { + return nil, 0 + } + ipString, portString, _ := net.SplitHostPort(ep) + ip := net.ParseIP(ipString) + port, _ := strconv.Atoi(portString) + return ip, port +} + +func (ln *LocalNode) invalidate() { + ln.cur.Store((*Node)(nil)) +} + +func (ln *LocalNode) sign() { + if n := ln.cur.Load().(*Node); n != nil { + return // no changes + } + + var r enr.Record + for _, e := range ln.entries { + r.Set(e) + } + ln.bumpSeq() + r.SetSeq(ln.seq) + if err := SignV4(&r, ln.key); err != nil { + panic(fmt.Errorf("enode: can't sign record: %v", err)) + } + n, err := New(ValidSchemes, &r) + if err != nil { + panic(fmt.Errorf("enode: can't verify local record: %v", err)) + } + ln.cur.Store(n) + log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP()) +} + +func (ln *LocalNode) bumpSeq() { + ln.seq++ + ln.db.storeLocalSeq(ln.id, ln.seq) +} diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go new file mode 100644 index 000000000..f5e3496d6 --- /dev/null +++ b/p2p/enode/localnode_test.go @@ -0,0 +1,76 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enode + +import ( + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +func newLocalNodeForTesting() (*LocalNode, *DB) { + db, _ := OpenDB("") + key, _ := crypto.GenerateKey() + return NewLocalNode(db, key), db +} + +func TestLocalNode(t *testing.T) { + ln, db := newLocalNodeForTesting() + defer db.Close() + + if ln.Node().ID() != ln.ID() { + t.Fatal("inconsistent ID") + } + + ln.Set(enr.WithEntry("x", uint(3))) + var x uint + if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil { + t.Fatal("can't load entry 'x':", err) + } else if x != 3 { + t.Fatal("wrong value for entry 'x':", x) + } +} + +func TestLocalNodeSeqPersist(t *testing.T) { + ln, db := newLocalNodeForTesting() + defer db.Close() + + if s := ln.Node().Seq(); s != 1 { + t.Fatalf("wrong initial seq %d, want 1", s) + } + ln.Set(enr.WithEntry("x", uint(1))) + if s := ln.Node().Seq(); s != 2 { + t.Fatalf("wrong seq %d after set, want 2", s) + } + + // Create a new instance, it should reload the sequence number. + // The number increases just after that because a new record is + // created without the "x" entry. + ln2 := NewLocalNode(db, ln.key) + if s := ln2.Node().Seq(); s != 3 { + t.Fatalf("wrong seq %d on new instance, want 3", s) + } + + // Create a new instance with a different node key on the same database. + // This should reset the sequence number. 
+ key, _ := crypto.GenerateKey() + ln3 := NewLocalNode(db, key) + if s := ln3.Node().Seq(); s != 1 { + t.Fatalf("wrong seq %d on instance with changed key, want 1", s) + } +} diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 84088fcd2..b454ab255 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey { return &key } +// Record returns the node's record. The return value is a copy and may +// be modified by the caller. +func (n *Node) Record() *enr.Record { + cpy := n.r + return &cpy +} + // checks whether n is a valid complete node. func (n *Node) ValidateComplete() error { if n.Incomplete() { diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index a929b75d7..7ee0c09a9 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -35,11 +35,24 @@ import ( "github.com/syndtr/goleveldb/leveldb/util" ) +// Keys in the node database. +const ( + dbVersionKey = "version" // Version of the database to flush if changes + dbItemPrefix = "n:" // Identifier to prefix node entries with + + dbDiscoverRoot = ":discover" + dbDiscoverSeq = dbDiscoverRoot + ":seq" + dbDiscoverPing = dbDiscoverRoot + ":lastping" + dbDiscoverPong = dbDiscoverRoot + ":lastpong" + dbDiscoverFindFails = dbDiscoverRoot + ":findfail" + dbLocalRoot = ":local" + dbLocalSeq = dbLocalRoot + ":seq" +) + var ( - nodeDBNilID = ID{} // Special node ID to use as a nil element. - nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. - nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. - nodeDBVersion = 6 + dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. + dbCleanupCycle = time.Hour // Time period for running the expiration task. + dbVersion = 7 ) // DB is the node database, storing previously seen nodes and any collected metadata about @@ -50,17 +63,6 @@ type DB struct { quit chan struct{} // Channel to signal the expiring thread to stop } -// Schema layout for the node database -var ( - nodeDBVersionKey = []byte("version") // Version of the database to flush if changes - nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with - - nodeDBDiscoverRoot = ":discover" - nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping" - nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong" - nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail" -) - // OpenDB opens a node database for storing and retrieving infos about known peers in the // network. If no path is given an in-memory, temporary database is constructed. func OpenDB(path string) (*DB, error) { @@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) { // The nodes contained in the cache correspond to a certain protocol version. // Flush all nodes if the version doesn't match. currentVer := make([]byte, binary.MaxVarintLen64) - currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))] + currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))] - blob, err := db.Get(nodeDBVersionKey, nil) + blob, err := db.Get([]byte(dbVersionKey), nil) switch err { case leveldb.ErrNotFound: // Version not found (i.e. empty cache), insert it - if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil { + if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil { db.Close() return nil, err } @@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) { // makeKey generates the leveldb key-blob from a node id and its particular // field of interest. 
func makeKey(id ID, field string) []byte { - if bytes.Equal(id[:], nodeDBNilID[:]) { + if (id == ID{}) { return []byte(field) } - return append(nodeDBItemPrefix, append(id[:], field...)...) + return append([]byte(dbItemPrefix), append(id[:], field...)...) } // splitKey tries to split a database key into a node id and a field part. func splitKey(key []byte) (id ID, field string) { // If the key is not of a node, return it plainly - if !bytes.HasPrefix(key, nodeDBItemPrefix) { + if !bytes.HasPrefix(key, []byte(dbItemPrefix)) { return ID{}, string(key) } // Otherwise split the id and field - item := key[len(nodeDBItemPrefix):] + item := key[len(dbItemPrefix):] copy(id[:], item[:len(id)]) field = string(item[len(id):]) return id, field } -// fetchInt64 retrieves an integer instance associated with a particular -// database key. +// fetchInt64 retrieves an integer associated with a particular key. func (db *DB) fetchInt64(key []byte) int64 { blob, err := db.lvl.Get(key, nil) if err != nil { @@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 { return val } -// storeInt64 update a specific database entry to the current time instance as a -// unix timestamp. +// storeInt64 stores an integer in the given key. func (db *DB) storeInt64(key []byte, n int64) error { blob := make([]byte, binary.MaxVarintLen64) blob = blob[:binary.PutVarint(blob, n)] + return db.lvl.Put(key, blob, nil) +} + +// fetchUint64 retrieves an integer associated with a particular key. +func (db *DB) fetchUint64(key []byte) uint64 { + blob, err := db.lvl.Get(key, nil) + if err != nil { + return 0 + } + val, _ := binary.Uvarint(blob) + return val +} +// storeUint64 stores an integer in the given key. +func (db *DB) storeUint64(key []byte, n uint64) error { + blob := make([]byte, binary.MaxVarintLen64) + blob = blob[:binary.PutUvarint(blob, n)] return db.lvl.Put(key, blob, nil) } // Node retrieves a node with a given id from the database. func (db *DB) Node(id ID) *Node { - blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil) + blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil) if err != nil { return nil } @@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node { // UpdateNode inserts - potentially overwriting - a node into the peer database. func (db *DB) UpdateNode(node *Node) error { + if node.Seq() < db.NodeSeq(node.ID()) { + return nil + } blob, err := rlp.EncodeToBytes(&node.r) if err != nil { return err } - return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil) + if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil { + return err + } + return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq()) +} + +// NodeSeq returns the stored record sequence number of the given node. +func (db *DB) NodeSeq(id ID) uint64 { + return db.fetchUint64(makeKey(id, dbDiscoverSeq)) +} + +// Resolve returns the stored record of the node if it has a larger sequence +// number than n. +func (db *DB) Resolve(n *Node) *Node { + if n.Seq() > db.NodeSeq(n.ID()) { + return n + } + return db.Node(n.ID()) } // DeleteNode deletes all information/keys associated with a node. @@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. 
func (db *DB) expirer() { - tick := time.NewTicker(nodeDBCleanupCycle) + tick := time.NewTicker(dbCleanupCycle) defer tick.Stop() for { select { @@ -235,7 +271,7 @@ func (db *DB) expirer() { // expireNodes iterates over the database and deletes all nodes that have not // been seen (i.e. received a pong from) for some allotted time. func (db *DB) expireNodes() error { - threshold := time.Now().Add(-nodeDBNodeExpiration) + threshold := time.Now().Add(-dbNodeExpiration) // Find discovered nodes that are older than the allowance it := db.lvl.NewIterator(nil, nil) @@ -244,7 +280,7 @@ func (db *DB) expireNodes() error { for it.Next() { // Skip the item if not a discovery node id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { + if field != dbDiscoverRoot { continue } // Skip the node if not expired yet (and not self) @@ -260,34 +296,44 @@ func (db *DB) expireNodes() error { // LastPingReceived retrieves the time of the last ping packet received from // a remote node. func (db *DB) LastPingReceived(id ID) time.Time { - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) + return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0) } // UpdateLastPingReceived updates the last time we tried contacting a remote node. func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) + return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix()) } // LastPongReceived retrieves the time of the last successful pong from remote node. func (db *DB) LastPongReceived(id ID) time.Time { // Launch expirer db.ensureExpirer() - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) + return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0) } // UpdateLastPongReceived updates the last pong time of a node. func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) + return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix()) } // FindFails retrieves the number of findnode failures since bonding. func (db *DB) FindFails(id ID) int { - return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails))) + return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails))) } // UpdateFindFails updates the number of findnode failures since bonding. func (db *DB) UpdateFindFails(id ID, fails int) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails)) + return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails)) +} + +// LocalSeq retrieves the local record sequence counter. +func (db *DB) localSeq(id ID) uint64 { + return db.fetchUint64(makeKey(id, dbLocalSeq)) +} + +// storeLocalSeq stores the local record sequence counter. 
+func (db *DB) storeLocalSeq(id ID, n uint64) { + db.storeUint64(makeKey(id, dbLocalSeq), n) } // QuerySeeds retrieves random nodes to be used as potential seed nodes @@ -309,7 +355,7 @@ seek: ctr := id[0] rand.Read(id[:]) id[0] = ctr + id[0]%16 - it.Seek(makeKey(id, nodeDBDiscoverRoot)) + it.Seek(makeKey(id, dbDiscoverRoot)) n := nextNode(it) if n == nil { @@ -334,7 +380,7 @@ seek: func nextNode(it iterator.Iterator) *Node { for end := false; !end; end = !it.Next() { id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { + if field != dbDiscoverRoot { continue } return mustDecodeNode(id[:], it.Value()) diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index b476a3439..96794827c 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct { 30303, 30303, ), - pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute), + pong: time.Now().Add(-dbNodeExpiration + time.Minute), exp: false, }, { node: NewV4( @@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct { 30303, 30303, ), - pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute), + pong: time.Now().Add(-dbNodeExpiration - time.Minute), exp: true, }, } diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go index 251caf458..444820c15 100644 --- a/p2p/enr/enr.go +++ b/p2p/enr/enr.go @@ -156,7 +156,7 @@ func (r *Record) Set(e Entry) { } func (r *Record) invalidate() { - if r.signature == nil { + if r.signature != nil { r.seq++ } r.signature = nil diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go index 9bf22478d..449c898a8 100644 --- a/p2p/enr/enr_test.go +++ b/p2p/enr/enr_test.go @@ -169,6 +169,18 @@ func TestDirty(t *testing.T) { } } +func TestSeq(t *testing.T) { + var r Record + + assert.Equal(t, uint64(0), r.Seq()) + r.Set(UDP(1)) + assert.Equal(t, uint64(0), r.Seq()) + signTest([]byte{5}, &r) + assert.Equal(t, uint64(0), r.Seq()) + r.Set(UDP(2)) + assert.Equal(t, uint64(1), r.Seq()) +} + // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. func TestGetSetOverwrite(t *testing.T) { var r Record diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go index a254648c6..8fad921c4 100644 --- a/p2p/nat/nat.go +++ b/p2p/nat/nat.go @@ -129,21 +129,15 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na // ExtIP assumes that the local machine is reachable on the given // external IP address, and that any required ports were mapped manually. // Mapping operations will not return an error but won't actually do anything. -func ExtIP(ip net.IP) Interface { - if ip == nil { - panic("IP must not be nil") - } - return extIP(ip) -} +type ExtIP net.IP -type extIP net.IP - -func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil } -func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) } +func (n ExtIP) ExternalIP() (net.IP, error) { return net.IP(n), nil } +func (n ExtIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) } // These do nothing. -func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil } -func (extIP) DeleteMapping(string, int, int) error { return nil } + +func (ExtIP) AddMapping(string, int, int, string, time.Duration) error { return nil } +func (ExtIP) DeleteMapping(string, int, int) error { return nil } // Any returns a port mapper that tries to discover any supported // mechanism on the local network. 
diff --git a/p2p/nat/nat_test.go b/p2p/nat/nat_test.go index 469101e99..814e6d9e1 100644 --- a/p2p/nat/nat_test.go +++ b/p2p/nat/nat_test.go @@ -28,7 +28,7 @@ import ( func TestAutoDiscRace(t *testing.T) { ad := startautodisc("thing", func() Interface { time.Sleep(500 * time.Millisecond) - return extIP{33, 44, 55, 66} + return ExtIP{33, 44, 55, 66} }) // Spawn a few concurrent calls to ad.ExternalIP. diff --git a/p2p/netutil/iptrack.go b/p2p/netutil/iptrack.go new file mode 100644 index 000000000..b9cbd5e1c --- /dev/null +++ b/p2p/netutil/iptrack.go @@ -0,0 +1,130 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package netutil + +import ( + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host +// based on statements made by other hosts. +type IPTracker struct { + window time.Duration + contactWindow time.Duration + minStatements int + clock mclock.Clock + statements map[string]ipStatement + contact map[string]mclock.AbsTime + lastStatementGC mclock.AbsTime + lastContactGC mclock.AbsTime +} + +type ipStatement struct { + endpoint string + time mclock.AbsTime +} + +// NewIPTracker creates an IP tracker. +// +// The window parameters configure the amount of past network events which are kept. The +// minStatements parameter enforces a minimum number of statements which must be recorded +// before any prediction is made. Higher values for these parameters decrease 'flapping' of +// predictions as network conditions change. Window duration values should typically be in +// the range of minutes. +func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker { + return &IPTracker{ + window: window, + contactWindow: contactWindow, + statements: make(map[string]ipStatement), + minStatements: minStatements, + contact: make(map[string]mclock.AbsTime), + clock: mclock.System{}, + } +} + +// PredictFullConeNAT checks whether the local host is behind full cone NAT. It predicts by +// checking whether any statement has been received from a node we didn't contact before +// the statement was made. +func (it *IPTracker) PredictFullConeNAT() bool { + now := it.clock.Now() + it.gcContact(now) + it.gcStatements(now) + for host, st := range it.statements { + if c, ok := it.contact[host]; !ok || c > st.time { + return true + } + } + return false +} + +// PredictEndpoint returns the current prediction of the external endpoint. +func (it *IPTracker) PredictEndpoint() string { + it.gcStatements(it.clock.Now()) + + // The current strategy is simple: find the endpoint with most statements. 
+ counts := make(map[string]int) + maxcount, max := 0, "" + for _, s := range it.statements { + c := counts[s.endpoint] + 1 + counts[s.endpoint] = c + if c > maxcount && c >= it.minStatements { + maxcount, max = c, s.endpoint + } + } + return max +} + +// AddStatement records that a certain host thinks our external endpoint is the one given. +func (it *IPTracker) AddStatement(host, endpoint string) { + now := it.clock.Now() + it.statements[host] = ipStatement{endpoint, now} + if time.Duration(now-it.lastStatementGC) >= it.window { + it.gcStatements(now) + } +} + +// AddContact records that a packet containing our endpoint information has been sent to a +// certain host. +func (it *IPTracker) AddContact(host string) { + now := it.clock.Now() + it.contact[host] = now + if time.Duration(now-it.lastContactGC) >= it.contactWindow { + it.gcContact(now) + } +} + +func (it *IPTracker) gcStatements(now mclock.AbsTime) { + it.lastStatementGC = now + cutoff := now.Add(-it.window) + for host, s := range it.statements { + if s.time < cutoff { + delete(it.statements, host) + } + } +} + +func (it *IPTracker) gcContact(now mclock.AbsTime) { + it.lastContactGC = now + cutoff := now.Add(-it.contactWindow) + for host, ct := range it.contact { + if ct < cutoff { + delete(it.contact, host) + } + } +} diff --git a/p2p/netutil/iptrack_test.go b/p2p/netutil/iptrack_test.go new file mode 100644 index 000000000..a9a2998a6 --- /dev/null +++ b/p2p/netutil/iptrack_test.go @@ -0,0 +1,138 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 
+ +package netutil + +import ( + "fmt" + mrand "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +const ( + opStatement = iota + opContact + opPredict + opCheckFullCone +) + +type iptrackTestEvent struct { + op int + time int // absolute, in milliseconds + ip, from string +} + +func TestIPTracker(t *testing.T) { + tests := map[string][]iptrackTestEvent{ + "minStatements": { + {opPredict, 0, "", ""}, + {opStatement, 0, "127.0.0.1", "127.0.0.2"}, + {opPredict, 1000, "", ""}, + {opStatement, 1000, "127.0.0.1", "127.0.0.3"}, + {opPredict, 1000, "", ""}, + {opStatement, 1000, "127.0.0.1", "127.0.0.4"}, + {opPredict, 1000, "127.0.0.1", ""}, + }, + "window": { + {opStatement, 0, "127.0.0.1", "127.0.0.2"}, + {opStatement, 2000, "127.0.0.1", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, + {opPredict, 10000, "127.0.0.1", ""}, + {opPredict, 10001, "", ""}, // first statement expired + {opStatement, 10100, "127.0.0.1", "127.0.0.2"}, + {opPredict, 10200, "127.0.0.1", ""}, + }, + "fullcone": { + {opContact, 0, "", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opContact, 2000, "", "127.0.0.3"}, + {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, + {opContact, 3000, "", "127.0.0.4"}, + {opStatement, 3010, "127.0.0.1", "127.0.0.4"}, + {opCheckFullCone, 3500, "false", ""}, + }, + "fullcone_2": { + {opContact, 0, "", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opContact, 2000, "", "127.0.0.3"}, + {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, + {opContact, 3010, "", "127.0.0.4"}, + {opCheckFullCone, 3500, "true", ""}, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) }) + } +} + +func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) { + var ( + clock mclock.Simulated + it = NewIPTracker(10*time.Second, 10*time.Second, 3) + ) + it.clock = &clock + for i, ev := range evs { + evtime := time.Duration(ev.time) * time.Millisecond + clock.Run(evtime - time.Duration(clock.Now())) + switch ev.op { + case opStatement: + it.AddStatement(ev.from, ev.ip) + case opContact: + it.AddContact(ev.from) + case opPredict: + if pred := it.PredictEndpoint(); pred != ev.ip { + t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip) + } + case opCheckFullCone: + pred := fmt.Sprintf("%t", it.PredictFullConeNAT()) + if pred != ev.ip { + t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip) + } + } + } +} + +// This checks that old statements and contacts are GCed even if Predict* isn't called. +func TestIPTrackerForceGC(t *testing.T) { + var ( + clock mclock.Simulated + window = 10 * time.Second + rate = 50 * time.Millisecond + max = int(window/rate) + 1 + it = NewIPTracker(window, window, 3) + ) + it.clock = &clock + + for i := 0; i < 5*max; i++ { + e1 := make([]byte, 4) + e2 := make([]byte, 4) + mrand.Read(e1) + mrand.Read(e2) + it.AddStatement(string(e1), string(e2)) + it.AddContact(string(e1)) + clock.Run(rate) + } + if len(it.contact) > 2*max { + t.Errorf("contacts not GCed, have %d", len(it.contact)) + } + if len(it.statements) > 2*max { + t.Errorf("statements not GCed, have %d", len(it.statements)) + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go index 4b90a2a70..9438ab8e4 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" ) // Protocol represents a P2P subprotocol implementation. 
@@ -52,6 +53,9 @@ type Protocol struct { // about a certain peer in the network. If an info retrieval function is set, // but returns nil, it is assumed that the protocol handshake is still running. PeerInfo func(id enode.ID) interface{} + + // Attributes contains protocol specific information for the node record. + Attributes []enr.Entry } func (p Protocol) cap() Cap { @@ -64,10 +68,6 @@ type Cap struct { Version uint } -func (cap Cap) RlpData() interface{} { - return []interface{}{cap.Name, cap.Version} -} - func (cap Cap) String() string { return fmt.Sprintf("%s/%d", cap.Name, cap.Version) } @@ -79,3 +79,5 @@ func (cs capsByNameAndVersion) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } func (cs capsByNameAndVersion) Less(i, j int) bool { return cs[i].Name < cs[j].Name || (cs[i].Name == cs[j].Name && cs[i].Version < cs[j].Version) } + +func (capsByNameAndVersion) ENRKey() string { return "cap" } diff --git a/p2p/server.go b/p2p/server.go index 40db758e2..6482c0401 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -20,9 +20,11 @@ package p2p import ( "bytes" "crypto/ecdsa" + "encoding/hex" "errors" "fmt" "net" + "sort" "sync" "sync/atomic" "time" @@ -35,8 +37,10 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" + "github.com/ethereum/go-ethereum/rlp" ) const ( @@ -160,6 +164,8 @@ type Server struct { lock sync.Mutex // protects running running bool + nodedb *enode.DB + localnode *enode.LocalNode ntab discoverTable listener net.Listener ourHandshake *protoHandshake @@ -347,43 +353,13 @@ func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription { // Self returns the local node's endpoint information. func (srv *Server) Self() *enode.Node { srv.lock.Lock() - running, listener, ntab := srv.running, srv.listener, srv.ntab + ln := srv.localnode srv.lock.Unlock() - if !running { + if ln == nil { return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0) } - return srv.makeSelf(listener, ntab) -} - -func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node { - // If the node is running but discovery is off, manually assemble the node infos. - if ntab == nil { - addr := srv.tcpAddr(listener) - return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0) - } - // Otherwise return the discovery node. - return ntab.Self() -} - -func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr { - addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}} - if listener == nil { - return addr // Inbound connections disabled, use zero address. - } - // Otherwise inject the listener address too. - if a, ok := listener.Addr().(*net.TCPAddr); ok { - addr = *a - } - if srv.NAT != nil { - if ip, err := srv.NAT.ExternalIP(); err == nil { - addr.IP = ip - } - } - if addr.IP.IsUnspecified() { - addr.IP = net.IP{127, 0, 0, 1} - } - return addr + return ln.Node() } // Stop terminates the server and all active peer connections. 
@@ -443,7 +419,9 @@ func (srv *Server) Start() (err error) { if srv.log == nil { srv.log = log.New() } - srv.log.Info("Starting P2P networking") + if srv.NoDial && srv.ListenAddr == "" { + srv.log.Warn("P2P server will be useless, neither dialing nor listening") + } // static fields if srv.PrivateKey == nil { @@ -466,65 +444,120 @@ func (srv *Server) Start() (err error) { srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) - var ( - conn *net.UDPConn - sconn *sharedUDPConn - realaddr *net.UDPAddr - unhandled chan discover.ReadPacket - ) - - if !srv.NoDiscovery || srv.DiscoveryV5 { - addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) - if err != nil { + if err := srv.setupLocalNode(); err != nil { + return err + } + if srv.ListenAddr != "" { + if err := srv.setupListening(); err != nil { return err } - conn, err = net.ListenUDP("udp", addr) - if err != nil { - return err + } + if err := srv.setupDiscovery(); err != nil { + return err + } + + dynPeers := srv.maxDialedConns() + dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) + srv.loopWG.Add(1) + go srv.run(dialer) + return nil +} + +func (srv *Server) setupLocalNode() error { + // Create the devp2p handshake. + pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey) + srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]} + for _, p := range srv.Protocols { + srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) + } + sort.Sort(capsByNameAndVersion(srv.ourHandshake.Caps)) + + // Create the local node. + db, err := enode.OpenDB(srv.Config.NodeDatabase) + if err != nil { + return err + } + srv.nodedb = db + srv.localnode = enode.NewLocalNode(db, srv.PrivateKey) + srv.localnode.SetFallbackIP(net.IP{127, 0, 0, 1}) + srv.localnode.Set(capsByNameAndVersion(srv.ourHandshake.Caps)) + // TODO: check conflicts + for _, p := range srv.Protocols { + for _, e := range p.Attributes { + srv.localnode.Set(e) } - realaddr = conn.LocalAddr().(*net.UDPAddr) - if srv.NAT != nil { - if !realaddr.IP.IsLoopback() { - go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") - } - // TODO: react to external IP changes over time. - if ext, err := srv.NAT.ExternalIP(); err == nil { - realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + switch srv.NAT.(type) { + case nil: + // No NAT interface, do nothing. + case nat.ExtIP: + // ExtIP doesn't block, set the IP right away. + ip, _ := srv.NAT.ExternalIP() + srv.localnode.SetStaticIP(ip) + default: + // Ask the router about the IP. This takes a while and blocks startup, + // do it in the background. 
+ srv.loopWG.Add(1) + go func() { + defer srv.loopWG.Done() + if ip, err := srv.NAT.ExternalIP(); err == nil { + srv.localnode.SetStaticIP(ip) } - } + }() + } + return nil +} + +func (srv *Server) setupDiscovery() error { + if srv.NoDiscovery && !srv.DiscoveryV5 { + return nil } - if !srv.NoDiscovery && srv.DiscoveryV5 { - unhandled = make(chan discover.ReadPacket, 100) - sconn = &sharedUDPConn{conn, unhandled} + addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) + if err != nil { + return err + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return err + } + realaddr := conn.LocalAddr().(*net.UDPAddr) + srv.log.Debug("UDP listener up", "addr", realaddr) + if srv.NAT != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } } + srv.localnode.SetFallbackUDP(realaddr.Port) - // node table + // Discovery V4 + var unhandled chan discover.ReadPacket + var sconn *sharedUDPConn if !srv.NoDiscovery { + if srv.DiscoveryV5 { + unhandled = make(chan discover.ReadPacket, 100) + sconn = &sharedUDPConn{conn, unhandled} + } cfg := discover.Config{ - PrivateKey: srv.PrivateKey, - AnnounceAddr: realaddr, - NodeDBPath: srv.NodeDatabase, - NetRestrict: srv.NetRestrict, - Bootnodes: srv.BootstrapNodes, - Unhandled: unhandled, + PrivateKey: srv.PrivateKey, + NetRestrict: srv.NetRestrict, + Bootnodes: srv.BootstrapNodes, + Unhandled: unhandled, } - ntab, err := discover.ListenUDP(conn, cfg) + ntab, err := discover.ListenUDP(conn, srv.localnode, cfg) if err != nil { return err } srv.ntab = ntab } - + // Discovery V5 if srv.DiscoveryV5 { - var ( - ntab *discv5.Network - err error - ) + var ntab *discv5.Network + var err error if sconn != nil { - ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, "", srv.NetRestrict) } else { - ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, "", srv.NetRestrict) } if err != nil { return err @@ -534,32 +567,10 @@ func (srv *Server) Start() (err error) { } srv.DiscV5 = ntab } - - dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) - - // handshake - pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey) - srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]} - for _, p := range srv.Protocols { - srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) - } - // listen/dial - if srv.ListenAddr != "" { - if err := srv.startListening(); err != nil { - return err - } - } - if srv.NoDial && srv.ListenAddr == "" { - srv.log.Warn("P2P server will be useless, neither dialing nor listening") - } - - srv.loopWG.Add(1) - go srv.run(dialer) return nil } -func (srv *Server) startListening() error { +func (srv *Server) setupListening() error { // Launch the TCP listener. listener, err := net.Listen("tcp", srv.ListenAddr) if err != nil { @@ -568,8 +579,11 @@ func (srv *Server) startListening() error { laddr := listener.Addr().(*net.TCPAddr) srv.ListenAddr = laddr.String() srv.listener = listener + srv.localnode.Set(enr.TCP(laddr.Port)) + srv.loopWG.Add(1) go srv.listenLoop() + // Map the TCP listening port if NAT is configured. 
if !laddr.IP.IsLoopback() && srv.NAT != nil { srv.loopWG.Add(1) @@ -589,7 +603,10 @@ type dialer interface { } func (srv *Server) run(dialstate dialer) { + srv.log.Info("Started P2P networking", "self", srv.localnode.Node()) defer srv.loopWG.Done() + defer srv.nodedb.Close() + var ( peers = make(map[enode.ID]*Peer) inboundCount = 0 @@ -781,7 +798,7 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int return DiscTooManyPeers case peers[c.node.ID()] != nil: return DiscAlreadyConnected - case c.node.ID() == srv.Self().ID(): + case c.node.ID() == srv.localnode.ID(): return DiscSelf default: return nil @@ -802,15 +819,11 @@ func (srv *Server) maxDialedConns() int { return srv.MaxPeers / r } -type tempError interface { - Temporary() bool -} - // listenLoop runs in its own goroutine and accepts // inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() - srv.log.Info("RLPx listener up", "self", srv.Self()) + srv.log.Debug("TCP listener up", "addr", srv.listener.Addr()) tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { @@ -831,7 +844,7 @@ func (srv *Server) listenLoop() { ) for { fd, err = srv.listener.Accept() - if tempErr, ok := err.(tempError); ok && tempErr.Temporary() { + if netutil.IsTemporaryError(err) { srv.log.Debug("Temporary read error", "err", err) continue } else if err != nil { @@ -864,10 +877,6 @@ func (srv *Server) listenLoop() { // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { - self := srv.Self() - if self == nil { - return errors.New("shutdown") - } c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} err := srv.setupConn(c, flags, dialDest) if err != nil { @@ -1003,6 +1012,7 @@ type NodeInfo struct { ID string `json:"id"` // Unique node identifier (also the encryption key) Name string `json:"name"` // Name of the node, including client type, version, OS, custom data Enode string `json:"enode"` // Enode URL for adding this peer from remote peers + ENR string `json:"enr"` // Ethereum Node Record IP string `json:"ip"` // IP address of the node Ports struct { Discovery int `json:"discovery"` // UDP listening port for discovery protocol @@ -1014,9 +1024,8 @@ type NodeInfo struct { // NodeInfo gathers and returns a collection of metadata known about the host. func (srv *Server) NodeInfo() *NodeInfo { - node := srv.Self() - // Gather and assemble the generic node infos + node := srv.Self() info := &NodeInfo{ Name: srv.Name, Enode: node.String(), @@ -1027,6 +1036,9 @@ func (srv *Server) NodeInfo() *NodeInfo { } info.Ports.Discovery = node.UDP() info.Ports.Listener = node.TCP() + if enc, err := rlp.EncodeToBytes(node.Record()); err == nil { + info.ENR = "0x" + hex.EncodeToString(enc) + } // Gather all the running protocol infos (only once per protocol type) for _, proto := range srv.Protocols { diff --git a/p2p/server_test.go b/p2p/server_test.go index e0b1fc122..7e11577d6 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -225,12 +225,15 @@ func TestServerTaskScheduling(t *testing.T) { // The Server in this test isn't actually running // because we're only interested in what run does. 
+	db, _ := enode.OpenDB("")
 	srv := &Server{
-		Config:  Config{MaxPeers: 10},
-		quit:    make(chan struct{}),
-		ntab:    fakeTable{},
-		running: true,
-		log:     log.New(),
+		Config:    Config{MaxPeers: 10},
+		localnode: enode.NewLocalNode(db, newkey()),
+		nodedb:    db,
+		quit:      make(chan struct{}),
+		ntab:      fakeTable{},
+		running:   true,
+		log:       log.New(),
 	}
 	srv.loopWG.Add(1)
 	go func() {
@@ -271,11 +274,14 @@ func TestServerManyTasks(t *testing.T) {
 	}
 
 	var (
-		srv = &Server{
-			quit:    make(chan struct{}),
-			ntab:    fakeTable{},
-			running: true,
-			log:     log.New(),
+		db, _ = enode.OpenDB("")
+		srv   = &Server{
+			quit:      make(chan struct{}),
+			localnode: enode.NewLocalNode(db, newkey()),
+			nodedb:    db,
+			ntab:      fakeTable{},
+			running:   true,
+			log:       log.New(),
 		}
 		done       = make(chan *testTask)
 		start, end = 0, 0
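
Note on the netutil.IPTracker API exercised by iptrack_test.go above: outside the simulated-clock harness, the tracker can be driven directly with contact and statement events. The sketch below is illustrative only; the peer addresses and the "IP:port" endpoint format are assumptions, not part of this change.

package main

import (
	"fmt"
	"time"

	"github.com/ethereum/go-ethereum/p2p/netutil"
)

func main() {
	// Same parameters as the tests: keep statements and contacts for 10s,
	// require 3 agreeing statements before predicting an endpoint.
	it := netutil.NewIPTracker(10*time.Second, 10*time.Second, 3)

	// Record what three different peers report as our external endpoint.
	it.AddStatement("198.51.100.2", "203.0.113.7:30303")
	it.AddStatement("198.51.100.3", "203.0.113.7:30303")
	it.AddStatement("198.51.100.4", "203.0.113.7:30303")
	fmt.Println("predicted endpoint:", it.PredictEndpoint())

	// None of the reporting hosts were contacted by us first; judging from
	// the "fullcone_2" case above, that pattern reads as full-cone NAT.
	fmt.Println("full-cone NAT:", it.PredictFullConeNAT())
}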
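
Note on the new Protocol.Attributes field: setupLocalNode copies each entry into the local node record (see the TODO about conflicts above). Below is a minimal sketch of a subprotocol publishing such an entry; the "foo" entry type, key, and value are made up for illustration.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p"
	"github.com/ethereum/go-ethereum/p2p/enr"
)

// fooEntry is a hypothetical ENR entry: any RLP-encodable value with an
// ENRKey method satisfies enr.Entry (compare capsByNameAndVersion above).
type fooEntry uint

func (fooEntry) ENRKey() string { return "foo" }

func main() {
	proto := p2p.Protocol{
		Name:    "foo",
		Version: 1,
		Length:  1,
		Run:     func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return nil },
		// Entries listed here end up in the node record at startup.
		Attributes: []enr.Entry{fooEntry(42)},
	}
	fmt.Println(proto.Name, proto.Attributes)
}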
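
Note on enode.LocalNode as wired up in setupLocalNode, setupDiscovery, and setupListening above: a standalone sketch, assuming an in-memory node database and a made-up port number (30303 for both UDP and TCP).

package main

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
)

func main() {
	key, _ := crypto.GenerateKey() // illustrative node key
	db, _ := enode.OpenDB("")      // "" gives an in-memory database, as in the tests
	defer db.Close()

	ln := enode.NewLocalNode(db, key)
	ln.SetFallbackIP(net.IP{127, 0, 0, 1}) // used until a better endpoint is known
	ln.SetFallbackUDP(30303)               // assumed discovery port
	ln.Set(enr.TCP(30303))                 // assumed RLPx listening port

	// Node() returns the current signed record for the local node.
	fmt.Println(ln.Node())
}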