diff options
author | Janos Guljas <janos@resenje.org> | 2018-02-09 19:23:30 +0800 |
---|---|---|
committer | Janos Guljas <janos@resenje.org> | 2018-02-22 21:23:17 +0800 |
commit | a3a07350dcef0ba39829a20d8ddba4bd3463d293 (patch) | |
tree | 100f2515cadd92105537a12e6981fab2193435ee /p2p | |
parent | 820cf09c98706f71d4b02b6c25e62db15830f3fb (diff) | |
parent | 1a4e68721a901e86322631fed1191025a6d14c52 (diff) | |
download | dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.gz dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.bz2 dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.lz dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.xz dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.zst dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.zip |
swarm, cmd/swarm: Merge branch 'master' into multiple-ens-endpoints
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/discover/database.go | 6 | ||||
-rw-r--r-- | p2p/discover/udp.go | 44 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 3 | ||||
-rw-r--r-- | p2p/discv5/database.go | 6 | ||||
-rw-r--r-- | p2p/discv5/net.go | 59 | ||||
-rw-r--r-- | p2p/discv5/net_test.go | 2 | ||||
-rw-r--r-- | p2p/discv5/node.go | 5 | ||||
-rw-r--r-- | p2p/discv5/nodeevent_string.go | 6 | ||||
-rw-r--r-- | p2p/discv5/sim_test.go | 2 | ||||
-rw-r--r-- | p2p/discv5/ticket.go | 227 | ||||
-rw-r--r-- | p2p/discv5/topic.go | 3 | ||||
-rw-r--r-- | p2p/discv5/udp.go | 47 | ||||
-rw-r--r-- | p2p/enr/enr.go | 290 | ||||
-rw-r--r-- | p2p/enr/enr_test.go | 318 | ||||
-rw-r--r-- | p2p/enr/entries.go | 160 | ||||
-rw-r--r-- | p2p/server.go | 75 |
16 files changed, 1035 insertions, 218 deletions
diff --git a/p2p/discover/database.go b/p2p/discover/database.go index 7206a63c6..b136609f2 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -226,14 +226,14 @@ func (db *nodeDB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. func (db *nodeDB) expirer() { - tick := time.Tick(nodeDBCleanupCycle) + tick := time.NewTicker(nodeDBCleanupCycle) + defer tick.Stop() for { select { - case <-tick: + case <-tick.C: if err := db.expireNodes(); err != nil { log.Error("Failed to expire nodedb items", "err", err) } - case <-db.quit: return } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index f9eb99ee3..60436952d 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -210,17 +210,15 @@ type reply struct { matched chan<- bool } +// ReadPacket is sent to the unhandled channel when it could not be processed +type ReadPacket struct { + Data []byte + Addr *net.UDPAddr +} + // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) +func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { + tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict) if err != nil { return nil, err } @@ -228,7 +226,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { +func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { udp := &udp{ conn: c, priv: priv, @@ -237,16 +235,6 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath strin gotreply: make(chan reply), addpending: make(chan *pending), } - realaddr := c.LocalAddr().(*net.UDPAddr) - if natm != nil { - if !realaddr.IP.IsLoopback() { - go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") - } - // TODO: react to external IP changes over time. - if ext, err := natm.ExternalIP(); err == nil { - realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} - } - } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) @@ -256,7 +244,7 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath strin udp.Table = tab go udp.loop() - go udp.readLoop() + go udp.readLoop(unhandled) return udp.Table, udp, nil } @@ -492,8 +480,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, } // readLoop runs in its own goroutine. it handles incoming UDP packets. -func (t *udp) readLoop() { +func (t *udp) readLoop(unhandled chan ReadPacket) { defer t.conn.Close() + if unhandled != nil { + defer close(unhandled) + } // Discovery packets are defined to be no larger than 1280 bytes. // Packets larger than this size will be cut at the end and treated // as invalid because their hash won't match. @@ -509,7 +500,12 @@ func (t *udp) readLoop() { log.Debug("UDP read error", "err", err) return } - t.handlePacket(from, buf[:nbytes]) + if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { + select { + case unhandled <- ReadPacket{buf[:nbytes], from}: + default: + } + } } } diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 21e8b561d..b81caf839 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -70,7 +70,8 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) + realaddr := test.pipe.LocalAddr().(*net.UDPAddr) + test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil) return test } diff --git a/p2p/discv5/database.go b/p2p/discv5/database.go index a3b044ec1..3c2d5744c 100644 --- a/p2p/discv5/database.go +++ b/p2p/discv5/database.go @@ -239,14 +239,14 @@ func (db *nodeDB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. func (db *nodeDB) expirer() { - tick := time.Tick(nodeDBCleanupCycle) + tick := time.NewTicker(nodeDBCleanupCycle) + defer tick.Stop() for { select { - case <-tick: + case <-tick.C: if err := db.expireNodes(); err != nil { log.Error(fmt.Sprintf("Failed to expire nodedb items: %v", err)) } - case <-db.quit: return } diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index 2fbb60824..f9baf126f 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -29,7 +29,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -51,16 +50,9 @@ const ( const testTopic = "foo" const ( - printDebugLogs = false printTestImgLogs = false ) -func debugLog(s string) { - if printDebugLogs { - fmt.Println(s) - } -} - // Network manages the table and all protocol interaction. type Network struct { db *nodeDB // database of known nodes @@ -141,7 +133,7 @@ type timeoutEvent struct { node *Node } -func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { +func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { ourID := PubkeyID(&ourPubkey) var db *nodeDB @@ -388,14 +380,14 @@ func (net *Network) loop() { } }() resetNextTicket := func() { - t, timeout := net.ticketStore.nextFilteredTicket() - if t != nextTicket { - nextTicket = t + ticket, timeout := net.ticketStore.nextFilteredTicket() + if nextTicket != ticket { + nextTicket = ticket if nextRegisterTimer != nil { nextRegisterTimer.Stop() nextRegisterTime = nil } - if t != nil { + if ticket != nil { nextRegisterTimer = time.NewTimer(timeout) nextRegisterTime = nextRegisterTimer.C } @@ -423,13 +415,13 @@ loop: select { case <-net.closeReq: - debugLog("<-net.closeReq") + log.Trace("<-net.closeReq") break loop // Ingress packet handling. case pkt := <-net.read: //fmt.Println("read", pkt.ev) - debugLog("<-net.read") + log.Trace("<-net.read") n := net.internNode(&pkt) prestate := n.state status := "ok" @@ -444,7 +436,7 @@ loop: // State transition timeouts. case timeout := <-net.timeout: - debugLog("<-net.timeout") + log.Trace("<-net.timeout") if net.timeoutTimers[timeout] == nil { // Stale timer (was aborted). continue @@ -462,20 +454,20 @@ loop: // Querying. case q := <-net.queryReq: - debugLog("<-net.queryReq") + log.Trace("<-net.queryReq") if !q.start(net) { q.remote.deferQuery(q) } // Interacting with the table. case f := <-net.tableOpReq: - debugLog("<-net.tableOpReq") + log.Trace("<-net.tableOpReq") f() net.tableOpResp <- struct{}{} // Topic registration stuff. case req := <-net.topicRegisterReq: - debugLog("<-net.topicRegisterReq") + log.Trace("<-net.topicRegisterReq") if !req.add { net.ticketStore.removeRegisterTopic(req.topic) continue @@ -486,7 +478,7 @@ loop: // determination for new topics. // if topicRegisterLookupDone == nil { if topicRegisterLookupTarget.target == (common.Hash{}) { - debugLog("topicRegisterLookupTarget == null") + log.Trace("topicRegisterLookupTarget == null") if topicRegisterLookupTick.Stop() { <-topicRegisterLookupTick.C } @@ -496,7 +488,7 @@ loop: } case nodes := <-topicRegisterLookupDone: - debugLog("<-topicRegisterLookupDone") + log.Trace("<-topicRegisterLookupDone") net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte { net.ping(n, n.addr()) return n.pingEcho @@ -507,7 +499,7 @@ loop: topicRegisterLookupDone = nil case <-topicRegisterLookupTick.C: - debugLog("<-topicRegisterLookupTick") + log.Trace("<-topicRegisterLookupTick") if (topicRegisterLookupTarget.target == common.Hash{}) { target, delay := net.ticketStore.nextRegisterLookup() topicRegisterLookupTarget = target @@ -520,14 +512,14 @@ loop: } case <-nextRegisterTime: - debugLog("<-nextRegisterTime") + log.Trace("<-nextRegisterTime") net.ticketStore.ticketRegistered(*nextTicket) //fmt.Println("sendTopicRegister", nextTicket.t.node.addr().String(), nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) case req := <-net.topicSearchReq: if refreshDone == nil { - debugLog("<-net.topicSearchReq") + log.Trace("<-net.topicSearchReq") info, ok := searchInfo[req.topic] if ok { if req.delay == time.Duration(0) { @@ -588,7 +580,7 @@ loop: }) case <-statsDump.C: - debugLog("<-statsDump.C") + log.Trace("<-statsDump.C") /*r, ok := net.ticketStore.radius[testTopic] if !ok { fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now()) @@ -617,7 +609,7 @@ loop: // Periodic / lookup-initiated bucket refresh. case <-refreshTimer.C: - debugLog("<-refreshTimer.C") + log.Trace("<-refreshTimer.C") // TODO: ideally we would start the refresh timer after // fallback nodes have been set for the first time. if refreshDone == nil { @@ -631,7 +623,7 @@ loop: bucketRefreshTimer.Reset(bucketRefreshInterval) }() case newNursery := <-net.refreshReq: - debugLog("<-net.refreshReq") + log.Trace("<-net.refreshReq") if newNursery != nil { net.nursery = newNursery } @@ -641,7 +633,7 @@ loop: } net.refreshResp <- refreshDone case <-refreshDone: - debugLog("<-net.refreshDone") + log.Trace("<-net.refreshDone") refreshDone = nil list := searchReqWhenRefreshDone searchReqWhenRefreshDone = nil @@ -652,7 +644,7 @@ loop: }() } } - debugLog("loop stopped") + log.Trace("loop stopped") log.Debug(fmt.Sprintf("shutting down")) if net.conn != nil { @@ -1109,14 +1101,14 @@ func (net *Network) ping(n *Node, addr *net.UDPAddr) { //fmt.Println(" not sent") return } - debugLog(fmt.Sprintf("ping(node = %x)", n.ID[:8])) + log.Trace("Pinging remote node", "node", n.ID) n.pingTopics = net.ticketStore.regTopicSet() n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics) net.timedEvent(respTimeout, n, pongTimeout) } func (net *Network) handlePing(n *Node, pkt *ingressPacket) { - debugLog(fmt.Sprintf("handlePing(node = %x)", n.ID[:8])) + log.Trace("Handling remote ping", "node", n.ID) ping := pkt.data.(*ping) n.TCP = ping.From.TCP t := net.topictab.getTicket(n, ping.Topics) @@ -1131,7 +1123,7 @@ func (net *Network) handlePing(n *Node, pkt *ingressPacket) { } func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error { - debugLog(fmt.Sprintf("handleKnownPong(node = %x)", n.ID[:8])) + log.Trace("Handling known pong", "node", n.ID) net.abortTimedEvent(n, pongTimeout) now := mclock.Now() ticket, err := pongToTicket(now, n.pingTopics, n, pkt) @@ -1139,9 +1131,8 @@ func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error { // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data) net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket) } else { - debugLog(fmt.Sprintf(" error: %v", err)) + log.Trace("Failed to convert pong to ticket", "err", err) } - n.pingEcho = nil n.pingTopics = nil return err diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go index bd234f5ba..369282ca9 100644 --- a/p2p/discv5/net_test.go +++ b/p2p/discv5/net_test.go @@ -28,7 +28,7 @@ import ( func TestNetwork_Lookup(t *testing.T) { key, _ := crypto.GenerateKey() - network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) + network, err := newNetwork(lookupTestnet, key.PublicKey, "", nil) if err != nil { t.Fatal(err) } diff --git a/p2p/discv5/node.go b/p2p/discv5/node.go index 2db7a508f..fd88a55b1 100644 --- a/p2p/discv5/node.go +++ b/p2p/discv5/node.go @@ -273,6 +273,11 @@ func (n NodeID) GoString() string { return fmt.Sprintf("discover.HexID(\"%x\")", n[:]) } +// TerminalString returns a shortened hex string for terminal logging. +func (n NodeID) TerminalString() string { + return hex.EncodeToString(n[:8]) +} + // HexID converts a hex string to a NodeID. // The string may be prefixed with 0x. func HexID(in string) (NodeID, error) { diff --git a/p2p/discv5/nodeevent_string.go b/p2p/discv5/nodeevent_string.go index fde9045c5..eb696fb8b 100644 --- a/p2p/discv5/nodeevent_string.go +++ b/p2p/discv5/nodeevent_string.go @@ -1,8 +1,8 @@ -// Code generated by "stringer -type nodeEvent"; DO NOT EDIT +// Code generated by "stringer -type=nodeEvent"; DO NOT EDIT. package discv5 -import "fmt" +import "strconv" const ( _nodeEvent_name_0 = "invalidEventpingPacketpongPacketfindnodePacketneighborsPacketfindnodeHashPackettopicRegisterPackettopicQueryPackettopicNodesPacket" @@ -22,6 +22,6 @@ func (i nodeEvent) String() string { i -= 265 return _nodeEvent_name_1[_nodeEvent_index_1[i]:_nodeEvent_index_1[i+1]] default: - return fmt.Sprintf("nodeEvent(%d)", i) + return "nodeEvent(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go index bf57872e2..543faecd4 100644 --- a/p2p/discv5/sim_test.go +++ b/p2p/discv5/sim_test.go @@ -282,7 +282,7 @@ func (s *simulation) launchNode(log bool) *Network { addr := &net.UDPAddr{IP: ip, Port: 30303} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} - net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil) + net, err := newNetwork(transport, key.PublicKey, "<no database>", nil) if err != nil { panic("cannot launch new node: " + err.Error()) } diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go index 193cef4be..37ce8d23c 100644 --- a/p2p/discv5/ticket.go +++ b/p2p/discv5/ticket.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" ) const ( @@ -128,8 +129,11 @@ type ticketStore struct { // Contains buckets (for each absolute minute) of tickets // that can be used in that minute. // This is only set if the topic is being registered. - tickets map[Topic]topicTickets - regtopics []Topic + tickets map[Topic]*topicTickets + + regQueue []Topic // Topic registration queue for round robin attempts + regSet map[Topic]struct{} // Topic registration queue contents for fast filling + nodes map[*Node]*ticket nodeLastReq map[*Node]reqInfo @@ -152,14 +156,16 @@ type sentQuery struct { } type topicTickets struct { - buckets map[timeBucket][]ticketRef - nextLookup, nextReg mclock.AbsTime + buckets map[timeBucket][]ticketRef + nextLookup mclock.AbsTime + nextReg mclock.AbsTime } func newTicketStore() *ticketStore { return &ticketStore{ radius: make(map[Topic]*topicRadius), - tickets: make(map[Topic]topicTickets), + tickets: make(map[Topic]*topicTickets), + regSet: make(map[Topic]struct{}), nodes: make(map[*Node]*ticket), nodeLastReq: make(map[*Node]reqInfo), searchTopicMap: make(map[Topic]searchTopic), @@ -169,13 +175,13 @@ func newTicketStore() *ticketStore { // addTopic starts tracking a topic. If register is true, // the local node will register the topic and tickets will be collected. -func (s *ticketStore) addTopic(t Topic, register bool) { - debugLog(fmt.Sprintf(" addTopic(%v, %v)", t, register)) - if s.radius[t] == nil { - s.radius[t] = newTopicRadius(t) +func (s *ticketStore) addTopic(topic Topic, register bool) { + log.Trace("Adding discovery topic", "topic", topic, "register", register) + if s.radius[topic] == nil { + s.radius[topic] = newTopicRadius(topic) } - if register && s.tickets[t].buckets == nil { - s.tickets[t] = topicTickets{buckets: make(map[timeBucket][]ticketRef)} + if register && s.tickets[topic] == nil { + s.tickets[topic] = &topicTickets{buckets: make(map[timeBucket][]ticketRef)} } } @@ -194,7 +200,11 @@ func (s *ticketStore) removeSearchTopic(t Topic) { // removeRegisterTopic deletes all tickets for the given topic. func (s *ticketStore) removeRegisterTopic(topic Topic) { - debugLog(fmt.Sprintf(" removeRegisterTopic(%v)", topic)) + log.Trace("Removing discovery topic", "topic", topic) + if s.tickets[topic] == nil { + log.Warn("Removing non-existent discovery topic", "topic", topic) + return + } for _, list := range s.tickets[topic].buckets { for _, ref := range list { ref.t.refCnt-- @@ -216,23 +226,35 @@ func (s *ticketStore) regTopicSet() []Topic { } // nextRegisterLookup returns the target of the next lookup for ticket collection. -func (s *ticketStore) nextRegisterLookup() (lookup lookupInfo, delay time.Duration) { - debugLog("nextRegisterLookup()") - firstTopic, ok := s.iterRegTopics() - for topic := firstTopic; ok; { - debugLog(fmt.Sprintf(" checking topic %v, len(s.tickets[topic]) = %d", topic, len(s.tickets[topic].buckets))) - if s.tickets[topic].buckets != nil && s.needMoreTickets(topic) { - next := s.radius[topic].nextTarget(false) - debugLog(fmt.Sprintf(" %x 1s", next.target[:8])) - return next, 100 * time.Millisecond +func (s *ticketStore) nextRegisterLookup() (lookupInfo, time.Duration) { + // Queue up any new topics (or discarded ones), preserving iteration order + for topic := range s.tickets { + if _, ok := s.regSet[topic]; !ok { + s.regQueue = append(s.regQueue, topic) + s.regSet[topic] = struct{}{} + } + } + // Iterate over the set of all topics and look up the next suitable one + for len(s.regQueue) > 0 { + // Fetch the next topic from the queue, and ensure it still exists + topic := s.regQueue[0] + s.regQueue = s.regQueue[1:] + delete(s.regSet, topic) + + if s.tickets[topic] == nil { + continue } - topic, ok = s.iterRegTopics() - if topic == firstTopic { - break // We have checked all topics. + // If the topic needs more tickets, return it + if s.tickets[topic].nextLookup < mclock.Now() { + next, delay := s.radius[topic].nextTarget(false), 100*time.Millisecond + log.Trace("Found discovery topic to register", "topic", topic, "target", next.target, "delay", delay) + return next, delay } } - debugLog(" null, 40s") - return lookupInfo{}, 40 * time.Second + // No registration topics found or all exhausted, sleep + delay := 40 * time.Second + log.Trace("No topic found to register", "delay", delay) + return lookupInfo{}, delay } func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo { @@ -246,40 +268,22 @@ func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo { return target } -// iterRegTopics returns topics to register in arbitrary order. -// The second return value is false if there are no topics. -func (s *ticketStore) iterRegTopics() (Topic, bool) { - debugLog("iterRegTopics()") - if len(s.regtopics) == 0 { - if len(s.tickets) == 0 { - debugLog(" false") - return "", false - } - // Refill register list. - for t := range s.tickets { - s.regtopics = append(s.regtopics, t) - } +// ticketsInWindow returns the tickets of a given topic in the registration window. +func (s *ticketStore) ticketsInWindow(topic Topic) []ticketRef { + // Sanity check that the topic still exists before operating on it + if s.tickets[topic] == nil { + log.Warn("Listing non-existing discovery tickets", "topic", topic) + return nil } - topic := s.regtopics[len(s.regtopics)-1] - s.regtopics = s.regtopics[:len(s.regtopics)-1] - debugLog(" " + string(topic) + " true") - return topic, true -} - -func (s *ticketStore) needMoreTickets(t Topic) bool { - return s.tickets[t].nextLookup < mclock.Now() -} + // Gather all the tickers in the next time window + var tickets []ticketRef -// ticketsInWindow returns the tickets of a given topic in the registration window. -func (s *ticketStore) ticketsInWindow(t Topic) []ticketRef { - ltBucket := s.lastBucketFetched - var res []ticketRef - tickets := s.tickets[t].buckets - for g := ltBucket; g < ltBucket+timeWindow; g++ { - res = append(res, tickets[g]...) + buckets := s.tickets[topic].buckets + for idx := timeBucket(0); idx < timeWindow; idx++ { + tickets = append(tickets, buckets[s.lastBucketFetched+idx]...) } - debugLog(fmt.Sprintf("ticketsInWindow(%v) = %v", t, len(res))) - return res + log.Trace("Retrieved discovery registration tickets", "topic", topic, "from", s.lastBucketFetched, "tickets", len(tickets)) + return tickets } func (s *ticketStore) removeExcessTickets(t Topic) { @@ -317,53 +321,55 @@ func (s ticketRefByWaitTime) Swap(i, j int) { func (s *ticketStore) addTicketRef(r ticketRef) { topic := r.t.topics[r.idx] - t := s.tickets[topic] - if t.buckets == nil { + tickets := s.tickets[topic] + if tickets == nil { + log.Warn("Adding ticket to non-existent topic", "topic", topic) return } bucket := timeBucket(r.t.regTime[r.idx] / mclock.AbsTime(ticketTimeBucketLen)) - t.buckets[bucket] = append(t.buckets[bucket], r) + tickets.buckets[bucket] = append(tickets.buckets[bucket], r) r.t.refCnt++ min := mclock.Now() - mclock.AbsTime(collectFrequency)*maxCollectDebt - if t.nextLookup < min { - t.nextLookup = min + if tickets.nextLookup < min { + tickets.nextLookup = min } - t.nextLookup += mclock.AbsTime(collectFrequency) - s.tickets[topic] = t + tickets.nextLookup += mclock.AbsTime(collectFrequency) //s.removeExcessTickets(topic) } -func (s *ticketStore) nextFilteredTicket() (t *ticketRef, wait time.Duration) { +func (s *ticketStore) nextFilteredTicket() (*ticketRef, time.Duration) { now := mclock.Now() for { - t, wait = s.nextRegisterableTicket() - if t == nil { - return + ticket, wait := s.nextRegisterableTicket() + if ticket == nil { + return ticket, wait } + log.Trace("Found discovery ticket to register", "node", ticket.t.node, "serial", ticket.t.serial, "wait", wait) + regTime := now + mclock.AbsTime(wait) - topic := t.t.topics[t.idx] - if regTime >= s.tickets[topic].nextReg { - return + topic := ticket.t.topics[ticket.idx] + if s.tickets[topic] != nil && regTime >= s.tickets[topic].nextReg { + return ticket, wait } - s.removeTicketRef(*t) + s.removeTicketRef(*ticket) } } -func (s *ticketStore) ticketRegistered(t ticketRef) { +func (s *ticketStore) ticketRegistered(ref ticketRef) { now := mclock.Now() - topic := t.t.topics[t.idx] - tt := s.tickets[topic] + topic := ref.t.topics[ref.idx] + tickets := s.tickets[topic] min := now - mclock.AbsTime(registerFrequency)*maxRegisterDebt - if min > tt.nextReg { - tt.nextReg = min + if min > tickets.nextReg { + tickets.nextReg = min } - tt.nextReg += mclock.AbsTime(registerFrequency) - s.tickets[topic] = tt + tickets.nextReg += mclock.AbsTime(registerFrequency) + s.tickets[topic] = tickets - s.removeTicketRef(t) + s.removeTicketRef(ref) } // nextRegisterableTicket returns the next ticket that can be used @@ -374,16 +380,7 @@ func (s *ticketStore) ticketRegistered(t ticketRef) { // // A ticket can be returned more than once with <= zero wait time in case // the ticket contains multiple topics. -func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration) { - defer func() { - if t == nil { - debugLog(" nil") - } else { - debugLog(fmt.Sprintf(" node = %x sn = %v wait = %v", t.t.node.ID[:8], t.t.serial, wait)) - } - }() - - debugLog("nextRegisterableTicket()") +func (s *ticketStore) nextRegisterableTicket() (*ticketRef, time.Duration) { now := mclock.Now() if s.nextTicketCached != nil { return s.nextTicketCached, time.Duration(s.nextTicketCached.topicRegTime() - now) @@ -412,9 +409,8 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration return nil, 0 } if nextTicket.t != nil { - wait = time.Duration(nextTicket.topicRegTime() - now) s.nextTicketCached = &nextTicket - return &nextTicket, wait + return &nextTicket, time.Duration(nextTicket.topicRegTime() - now) } s.lastBucketFetched = bucket } @@ -422,14 +418,20 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration // removeTicket removes a ticket from the ticket store func (s *ticketStore) removeTicketRef(ref ticketRef) { - debugLog(fmt.Sprintf("removeTicketRef(node = %x sn = %v)", ref.t.node.ID[:8], ref.t.serial)) + log.Trace("Removing discovery ticket reference", "node", ref.t.node.ID, "serial", ref.t.serial) + + // Make nextRegisterableTicket return the next available ticket. + s.nextTicketCached = nil + topic := ref.topic() - tickets := s.tickets[topic].buckets + tickets := s.tickets[topic] + if tickets == nil { + log.Trace("Removing tickets from unknown topic", "topic", topic) return } bucket := timeBucket(ref.t.regTime[ref.idx] / mclock.AbsTime(ticketTimeBucketLen)) - list := tickets[bucket] + list := tickets.buckets[bucket] idx := -1 for i, bt := range list { if bt.t == ref.t { @@ -442,18 +444,15 @@ func (s *ticketStore) removeTicketRef(ref ticketRef) { } list = append(list[:idx], list[idx+1:]...) if len(list) != 0 { - tickets[bucket] = list + tickets.buckets[bucket] = list } else { - delete(tickets, bucket) + delete(tickets.buckets, bucket) } ref.t.refCnt-- if ref.t.refCnt == 0 { delete(s.nodes, ref.t.node) delete(s.nodeLastReq, ref.t.node) } - - // Make nextRegisterableTicket return the next available ticket. - s.nextTicketCached = nil } type lookupInfo struct { @@ -523,21 +522,21 @@ func (s *ticketStore) adjustWithTicket(now mclock.AbsTime, targetHash common.Has } } -func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ticket) { - debugLog(fmt.Sprintf("add(node = %x sn = %v)", t.node.ID[:8], t.serial)) +func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, ticket *ticket) { + log.Trace("Adding discovery ticket", "node", ticket.node.ID, "serial", ticket.serial) - lastReq, ok := s.nodeLastReq[t.node] + lastReq, ok := s.nodeLastReq[ticket.node] if !(ok && bytes.Equal(pingHash, lastReq.pingHash)) { return } - s.adjustWithTicket(localTime, lastReq.lookup.target, t) + s.adjustWithTicket(localTime, lastReq.lookup.target, ticket) - if lastReq.lookup.radiusLookup || s.nodes[t.node] != nil { + if lastReq.lookup.radiusLookup || s.nodes[ticket.node] != nil { return } topic := lastReq.lookup.topic - topicIdx := t.findIdx(topic) + topicIdx := ticket.findIdx(topic) if topicIdx == -1 { return } @@ -548,29 +547,29 @@ func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ti } if _, ok := s.tickets[topic]; ok { - wait := t.regTime[topicIdx] - localTime + wait := ticket.regTime[topicIdx] - localTime rnd := rand.ExpFloat64() if rnd > 10 { rnd = 10 } if float64(wait) < float64(keepTicketConst)+float64(keepTicketExp)*rnd { // use the ticket to register this topic - //fmt.Println("addTicket", t.node.ID[:8], t.node.addr().String(), t.serial, t.pong) - s.addTicketRef(ticketRef{t, topicIdx}) + //fmt.Println("addTicket", ticket.node.ID[:8], ticket.node.addr().String(), ticket.serial, ticket.pong) + s.addTicketRef(ticketRef{ticket, topicIdx}) } } - if t.refCnt > 0 { + if ticket.refCnt > 0 { s.nextTicketCached = nil - s.nodes[t.node] = t + s.nodes[ticket.node] = ticket } } func (s *ticketStore) getNodeTicket(node *Node) *ticket { if s.nodes[node] == nil { - debugLog(fmt.Sprintf("getNodeTicket(%x) sn = nil", node.ID[:8])) + log.Trace("Retrieving node ticket", "node", node.ID, "serial", nil) } else { - debugLog(fmt.Sprintf("getNodeTicket(%x) sn = %v", node.ID[:8], s.nodes[node].serial)) + log.Trace("Retrieving node ticket", "node", node.ID, "serial", s.nodes[node].serial) } return s.nodes[node] } @@ -643,7 +642,7 @@ func (s *ticketStore) gotTopicNodes(from *Node, hash common.Hash, nodes []rpcNod if ip.IsUnspecified() || ip.IsLoopback() { ip = from.IP } - n := NewNode(node.ID, ip, node.UDP-1, node.TCP-1) // subtract one from port while discv5 is running in test mode on UDPport+1 + n := NewNode(node.ID, ip, node.UDP, node.TCP) select { case chn <- n: default: diff --git a/p2p/discv5/topic.go b/p2p/discv5/topic.go index b6bea013c..e7a7f8e02 100644 --- a/p2p/discv5/topic.go +++ b/p2p/discv5/topic.go @@ -24,6 +24,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/log" ) const ( @@ -235,7 +236,7 @@ func (t *topicTable) deleteEntry(e *topicEntry) { // It is assumed that topics and waitPeriods have the same length. func (t *topicTable) useTicket(node *Node, serialNo uint32, topics []Topic, idx int, issueTime uint64, waitPeriods []uint32) (registered bool) { - debugLog(fmt.Sprintf("useTicket %v %v %v", serialNo, topics, waitPeriods)) + log.Trace("Using discovery ticket", "serial", serialNo, "topics", topics, "waits", waitPeriods) //fmt.Println("useTicket", serialNo, topics, waitPeriods) t.collectGarbage() diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 26087cd8e..543771817 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -37,7 +37,7 @@ const Version = 4 // Errors var ( errPacketTooSmall = errors.New("too small") - errBadHash = errors.New("bad hash") + errBadPrefix = errors.New("bad prefix") errExpired = errors.New("expired") errUnsolicitedReply = errors.New("unsolicited reply") errUnknownNode = errors.New("unknown node") @@ -145,10 +145,11 @@ type ( } ) -const ( - macSize = 256 / 8 - sigSize = 520 / 8 - headSize = macSize + sigSize // space of packet frame data +var ( + versionPrefix = []byte("temporary discovery v5") + versionPrefixSize = len(versionPrefix) + sigSize = 520 / 8 + headSize = versionPrefixSize + sigSize // space of packet frame data ) // Neighbors replies are sent across multiple packets to @@ -237,30 +238,23 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { - transport, err := listenUDP(priv, laddr) +func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { + transport, err := listenUDP(priv, conn, realaddr) if err != nil { return nil, err } - net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) + net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict) if err != nil { return nil, err } + log.Info("UDP listener up", "net", net.tab.self) transport.net = net go transport.readLoop() return net, nil } -func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil +func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) { + return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil } func (t *udp) localAddr() *net.UDPAddr { @@ -372,11 +366,9 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash log.Error(fmt.Sprint("could not sign packet:", err)) return nil, nil, err } - copy(packet[macSize:], sig) - // add the hash to the front. Note: this doesn't protect the - // packet in any way. - hash = crypto.Keccak256(packet[macSize:]) - copy(packet, hash) + copy(packet, versionPrefix) + copy(packet[versionPrefixSize:], sig) + hash = crypto.Keccak256(packet[versionPrefixSize:]) return packet, hash, nil } @@ -420,17 +412,16 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error { } buf := make([]byte, len(buffer)) copy(buf, buffer) - hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] - shouldhash := crypto.Keccak256(buf[macSize:]) - if !bytes.Equal(hash, shouldhash) { - return errBadHash + prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:] + if !bytes.Equal(prefix, versionPrefix) { + return errBadPrefix } fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) if err != nil { return err } pkt.rawData = buf - pkt.hash = hash + pkt.hash = crypto.Keccak256(buf[versionPrefixSize:]) pkt.remoteID = fromID switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev { case pingPacket: diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go new file mode 100644 index 000000000..2c3afb43e --- /dev/null +++ b/p2p/enr/enr.go @@ -0,0 +1,290 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Package enr implements Ethereum Node Records as defined in EIP-778. A node record holds +// arbitrary information about a node on the peer-to-peer network. +// +// Records contain named keys. To store and retrieve key/values in a record, use the Entry +// interface. +// +// Records must be signed before transmitting them to another node. Decoding a record verifies +// its signature. When creating a record, set the entries you want, then call Sign to add the +// signature. Modifying a record invalidates the signature. +// +// Package enr supports the "secp256k1-keccak" identity scheme. +package enr + +import ( + "bytes" + "crypto/ecdsa" + "errors" + "fmt" + "io" + "sort" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +const SizeLimit = 300 // maximum encoded size of a node record in bytes + +const ID_SECP256k1_KECCAK = ID("secp256k1-keccak") // the default identity scheme + +var ( + errNoID = errors.New("unknown or unspecified identity scheme") + errInvalidSigsize = errors.New("invalid signature size") + errInvalidSig = errors.New("invalid signature") + errNotSorted = errors.New("record key/value pairs are not sorted by key") + errDuplicateKey = errors.New("record contains duplicate key") + errIncompletePair = errors.New("record contains incomplete k/v pair") + errTooBig = fmt.Errorf("record bigger than %d bytes", SizeLimit) + errEncodeUnsigned = errors.New("can't encode unsigned record") + errNotFound = errors.New("no such key in record") +) + +// Record represents a node record. The zero value is an empty record. +type Record struct { + seq uint64 // sequence number + signature []byte // the signature + raw []byte // RLP encoded record + pairs []pair // sorted list of all key/value pairs +} + +// pair is a key/value pair in a record. +type pair struct { + k string + v rlp.RawValue +} + +// Signed reports whether the record has a valid signature. +func (r *Record) Signed() bool { + return r.signature != nil +} + +// Seq returns the sequence number. +func (r *Record) Seq() uint64 { + return r.seq +} + +// SetSeq updates the record sequence number. This invalidates any signature on the record. +// Calling SetSeq is usually not required because signing the redord increments the +// sequence number. +func (r *Record) SetSeq(s uint64) { + r.signature = nil + r.raw = nil + r.seq = s +} + +// Load retrieves the value of a key/value pair. The given Entry must be a pointer and will +// be set to the value of the entry in the record. +// +// Errors returned by Load are wrapped in KeyError. You can distinguish decoding errors +// from missing keys using the IsNotFound function. +func (r *Record) Load(e Entry) error { + i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() }) + if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() { + if err := rlp.DecodeBytes(r.pairs[i].v, e); err != nil { + return &KeyError{Key: e.ENRKey(), Err: err} + } + return nil + } + return &KeyError{Key: e.ENRKey(), Err: errNotFound} +} + +// Set adds or updates the given entry in the record. +// It panics if the value can't be encoded. +func (r *Record) Set(e Entry) { + r.signature = nil + r.raw = nil + blob, err := rlp.EncodeToBytes(e) + if err != nil { + panic(fmt.Errorf("enr: can't encode %s: %v", e.ENRKey(), err)) + } + + i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() }) + + if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() { + // element is present at r.pairs[i] + r.pairs[i].v = blob + return + } else if i < len(r.pairs) { + // insert pair before i-th elem + el := pair{e.ENRKey(), blob} + r.pairs = append(r.pairs, pair{}) + copy(r.pairs[i+1:], r.pairs[i:]) + r.pairs[i] = el + return + } + + // element should be placed at the end of r.pairs + r.pairs = append(r.pairs, pair{e.ENRKey(), blob}) +} + +// EncodeRLP implements rlp.Encoder. Encoding fails if +// the record is unsigned. +func (r Record) EncodeRLP(w io.Writer) error { + if !r.Signed() { + return errEncodeUnsigned + } + _, err := w.Write(r.raw) + return err +} + +// DecodeRLP implements rlp.Decoder. Decoding verifies the signature. +func (r *Record) DecodeRLP(s *rlp.Stream) error { + raw, err := s.Raw() + if err != nil { + return err + } + if len(raw) > SizeLimit { + return errTooBig + } + + // Decode the RLP container. + dec := Record{raw: raw} + s = rlp.NewStream(bytes.NewReader(raw), 0) + if _, err := s.List(); err != nil { + return err + } + if err = s.Decode(&dec.signature); err != nil { + return err + } + if err = s.Decode(&dec.seq); err != nil { + return err + } + // The rest of the record contains sorted k/v pairs. + var prevkey string + for i := 0; ; i++ { + var kv pair + if err := s.Decode(&kv.k); err != nil { + if err == rlp.EOL { + break + } + return err + } + if err := s.Decode(&kv.v); err != nil { + if err == rlp.EOL { + return errIncompletePair + } + return err + } + if i > 0 { + if kv.k == prevkey { + return errDuplicateKey + } + if kv.k < prevkey { + return errNotSorted + } + } + dec.pairs = append(dec.pairs, kv) + prevkey = kv.k + } + if err := s.ListEnd(); err != nil { + return err + } + + // Verify signature. + if err = dec.verifySignature(); err != nil { + return err + } + *r = dec + return nil +} + +type s256raw []byte + +func (s256raw) ENRKey() string { return "secp256k1" } + +// NodeAddr returns the node address. The return value will be nil if the record is +// unsigned. +func (r *Record) NodeAddr() []byte { + var entry s256raw + if r.Load(&entry) != nil { + return nil + } + return crypto.Keccak256(entry) +} + +// Sign signs the record with the given private key. It updates the record's identity +// scheme, public key and increments the sequence number. Sign returns an error if the +// encoded record is larger than the size limit. +func (r *Record) Sign(privkey *ecdsa.PrivateKey) error { + r.seq = r.seq + 1 + r.Set(ID_SECP256k1_KECCAK) + r.Set(Secp256k1(privkey.PublicKey)) + return r.signAndEncode(privkey) +} + +func (r *Record) appendPairs(list []interface{}) []interface{} { + list = append(list, r.seq) + for _, p := range r.pairs { + list = append(list, p.k, p.v) + } + return list +} + +func (r *Record) signAndEncode(privkey *ecdsa.PrivateKey) error { + // Put record elements into a flat list. Leave room for the signature. + list := make([]interface{}, 1, len(r.pairs)*2+2) + list = r.appendPairs(list) + + // Sign the tail of the list. + h := sha3.NewKeccak256() + rlp.Encode(h, list[1:]) + sig, err := crypto.Sign(h.Sum(nil), privkey) + if err != nil { + return err + } + sig = sig[:len(sig)-1] // remove v + + // Put signature in front. + r.signature, list[0] = sig, sig + r.raw, err = rlp.EncodeToBytes(list) + if err != nil { + return err + } + if len(r.raw) > SizeLimit { + return errTooBig + } + return nil +} + +func (r *Record) verifySignature() error { + // Get identity scheme, public key, signature. + var id ID + var entry s256raw + if err := r.Load(&id); err != nil { + return err + } else if id != ID_SECP256k1_KECCAK { + return errNoID + } + if err := r.Load(&entry); err != nil { + return err + } else if len(entry) != 33 { + return fmt.Errorf("invalid public key") + } + + // Verify the signature. + list := make([]interface{}, 0, len(r.pairs)*2+1) + list = r.appendPairs(list) + h := sha3.NewKeccak256() + rlp.Encode(h, list) + if !crypto.VerifySignature(entry, h.Sum(nil), r.signature) { + return errInvalidSig + } + return nil +} diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go new file mode 100644 index 000000000..ce7767d10 --- /dev/null +++ b/p2p/enr/enr_test.go @@ -0,0 +1,318 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enr + +import ( + "bytes" + "encoding/hex" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + pubkey = &privkey.PublicKey +) + +var rnd = rand.New(rand.NewSource(time.Now().UnixNano())) + +func randomString(strlen int) string { + b := make([]byte, strlen) + rnd.Read(b) + return string(b) +} + +// TestGetSetID tests encoding/decoding and setting/getting of the ID key. +func TestGetSetID(t *testing.T) { + id := ID("someid") + var r Record + r.Set(id) + + var id2 ID + require.NoError(t, r.Load(&id2)) + assert.Equal(t, id, id2) +} + +// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key. +func TestGetSetIP4(t *testing.T) { + ip := IP4{192, 168, 0, 3} + var r Record + r.Set(ip) + + var ip2 IP4 + require.NoError(t, r.Load(&ip2)) + assert.Equal(t, ip, ip2) +} + +// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key. +func TestGetSetIP6(t *testing.T) { + ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68} + var r Record + r.Set(ip) + + var ip2 IP6 + require.NoError(t, r.Load(&ip2)) + assert.Equal(t, ip, ip2) +} + +// TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key. +func TestGetSetDiscPort(t *testing.T) { + port := DiscPort(30309) + var r Record + r.Set(port) + + var port2 DiscPort + require.NoError(t, r.Load(&port2)) + assert.Equal(t, port, port2) +} + +// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key. +func TestGetSetSecp256k1(t *testing.T) { + var r Record + if err := r.Sign(privkey); err != nil { + t.Fatal(err) + } + + var pk Secp256k1 + require.NoError(t, r.Load(&pk)) + assert.EqualValues(t, pubkey, &pk) +} + +func TestLoadErrors(t *testing.T) { + var r Record + ip4 := IP4{127, 0, 0, 1} + r.Set(ip4) + + // Check error for missing keys. + var ip6 IP6 + err := r.Load(&ip6) + if !IsNotFound(err) { + t.Error("IsNotFound should return true for missing key") + } + assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err) + + // Check error for invalid keys. + var list []uint + err = r.Load(WithEntry(ip4.ENRKey(), &list)) + kerr, ok := err.(*KeyError) + if !ok { + t.Fatalf("expected KeyError, got %T", err) + } + assert.Equal(t, kerr.Key, ip4.ENRKey()) + assert.Error(t, kerr.Err) + if IsNotFound(err) { + t.Error("IsNotFound should return false for decoding errors") + } +} + +// TestSortedGetAndSet tests that Set produced a sorted pairs slice. +func TestSortedGetAndSet(t *testing.T) { + type pair struct { + k string + v uint32 + } + + for _, tt := range []struct { + input []pair + want []pair + }{ + { + input: []pair{{"a", 1}, {"c", 2}, {"b", 3}}, + want: []pair{{"a", 1}, {"b", 3}, {"c", 2}}, + }, + { + input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, + want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, + }, + { + input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, + want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, + }, + } { + var r Record + for _, i := range tt.input { + r.Set(WithEntry(i.k, &i.v)) + } + for i, w := range tt.want { + // set got's key from r.pair[i], so that we preserve order of pairs + got := pair{k: r.pairs[i].k} + assert.NoError(t, r.Load(WithEntry(w.k, &got.v))) + assert.Equal(t, w, got) + } + } +} + +// TestDirty tests record signature removal on setting of new key/value pair in record. +func TestDirty(t *testing.T) { + var r Record + + if r.Signed() { + t.Error("Signed returned true for zero record") + } + if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { + t.Errorf("expected errEncodeUnsigned, got %#v", err) + } + + require.NoError(t, r.Sign(privkey)) + if !r.Signed() { + t.Error("Signed return false for signed record") + } + _, err := rlp.EncodeToBytes(r) + assert.NoError(t, err) + + r.SetSeq(3) + if r.Signed() { + t.Error("Signed returned true for modified record") + } + if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { + t.Errorf("expected errEncodeUnsigned, got %#v", err) + } +} + +// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. +func TestGetSetOverwrite(t *testing.T) { + var r Record + + ip := IP4{192, 168, 0, 3} + r.Set(ip) + + ip2 := IP4{192, 168, 0, 4} + r.Set(ip2) + + var ip3 IP4 + require.NoError(t, r.Load(&ip3)) + assert.Equal(t, ip2, ip3) +} + +// TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record. +func TestSignEncodeAndDecode(t *testing.T) { + var r Record + r.Set(DiscPort(30303)) + r.Set(IP4{127, 0, 0, 1}) + require.NoError(t, r.Sign(privkey)) + + blob, err := rlp.EncodeToBytes(r) + require.NoError(t, err) + + var r2 Record + require.NoError(t, rlp.DecodeBytes(blob, &r2)) + assert.Equal(t, r, r2) + + blob2, err := rlp.EncodeToBytes(r2) + require.NoError(t, err) + assert.Equal(t, blob, blob2) +} + +func TestNodeAddr(t *testing.T) { + var r Record + if addr := r.NodeAddr(); addr != nil { + t.Errorf("wrong address on empty record: got %v, want %v", addr, nil) + } + + require.NoError(t, r.Sign(privkey)) + expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726" + assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr())) +} + +var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138") + +// TestPythonInterop checks that we can decode and verify a record produced by the Python +// implementation. +func TestPythonInterop(t *testing.T) { + var r Record + if err := rlp.DecodeBytes(pyRecord, &r); err != nil { + t.Fatalf("can't decode: %v", err) + } + + var ( + wantAddr, _ = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726") + wantSeq = uint64(1) + wantIP = IP4{127, 0, 0, 1} + wantDiscport = DiscPort(30303) + ) + if r.Seq() != wantSeq { + t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq) + } + if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) { + t.Errorf("wrong addr: got %x, want %x", addr, wantAddr) + } + want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport} + for k, v := range want { + desc := fmt.Sprintf("loading key %q", k.ENRKey()) + if assert.NoError(t, r.Load(k), desc) { + assert.Equal(t, k, v, desc) + } + } +} + +// TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed. +func TestRecordTooBig(t *testing.T) { + var r Record + key := randomString(10) + + // set a big value for random key, expect error + r.Set(WithEntry(key, randomString(300))) + if err := r.Sign(privkey); err != errTooBig { + t.Fatalf("expected to get errTooBig, got %#v", err) + } + + // set an acceptable value for random key, expect no error + r.Set(WithEntry(key, randomString(100))) + require.NoError(t, r.Sign(privkey)) +} + +// TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs. +func TestSignEncodeAndDecodeRandom(t *testing.T) { + var r Record + + // random key/value pairs for testing + pairs := map[string]uint32{} + for i := 0; i < 10; i++ { + key := randomString(7) + value := rnd.Uint32() + pairs[key] = value + r.Set(WithEntry(key, &value)) + } + + require.NoError(t, r.Sign(privkey)) + _, err := rlp.EncodeToBytes(r) + require.NoError(t, err) + + for k, v := range pairs { + desc := fmt.Sprintf("key %q", k) + var got uint32 + buf := WithEntry(k, &got) + require.NoError(t, r.Load(buf), desc) + require.Equal(t, v, got, desc) + } +} + +func BenchmarkDecode(b *testing.B) { + var r Record + for i := 0; i < b.N; i++ { + rlp.DecodeBytes(pyRecord, &r) + } + b.StopTimer() + r.NodeAddr() +} diff --git a/p2p/enr/entries.go b/p2p/enr/entries.go new file mode 100644 index 000000000..7591e6eff --- /dev/null +++ b/p2p/enr/entries.go @@ -0,0 +1,160 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enr + +import ( + "crypto/ecdsa" + "fmt" + "io" + "net" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" +) + +// Entry is implemented by known node record entry types. +// +// To define a new entry that is to be included in a node record, +// create a Go type that satisfies this interface. The type should +// also implement rlp.Decoder if additional checks are needed on the value. +type Entry interface { + ENRKey() string +} + +type generic struct { + key string + value interface{} +} + +func (g generic) ENRKey() string { return g.key } + +func (g generic) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, g.value) +} + +func (g *generic) DecodeRLP(s *rlp.Stream) error { + return s.Decode(g.value) +} + +// WithEntry wraps any value with a key name. It can be used to set and load arbitrary values +// in a record. The value v must be supported by rlp. To use WithEntry with Load, the value +// must be a pointer. +func WithEntry(k string, v interface{}) Entry { + return &generic{key: k, value: v} +} + +// DiscPort is the "discv5" key, which holds the UDP port for discovery v5. +type DiscPort uint16 + +func (v DiscPort) ENRKey() string { return "discv5" } + +// ID is the "id" key, which holds the name of the identity scheme. +type ID string + +func (v ID) ENRKey() string { return "id" } + +// IP4 is the "ip4" key, which holds a 4-byte IPv4 address. +type IP4 net.IP + +func (v IP4) ENRKey() string { return "ip4" } + +// EncodeRLP implements rlp.Encoder. +func (v IP4) EncodeRLP(w io.Writer) error { + ip4 := net.IP(v).To4() + if ip4 == nil { + return fmt.Errorf("invalid IPv4 address: %v", v) + } + return rlp.Encode(w, ip4) +} + +// DecodeRLP implements rlp.Decoder. +func (v *IP4) DecodeRLP(s *rlp.Stream) error { + if err := s.Decode((*net.IP)(v)); err != nil { + return err + } + if len(*v) != 4 { + return fmt.Errorf("invalid IPv4 address, want 4 bytes: %v", *v) + } + return nil +} + +// IP6 is the "ip6" key, which holds a 16-byte IPv6 address. +type IP6 net.IP + +func (v IP6) ENRKey() string { return "ip6" } + +// EncodeRLP implements rlp.Encoder. +func (v IP6) EncodeRLP(w io.Writer) error { + ip6 := net.IP(v) + return rlp.Encode(w, ip6) +} + +// DecodeRLP implements rlp.Decoder. +func (v *IP6) DecodeRLP(s *rlp.Stream) error { + if err := s.Decode((*net.IP)(v)); err != nil { + return err + } + if len(*v) != 16 { + return fmt.Errorf("invalid IPv6 address, want 16 bytes: %v", *v) + } + return nil +} + +// Secp256k1 is the "secp256k1" key, which holds a public key. +type Secp256k1 ecdsa.PublicKey + +func (v Secp256k1) ENRKey() string { return "secp256k1" } + +// EncodeRLP implements rlp.Encoder. +func (v Secp256k1) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, crypto.CompressPubkey((*ecdsa.PublicKey)(&v))) +} + +// DecodeRLP implements rlp.Decoder. +func (v *Secp256k1) DecodeRLP(s *rlp.Stream) error { + buf, err := s.Bytes() + if err != nil { + return err + } + pk, err := crypto.DecompressPubkey(buf) + if err != nil { + return err + } + *v = (Secp256k1)(*pk) + return nil +} + +// KeyError is an error related to a key. +type KeyError struct { + Key string + Err error +} + +// Error implements error. +func (err *KeyError) Error() string { + if err.Err == errNotFound { + return fmt.Sprintf("missing ENR key %q", err.Key) + } + return fmt.Sprintf("ENR key %q: %v", err.Key, err.Err) +} + +// IsNotFound reports whether the given error means that a key/value pair is +// missing from a record. +func IsNotFound(err error) bool { + kerr, ok := err.(*KeyError) + return ok && kerr.Err == errNotFound +} diff --git a/p2p/server.go b/p2p/server.go index 922df55ba..2cff94ea5 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -78,9 +78,6 @@ type Config struct { // protocol should be started or not. DiscoveryV5 bool `toml:",omitempty"` - // Listener address for the V5 discovery protocol UDP traffic. - DiscoveryV5Addr string `toml:",omitempty"` - // Name sets the node name of this server. // Use common.MakeName to create a name that follows existing conventions. Name string `toml:"-"` @@ -354,6 +351,32 @@ func (srv *Server) Stop() { srv.loopWG.Wait() } +// sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns +// messages that were found unprocessable and sent to the unhandled channel by the primary listener. +type sharedUDPConn struct { + *net.UDPConn + unhandled chan discover.ReadPacket +} + +// ReadFromUDP implements discv5.conn +func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { + packet, ok := <-s.unhandled + if !ok { + return 0, nil, fmt.Errorf("Connection was closed") + } + l := len(packet.Data) + if l > len(b) { + l = len(b) + } + copy(b[:l], packet.Data[:l]) + return l, packet.Addr, nil +} + +// Close implements discv5.conn +func (s *sharedUDPConn) Close() error { + return nil +} + // Start starts running the server. // Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { @@ -388,9 +411,43 @@ func (srv *Server) Start() (err error) { srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) + var ( + conn *net.UDPConn + sconn *sharedUDPConn + realaddr *net.UDPAddr + unhandled chan discover.ReadPacket + ) + + if !srv.NoDiscovery || srv.DiscoveryV5 { + addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) + if err != nil { + return err + } + conn, err = net.ListenUDP("udp", addr) + if err != nil { + return err + } + + realaddr = conn.LocalAddr().(*net.UDPAddr) + if srv.NAT != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := srv.NAT.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + } + + if !srv.NoDiscovery && srv.DiscoveryV5 { + unhandled = make(chan discover.ReadPacket, 100) + sconn = &sharedUDPConn{conn, unhandled} + } + // node table if !srv.NoDiscovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) + ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict) if err != nil { return err } @@ -401,7 +458,15 @@ func (srv *Server) Start() (err error) { } if srv.DiscoveryV5 { - ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) + var ( + ntab *discv5.Network + err error + ) + if sconn != nil { + ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } else { + ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } if err != nil { return err } |