diff options
-rw-r--r-- | dex/backend.go | 1 | ||||
-rw-r--r-- | dex/handler.go | 22 | ||||
-rw-r--r-- | dex/nodetable.go | 4 | ||||
-rw-r--r-- | dex/peer.go | 463 | ||||
-rw-r--r-- | dex/peer_test.go | 578 | ||||
-rw-r--r-- | dex/protocol.go | 4 | ||||
-rw-r--r-- | dex/protocol_test.go | 10 | ||||
-rw-r--r-- | p2p/dial.go | 59 | ||||
-rw-r--r-- | p2p/dial_test.go | 131 | ||||
-rw-r--r-- | p2p/server.go | 146 | ||||
-rw-r--r-- | p2p/server_test.go | 90 |
11 files changed, 424 insertions, 1084 deletions
diff --git a/dex/backend.go b/dex/backend.go index b16bdc493..6c5c03861 100644 --- a/dex/backend.go +++ b/dex/backend.go @@ -245,7 +245,6 @@ func (s *Dexon) Start(srvr *p2p.Server) error { } // Start the networking layer and the light server if requested s.protocolManager.Start(srvr, maxPeers) - s.protocolManager.addSelfRecord() return nil } diff --git a/dex/handler.go b/dex/handler.go index df5516f62..490e1ec33 100644 --- a/dex/handler.go +++ b/dex/handler.go @@ -38,6 +38,7 @@ import ( "errors" "fmt" "math" + "math/rand" "net" "sync" "sync/atomic" @@ -304,10 +305,6 @@ func (pm *ProtocolManager) Start(srvr p2pServer, maxPeers int) { } -func (pm *ProtocolManager) addSelfRecord() { - pm.nodeTable.AddRecords([]*enr.Record{pm.srvr.Self().Record()}) -} - func (pm *ProtocolManager) Stop() { log.Info("Stopping Ethereum protocol") @@ -1177,10 +1174,27 @@ func (pm *ProtocolManager) finalizedBlockBroadcastLoop() { } func (pm *ProtocolManager) recordBroadcastLoop() { + r := rand.New(rand.NewSource(time.Now().Unix())) + t := time.NewTimer(0) + defer t.Stop() + for { select { case event := <-pm.recordsCh: pm.BroadcastRecords(event.Records) + pm.peers.Refresh() + + case <-t.C: + record := pm.srvr.Self().Record() + log.Debug("refresh our node record", "seq", record.Seq()) + pm.nodeTable.AddRecords([]*enr.Record{record}) + + // Log current peers connection status. + pm.peers.Status() + + // Reset timer. + d := 1*time.Minute + time.Duration(r.Int63n(60))*time.Second + t.Reset(d) // Err() channel will be closed when unsubscribing. case <-pm.recordsSub.Err(): diff --git a/dex/nodetable.go b/dex/nodetable.go index 12cc9ba46..ba1c28994 100644 --- a/dex/nodetable.go +++ b/dex/nodetable.go @@ -52,7 +52,9 @@ func (t *nodeTable) AddRecords(records []*enr.Record) { log.Debug("Add new record to node table", "id", node.ID().String(), "ip", node.IP().String(), "udp", node.UDP(), "tcp", node.TCP()) } - t.feed.Send(newRecordsEvent{newRecords}) + if len(newRecords) > 0 { + go t.feed.Send(newRecordsEvent{newRecords}) + } } func (t *nodeTable) Records() []*enr.Record { diff --git a/dex/peer.go b/dex/peer.go index 0a67db688..67a59348d 100644 --- a/dex/peer.go +++ b/dex/peer.go @@ -37,7 +37,6 @@ import ( "encoding/hex" "errors" "fmt" - "net" "sync" "time" @@ -105,7 +104,8 @@ const ( handshakeTimeout = 5 * time.Second - groupNodeNum = 3 + groupConnNum = 3 + groupConnTimeout = 3 * time.Minute ) // PeerInfo represents a short summary of the Ethereum sub-protocol metadata known @@ -129,6 +129,17 @@ type peerLabel struct { round uint64 } +func (p peerLabel) String() string { + var t string + switch p.set { + case dkgset: + t = fmt.Sprintf("DKGSet round: %d", p.round) + case notaryset: + t = fmt.Sprintf("NotarySet round: %d chain: %d", p.round, p.chainID) + } + return t +} + type peer struct { id string @@ -711,28 +722,27 @@ type peerSet struct { tab *nodeTable selfPK string - srvr p2pServer - gov governance - peer2Labels map[string]map[peerLabel]struct{} - label2Peers map[peerLabel]map[string]struct{} - history map[uint64]struct{} - notaryHistory map[uint64]struct{} - dkgHistory map[uint64]struct{} + srvr p2pServer + gov governance + + label2Nodes map[peerLabel]map[string]*enode.Node + directConn map[peerLabel]struct{} + groupConnPeers map[peerLabel]map[string]time.Time + allDirectPeers map[string]map[peerLabel]struct{} } // newPeerSet creates a new peer set to track the active participants. func newPeerSet(gov governance, srvr p2pServer, tab *nodeTable) *peerSet { return &peerSet{ - peers: make(map[string]*peer), - gov: gov, - srvr: srvr, - tab: tab, - selfPK: hex.EncodeToString(crypto.FromECDSAPub(&srvr.GetPrivateKey().PublicKey)), - peer2Labels: make(map[string]map[peerLabel]struct{}), - label2Peers: make(map[peerLabel]map[string]struct{}), - history: make(map[uint64]struct{}), - notaryHistory: make(map[uint64]struct{}), - dkgHistory: make(map[uint64]struct{}), + peers: make(map[string]*peer), + gov: gov, + srvr: srvr, + tab: tab, + selfPK: hex.EncodeToString(crypto.FromECDSAPub(&srvr.GetPrivateKey().PublicKey)), + label2Nodes: make(map[peerLabel]map[string]*enode.Node), + directConn: make(map[peerLabel]struct{}), + groupConnPeers: make(map[peerLabel]map[string]time.Time), + allDirectPeers: make(map[string]map[peerLabel]struct{}), } } @@ -832,8 +842,8 @@ func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer { func (ps *peerSet) PeersWithLabel(label peerLabel) []*peer { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*peer, 0, len(ps.label2Peers[label])) - for id := range ps.label2Peers[label] { + list := make([]*peer, 0, len(ps.label2Nodes[label])) + for id := range ps.label2Nodes[label] { if p, ok := ps.peers[id]; ok { list = append(list, p) } @@ -845,8 +855,8 @@ func (ps *peerSet) PeersWithoutVote(hash common.Hash, label peerLabel) []*peer { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*peer, 0, len(ps.label2Peers[label])) - for id := range ps.label2Peers[label] { + list := make([]*peer, 0, len(ps.label2Nodes[label])) + for id := range ps.label2Nodes[label] { if p, ok := ps.peers[id]; ok { if !p.knownVotes.Contains(hash) { list = append(list, p) @@ -948,345 +958,178 @@ func (ps *peerSet) Close() { } func (ps *peerSet) BuildConnection(round uint64) { - ps.lock.Lock() - defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("BuildConnection: %d", round)) - - ps.history[round] = struct{}{} + dkgLabel := peerLabel{set: dkgset, round: round} + if _, ok := ps.label2Nodes[dkgLabel]; !ok { + dkgPKs, err := ps.gov.DKGSet(round) + if err != nil { + log.Error("get DKG set fail", "round", round, "err", err) + } - dkgPKs, err := ps.gov.DKGSet(round) - if err != nil { - log.Error("get dkg set fail", "round", round, "err", err) - } + nodes := ps.pksToNodes(dkgPKs) + ps.label2Nodes[dkgLabel] = nodes - // build dkg connection - _, inDKGSet := dkgPKs[ps.selfPK] - if inDKGSet { - delete(dkgPKs, ps.selfPK) - dkgLabel := peerLabel{set: dkgset, round: round} - for pk := range dkgPKs { - ps.addDirectPeer(pk, dkgLabel) + if _, exists := nodes[ps.srvr.Self().ID().String()]; exists { + ps.buildDirectConn(dkgLabel) + } else { + ps.buildGroupConn(dkgLabel) } } - var inOneNotarySet bool - for cid := uint32(0); cid < ps.gov.GetNumChains(round); cid++ { - notaryPKs, err := ps.gov.NotarySet(round, cid) - if err != nil { - log.Error("get notary set fail", - "round", round, "chain id", cid, "err", err) - continue - } - label := peerLabel{set: notaryset, chainID: cid, round: round} - // not in notary set, add group - if _, ok := notaryPKs[ps.selfPK]; !ok { - var nodes []*enode.Node - for pk := range notaryPKs { - node := ps.newNode(pk) - nodes = append(nodes, node) - ps.addLabel(node, label) + for chainID := uint32(0); chainID < ps.gov.GetNumChains(round); chainID++ { + notaryLabel := peerLabel{set: notaryset, chainID: chainID, round: round} + if _, ok := ps.label2Nodes[notaryLabel]; !ok { + notaryPKs, err := ps.gov.NotarySet(round, chainID) + if err != nil { + log.Error("get notary set fail", + "round", round, "chainID", chainID, "err", err) + continue } - ps.srvr.AddGroup(notarySetName(cid, round), nodes, groupNodeNum) - continue - } - delete(notaryPKs, ps.selfPK) - for pk := range notaryPKs { - ps.addDirectPeer(pk, label) - } - inOneNotarySet = true - } + nodes := ps.pksToNodes(notaryPKs) + ps.label2Nodes[notaryLabel] = nodes - // build some connections to DKG nodes - if !inDKGSet && inOneNotarySet { - var nodes []*enode.Node - label := peerLabel{set: dkgset, round: round} - for pk := range dkgPKs { - node := ps.newNode(pk) - nodes = append(nodes, node) - ps.addLabel(node, label) + if _, exists := nodes[ps.srvr.Self().ID().String()]; exists { + ps.buildDirectConn(notaryLabel) + } else { + ps.buildGroupConn(notaryLabel) + } } - ps.srvr.AddGroup(dkgSetName(round), nodes, groupNodeNum) } } func (ps *peerSet) ForgetConnection(round uint64) { ps.lock.Lock() defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("ForgetConnection: %d", round)) - for r := range ps.history { - if r <= round { - ps.forgetConnection(round) - delete(ps.history, r) + for label := range ps.directConn { + if label.round <= round { + ps.forgetDirectConn(label) } } -} - -func (ps *peerSet) forgetConnection(round uint64) { - dkgPKs, err := ps.gov.DKGSet(round) - if err != nil { - log.Error("get dkg set fail", "round", round, "err", err) - } - _, inDKGSet := dkgPKs[ps.selfPK] - if inDKGSet { - delete(dkgPKs, ps.selfPK) - label := peerLabel{set: dkgset, round: round} - for id := range dkgPKs { - ps.removeDirectPeer(id, label) + for label := range ps.groupConnPeers { + if label.round <= round { + ps.forgetGroupConn(label) } } - var inOneNotarySet bool - for cid := uint32(0); cid < ps.gov.GetNumChains(round); cid++ { - notaryPKs, err := ps.gov.NotarySet(round, cid) - if err != nil { - log.Error("get notary set fail", - "round", round, "chain id", cid, "err", err) - continue + for label := range ps.label2Nodes { + if label.round <= round { + delete(ps.label2Nodes, label) } - - label := peerLabel{set: notaryset, chainID: cid, round: round} - - // not in notary set, add group - if _, ok := notaryPKs[ps.selfPK]; !ok { - var nodes []*enode.Node - for id := range notaryPKs { - node := ps.newNode(id) - nodes = append(nodes, node) - ps.removeLabel(node, label) - } - ps.srvr.RemoveGroup(notarySetName(cid, round)) - continue - } - - delete(notaryPKs, ps.selfPK) - for pk := range notaryPKs { - ps.removeDirectPeer(pk, label) - } - inOneNotarySet = true - } - - // build some connections to DKG nodes - if !inDKGSet && inOneNotarySet { - var nodes []*enode.Node - label := peerLabel{set: dkgset, round: round} - for id := range dkgPKs { - node := ps.newNode(id) - nodes = append(nodes, node) - ps.removeLabel(node, label) - } - ps.srvr.RemoveGroup(dkgSetName(round)) } } -func (ps *peerSet) BuildNotaryConn(round uint64) { +func (ps *peerSet) EnsureGroupConn() { ps.lock.Lock() defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("BuildNotaryConn: %d", round)) - - if _, ok := ps.notaryHistory[round]; ok { - return - } - - ps.notaryHistory[round] = struct{}{} - for chainID := uint32(0); chainID < ps.gov.GetNumChains(round); chainID++ { - s, err := ps.gov.NotarySet(round, chainID) - if err != nil { - log.Error("get notary set fail", - "round", round, "chain id", chainID, "err", err) - continue - } - - // not in notary set, add group - if _, ok := s[ps.selfPK]; !ok { - var nodes []*enode.Node - for id := range s { - nodes = append(nodes, ps.newNode(id)) + now := time.Now() + for label, peers := range ps.groupConnPeers { + // Remove timeout group conn peer. + for id, addtime := range peers { + if ps.peers[id] == nil && time.Since(addtime) > groupConnTimeout { + ps.removeDirectPeer(id, label) + delete(ps.groupConnPeers[label], id) } - ps.srvr.AddGroup(notarySetName(chainID, round), nodes, groupNodeNum) - continue - } - - label := peerLabel{ - set: notaryset, - chainID: chainID, - round: round, } - delete(s, ps.selfPK) - for pk := range s { - ps.addDirectPeer(pk, label) - } - } -} -func (ps *peerSet) dumpPeerLabel(s string) { - log.Debug(s, "peer num", len(ps.peers)) - for id, labels := range ps.peer2Labels { - _, ok := ps.peers[id] - for label := range labels { - log.Debug(s, "connected", ok, "id", id[:16], - "round", label.round, "cid", label.chainID, "set", label.set) + // Add new group conn peer. + for id := range ps.label2Nodes[label] { + if len(ps.groupConnPeers[label]) >= groupConnNum { + break + } + ps.groupConnPeers[label][id] = now + ps.addDirectPeer(id, label) } } } -func (ps *peerSet) ForgetNotaryConn(round uint64) { +func (ps *peerSet) Refresh() { ps.lock.Lock() defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("ForgetNotaryConn: %d", round)) - - // forget all the rounds before the given round - for r := range ps.notaryHistory { - if r <= round { - ps.forgetNotaryConn(r) - delete(ps.notaryHistory, r) + for id := range ps.allDirectPeers { + if ps.peers[id] == nil { + if node := ps.tab.GetNode(enode.HexID(id)); node != nil { + ps.srvr.AddDirectPeer(node) + } } } } -func (ps *peerSet) forgetNotaryConn(round uint64) { - for chainID := uint32(0); chainID < ps.gov.GetNumChains(round); chainID++ { - s, err := ps.gov.NotarySet(round, chainID) - if err != nil { - log.Error("get notary set fail", - "round", round, "chain id", chainID, "err", err) - continue - } - if _, ok := s[ps.selfPK]; !ok { - ps.srvr.RemoveGroup(notarySetName(chainID, round)) - continue - } - - label := peerLabel{ - set: notaryset, - chainID: chainID, - round: round, - } - delete(s, ps.selfPK) - for pk := range s { - ps.removeDirectPeer(pk, label) - } +func (ps *peerSet) buildDirectConn(label peerLabel) { + ps.directConn[label] = struct{}{} + for id := range ps.label2Nodes[label] { + ps.addDirectPeer(id, label) } } -func notarySetName(chainID uint32, round uint64) string { - return fmt.Sprintf("%d-%d-notaryset", chainID, round) -} - -func dkgSetName(round uint64) string { - return fmt.Sprintf("%d-dkgset", round) -} - -func (ps *peerSet) BuildDKGConn(round uint64) { - ps.lock.Lock() - defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("BuildDKGConn: %d", round)) - s, err := ps.gov.DKGSet(round) - if err != nil { - log.Error("get dkg set fail", "round", round) - return - } - - if _, ok := s[ps.selfPK]; !ok { - return - } - ps.dkgHistory[round] = struct{}{} - - delete(s, ps.selfPK) - for pk := range s { - ps.addDirectPeer(pk, peerLabel{ - set: dkgset, - round: round, - }) +func (ps *peerSet) forgetDirectConn(label peerLabel) { + for id := range ps.label2Nodes[label] { + ps.removeDirectPeer(id, label) } + delete(ps.directConn, label) } -func (ps *peerSet) ForgetDKGConn(round uint64) { - ps.lock.Lock() - defer ps.lock.Unlock() - defer ps.dumpPeerLabel(fmt.Sprintf("ForgetDKGConn: %d", round)) - - // forget all the rounds before the given round - for r := range ps.dkgHistory { - if r <= round { - ps.forgetDKGConn(r) - delete(ps.dkgHistory, r) +func (ps *peerSet) buildGroupConn(label peerLabel) { + peers := make(map[string]time.Time) + now := time.Now() + for id := range ps.label2Nodes[label] { + peers[id] = now + ps.addDirectPeer(id, label) + if len(peers) >= groupConnNum { + break } } + ps.groupConnPeers[label] = peers } -func (ps *peerSet) forgetDKGConn(round uint64) { - s, err := ps.gov.DKGSet(round) - if err != nil { - log.Error("get dkg set fail", "round", round) - return +func (ps *peerSet) forgetGroupConn(label peerLabel) { + for id := range ps.groupConnPeers[label] { + ps.removeDirectPeer(id, label) } - if _, ok := s[ps.selfPK]; !ok { + delete(ps.groupConnPeers, label) +} + +func (ps *peerSet) addDirectPeer(id string, label peerLabel) { + if len(ps.allDirectPeers[id]) > 0 { + ps.allDirectPeers[id][label] = struct{}{} return } + ps.allDirectPeers[id] = map[peerLabel]struct{}{label: {}} - delete(s, ps.selfPK) - label := peerLabel{ - set: dkgset, - round: round, + node := ps.tab.GetNode(enode.HexID(id)) + if node == nil { + node = ps.label2Nodes[label][id] } - for pk := range s { - ps.removeDirectPeer(pk, label) - } -} - -// make sure the ps.lock is held -func (ps *peerSet) addDirectPeer(pk string, label peerLabel) { - node := ps.newNode(pk) - ps.addLabel(node, label) ps.srvr.AddDirectPeer(node) } -// make sure the ps.lock is held -func (ps *peerSet) removeDirectPeer(pk string, label peerLabel) { - node := ps.newNode(pk) - ps.removeLabel(node, label) - if len(ps.peer2Labels[node.ID().String()]) == 0 { - ps.srvr.RemoveDirectPeer(node) +func (ps *peerSet) removeDirectPeer(id string, label peerLabel) { + if len(ps.allDirectPeers[id]) == 0 { + return } -} - -// make sure the ps.lock is held -func (ps *peerSet) addLabel(node *enode.Node, label peerLabel) { - id := node.ID().String() - if _, ok := ps.peer2Labels[id]; !ok { - ps.peer2Labels[id] = make(map[peerLabel]struct{}) + delete(ps.allDirectPeers[id], label) + if len(ps.allDirectPeers[id]) == 0 { + ps.srvr.RemoveDirectPeer(ps.label2Nodes[label][id]) + delete(ps.allDirectPeers, id) } - if _, ok := ps.label2Peers[label]; !ok { - ps.label2Peers[label] = make(map[string]struct{}) - } - ps.peer2Labels[id][label] = struct{}{} - ps.label2Peers[label][id] = struct{}{} } -// make sure the ps.lock is held -func (ps *peerSet) removeLabel(node *enode.Node, label peerLabel) { - id := node.ID().String() - - delete(ps.peer2Labels[id], label) - delete(ps.label2Peers[label], id) - if len(ps.peer2Labels[id]) == 0 { - delete(ps.peer2Labels, id) - } - if len(ps.label2Peers[label]) == 0 { - delete(ps.label2Peers, label) +func (ps *peerSet) pksToNodes(pks map[string]struct{}) map[string]*enode.Node { + nodes := map[string]*enode.Node{} + for pk := range pks { + n := ps.newEmptyNode(pk) + if n.ID() == ps.srvr.Self().ID() { + n = ps.srvr.Self() + } + nodes[n.ID().String()] = n } + return nodes } -// TODO: improve this by not using pk. -func (ps *peerSet) newNode(pk string) *enode.Node { - var ip net.IP - var tcp, udp int - +func (ps *peerSet) newEmptyNode(pk string) *enode.Node { b, err := hex.DecodeString(pk) if err != nil { panic(err) @@ -1296,10 +1139,34 @@ func (ps *peerSet) newNode(pk string) *enode.Node { if err != nil { panic(err) } + return enode.NewV4(pubkey, nil, 0, 0) +} + +func (ps *peerSet) Status() { + ps.lock.Lock() + defer ps.lock.Unlock() + for label := range ps.directConn { + l := label.String() + for id := range ps.label2Nodes[label] { + _, ok := ps.peers[id] + log.Debug("direct conn", "label", l, "id", id, "connected", ok) + } + } + + for label, peers := range ps.groupConnPeers { + l := label.String() + for id := range peers { + _, ok := ps.peers[id] + log.Debug("group conn", "label", l, "id", id, "connected", ok) + } + } - node := ps.tab.GetNode(enode.PubkeyToIDV4(pubkey)) - if node != nil { - return node + connected := 0 + for id := range ps.allDirectPeers { + if _, ok := ps.peers[id]; ok { + connected++ + } } - return enode.NewV4(pubkey, ip, tcp, udp) + log.Debug("all direct peers", + "connected", connected, "all", len(ps.allDirectPeers)) } diff --git a/dex/peer_test.go b/dex/peer_test.go index 9caa62d1e..29b4971c5 100644 --- a/dex/peer_test.go +++ b/dex/peer_test.go @@ -1,15 +1,16 @@ package dex import ( + "crypto/ecdsa" "encoding/hex" - "fmt" + "reflect" "testing" "github.com/dexon-foundation/dexon/crypto" "github.com/dexon-foundation/dexon/p2p/enode" ) -func TestPeerSetBuildAndForgetNotaryConn(t *testing.T) { +func TestPeerSetBuildAndForgetConn(t *testing.T) { key, err := crypto.GenerateKey() if err != nil { t.Fatal(err) @@ -26,7 +27,7 @@ func TestPeerSetBuildAndForgetNotaryConn(t *testing.T) { var nodes []*enode.Node for i := 0; i < 9; i++ { - nodes = append(nodes, randomNode()) + nodes = append(nodes, randomV4CompactNode()) } round10 := [][]*enode.Node{ @@ -55,445 +56,268 @@ func TestPeerSetBuildAndForgetNotaryConn(t *testing.T) { return newTestNodeSet(m[round][cid]), nil } - ps := newPeerSet(gov, server, table) - peer1 := newDummyPeer(nodes[1]) - peer2 := newDummyPeer(nodes[2]) - err = ps.Register(peer1) - if err != nil { - t.Error(err) - } - err = ps.Register(peer2) - if err != nil { - t.Error(err) - } - - // build round 10 - ps.BuildNotaryConn(10) - - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {notaryset, 0, 10}, - }, - nodes[2].ID().String(): { - {notaryset, 0, 10}, - }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10}, notaryset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), - }) - if err != nil { - t.Error(err) - } - err = checkGroup(server, []string{ - notarySetName(1, 10), - notarySetName(2, 10), - }) - if err != nil { - t.Error(err) + gov.dkgSetFunc = func(round uint64) (map[string]struct{}, error) { + m := map[uint64][]*enode.Node{ + 10: {self, nodes[1], nodes[3]}, + 11: {nodes[1], nodes[2], nodes[5]}, + 12: {self, nodes[3], nodes[5]}, + } + return newTestNodeSet(m[round]), nil } - // build round 11 - ps.BuildNotaryConn(11) + ps := newPeerSet(gov, server, table) - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {notaryset, 0, 10}, - {notaryset, 0, 11}, + // build round 10 + ps.BuildConnection(10) + ps.BuildConnection(11) + ps.BuildConnection(12) + + expectedlabel2Nodes := map[peerLabel]map[string]*enode.Node{ + {set: notaryset, round: 10, chainID: 0}: { + self.ID().String(): self, + nodes[1].ID().String(): nodes[1], + nodes[2].ID().String(): nodes[2], }, - nodes[2].ID().String(): { - {notaryset, 0, 10}, - {notaryset, 2, 11}, + {set: notaryset, round: 10, chainID: 1}: { + nodes[1].ID().String(): nodes[1], + nodes[3].ID().String(): nodes[3], }, - nodes[4].ID().String(): { - {notaryset, 2, 11}, + {set: notaryset, round: 10, chainID: 2}: { + nodes[2].ID().String(): nodes[2], + nodes[4].ID().String(): nodes[4], }, - nodes[5].ID().String(): { - {notaryset, 0, 11}, + {set: dkgset, round: 10}: { + self.ID().String(): self, + nodes[1].ID().String(): nodes[1], + nodes[3].ID().String(): nodes[3], }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10, 11}, notaryset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), nodes[4].ID(), nodes[5].ID(), - }) - if err != nil { - t.Error(err) - } - err = checkGroup(server, []string{ - notarySetName(1, 10), - notarySetName(2, 10), - notarySetName(1, 11), - }) - if err != nil { - t.Error(err) - } - - // build round 12 - ps.BuildNotaryConn(12) - - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {notaryset, 0, 10}, - {notaryset, 0, 11}, + {set: notaryset, round: 11, chainID: 0}: { + self.ID().String(): self, + nodes[1].ID().String(): nodes[1], + nodes[5].ID().String(): nodes[5], }, - nodes[2].ID().String(): { - {notaryset, 0, 10}, - {notaryset, 2, 11}, - {notaryset, 2, 12}, + {set: notaryset, round: 11, chainID: 1}: { + nodes[5].ID().String(): nodes[5], + nodes[6].ID().String(): nodes[6], }, - nodes[3].ID().String(): { - {notaryset, 0, 12}, + {set: notaryset, round: 11, chainID: 2}: { + self.ID().String(): self, + nodes[2].ID().String(): nodes[2], + nodes[4].ID().String(): nodes[4], }, - nodes[4].ID().String(): { - {notaryset, 2, 11}, + {set: dkgset, round: 11}: { + nodes[1].ID().String(): nodes[1], + nodes[2].ID().String(): nodes[2], + nodes[5].ID().String(): nodes[5], }, - nodes[5].ID().String(): { - {notaryset, 0, 11}, - {notaryset, 0, 12}, + {set: notaryset, round: 12, chainID: 0}: { + self.ID().String(): self, + nodes[3].ID().String(): nodes[3], + nodes[5].ID().String(): nodes[5], }, - nodes[6].ID().String(): { - {notaryset, 2, 12}, + {set: notaryset, round: 12, chainID: 1}: { + self.ID().String(): self, + nodes[7].ID().String(): nodes[7], + nodes[8].ID().String(): nodes[8], }, - nodes[7].ID().String(): { - {notaryset, 1, 12}, + {set: notaryset, round: 12, chainID: 2}: { + self.ID().String(): self, + nodes[2].ID().String(): nodes[2], + nodes[6].ID().String(): nodes[6], }, - nodes[8].ID().String(): { - {notaryset, 1, 12}, + {set: dkgset, round: 12}: { + self.ID().String(): self, + nodes[3].ID().String(): nodes[3], + nodes[5].ID().String(): nodes[5], }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10, 11, 12}, notaryset) - if err != nil { - t.Error(err) } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), nodes[3].ID(), nodes[4].ID(), - nodes[5].ID(), nodes[6].ID(), nodes[7].ID(), nodes[8].ID(), - }) - if err != nil { - t.Error(err) - } - err = checkGroup(server, []string{ - notarySetName(1, 10), - notarySetName(2, 10), - notarySetName(1, 11), - }) - if err != nil { - t.Error(err) - } - - // forget round 11 - ps.ForgetNotaryConn(11) - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[2].ID().String(): { - {notaryset, 2, 12}, - }, - nodes[3].ID().String(): { - {notaryset, 0, 12}, - }, - nodes[5].ID().String(): { - {notaryset, 0, 12}, - }, - nodes[6].ID().String(): { - {notaryset, 2, 12}, - }, - nodes[7].ID().String(): { - {notaryset, 1, 12}, - }, - nodes[8].ID().String(): { - {notaryset, 1, 12}, - }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{12}, notaryset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[2].ID(), nodes[3].ID(), - nodes[5].ID(), nodes[6].ID(), nodes[7].ID(), nodes[8].ID(), - }) - if err != nil { - t.Error(err) - } - err = checkGroup(server, []string{}) - if err != nil { - t.Error(err) + if !reflect.DeepEqual(ps.label2Nodes, expectedlabel2Nodes) { + t.Errorf("label2Nodes not match") } - // forget round 12 - ps.ForgetNotaryConn(12) - err = checkPeer2Labels(ps, map[string][]peerLabel{}) - if err != nil { - t.Error(err) + expectedDirectConn := map[peerLabel]struct{}{ + {set: notaryset, round: 10, chainID: 0}: {}, + {set: notaryset, round: 11, chainID: 0}: {}, + {set: notaryset, round: 11, chainID: 2}: {}, + {set: notaryset, round: 12, chainID: 0}: {}, + {set: notaryset, round: 12, chainID: 1}: {}, + {set: notaryset, round: 12, chainID: 2}: {}, + {set: dkgset, round: 10}: {}, + {set: dkgset, round: 12}: {}, } - err = checkPeerSetHistory(ps, []uint64{}, notaryset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{}) - if err != nil { - t.Error(err) - } - err = checkGroup(server, []string{}) - if err != nil { - t.Error(err) - } - -} -func TestPeerSetBuildDKGConn(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatal(err) + if !reflect.DeepEqual(ps.directConn, expectedDirectConn) { + t.Errorf("direct conn not match") } - server := newTestP2PServer(key) - self := server.Self() - table := newNodeTable() - var nodes []*enode.Node - for i := 0; i < 6; i++ { - nodes = append(nodes, randomNode()) + expectedGroupConn := []peerLabel{ + {set: notaryset, round: 10, chainID: 1}, + {set: notaryset, round: 10, chainID: 2}, + {set: notaryset, round: 11, chainID: 1}, + {set: dkgset, round: 11}, } - gov := &testGovernance{} - - gov.dkgSetFunc = func(round uint64) (map[string]struct{}, error) { - m := map[uint64][]*enode.Node{ - 10: {self, nodes[1], nodes[2]}, - 11: {nodes[1], nodes[2], nodes[5]}, - 12: {self, nodes[3], nodes[5]}, - } - return newTestNodeSet(m[round]), nil + if len(ps.groupConnPeers) != len(expectedGroupConn) { + t.Errorf("group conn peers not match") } - ps := newPeerSet(gov, server, table) - peer1 := newDummyPeer(nodes[1]) - peer2 := newDummyPeer(nodes[2]) - err = ps.Register(peer1) - if err != nil { - t.Error(err) - } - err = ps.Register(peer2) - if err != nil { - t.Error(err) + for _, l := range expectedGroupConn { + if len(ps.groupConnPeers[l]) == 0 { + t.Errorf("group conn peers is 0") + } } - // build round 10 - ps.BuildDKGConn(10) + expectedAllDirect := make(map[string]map[peerLabel]struct{}) - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {dkgset, 0, 10}, - }, - nodes[2].ID().String(): { - {dkgset, 0, 10}, - }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10}, dkgset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), - }) - if err != nil { - t.Error(err) + for l := range ps.directConn { + for id := range ps.label2Nodes[l] { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) + } + expectedAllDirect[id][l] = struct{}{} + } } - // build round 11 - ps.BuildDKGConn(11) - - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {dkgset, 0, 10}, - }, - nodes[2].ID().String(): { - {dkgset, 0, 10}, - }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10}, dkgset) - if err != nil { - t.Error(err) + for l, peers := range ps.groupConnPeers { + for id := range peers { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) + } + expectedAllDirect[id][l] = struct{}{} + } } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), - }) - if err != nil { - t.Error(err) + + if !reflect.DeepEqual(ps.allDirectPeers, expectedAllDirect) { + t.Errorf("all direct peers not match") } - // build round 12 - ps.BuildDKGConn(12) + // forget round 11 + ps.ForgetConnection(11) - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[1].ID().String(): { - {dkgset, 0, 10}, + expectedlabel2Nodes = map[peerLabel]map[string]*enode.Node{ + {set: notaryset, round: 12, chainID: 0}: { + self.ID().String(): self, + nodes[3].ID().String(): nodes[3], + nodes[5].ID().String(): nodes[5], }, - nodes[2].ID().String(): { - {dkgset, 0, 10}, + {set: notaryset, round: 12, chainID: 1}: { + self.ID().String(): self, + nodes[7].ID().String(): nodes[7], + nodes[8].ID().String(): nodes[8], }, - nodes[3].ID().String(): { - {dkgset, 0, 12}, + {set: notaryset, round: 12, chainID: 2}: { + self.ID().String(): self, + nodes[2].ID().String(): nodes[2], + nodes[6].ID().String(): nodes[6], }, - nodes[5].ID().String(): { - {dkgset, 0, 12}, + {set: dkgset, round: 12}: { + self.ID().String(): self, + nodes[3].ID().String(): nodes[3], + nodes[5].ID().String(): nodes[5], }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{10, 12}, dkgset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[1].ID(), nodes[2].ID(), nodes[3].ID(), nodes[5].ID(), - }) - if err != nil { - t.Error(err) } - // forget round 11 - ps.ForgetDKGConn(11) - - err = checkPeer2Labels(ps, map[string][]peerLabel{ - nodes[3].ID().String(): { - {dkgset, 0, 12}, - }, - nodes[5].ID().String(): { - {dkgset, 0, 12}, - }, - }) - if err != nil { - t.Error(err) - } - err = checkPeerSetHistory(ps, []uint64{12}, dkgset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{ - nodes[3].ID(), nodes[5].ID(), - }) - if err != nil { - t.Error(err) + if !reflect.DeepEqual(ps.label2Nodes, expectedlabel2Nodes) { + t.Errorf("label2Nodes not match") } - // forget round 12 - ps.ForgetDKGConn(12) - err = checkPeer2Labels(ps, map[string][]peerLabel{}) - if err != nil { - t.Error(err) + expectedDirectConn = map[peerLabel]struct{}{ + {set: notaryset, round: 12, chainID: 0}: {}, + {set: notaryset, round: 12, chainID: 1}: {}, + {set: notaryset, round: 12, chainID: 2}: {}, + {set: dkgset, round: 12}: {}, } - err = checkPeerSetHistory(ps, []uint64{}, dkgset) - if err != nil { - t.Error(err) - } - err = checkDirectPeer(server, []enode.ID{}) - if err != nil { - t.Error(err) + + if !reflect.DeepEqual(ps.directConn, expectedDirectConn) { + t.Error("direct conn not match") } -} -func checkPeer2Labels(ps *peerSet, want map[string][]peerLabel) error { - if len(ps.peer2Labels) != len(want) { - return fmt.Errorf("peer num mismatch: got %d, want %d", - len(ps.peer2Labels), len(want)) + expectedGroupConn = []peerLabel{} + + if len(ps.groupConnPeers) != len(expectedGroupConn) { + t.Errorf("group conn peers not match") } - for peerID, gotLabels := range ps.peer2Labels { - wantLabels, ok := want[peerID] - if !ok { - return fmt.Errorf("peer id %s not exists", peerID) + for _, l := range expectedGroupConn { + if len(ps.groupConnPeers[l]) == 0 { + t.Errorf("group conn peers is 0") } + } - if len(gotLabels) != len(wantLabels) { - return fmt.Errorf( - "num of labels of peer id %s mismatch: got %d, want %d", - peerID, len(gotLabels), len(wantLabels)) + expectedAllDirect = make(map[string]map[peerLabel]struct{}) + + for l := range ps.directConn { + for id := range ps.label2Nodes[l] { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) + } + expectedAllDirect[id][l] = struct{}{} } + } - for _, label := range wantLabels { - if _, ok := gotLabels[label]; !ok { - return fmt.Errorf("label: %+v not exists", label) + for l, peers := range ps.groupConnPeers { + for id := range peers { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) } + expectedAllDirect[id][l] = struct{}{} } } - return nil -} -func checkPeerSetHistory(ps *peerSet, want []uint64, set setType) error { - var history map[uint64]struct{} - switch set { - case notaryset: - history = ps.notaryHistory - case dkgset: - history = ps.dkgHistory - default: - return fmt.Errorf("invalid set: %d", set) + if !reflect.DeepEqual(ps.allDirectPeers, expectedAllDirect) { + t.Errorf("all direct peers not match") } - if len(history) != len(want) { - return fmt.Errorf("num of history mismatch: got %d, want %d", - len(history), len(want)) + // forget round 12 + ps.ForgetConnection(12) + + expectedlabel2Nodes = map[peerLabel]map[string]*enode.Node{} + if !reflect.DeepEqual(ps.label2Nodes, expectedlabel2Nodes) { + t.Errorf("label2Nodes not match") } - for _, r := range want { - if _, ok := history[r]; !ok { - return fmt.Errorf("round %d not exists", r) - } + expectedDirectConn = map[peerLabel]struct{}{} + + if !reflect.DeepEqual(ps.directConn, expectedDirectConn) { + t.Error("direct conn not match") } - return nil -} -func checkDirectPeer(srvr *testP2PServer, want []enode.ID) error { - if len(srvr.direct) != len(want) { - return fmt.Errorf("num of direct peer mismatch: got %d, want %d", - len(srvr.direct), len(want)) + expectedGroupConn = []peerLabel{} + + if len(ps.groupConnPeers) != len(expectedGroupConn) { + t.Errorf("group conn peers not match") } - for _, id := range want { - if _, ok := srvr.direct[id]; !ok { - return fmt.Errorf("direct peer %s not exists", id.String()) + for _, l := range expectedGroupConn { + if len(ps.groupConnPeers[l]) == 0 { + t.Errorf("group conn peers is 0") } } - return nil -} -func checkGroup(srvr *testP2PServer, want []string) error { - if len(srvr.group) != len(want) { - return fmt.Errorf("num of group mismatch: got %d, want %d", - len(srvr.group), len(want)) + + expectedAllDirect = make(map[string]map[peerLabel]struct{}) + + for l := range ps.directConn { + for id := range ps.label2Nodes[l] { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) + } + expectedAllDirect[id][l] = struct{}{} + } } - for _, name := range want { - if _, ok := srvr.group[name]; !ok { - return fmt.Errorf("group %s not exists", name) + for l, peers := range ps.groupConnPeers { + for id := range peers { + if expectedAllDirect[id] == nil { + expectedAllDirect[id] = make(map[peerLabel]struct{}) + } + expectedAllDirect[id][l] = struct{}{} } } - return nil + + if !reflect.DeepEqual(ps.allDirectPeers, expectedAllDirect) { + t.Errorf("all direct peers not match") + } } func newTestNodeSet(nodes []*enode.Node) map[string]struct{} { @@ -505,6 +329,14 @@ func newTestNodeSet(nodes []*enode.Node) map[string]struct{} { return m } -func newDummyPeer(node *enode.Node) *peer { - return &peer{id: node.ID().String()} +func randomV4CompactNode() *enode.Node { + var err error + var privkey *ecdsa.PrivateKey + for { + privkey, err = crypto.GenerateKey() + if err == nil { + break + } + } + return enode.NewV4(&privkey.PublicKey, nil, 0, 0) } diff --git a/dex/protocol.go b/dex/protocol.go index 9244b8bb7..0cb00ada6 100644 --- a/dex/protocol.go +++ b/dex/protocol.go @@ -169,10 +169,6 @@ type p2pServer interface { AddDirectPeer(*enode.Node) RemoveDirectPeer(*enode.Node) - - AddGroup(string, []*enode.Node, uint64) - - RemoveGroup(string) } // statusData is the network packet for the status message. diff --git a/dex/protocol_test.go b/dex/protocol_test.go index 5c68785e8..74778cdda 100644 --- a/dex/protocol_test.go +++ b/dex/protocol_test.go @@ -18,7 +18,6 @@ package dex import ( "crypto/ecdsa" - "encoding/hex" "fmt" "reflect" "sync" @@ -36,6 +35,7 @@ import ( "github.com/dexon-foundation/dexon/crypto" "github.com/dexon-foundation/dexon/dex/downloader" "github.com/dexon-foundation/dexon/p2p" + "github.com/dexon-foundation/dexon/p2p/enode" "github.com/dexon-foundation/dexon/p2p/enr" "github.com/dexon-foundation/dexon/rlp" ) @@ -553,11 +553,15 @@ func TestSendVote(t *testing.T) { }, } + pm.peers.label2Nodes = make(map[peerLabel]map[string]*enode.Node) for i, tt := range testPeers { p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), dex64, pm, true) if tt.label != nil { - b := crypto.FromECDSAPub(p.Node().Pubkey()) - pm.peers.addDirectPeer(hex.EncodeToString(b), *tt.label) + if pm.peers.label2Nodes[*tt.label] == nil { + pm.peers.label2Nodes[*tt.label] = make(map[string]*enode.Node) + } + pm.peers.label2Nodes[*tt.label][p.ID().String()] = p.Node() + pm.peers.addDirectPeer(p.ID().String(), *tt.label) } wg.Add(1) go checkvote(p, tt.isReceiver) diff --git a/p2p/dial.go b/p2p/dial.go index 909bed863..99acade36 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -64,12 +64,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { return t.Dialer.Dial("tcp", addr.String()) } -type dialGroup struct { - name string - nodes map[enode.ID]*enode.Node - num uint64 -} - // dialstate schedules dials and discovery lookups. // it get's a chance to compute new tasks on every iteration // of the main loop in Server.run. @@ -85,7 +79,6 @@ type dialstate struct { randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask direct map[enode.ID]*dialTask - group map[string]*dialGroup hist *dialHistory start time.Time // time when the dialer was first used @@ -143,7 +136,6 @@ func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, netrestrict: netrestrict, static: make(map[enode.ID]*dialTask), direct: make(map[enode.ID]*dialTask), - group: make(map[string]*dialGroup), dialing: make(map[enode.ID]connFlag), bootnodes: make([]*enode.Node, len(bootnodes)), randomNodes: make([]*enode.Node, maxdyn/2), @@ -179,14 +171,6 @@ func (s *dialstate) removeDirect(n *enode.Node) { s.hist.remove(n.ID()) } -func (s *dialstate) addGroup(g *dialGroup) { - s.group[g.name] = g -} - -func (s *dialstate) removeGroup(g *dialGroup) { - delete(s.group, g.name) -} - func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { if s.start.IsZero() { s.start = now @@ -244,49 +228,6 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti } } - // compute connected - connected := map[string]map[enode.ID]struct{}{} - for _, g := range s.group { - connected[g.name] = map[enode.ID]struct{}{} - } - - for id := range peers { - for _, g := range s.group { - if _, ok := g.nodes[id]; ok { - connected[g.name][id] = struct{}{} - } - } - } - - for id := range s.dialing { - for _, g := range s.group { - if _, ok := g.nodes[id]; ok { - connected[g.name][id] = struct{}{} - } - } - } - - groupNodes := map[enode.ID]*enode.Node{} - for _, g := range s.group { - for _, n := range g.nodes { - if uint64(len(connected[g.name])) >= g.num { - break - } - err := s.checkDial(n, peers) - switch err { - case errNotWhitelisted, errSelf: - log.Warn("Removing group dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) - delete(g.nodes, n.ID()) - case nil: - groupNodes[n.ID()] = n - connected[g.name][n.ID()] = struct{}{} - } - } - } - for _, n := range groupNodes { - addDial(groupDialedConn, n) - } - // If we don't have any peers whatsoever, try to dial a random bootnode. This // scenario is useful for the testnet (and private networks) where the discovery // table might be full of mostly bad peers, making it hard to find good ones. diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 35e439798..ab687c2ea 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -611,137 +611,6 @@ func TestDialStateDirectDial(t *testing.T) { }) } -func TestDialStateGroupDial(t *testing.T) { - groups := []*dialGroup{ - { - name: "g1", - nodes: map[enode.ID]*enode.Node{ - uintID(1): newNode(uintID(1), nil), - uintID(2): newNode(uintID(2), nil), - }, - num: 2, - }, - { - name: "g2", - nodes: map[enode.ID]*enode.Node{ - uintID(2): newNode(uintID(2), nil), - uintID(3): newNode(uintID(3), nil), - uintID(4): newNode(uintID(4), nil), - uintID(5): newNode(uintID(5), nil), - uintID(6): newNode(uintID(6), nil), - }, - num: 2, - }, - } - - type groupTest struct { - peers []*Peer - dialing map[enode.ID]connFlag - ceiling map[string]uint64 - } - - tests := []groupTest{ - { - peers: nil, - dialing: map[enode.ID]connFlag{}, - ceiling: map[string]uint64{"g1": 2, "g2": 4}, - }, - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - dialing: map[enode.ID]connFlag{ - uintID(1): staticDialedConn, - }, - ceiling: map[string]uint64{"g1": 2, "g2": 2}, - }, - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - dialing: map[enode.ID]connFlag{ - uintID(2): staticDialedConn, - }, - ceiling: map[string]uint64{"g1": 2, "g2": 4}, - }, - { - peers: nil, - dialing: map[enode.ID]connFlag{ - uintID(1): staticDialedConn, - uintID(2): staticDialedConn, - uintID(3): staticDialedConn, - }, - ceiling: map[string]uint64{"g1": 2, "g2": 4}, - }, - } - - pm := func(ps []*Peer) map[enode.ID]*Peer { - m := make(map[enode.ID]*Peer) - for _, p := range ps { - m[p.rw.node.ID()] = p - } - return m - } - - run := func(i int, tt groupTest) { - d := newDialState(enode.ID{}, nil, nil, fakeTable{}, 0, nil) - d.dialing = make(map[enode.ID]connFlag) - for k, v := range tt.dialing { - d.dialing[k] = v - } - - for _, g := range groups { - d.addGroup(g) - } - peermap := pm(tt.peers) - new := d.newTasks(len(tt.dialing), peermap, time.Now()) - - cnt := map[string]uint64{} - for id := range peermap { - for _, g := range groups { - if _, ok := g.nodes[id]; ok { - cnt[g.name]++ - } - } - } - - for id := range tt.dialing { - for _, g := range groups { - if _, ok := g.nodes[id]; ok { - cnt[g.name]++ - } - } - } - - for _, task := range new { - id := task.(*dialTask).dest.ID() - for _, g := range groups { - if _, ok := g.nodes[id]; ok { - cnt[g.name]++ - } - } - } - - for _, g := range groups { - if cnt[g.name] < g.num { - t.Errorf("test %d) group %s peers + dialing + new < num (%d < %d)", - i, g.name, cnt[g.name], g.num) - } - if cnt[g.name] > tt.ceiling[g.name] { - t.Errorf("test %d) group %s peers + dialing + new > ceiling (%d > %d)", - i, g.name, cnt[g.name], tt.ceiling[g.name]) - } - } - } - - for i, tt := range tests { - run(i, tt) - } -} - // This test checks that static peers will be redialed immediately if they were re-added to a static list. func TestDialStaticAfterReset(t *testing.T) { wantStatic := []*enode.Node{ diff --git a/p2p/server.go b/p2p/server.go index 36b1721a5..58b76a708 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -182,8 +182,6 @@ type Server struct { removedirect chan *enode.Node addtrusted chan *enode.Node removetrusted chan *enode.Node - addgroup chan *dialGroup - removegroup chan *dialGroup posthandshake chan *conn addpeer chan *conn delpeer chan peerDrop @@ -206,7 +204,6 @@ const ( dynDialedConn connFlag = 1 << iota staticDialedConn directDialedConn - groupDialedConn inboundConn trustedConn ) @@ -260,9 +257,6 @@ func (f connFlag) String() string { if f&directDialedConn != 0 { s += "-directdial" } - if f&groupDialedConn != 0 { - s += "-groupdial" - } if f&inboundConn != 0 { s += "-inbound" } @@ -357,26 +351,6 @@ func (srv *Server) RemoveDirectPeer(node *enode.Node) { } } -func (srv *Server) AddGroup(name string, nodes []*enode.Node, num uint64) { - m := map[enode.ID]*enode.Node{} - for _, node := range nodes { - m[node.ID()] = node - } - g := &dialGroup{name: name, nodes: m, num: num} - select { - case srv.addgroup <- g: - case <-srv.quit: - } -} - -func (srv *Server) RemoveGroup(name string) { - g := &dialGroup{name: name} - select { - case srv.removegroup <- g: - case <-srv.quit: - } -} - // AddTrustedPeer adds the given node to a reserved whitelist which allows the // node to always connect, even if the slot are full. func (srv *Server) AddTrustedPeer(node *enode.Node) { @@ -524,8 +498,6 @@ func (srv *Server) Start() (err error) { srv.removestatic = make(chan *enode.Node) srv.adddirect = make(chan *enode.Node) srv.removedirect = make(chan *enode.Node) - srv.addgroup = make(chan *dialGroup) - srv.removegroup = make(chan *dialGroup) srv.addtrusted = make(chan *enode.Node) srv.removetrusted = make(chan *enode.Node) srv.peerOp = make(chan peerOpFunc) @@ -689,8 +661,6 @@ type dialer interface { removeStatic(*enode.Node) addDirect(*enode.Node) removeDirect(*enode.Node) - addGroup(*dialGroup) - removeGroup(*dialGroup) } func (srv *Server) run(dialstate dialer) { @@ -699,15 +669,12 @@ func (srv *Server) run(dialstate dialer) { defer srv.nodedb.Close() var ( - peers = make(map[enode.ID]*Peer) - inboundCount = 0 - trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) - peerflags = make(map[enode.ID]connFlag) - groupRefCount = make(map[enode.ID]int32) - groups = make(map[string]*dialGroup) - taskdone = make(chan task, maxActiveDialTasks) - runningTasks []task - queuedTasks []task // tasks that can't run yet + peers = make(map[enode.ID]*Peer) + inboundCount = 0 + trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) + taskdone = make(chan task, maxActiveDialTasks) + runningTasks []task + queuedTasks []task // tasks that can't run yet ) // Put trusted nodes into a map to speed up checks. // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. @@ -745,60 +712,6 @@ func (srv *Server) run(dialstate dialer) { } } - // remember and maintain the connection flags locally - setConnFlags := func(id enode.ID, f connFlag, val bool) { - if p, ok := peers[id]; ok { - p.rw.set(f, val) - } - if val { - peerflags[id] |= f - } else { - peerflags[id] &= ^f - } - if peerflags[id] == 0 { - delete(peerflags, id) - } - } - - // Put trusted nodes into a map to speed up checks. - // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. - for _, n := range srv.TrustedNodes { - setConnFlags(n.ID(), trustedConn, true) - } - - canDisconnect := func(p *Peer) bool { - f, ok := peerflags[p.ID()] - if ok && f != 0 { - return false - } - return !p.rw.is(dynDialedConn | inboundConn) - } - - removeGroup := func(g *dialGroup) { - if gg, ok := groups[g.name]; ok { - for id := range gg.nodes { - groupRefCount[id]-- - if groupRefCount[id] == 0 { - setConnFlags(id, groupDialedConn, false) - delete(groupRefCount, id) - } - } - } - } - - addGroup := func(g *dialGroup) { - if _, ok := groups[g.name]; ok { - removeGroup(groups[g.name]) - } - for id := range g.nodes { - groupRefCount[id]++ - if groupRefCount[id] > 0 { - setConnFlags(id, groupDialedConn, true) - } - } - groups[g.name] = g - } - running: for { scheduleTasks() @@ -812,16 +725,14 @@ running: // ephemeral static peer list. Add it to the dialer, // it will keep the node connected. srv.log.Trace("Adding static node", "node", n) - setConnFlags(n.ID(), staticDialedConn, true) dialstate.addStatic(n) case n := <-srv.removestatic: // This channel is used by RemovePeer to send a // disconnect request to a peer and begin the // stop keeping the node connected. srv.log.Trace("Removing static node", "node", n) - setConnFlags(n.ID(), staticDialedConn, false) dialstate.removeStatic(n) - if p, ok := peers[n.ID()]; ok && canDisconnect(p) { + if p, ok := peers[n.ID()]; ok { p.Disconnect(DiscRequested) } case n := <-srv.adddirect: @@ -829,42 +740,36 @@ running: // ephemeral direct peer list. Add it to the dialer, // it will keep the node connected. srv.log.Trace("Adding direct node", "node", n) - setConnFlags(n.ID(), directDialedConn, true) - if p, ok := peers[n.ID()]; ok { - p.rw.set(directDialedConn, true) - } dialstate.addDirect(n) case n := <-srv.removedirect: // This channel is used by RemoveDirectPeer to send a // disconnect request to a peer and begin the // stop keeping the node connected. srv.log.Trace("Removing direct node", "node", n) - setConnFlags(n.ID(), directDialedConn, false) + dialstate.removeDirect(n) if p, ok := peers[n.ID()]; ok { - p.rw.set(directDialedConn, false) - if !p.rw.is(trustedConn | groupDialedConn) { - p.Disconnect(DiscRequested) - } + p.Disconnect(DiscRequested) } - dialstate.removeDirect(n) - case g := <-srv.addgroup: - srv.log.Trace("Adding group", "group", g) - addGroup(g) - dialstate.addGroup(g) - case g := <-srv.removegroup: - srv.log.Trace("Removing group", "group", g) - removeGroup(g) - dialstate.removeGroup(g) case n := <-srv.addtrusted: // This channel is used by AddTrustedPeer to add an enode // to the trusted node set. srv.log.Trace("Adding trusted node", "node", n) - setConnFlags(n.ID(), trustedConn, true) + trusted[n.ID()] = true + // Mark any already-connected peer as trusted + if p, ok := peers[n.ID()]; ok { + p.rw.set(trustedConn, true) + } case n := <-srv.removetrusted: // This channel is used by RemoveTrustedPeer to remove an enode // from the trusted node set. srv.log.Trace("Removing trusted node", "node", n) - setConnFlags(n.ID(), trustedConn, false) + if _, ok := trusted[n.ID()]; ok { + delete(trusted, n.ID()) + } + // Unmark any already-connected peer as trusted + if p, ok := peers[n.ID()]; ok { + p.rw.set(trustedConn, false) + } case op := <-srv.peerOp: // This channel is used by Peers and PeerCount. op(peers) @@ -879,8 +784,9 @@ running: case c := <-srv.posthandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). - if f, ok := peerflags[c.node.ID()]; ok { - c.flags |= f + if trusted[c.node.ID()] { + // Ensure that the trusted flag is set before checking against MaxPeers. + c.flags |= trustedConn } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. select { @@ -962,9 +868,9 @@ func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount i func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { - case !c.is(trustedConn|staticDialedConn|directDialedConn|groupDialedConn) && len(peers) >= srv.MaxPeers: + case !c.is(trustedConn|staticDialedConn|directDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers - case !c.is(trustedConn|directDialedConn|groupDialedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): + case !c.is(trustedConn|directDialedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): return DiscTooManyPeers case peers[c.node.ID()] != nil: return DiscAlreadyConnected diff --git a/p2p/server_test.go b/p2p/server_test.go index 8bd113791..734b2a8c1 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -203,92 +203,6 @@ func TestServerDial(t *testing.T) { } } -func TestServerPeerConnFlag(t *testing.T) { - srv := &Server{ - Config: Config{ - PrivateKey: newkey(), - MaxPeers: 10, - NoDial: true, - }, - } - if err := srv.Start(); err != nil { - t.Fatalf("could not start: %v", err) - } - defer srv.Stop() - - // inject a peer - key := newkey() - id := enode.PubkeyToIDV4(&key.PublicKey) - node := newNode(id, nil) - fd, _ := net.Pipe() - c := &conn{ - node: node, - fd: fd, - transport: newTestTransport(&key.PublicKey, fd), - flags: inboundConn, - cont: make(chan error), - } - if err := srv.checkpoint(c, srv.addpeer); err != nil { - t.Fatalf("could not add conn: %v", err) - } - - srv.AddTrustedPeer(node) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | trustedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | trustedConn)) - } - - srv.AddDirectPeer(node) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | trustedConn | directDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | trustedConn | directDialedConn)) - } - - srv.AddGroup("g1", []*enode.Node{node}, 1) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | trustedConn | directDialedConn | groupDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | trustedConn | directDialedConn | groupDialedConn)) - } - - srv.AddGroup("g2", []*enode.Node{node}, 1) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | trustedConn | directDialedConn | groupDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | trustedConn | directDialedConn | groupDialedConn)) - } - - srv.RemoveTrustedPeer(node) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | directDialedConn | groupDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | directDialedConn | directDialedConn)) - } - - srv.RemoveDirectPeer(node) - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | groupDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | directDialedConn)) - } - - srv.RemoveGroup("g1") - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != (inboundConn | groupDialedConn) { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, (inboundConn | directDialedConn)) - } - - srv.RemoveGroup("g2") - srv.Peers() // leverage this function to ensure trusted peer is added - if c.flags != inboundConn { - t.Errorf("flags mismatch: got %d, want %d", - c.flags, inboundConn) - } -} - // This test checks that tasks generated by dialstate are // actually executed and taskdone is called for them. func TestServerTaskScheduling(t *testing.T) { @@ -429,10 +343,6 @@ func (tg taskgen) addDirect(*enode.Node) { } func (tg taskgen) removeDirect(*enode.Node) { } -func (tg taskgen) addGroup(*dialGroup) { -} -func (tg taskgen) removeGroup(*dialGroup) { -} type testTask struct { index int |