aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
authorJanos Guljas <janos@resenje.org>2018-02-09 19:23:30 +0800
committerJanos Guljas <janos@resenje.org>2018-02-22 21:23:17 +0800
commita3a07350dcef0ba39829a20d8ddba4bd3463d293 (patch)
tree100f2515cadd92105537a12e6981fab2193435ee /p2p
parent820cf09c98706f71d4b02b6c25e62db15830f3fb (diff)
parent1a4e68721a901e86322631fed1191025a6d14c52 (diff)
downloaddexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.gz
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.bz2
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.lz
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.xz
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.tar.zst
dexon-a3a07350dcef0ba39829a20d8ddba4bd3463d293.zip
swarm, cmd/swarm: Merge branch 'master' into multiple-ens-endpoints
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/database.go6
-rw-r--r--p2p/discover/udp.go44
-rw-r--r--p2p/discover/udp_test.go3
-rw-r--r--p2p/discv5/database.go6
-rw-r--r--p2p/discv5/net.go59
-rw-r--r--p2p/discv5/net_test.go2
-rw-r--r--p2p/discv5/node.go5
-rw-r--r--p2p/discv5/nodeevent_string.go6
-rw-r--r--p2p/discv5/sim_test.go2
-rw-r--r--p2p/discv5/ticket.go227
-rw-r--r--p2p/discv5/topic.go3
-rw-r--r--p2p/discv5/udp.go47
-rw-r--r--p2p/enr/enr.go290
-rw-r--r--p2p/enr/enr_test.go318
-rw-r--r--p2p/enr/entries.go160
-rw-r--r--p2p/server.go75
16 files changed, 1035 insertions, 218 deletions
diff --git a/p2p/discover/database.go b/p2p/discover/database.go
index 7206a63c6..b136609f2 100644
--- a/p2p/discover/database.go
+++ b/p2p/discover/database.go
@@ -226,14 +226,14 @@ func (db *nodeDB) ensureExpirer() {
// expirer should be started in a go routine, and is responsible for looping ad
// infinitum and dropping stale data from the database.
func (db *nodeDB) expirer() {
- tick := time.Tick(nodeDBCleanupCycle)
+ tick := time.NewTicker(nodeDBCleanupCycle)
+ defer tick.Stop()
for {
select {
- case <-tick:
+ case <-tick.C:
if err := db.expireNodes(); err != nil {
log.Error("Failed to expire nodedb items", "err", err)
}
-
case <-db.quit:
return
}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index f9eb99ee3..60436952d 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -210,17 +210,15 @@ type reply struct {
matched chan<- bool
}
+// ReadPacket is sent to the unhandled channel when it could not be processed
+type ReadPacket struct {
+ Data []byte
+ Addr *net.UDPAddr
+}
+
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
- addr, err := net.ResolveUDPAddr("udp", laddr)
- if err != nil {
- return nil, err
- }
- conn, err := net.ListenUDP("udp", addr)
- if err != nil {
- return nil, err
- }
- tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
+func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
+ tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
@@ -228,7 +226,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
return tab, nil
}
-func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
+func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
udp := &udp{
conn: c,
priv: priv,
@@ -237,16 +235,6 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath strin
gotreply: make(chan reply),
addpending: make(chan *pending),
}
- realaddr := c.LocalAddr().(*net.UDPAddr)
- if natm != nil {
- if !realaddr.IP.IsLoopback() {
- go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
- }
- // TODO: react to external IP changes over time.
- if ext, err := natm.ExternalIP(); err == nil {
- realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
- }
- }
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
@@ -256,7 +244,7 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath strin
udp.Table = tab
go udp.loop()
- go udp.readLoop()
+ go udp.readLoop(unhandled)
return udp.Table, udp, nil
}
@@ -492,8 +480,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte,
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
-func (t *udp) readLoop() {
+func (t *udp) readLoop(unhandled chan ReadPacket) {
defer t.conn.Close()
+ if unhandled != nil {
+ defer close(unhandled)
+ }
// Discovery packets are defined to be no larger than 1280 bytes.
// Packets larger than this size will be cut at the end and treated
// as invalid because their hash won't match.
@@ -509,7 +500,12 @@ func (t *udp) readLoop() {
log.Debug("UDP read error", "err", err)
return
}
- t.handlePacket(from, buf[:nbytes])
+ if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil {
+ select {
+ case unhandled <- ReadPacket{buf[:nbytes], from}:
+ default:
+ }
+ }
}
}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 21e8b561d..b81caf839 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -70,7 +70,8 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
+ realaddr := test.pipe.LocalAddr().(*net.UDPAddr)
+ test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil)
return test
}
diff --git a/p2p/discv5/database.go b/p2p/discv5/database.go
index a3b044ec1..3c2d5744c 100644
--- a/p2p/discv5/database.go
+++ b/p2p/discv5/database.go
@@ -239,14 +239,14 @@ func (db *nodeDB) ensureExpirer() {
// expirer should be started in a go routine, and is responsible for looping ad
// infinitum and dropping stale data from the database.
func (db *nodeDB) expirer() {
- tick := time.Tick(nodeDBCleanupCycle)
+ tick := time.NewTicker(nodeDBCleanupCycle)
+ defer tick.Stop()
for {
select {
- case <-tick:
+ case <-tick.C:
if err := db.expireNodes(); err != nil {
log.Error(fmt.Sprintf("Failed to expire nodedb items: %v", err))
}
-
case <-db.quit:
return
}
diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go
index 2fbb60824..f9baf126f 100644
--- a/p2p/discv5/net.go
+++ b/p2p/discv5/net.go
@@ -29,7 +29,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/log"
- "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -51,16 +50,9 @@ const (
const testTopic = "foo"
const (
- printDebugLogs = false
printTestImgLogs = false
)
-func debugLog(s string) {
- if printDebugLogs {
- fmt.Println(s)
- }
-}
-
// Network manages the table and all protocol interaction.
type Network struct {
db *nodeDB // database of known nodes
@@ -141,7 +133,7 @@ type timeoutEvent struct {
node *Node
}
-func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
+func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
ourID := PubkeyID(&ourPubkey)
var db *nodeDB
@@ -388,14 +380,14 @@ func (net *Network) loop() {
}
}()
resetNextTicket := func() {
- t, timeout := net.ticketStore.nextFilteredTicket()
- if t != nextTicket {
- nextTicket = t
+ ticket, timeout := net.ticketStore.nextFilteredTicket()
+ if nextTicket != ticket {
+ nextTicket = ticket
if nextRegisterTimer != nil {
nextRegisterTimer.Stop()
nextRegisterTime = nil
}
- if t != nil {
+ if ticket != nil {
nextRegisterTimer = time.NewTimer(timeout)
nextRegisterTime = nextRegisterTimer.C
}
@@ -423,13 +415,13 @@ loop:
select {
case <-net.closeReq:
- debugLog("<-net.closeReq")
+ log.Trace("<-net.closeReq")
break loop
// Ingress packet handling.
case pkt := <-net.read:
//fmt.Println("read", pkt.ev)
- debugLog("<-net.read")
+ log.Trace("<-net.read")
n := net.internNode(&pkt)
prestate := n.state
status := "ok"
@@ -444,7 +436,7 @@ loop:
// State transition timeouts.
case timeout := <-net.timeout:
- debugLog("<-net.timeout")
+ log.Trace("<-net.timeout")
if net.timeoutTimers[timeout] == nil {
// Stale timer (was aborted).
continue
@@ -462,20 +454,20 @@ loop:
// Querying.
case q := <-net.queryReq:
- debugLog("<-net.queryReq")
+ log.Trace("<-net.queryReq")
if !q.start(net) {
q.remote.deferQuery(q)
}
// Interacting with the table.
case f := <-net.tableOpReq:
- debugLog("<-net.tableOpReq")
+ log.Trace("<-net.tableOpReq")
f()
net.tableOpResp <- struct{}{}
// Topic registration stuff.
case req := <-net.topicRegisterReq:
- debugLog("<-net.topicRegisterReq")
+ log.Trace("<-net.topicRegisterReq")
if !req.add {
net.ticketStore.removeRegisterTopic(req.topic)
continue
@@ -486,7 +478,7 @@ loop:
// determination for new topics.
// if topicRegisterLookupDone == nil {
if topicRegisterLookupTarget.target == (common.Hash{}) {
- debugLog("topicRegisterLookupTarget == null")
+ log.Trace("topicRegisterLookupTarget == null")
if topicRegisterLookupTick.Stop() {
<-topicRegisterLookupTick.C
}
@@ -496,7 +488,7 @@ loop:
}
case nodes := <-topicRegisterLookupDone:
- debugLog("<-topicRegisterLookupDone")
+ log.Trace("<-topicRegisterLookupDone")
net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
net.ping(n, n.addr())
return n.pingEcho
@@ -507,7 +499,7 @@ loop:
topicRegisterLookupDone = nil
case <-topicRegisterLookupTick.C:
- debugLog("<-topicRegisterLookupTick")
+ log.Trace("<-topicRegisterLookupTick")
if (topicRegisterLookupTarget.target == common.Hash{}) {
target, delay := net.ticketStore.nextRegisterLookup()
topicRegisterLookupTarget = target
@@ -520,14 +512,14 @@ loop:
}
case <-nextRegisterTime:
- debugLog("<-nextRegisterTime")
+ log.Trace("<-nextRegisterTime")
net.ticketStore.ticketRegistered(*nextTicket)
//fmt.Println("sendTopicRegister", nextTicket.t.node.addr().String(), nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
case req := <-net.topicSearchReq:
if refreshDone == nil {
- debugLog("<-net.topicSearchReq")
+ log.Trace("<-net.topicSearchReq")
info, ok := searchInfo[req.topic]
if ok {
if req.delay == time.Duration(0) {
@@ -588,7 +580,7 @@ loop:
})
case <-statsDump.C:
- debugLog("<-statsDump.C")
+ log.Trace("<-statsDump.C")
/*r, ok := net.ticketStore.radius[testTopic]
if !ok {
fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now())
@@ -617,7 +609,7 @@ loop:
// Periodic / lookup-initiated bucket refresh.
case <-refreshTimer.C:
- debugLog("<-refreshTimer.C")
+ log.Trace("<-refreshTimer.C")
// TODO: ideally we would start the refresh timer after
// fallback nodes have been set for the first time.
if refreshDone == nil {
@@ -631,7 +623,7 @@ loop:
bucketRefreshTimer.Reset(bucketRefreshInterval)
}()
case newNursery := <-net.refreshReq:
- debugLog("<-net.refreshReq")
+ log.Trace("<-net.refreshReq")
if newNursery != nil {
net.nursery = newNursery
}
@@ -641,7 +633,7 @@ loop:
}
net.refreshResp <- refreshDone
case <-refreshDone:
- debugLog("<-net.refreshDone")
+ log.Trace("<-net.refreshDone")
refreshDone = nil
list := searchReqWhenRefreshDone
searchReqWhenRefreshDone = nil
@@ -652,7 +644,7 @@ loop:
}()
}
}
- debugLog("loop stopped")
+ log.Trace("loop stopped")
log.Debug(fmt.Sprintf("shutting down"))
if net.conn != nil {
@@ -1109,14 +1101,14 @@ func (net *Network) ping(n *Node, addr *net.UDPAddr) {
//fmt.Println(" not sent")
return
}
- debugLog(fmt.Sprintf("ping(node = %x)", n.ID[:8]))
+ log.Trace("Pinging remote node", "node", n.ID)
n.pingTopics = net.ticketStore.regTopicSet()
n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
net.timedEvent(respTimeout, n, pongTimeout)
}
func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
- debugLog(fmt.Sprintf("handlePing(node = %x)", n.ID[:8]))
+ log.Trace("Handling remote ping", "node", n.ID)
ping := pkt.data.(*ping)
n.TCP = ping.From.TCP
t := net.topictab.getTicket(n, ping.Topics)
@@ -1131,7 +1123,7 @@ func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
}
func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
- debugLog(fmt.Sprintf("handleKnownPong(node = %x)", n.ID[:8]))
+ log.Trace("Handling known pong", "node", n.ID)
net.abortTimedEvent(n, pongTimeout)
now := mclock.Now()
ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
@@ -1139,9 +1131,8 @@ func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
// fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data)
net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
} else {
- debugLog(fmt.Sprintf(" error: %v", err))
+ log.Trace("Failed to convert pong to ticket", "err", err)
}
-
n.pingEcho = nil
n.pingTopics = nil
return err
diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go
index bd234f5ba..369282ca9 100644
--- a/p2p/discv5/net_test.go
+++ b/p2p/discv5/net_test.go
@@ -28,7 +28,7 @@ import (
func TestNetwork_Lookup(t *testing.T) {
key, _ := crypto.GenerateKey()
- network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
+ network, err := newNetwork(lookupTestnet, key.PublicKey, "", nil)
if err != nil {
t.Fatal(err)
}
diff --git a/p2p/discv5/node.go b/p2p/discv5/node.go
index 2db7a508f..fd88a55b1 100644
--- a/p2p/discv5/node.go
+++ b/p2p/discv5/node.go
@@ -273,6 +273,11 @@ func (n NodeID) GoString() string {
return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
}
+// TerminalString returns a shortened hex string for terminal logging.
+func (n NodeID) TerminalString() string {
+ return hex.EncodeToString(n[:8])
+}
+
// HexID converts a hex string to a NodeID.
// The string may be prefixed with 0x.
func HexID(in string) (NodeID, error) {
diff --git a/p2p/discv5/nodeevent_string.go b/p2p/discv5/nodeevent_string.go
index fde9045c5..eb696fb8b 100644
--- a/p2p/discv5/nodeevent_string.go
+++ b/p2p/discv5/nodeevent_string.go
@@ -1,8 +1,8 @@
-// Code generated by "stringer -type nodeEvent"; DO NOT EDIT
+// Code generated by "stringer -type=nodeEvent"; DO NOT EDIT.
package discv5
-import "fmt"
+import "strconv"
const (
_nodeEvent_name_0 = "invalidEventpingPacketpongPacketfindnodePacketneighborsPacketfindnodeHashPackettopicRegisterPackettopicQueryPackettopicNodesPacket"
@@ -22,6 +22,6 @@ func (i nodeEvent) String() string {
i -= 265
return _nodeEvent_name_1[_nodeEvent_index_1[i]:_nodeEvent_index_1[i+1]]
default:
- return fmt.Sprintf("nodeEvent(%d)", i)
+ return "nodeEvent(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go
index bf57872e2..543faecd4 100644
--- a/p2p/discv5/sim_test.go
+++ b/p2p/discv5/sim_test.go
@@ -282,7 +282,7 @@ func (s *simulation) launchNode(log bool) *Network {
addr := &net.UDPAddr{IP: ip, Port: 30303}
transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
- net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
+ net, err := newNetwork(transport, key.PublicKey, "<no database>", nil)
if err != nil {
panic("cannot launch new node: " + err.Error())
}
diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go
index 193cef4be..37ce8d23c 100644
--- a/p2p/discv5/ticket.go
+++ b/p2p/discv5/ticket.go
@@ -28,6 +28,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/log"
)
const (
@@ -128,8 +129,11 @@ type ticketStore struct {
// Contains buckets (for each absolute minute) of tickets
// that can be used in that minute.
// This is only set if the topic is being registered.
- tickets map[Topic]topicTickets
- regtopics []Topic
+ tickets map[Topic]*topicTickets
+
+ regQueue []Topic // Topic registration queue for round robin attempts
+ regSet map[Topic]struct{} // Topic registration queue contents for fast filling
+
nodes map[*Node]*ticket
nodeLastReq map[*Node]reqInfo
@@ -152,14 +156,16 @@ type sentQuery struct {
}
type topicTickets struct {
- buckets map[timeBucket][]ticketRef
- nextLookup, nextReg mclock.AbsTime
+ buckets map[timeBucket][]ticketRef
+ nextLookup mclock.AbsTime
+ nextReg mclock.AbsTime
}
func newTicketStore() *ticketStore {
return &ticketStore{
radius: make(map[Topic]*topicRadius),
- tickets: make(map[Topic]topicTickets),
+ tickets: make(map[Topic]*topicTickets),
+ regSet: make(map[Topic]struct{}),
nodes: make(map[*Node]*ticket),
nodeLastReq: make(map[*Node]reqInfo),
searchTopicMap: make(map[Topic]searchTopic),
@@ -169,13 +175,13 @@ func newTicketStore() *ticketStore {
// addTopic starts tracking a topic. If register is true,
// the local node will register the topic and tickets will be collected.
-func (s *ticketStore) addTopic(t Topic, register bool) {
- debugLog(fmt.Sprintf(" addTopic(%v, %v)", t, register))
- if s.radius[t] == nil {
- s.radius[t] = newTopicRadius(t)
+func (s *ticketStore) addTopic(topic Topic, register bool) {
+ log.Trace("Adding discovery topic", "topic", topic, "register", register)
+ if s.radius[topic] == nil {
+ s.radius[topic] = newTopicRadius(topic)
}
- if register && s.tickets[t].buckets == nil {
- s.tickets[t] = topicTickets{buckets: make(map[timeBucket][]ticketRef)}
+ if register && s.tickets[topic] == nil {
+ s.tickets[topic] = &topicTickets{buckets: make(map[timeBucket][]ticketRef)}
}
}
@@ -194,7 +200,11 @@ func (s *ticketStore) removeSearchTopic(t Topic) {
// removeRegisterTopic deletes all tickets for the given topic.
func (s *ticketStore) removeRegisterTopic(topic Topic) {
- debugLog(fmt.Sprintf(" removeRegisterTopic(%v)", topic))
+ log.Trace("Removing discovery topic", "topic", topic)
+ if s.tickets[topic] == nil {
+ log.Warn("Removing non-existent discovery topic", "topic", topic)
+ return
+ }
for _, list := range s.tickets[topic].buckets {
for _, ref := range list {
ref.t.refCnt--
@@ -216,23 +226,35 @@ func (s *ticketStore) regTopicSet() []Topic {
}
// nextRegisterLookup returns the target of the next lookup for ticket collection.
-func (s *ticketStore) nextRegisterLookup() (lookup lookupInfo, delay time.Duration) {
- debugLog("nextRegisterLookup()")
- firstTopic, ok := s.iterRegTopics()
- for topic := firstTopic; ok; {
- debugLog(fmt.Sprintf(" checking topic %v, len(s.tickets[topic]) = %d", topic, len(s.tickets[topic].buckets)))
- if s.tickets[topic].buckets != nil && s.needMoreTickets(topic) {
- next := s.radius[topic].nextTarget(false)
- debugLog(fmt.Sprintf(" %x 1s", next.target[:8]))
- return next, 100 * time.Millisecond
+func (s *ticketStore) nextRegisterLookup() (lookupInfo, time.Duration) {
+ // Queue up any new topics (or discarded ones), preserving iteration order
+ for topic := range s.tickets {
+ if _, ok := s.regSet[topic]; !ok {
+ s.regQueue = append(s.regQueue, topic)
+ s.regSet[topic] = struct{}{}
+ }
+ }
+ // Iterate over the set of all topics and look up the next suitable one
+ for len(s.regQueue) > 0 {
+ // Fetch the next topic from the queue, and ensure it still exists
+ topic := s.regQueue[0]
+ s.regQueue = s.regQueue[1:]
+ delete(s.regSet, topic)
+
+ if s.tickets[topic] == nil {
+ continue
}
- topic, ok = s.iterRegTopics()
- if topic == firstTopic {
- break // We have checked all topics.
+ // If the topic needs more tickets, return it
+ if s.tickets[topic].nextLookup < mclock.Now() {
+ next, delay := s.radius[topic].nextTarget(false), 100*time.Millisecond
+ log.Trace("Found discovery topic to register", "topic", topic, "target", next.target, "delay", delay)
+ return next, delay
}
}
- debugLog(" null, 40s")
- return lookupInfo{}, 40 * time.Second
+ // No registration topics found or all exhausted, sleep
+ delay := 40 * time.Second
+ log.Trace("No topic found to register", "delay", delay)
+ return lookupInfo{}, delay
}
func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo {
@@ -246,40 +268,22 @@ func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo {
return target
}
-// iterRegTopics returns topics to register in arbitrary order.
-// The second return value is false if there are no topics.
-func (s *ticketStore) iterRegTopics() (Topic, bool) {
- debugLog("iterRegTopics()")
- if len(s.regtopics) == 0 {
- if len(s.tickets) == 0 {
- debugLog(" false")
- return "", false
- }
- // Refill register list.
- for t := range s.tickets {
- s.regtopics = append(s.regtopics, t)
- }
+// ticketsInWindow returns the tickets of a given topic in the registration window.
+func (s *ticketStore) ticketsInWindow(topic Topic) []ticketRef {
+ // Sanity check that the topic still exists before operating on it
+ if s.tickets[topic] == nil {
+ log.Warn("Listing non-existing discovery tickets", "topic", topic)
+ return nil
}
- topic := s.regtopics[len(s.regtopics)-1]
- s.regtopics = s.regtopics[:len(s.regtopics)-1]
- debugLog(" " + string(topic) + " true")
- return topic, true
-}
-
-func (s *ticketStore) needMoreTickets(t Topic) bool {
- return s.tickets[t].nextLookup < mclock.Now()
-}
+ // Gather all the tickers in the next time window
+ var tickets []ticketRef
-// ticketsInWindow returns the tickets of a given topic in the registration window.
-func (s *ticketStore) ticketsInWindow(t Topic) []ticketRef {
- ltBucket := s.lastBucketFetched
- var res []ticketRef
- tickets := s.tickets[t].buckets
- for g := ltBucket; g < ltBucket+timeWindow; g++ {
- res = append(res, tickets[g]...)
+ buckets := s.tickets[topic].buckets
+ for idx := timeBucket(0); idx < timeWindow; idx++ {
+ tickets = append(tickets, buckets[s.lastBucketFetched+idx]...)
}
- debugLog(fmt.Sprintf("ticketsInWindow(%v) = %v", t, len(res)))
- return res
+ log.Trace("Retrieved discovery registration tickets", "topic", topic, "from", s.lastBucketFetched, "tickets", len(tickets))
+ return tickets
}
func (s *ticketStore) removeExcessTickets(t Topic) {
@@ -317,53 +321,55 @@ func (s ticketRefByWaitTime) Swap(i, j int) {
func (s *ticketStore) addTicketRef(r ticketRef) {
topic := r.t.topics[r.idx]
- t := s.tickets[topic]
- if t.buckets == nil {
+ tickets := s.tickets[topic]
+ if tickets == nil {
+ log.Warn("Adding ticket to non-existent topic", "topic", topic)
return
}
bucket := timeBucket(r.t.regTime[r.idx] / mclock.AbsTime(ticketTimeBucketLen))
- t.buckets[bucket] = append(t.buckets[bucket], r)
+ tickets.buckets[bucket] = append(tickets.buckets[bucket], r)
r.t.refCnt++
min := mclock.Now() - mclock.AbsTime(collectFrequency)*maxCollectDebt
- if t.nextLookup < min {
- t.nextLookup = min
+ if tickets.nextLookup < min {
+ tickets.nextLookup = min
}
- t.nextLookup += mclock.AbsTime(collectFrequency)
- s.tickets[topic] = t
+ tickets.nextLookup += mclock.AbsTime(collectFrequency)
//s.removeExcessTickets(topic)
}
-func (s *ticketStore) nextFilteredTicket() (t *ticketRef, wait time.Duration) {
+func (s *ticketStore) nextFilteredTicket() (*ticketRef, time.Duration) {
now := mclock.Now()
for {
- t, wait = s.nextRegisterableTicket()
- if t == nil {
- return
+ ticket, wait := s.nextRegisterableTicket()
+ if ticket == nil {
+ return ticket, wait
}
+ log.Trace("Found discovery ticket to register", "node", ticket.t.node, "serial", ticket.t.serial, "wait", wait)
+
regTime := now + mclock.AbsTime(wait)
- topic := t.t.topics[t.idx]
- if regTime >= s.tickets[topic].nextReg {
- return
+ topic := ticket.t.topics[ticket.idx]
+ if s.tickets[topic] != nil && regTime >= s.tickets[topic].nextReg {
+ return ticket, wait
}
- s.removeTicketRef(*t)
+ s.removeTicketRef(*ticket)
}
}
-func (s *ticketStore) ticketRegistered(t ticketRef) {
+func (s *ticketStore) ticketRegistered(ref ticketRef) {
now := mclock.Now()
- topic := t.t.topics[t.idx]
- tt := s.tickets[topic]
+ topic := ref.t.topics[ref.idx]
+ tickets := s.tickets[topic]
min := now - mclock.AbsTime(registerFrequency)*maxRegisterDebt
- if min > tt.nextReg {
- tt.nextReg = min
+ if min > tickets.nextReg {
+ tickets.nextReg = min
}
- tt.nextReg += mclock.AbsTime(registerFrequency)
- s.tickets[topic] = tt
+ tickets.nextReg += mclock.AbsTime(registerFrequency)
+ s.tickets[topic] = tickets
- s.removeTicketRef(t)
+ s.removeTicketRef(ref)
}
// nextRegisterableTicket returns the next ticket that can be used
@@ -374,16 +380,7 @@ func (s *ticketStore) ticketRegistered(t ticketRef) {
//
// A ticket can be returned more than once with <= zero wait time in case
// the ticket contains multiple topics.
-func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration) {
- defer func() {
- if t == nil {
- debugLog(" nil")
- } else {
- debugLog(fmt.Sprintf(" node = %x sn = %v wait = %v", t.t.node.ID[:8], t.t.serial, wait))
- }
- }()
-
- debugLog("nextRegisterableTicket()")
+func (s *ticketStore) nextRegisterableTicket() (*ticketRef, time.Duration) {
now := mclock.Now()
if s.nextTicketCached != nil {
return s.nextTicketCached, time.Duration(s.nextTicketCached.topicRegTime() - now)
@@ -412,9 +409,8 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration
return nil, 0
}
if nextTicket.t != nil {
- wait = time.Duration(nextTicket.topicRegTime() - now)
s.nextTicketCached = &nextTicket
- return &nextTicket, wait
+ return &nextTicket, time.Duration(nextTicket.topicRegTime() - now)
}
s.lastBucketFetched = bucket
}
@@ -422,14 +418,20 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration
// removeTicket removes a ticket from the ticket store
func (s *ticketStore) removeTicketRef(ref ticketRef) {
- debugLog(fmt.Sprintf("removeTicketRef(node = %x sn = %v)", ref.t.node.ID[:8], ref.t.serial))
+ log.Trace("Removing discovery ticket reference", "node", ref.t.node.ID, "serial", ref.t.serial)
+
+ // Make nextRegisterableTicket return the next available ticket.
+ s.nextTicketCached = nil
+
topic := ref.topic()
- tickets := s.tickets[topic].buckets
+ tickets := s.tickets[topic]
+
if tickets == nil {
+ log.Trace("Removing tickets from unknown topic", "topic", topic)
return
}
bucket := timeBucket(ref.t.regTime[ref.idx] / mclock.AbsTime(ticketTimeBucketLen))
- list := tickets[bucket]
+ list := tickets.buckets[bucket]
idx := -1
for i, bt := range list {
if bt.t == ref.t {
@@ -442,18 +444,15 @@ func (s *ticketStore) removeTicketRef(ref ticketRef) {
}
list = append(list[:idx], list[idx+1:]...)
if len(list) != 0 {
- tickets[bucket] = list
+ tickets.buckets[bucket] = list
} else {
- delete(tickets, bucket)
+ delete(tickets.buckets, bucket)
}
ref.t.refCnt--
if ref.t.refCnt == 0 {
delete(s.nodes, ref.t.node)
delete(s.nodeLastReq, ref.t.node)
}
-
- // Make nextRegisterableTicket return the next available ticket.
- s.nextTicketCached = nil
}
type lookupInfo struct {
@@ -523,21 +522,21 @@ func (s *ticketStore) adjustWithTicket(now mclock.AbsTime, targetHash common.Has
}
}
-func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ticket) {
- debugLog(fmt.Sprintf("add(node = %x sn = %v)", t.node.ID[:8], t.serial))
+func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, ticket *ticket) {
+ log.Trace("Adding discovery ticket", "node", ticket.node.ID, "serial", ticket.serial)
- lastReq, ok := s.nodeLastReq[t.node]
+ lastReq, ok := s.nodeLastReq[ticket.node]
if !(ok && bytes.Equal(pingHash, lastReq.pingHash)) {
return
}
- s.adjustWithTicket(localTime, lastReq.lookup.target, t)
+ s.adjustWithTicket(localTime, lastReq.lookup.target, ticket)
- if lastReq.lookup.radiusLookup || s.nodes[t.node] != nil {
+ if lastReq.lookup.radiusLookup || s.nodes[ticket.node] != nil {
return
}
topic := lastReq.lookup.topic
- topicIdx := t.findIdx(topic)
+ topicIdx := ticket.findIdx(topic)
if topicIdx == -1 {
return
}
@@ -548,29 +547,29 @@ func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ti
}
if _, ok := s.tickets[topic]; ok {
- wait := t.regTime[topicIdx] - localTime
+ wait := ticket.regTime[topicIdx] - localTime
rnd := rand.ExpFloat64()
if rnd > 10 {
rnd = 10
}
if float64(wait) < float64(keepTicketConst)+float64(keepTicketExp)*rnd {
// use the ticket to register this topic
- //fmt.Println("addTicket", t.node.ID[:8], t.node.addr().String(), t.serial, t.pong)
- s.addTicketRef(ticketRef{t, topicIdx})
+ //fmt.Println("addTicket", ticket.node.ID[:8], ticket.node.addr().String(), ticket.serial, ticket.pong)
+ s.addTicketRef(ticketRef{ticket, topicIdx})
}
}
- if t.refCnt > 0 {
+ if ticket.refCnt > 0 {
s.nextTicketCached = nil
- s.nodes[t.node] = t
+ s.nodes[ticket.node] = ticket
}
}
func (s *ticketStore) getNodeTicket(node *Node) *ticket {
if s.nodes[node] == nil {
- debugLog(fmt.Sprintf("getNodeTicket(%x) sn = nil", node.ID[:8]))
+ log.Trace("Retrieving node ticket", "node", node.ID, "serial", nil)
} else {
- debugLog(fmt.Sprintf("getNodeTicket(%x) sn = %v", node.ID[:8], s.nodes[node].serial))
+ log.Trace("Retrieving node ticket", "node", node.ID, "serial", s.nodes[node].serial)
}
return s.nodes[node]
}
@@ -643,7 +642,7 @@ func (s *ticketStore) gotTopicNodes(from *Node, hash common.Hash, nodes []rpcNod
if ip.IsUnspecified() || ip.IsLoopback() {
ip = from.IP
}
- n := NewNode(node.ID, ip, node.UDP-1, node.TCP-1) // subtract one from port while discv5 is running in test mode on UDPport+1
+ n := NewNode(node.ID, ip, node.UDP, node.TCP)
select {
case chn <- n:
default:
diff --git a/p2p/discv5/topic.go b/p2p/discv5/topic.go
index b6bea013c..e7a7f8e02 100644
--- a/p2p/discv5/topic.go
+++ b/p2p/discv5/topic.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/log"
)
const (
@@ -235,7 +236,7 @@ func (t *topicTable) deleteEntry(e *topicEntry) {
// It is assumed that topics and waitPeriods have the same length.
func (t *topicTable) useTicket(node *Node, serialNo uint32, topics []Topic, idx int, issueTime uint64, waitPeriods []uint32) (registered bool) {
- debugLog(fmt.Sprintf("useTicket %v %v %v", serialNo, topics, waitPeriods))
+ log.Trace("Using discovery ticket", "serial", serialNo, "topics", topics, "waits", waitPeriods)
//fmt.Println("useTicket", serialNo, topics, waitPeriods)
t.collectGarbage()
diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go
index 26087cd8e..543771817 100644
--- a/p2p/discv5/udp.go
+++ b/p2p/discv5/udp.go
@@ -37,7 +37,7 @@ const Version = 4
// Errors
var (
errPacketTooSmall = errors.New("too small")
- errBadHash = errors.New("bad hash")
+ errBadPrefix = errors.New("bad prefix")
errExpired = errors.New("expired")
errUnsolicitedReply = errors.New("unsolicited reply")
errUnknownNode = errors.New("unknown node")
@@ -145,10 +145,11 @@ type (
}
)
-const (
- macSize = 256 / 8
- sigSize = 520 / 8
- headSize = macSize + sigSize // space of packet frame data
+var (
+ versionPrefix = []byte("temporary discovery v5")
+ versionPrefixSize = len(versionPrefix)
+ sigSize = 520 / 8
+ headSize = versionPrefixSize + sigSize // space of packet frame data
)
// Neighbors replies are sent across multiple packets to
@@ -237,30 +238,23 @@ type udp struct {
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
- transport, err := listenUDP(priv, laddr)
+func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
+ transport, err := listenUDP(priv, conn, realaddr)
if err != nil {
return nil, err
}
- net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
+ net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
+ log.Info("UDP listener up", "net", net.tab.self)
transport.net = net
go transport.readLoop()
return net, nil
}
-func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) {
- addr, err := net.ResolveUDPAddr("udp", laddr)
- if err != nil {
- return nil, err
- }
- conn, err := net.ListenUDP("udp", addr)
- if err != nil {
- return nil, err
- }
- return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil
+func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
+ return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
}
func (t *udp) localAddr() *net.UDPAddr {
@@ -372,11 +366,9 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash
log.Error(fmt.Sprint("could not sign packet:", err))
return nil, nil, err
}
- copy(packet[macSize:], sig)
- // add the hash to the front. Note: this doesn't protect the
- // packet in any way.
- hash = crypto.Keccak256(packet[macSize:])
- copy(packet, hash)
+ copy(packet, versionPrefix)
+ copy(packet[versionPrefixSize:], sig)
+ hash = crypto.Keccak256(packet[versionPrefixSize:])
return packet, hash, nil
}
@@ -420,17 +412,16 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error {
}
buf := make([]byte, len(buffer))
copy(buf, buffer)
- hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
- shouldhash := crypto.Keccak256(buf[macSize:])
- if !bytes.Equal(hash, shouldhash) {
- return errBadHash
+ prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:]
+ if !bytes.Equal(prefix, versionPrefix) {
+ return errBadPrefix
}
fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
if err != nil {
return err
}
pkt.rawData = buf
- pkt.hash = hash
+ pkt.hash = crypto.Keccak256(buf[versionPrefixSize:])
pkt.remoteID = fromID
switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
case pingPacket:
diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go
new file mode 100644
index 000000000..2c3afb43e
--- /dev/null
+++ b/p2p/enr/enr.go
@@ -0,0 +1,290 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package enr implements Ethereum Node Records as defined in EIP-778. A node record holds
+// arbitrary information about a node on the peer-to-peer network.
+//
+// Records contain named keys. To store and retrieve key/values in a record, use the Entry
+// interface.
+//
+// Records must be signed before transmitting them to another node. Decoding a record verifies
+// its signature. When creating a record, set the entries you want, then call Sign to add the
+// signature. Modifying a record invalidates the signature.
+//
+// Package enr supports the "secp256k1-keccak" identity scheme.
+package enr
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "io"
+ "sort"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const SizeLimit = 300 // maximum encoded size of a node record in bytes
+
+const ID_SECP256k1_KECCAK = ID("secp256k1-keccak") // the default identity scheme
+
+var (
+ errNoID = errors.New("unknown or unspecified identity scheme")
+ errInvalidSigsize = errors.New("invalid signature size")
+ errInvalidSig = errors.New("invalid signature")
+ errNotSorted = errors.New("record key/value pairs are not sorted by key")
+ errDuplicateKey = errors.New("record contains duplicate key")
+ errIncompletePair = errors.New("record contains incomplete k/v pair")
+ errTooBig = fmt.Errorf("record bigger than %d bytes", SizeLimit)
+ errEncodeUnsigned = errors.New("can't encode unsigned record")
+ errNotFound = errors.New("no such key in record")
+)
+
+// Record represents a node record. The zero value is an empty record.
+type Record struct {
+ seq uint64 // sequence number
+ signature []byte // the signature
+ raw []byte // RLP encoded record
+ pairs []pair // sorted list of all key/value pairs
+}
+
+// pair is a key/value pair in a record.
+type pair struct {
+ k string
+ v rlp.RawValue
+}
+
+// Signed reports whether the record has a valid signature.
+func (r *Record) Signed() bool {
+ return r.signature != nil
+}
+
+// Seq returns the sequence number.
+func (r *Record) Seq() uint64 {
+ return r.seq
+}
+
+// SetSeq updates the record sequence number. This invalidates any signature on the record.
+// Calling SetSeq is usually not required because signing the redord increments the
+// sequence number.
+func (r *Record) SetSeq(s uint64) {
+ r.signature = nil
+ r.raw = nil
+ r.seq = s
+}
+
+// Load retrieves the value of a key/value pair. The given Entry must be a pointer and will
+// be set to the value of the entry in the record.
+//
+// Errors returned by Load are wrapped in KeyError. You can distinguish decoding errors
+// from missing keys using the IsNotFound function.
+func (r *Record) Load(e Entry) error {
+ i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() })
+ if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() {
+ if err := rlp.DecodeBytes(r.pairs[i].v, e); err != nil {
+ return &KeyError{Key: e.ENRKey(), Err: err}
+ }
+ return nil
+ }
+ return &KeyError{Key: e.ENRKey(), Err: errNotFound}
+}
+
+// Set adds or updates the given entry in the record.
+// It panics if the value can't be encoded.
+func (r *Record) Set(e Entry) {
+ r.signature = nil
+ r.raw = nil
+ blob, err := rlp.EncodeToBytes(e)
+ if err != nil {
+ panic(fmt.Errorf("enr: can't encode %s: %v", e.ENRKey(), err))
+ }
+
+ i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() })
+
+ if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() {
+ // element is present at r.pairs[i]
+ r.pairs[i].v = blob
+ return
+ } else if i < len(r.pairs) {
+ // insert pair before i-th elem
+ el := pair{e.ENRKey(), blob}
+ r.pairs = append(r.pairs, pair{})
+ copy(r.pairs[i+1:], r.pairs[i:])
+ r.pairs[i] = el
+ return
+ }
+
+ // element should be placed at the end of r.pairs
+ r.pairs = append(r.pairs, pair{e.ENRKey(), blob})
+}
+
+// EncodeRLP implements rlp.Encoder. Encoding fails if
+// the record is unsigned.
+func (r Record) EncodeRLP(w io.Writer) error {
+ if !r.Signed() {
+ return errEncodeUnsigned
+ }
+ _, err := w.Write(r.raw)
+ return err
+}
+
+// DecodeRLP implements rlp.Decoder. Decoding verifies the signature.
+func (r *Record) DecodeRLP(s *rlp.Stream) error {
+ raw, err := s.Raw()
+ if err != nil {
+ return err
+ }
+ if len(raw) > SizeLimit {
+ return errTooBig
+ }
+
+ // Decode the RLP container.
+ dec := Record{raw: raw}
+ s = rlp.NewStream(bytes.NewReader(raw), 0)
+ if _, err := s.List(); err != nil {
+ return err
+ }
+ if err = s.Decode(&dec.signature); err != nil {
+ return err
+ }
+ if err = s.Decode(&dec.seq); err != nil {
+ return err
+ }
+ // The rest of the record contains sorted k/v pairs.
+ var prevkey string
+ for i := 0; ; i++ {
+ var kv pair
+ if err := s.Decode(&kv.k); err != nil {
+ if err == rlp.EOL {
+ break
+ }
+ return err
+ }
+ if err := s.Decode(&kv.v); err != nil {
+ if err == rlp.EOL {
+ return errIncompletePair
+ }
+ return err
+ }
+ if i > 0 {
+ if kv.k == prevkey {
+ return errDuplicateKey
+ }
+ if kv.k < prevkey {
+ return errNotSorted
+ }
+ }
+ dec.pairs = append(dec.pairs, kv)
+ prevkey = kv.k
+ }
+ if err := s.ListEnd(); err != nil {
+ return err
+ }
+
+ // Verify signature.
+ if err = dec.verifySignature(); err != nil {
+ return err
+ }
+ *r = dec
+ return nil
+}
+
+type s256raw []byte
+
+func (s256raw) ENRKey() string { return "secp256k1" }
+
+// NodeAddr returns the node address. The return value will be nil if the record is
+// unsigned.
+func (r *Record) NodeAddr() []byte {
+ var entry s256raw
+ if r.Load(&entry) != nil {
+ return nil
+ }
+ return crypto.Keccak256(entry)
+}
+
+// Sign signs the record with the given private key. It updates the record's identity
+// scheme, public key and increments the sequence number. Sign returns an error if the
+// encoded record is larger than the size limit.
+func (r *Record) Sign(privkey *ecdsa.PrivateKey) error {
+ r.seq = r.seq + 1
+ r.Set(ID_SECP256k1_KECCAK)
+ r.Set(Secp256k1(privkey.PublicKey))
+ return r.signAndEncode(privkey)
+}
+
+func (r *Record) appendPairs(list []interface{}) []interface{} {
+ list = append(list, r.seq)
+ for _, p := range r.pairs {
+ list = append(list, p.k, p.v)
+ }
+ return list
+}
+
+func (r *Record) signAndEncode(privkey *ecdsa.PrivateKey) error {
+ // Put record elements into a flat list. Leave room for the signature.
+ list := make([]interface{}, 1, len(r.pairs)*2+2)
+ list = r.appendPairs(list)
+
+ // Sign the tail of the list.
+ h := sha3.NewKeccak256()
+ rlp.Encode(h, list[1:])
+ sig, err := crypto.Sign(h.Sum(nil), privkey)
+ if err != nil {
+ return err
+ }
+ sig = sig[:len(sig)-1] // remove v
+
+ // Put signature in front.
+ r.signature, list[0] = sig, sig
+ r.raw, err = rlp.EncodeToBytes(list)
+ if err != nil {
+ return err
+ }
+ if len(r.raw) > SizeLimit {
+ return errTooBig
+ }
+ return nil
+}
+
+func (r *Record) verifySignature() error {
+ // Get identity scheme, public key, signature.
+ var id ID
+ var entry s256raw
+ if err := r.Load(&id); err != nil {
+ return err
+ } else if id != ID_SECP256k1_KECCAK {
+ return errNoID
+ }
+ if err := r.Load(&entry); err != nil {
+ return err
+ } else if len(entry) != 33 {
+ return fmt.Errorf("invalid public key")
+ }
+
+ // Verify the signature.
+ list := make([]interface{}, 0, len(r.pairs)*2+1)
+ list = r.appendPairs(list)
+ h := sha3.NewKeccak256()
+ rlp.Encode(h, list)
+ if !crypto.VerifySignature(entry, h.Sum(nil), r.signature) {
+ return errInvalidSig
+ }
+ return nil
+}
diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go
new file mode 100644
index 000000000..ce7767d10
--- /dev/null
+++ b/p2p/enr/enr_test.go
@@ -0,0 +1,318 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enr
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ pubkey = &privkey.PublicKey
+)
+
+var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
+
+func randomString(strlen int) string {
+ b := make([]byte, strlen)
+ rnd.Read(b)
+ return string(b)
+}
+
+// TestGetSetID tests encoding/decoding and setting/getting of the ID key.
+func TestGetSetID(t *testing.T) {
+ id := ID("someid")
+ var r Record
+ r.Set(id)
+
+ var id2 ID
+ require.NoError(t, r.Load(&id2))
+ assert.Equal(t, id, id2)
+}
+
+// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key.
+func TestGetSetIP4(t *testing.T) {
+ ip := IP4{192, 168, 0, 3}
+ var r Record
+ r.Set(ip)
+
+ var ip2 IP4
+ require.NoError(t, r.Load(&ip2))
+ assert.Equal(t, ip, ip2)
+}
+
+// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
+func TestGetSetIP6(t *testing.T) {
+ ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
+ var r Record
+ r.Set(ip)
+
+ var ip2 IP6
+ require.NoError(t, r.Load(&ip2))
+ assert.Equal(t, ip, ip2)
+}
+
+// TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key.
+func TestGetSetDiscPort(t *testing.T) {
+ port := DiscPort(30309)
+ var r Record
+ r.Set(port)
+
+ var port2 DiscPort
+ require.NoError(t, r.Load(&port2))
+ assert.Equal(t, port, port2)
+}
+
+// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
+func TestGetSetSecp256k1(t *testing.T) {
+ var r Record
+ if err := r.Sign(privkey); err != nil {
+ t.Fatal(err)
+ }
+
+ var pk Secp256k1
+ require.NoError(t, r.Load(&pk))
+ assert.EqualValues(t, pubkey, &pk)
+}
+
+func TestLoadErrors(t *testing.T) {
+ var r Record
+ ip4 := IP4{127, 0, 0, 1}
+ r.Set(ip4)
+
+ // Check error for missing keys.
+ var ip6 IP6
+ err := r.Load(&ip6)
+ if !IsNotFound(err) {
+ t.Error("IsNotFound should return true for missing key")
+ }
+ assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err)
+
+ // Check error for invalid keys.
+ var list []uint
+ err = r.Load(WithEntry(ip4.ENRKey(), &list))
+ kerr, ok := err.(*KeyError)
+ if !ok {
+ t.Fatalf("expected KeyError, got %T", err)
+ }
+ assert.Equal(t, kerr.Key, ip4.ENRKey())
+ assert.Error(t, kerr.Err)
+ if IsNotFound(err) {
+ t.Error("IsNotFound should return false for decoding errors")
+ }
+}
+
+// TestSortedGetAndSet tests that Set produced a sorted pairs slice.
+func TestSortedGetAndSet(t *testing.T) {
+ type pair struct {
+ k string
+ v uint32
+ }
+
+ for _, tt := range []struct {
+ input []pair
+ want []pair
+ }{
+ {
+ input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
+ want: []pair{{"a", 1}, {"b", 3}, {"c", 2}},
+ },
+ {
+ input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
+ want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
+ },
+ {
+ input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
+ want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
+ },
+ } {
+ var r Record
+ for _, i := range tt.input {
+ r.Set(WithEntry(i.k, &i.v))
+ }
+ for i, w := range tt.want {
+ // set got's key from r.pair[i], so that we preserve order of pairs
+ got := pair{k: r.pairs[i].k}
+ assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
+ assert.Equal(t, w, got)
+ }
+ }
+}
+
+// TestDirty tests record signature removal on setting of new key/value pair in record.
+func TestDirty(t *testing.T) {
+ var r Record
+
+ if r.Signed() {
+ t.Error("Signed returned true for zero record")
+ }
+ if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
+ t.Errorf("expected errEncodeUnsigned, got %#v", err)
+ }
+
+ require.NoError(t, r.Sign(privkey))
+ if !r.Signed() {
+ t.Error("Signed return false for signed record")
+ }
+ _, err := rlp.EncodeToBytes(r)
+ assert.NoError(t, err)
+
+ r.SetSeq(3)
+ if r.Signed() {
+ t.Error("Signed returned true for modified record")
+ }
+ if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
+ t.Errorf("expected errEncodeUnsigned, got %#v", err)
+ }
+}
+
+// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
+func TestGetSetOverwrite(t *testing.T) {
+ var r Record
+
+ ip := IP4{192, 168, 0, 3}
+ r.Set(ip)
+
+ ip2 := IP4{192, 168, 0, 4}
+ r.Set(ip2)
+
+ var ip3 IP4
+ require.NoError(t, r.Load(&ip3))
+ assert.Equal(t, ip2, ip3)
+}
+
+// TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
+func TestSignEncodeAndDecode(t *testing.T) {
+ var r Record
+ r.Set(DiscPort(30303))
+ r.Set(IP4{127, 0, 0, 1})
+ require.NoError(t, r.Sign(privkey))
+
+ blob, err := rlp.EncodeToBytes(r)
+ require.NoError(t, err)
+
+ var r2 Record
+ require.NoError(t, rlp.DecodeBytes(blob, &r2))
+ assert.Equal(t, r, r2)
+
+ blob2, err := rlp.EncodeToBytes(r2)
+ require.NoError(t, err)
+ assert.Equal(t, blob, blob2)
+}
+
+func TestNodeAddr(t *testing.T) {
+ var r Record
+ if addr := r.NodeAddr(); addr != nil {
+ t.Errorf("wrong address on empty record: got %v, want %v", addr, nil)
+ }
+
+ require.NoError(t, r.Sign(privkey))
+ expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726"
+ assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr()))
+}
+
+var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138")
+
+// TestPythonInterop checks that we can decode and verify a record produced by the Python
+// implementation.
+func TestPythonInterop(t *testing.T) {
+ var r Record
+ if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
+ t.Fatalf("can't decode: %v", err)
+ }
+
+ var (
+ wantAddr, _ = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726")
+ wantSeq = uint64(1)
+ wantIP = IP4{127, 0, 0, 1}
+ wantDiscport = DiscPort(30303)
+ )
+ if r.Seq() != wantSeq {
+ t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq)
+ }
+ if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) {
+ t.Errorf("wrong addr: got %x, want %x", addr, wantAddr)
+ }
+ want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport}
+ for k, v := range want {
+ desc := fmt.Sprintf("loading key %q", k.ENRKey())
+ if assert.NoError(t, r.Load(k), desc) {
+ assert.Equal(t, k, v, desc)
+ }
+ }
+}
+
+// TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
+func TestRecordTooBig(t *testing.T) {
+ var r Record
+ key := randomString(10)
+
+ // set a big value for random key, expect error
+ r.Set(WithEntry(key, randomString(300)))
+ if err := r.Sign(privkey); err != errTooBig {
+ t.Fatalf("expected to get errTooBig, got %#v", err)
+ }
+
+ // set an acceptable value for random key, expect no error
+ r.Set(WithEntry(key, randomString(100)))
+ require.NoError(t, r.Sign(privkey))
+}
+
+// TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
+func TestSignEncodeAndDecodeRandom(t *testing.T) {
+ var r Record
+
+ // random key/value pairs for testing
+ pairs := map[string]uint32{}
+ for i := 0; i < 10; i++ {
+ key := randomString(7)
+ value := rnd.Uint32()
+ pairs[key] = value
+ r.Set(WithEntry(key, &value))
+ }
+
+ require.NoError(t, r.Sign(privkey))
+ _, err := rlp.EncodeToBytes(r)
+ require.NoError(t, err)
+
+ for k, v := range pairs {
+ desc := fmt.Sprintf("key %q", k)
+ var got uint32
+ buf := WithEntry(k, &got)
+ require.NoError(t, r.Load(buf), desc)
+ require.Equal(t, v, got, desc)
+ }
+}
+
+func BenchmarkDecode(b *testing.B) {
+ var r Record
+ for i := 0; i < b.N; i++ {
+ rlp.DecodeBytes(pyRecord, &r)
+ }
+ b.StopTimer()
+ r.NodeAddr()
+}
diff --git a/p2p/enr/entries.go b/p2p/enr/entries.go
new file mode 100644
index 000000000..7591e6eff
--- /dev/null
+++ b/p2p/enr/entries.go
@@ -0,0 +1,160 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enr
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "io"
+ "net"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Entry is implemented by known node record entry types.
+//
+// To define a new entry that is to be included in a node record,
+// create a Go type that satisfies this interface. The type should
+// also implement rlp.Decoder if additional checks are needed on the value.
+type Entry interface {
+ ENRKey() string
+}
+
+type generic struct {
+ key string
+ value interface{}
+}
+
+func (g generic) ENRKey() string { return g.key }
+
+func (g generic) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, g.value)
+}
+
+func (g *generic) DecodeRLP(s *rlp.Stream) error {
+ return s.Decode(g.value)
+}
+
+// WithEntry wraps any value with a key name. It can be used to set and load arbitrary values
+// in a record. The value v must be supported by rlp. To use WithEntry with Load, the value
+// must be a pointer.
+func WithEntry(k string, v interface{}) Entry {
+ return &generic{key: k, value: v}
+}
+
+// DiscPort is the "discv5" key, which holds the UDP port for discovery v5.
+type DiscPort uint16
+
+func (v DiscPort) ENRKey() string { return "discv5" }
+
+// ID is the "id" key, which holds the name of the identity scheme.
+type ID string
+
+func (v ID) ENRKey() string { return "id" }
+
+// IP4 is the "ip4" key, which holds a 4-byte IPv4 address.
+type IP4 net.IP
+
+func (v IP4) ENRKey() string { return "ip4" }
+
+// EncodeRLP implements rlp.Encoder.
+func (v IP4) EncodeRLP(w io.Writer) error {
+ ip4 := net.IP(v).To4()
+ if ip4 == nil {
+ return fmt.Errorf("invalid IPv4 address: %v", v)
+ }
+ return rlp.Encode(w, ip4)
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (v *IP4) DecodeRLP(s *rlp.Stream) error {
+ if err := s.Decode((*net.IP)(v)); err != nil {
+ return err
+ }
+ if len(*v) != 4 {
+ return fmt.Errorf("invalid IPv4 address, want 4 bytes: %v", *v)
+ }
+ return nil
+}
+
+// IP6 is the "ip6" key, which holds a 16-byte IPv6 address.
+type IP6 net.IP
+
+func (v IP6) ENRKey() string { return "ip6" }
+
+// EncodeRLP implements rlp.Encoder.
+func (v IP6) EncodeRLP(w io.Writer) error {
+ ip6 := net.IP(v)
+ return rlp.Encode(w, ip6)
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (v *IP6) DecodeRLP(s *rlp.Stream) error {
+ if err := s.Decode((*net.IP)(v)); err != nil {
+ return err
+ }
+ if len(*v) != 16 {
+ return fmt.Errorf("invalid IPv6 address, want 16 bytes: %v", *v)
+ }
+ return nil
+}
+
+// Secp256k1 is the "secp256k1" key, which holds a public key.
+type Secp256k1 ecdsa.PublicKey
+
+func (v Secp256k1) ENRKey() string { return "secp256k1" }
+
+// EncodeRLP implements rlp.Encoder.
+func (v Secp256k1) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, crypto.CompressPubkey((*ecdsa.PublicKey)(&v)))
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (v *Secp256k1) DecodeRLP(s *rlp.Stream) error {
+ buf, err := s.Bytes()
+ if err != nil {
+ return err
+ }
+ pk, err := crypto.DecompressPubkey(buf)
+ if err != nil {
+ return err
+ }
+ *v = (Secp256k1)(*pk)
+ return nil
+}
+
+// KeyError is an error related to a key.
+type KeyError struct {
+ Key string
+ Err error
+}
+
+// Error implements error.
+func (err *KeyError) Error() string {
+ if err.Err == errNotFound {
+ return fmt.Sprintf("missing ENR key %q", err.Key)
+ }
+ return fmt.Sprintf("ENR key %q: %v", err.Key, err.Err)
+}
+
+// IsNotFound reports whether the given error means that a key/value pair is
+// missing from a record.
+func IsNotFound(err error) bool {
+ kerr, ok := err.(*KeyError)
+ return ok && kerr.Err == errNotFound
+}
diff --git a/p2p/server.go b/p2p/server.go
index 922df55ba..2cff94ea5 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -78,9 +78,6 @@ type Config struct {
// protocol should be started or not.
DiscoveryV5 bool `toml:",omitempty"`
- // Listener address for the V5 discovery protocol UDP traffic.
- DiscoveryV5Addr string `toml:",omitempty"`
-
// Name sets the node name of this server.
// Use common.MakeName to create a name that follows existing conventions.
Name string `toml:"-"`
@@ -354,6 +351,32 @@ func (srv *Server) Stop() {
srv.loopWG.Wait()
}
+// sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns
+// messages that were found unprocessable and sent to the unhandled channel by the primary listener.
+type sharedUDPConn struct {
+ *net.UDPConn
+ unhandled chan discover.ReadPacket
+}
+
+// ReadFromUDP implements discv5.conn
+func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+ packet, ok := <-s.unhandled
+ if !ok {
+ return 0, nil, fmt.Errorf("Connection was closed")
+ }
+ l := len(packet.Data)
+ if l > len(b) {
+ l = len(b)
+ }
+ copy(b[:l], packet.Data[:l])
+ return l, packet.Addr, nil
+}
+
+// Close implements discv5.conn
+func (s *sharedUDPConn) Close() error {
+ return nil
+}
+
// Start starts running the server.
// Servers can not be re-used after stopping.
func (srv *Server) Start() (err error) {
@@ -388,9 +411,43 @@ func (srv *Server) Start() (err error) {
srv.peerOp = make(chan peerOpFunc)
srv.peerOpDone = make(chan struct{})
+ var (
+ conn *net.UDPConn
+ sconn *sharedUDPConn
+ realaddr *net.UDPAddr
+ unhandled chan discover.ReadPacket
+ )
+
+ if !srv.NoDiscovery || srv.DiscoveryV5 {
+ addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
+ if err != nil {
+ return err
+ }
+ conn, err = net.ListenUDP("udp", addr)
+ if err != nil {
+ return err
+ }
+
+ realaddr = conn.LocalAddr().(*net.UDPAddr)
+ if srv.NAT != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := srv.NAT.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ }
+
+ if !srv.NoDiscovery && srv.DiscoveryV5 {
+ unhandled = make(chan discover.ReadPacket, 100)
+ sconn = &sharedUDPConn{conn, unhandled}
+ }
+
// node table
if !srv.NoDiscovery {
- ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
+ ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict)
if err != nil {
return err
}
@@ -401,7 +458,15 @@ func (srv *Server) Start() (err error) {
}
if srv.DiscoveryV5 {
- ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
+ var (
+ ntab *discv5.Network
+ err error
+ )
+ if sconn != nil {
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ } else {
+ ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
+ }
if err != nil {
return err
}