aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorgluk256 <gluk256@gmail.com>2019-04-19 17:15:17 +0800
committerAnton Evangelatov <anton.evangelatov@gmail.com>2019-04-19 17:15:17 +0800
commitd9403690ecaf1beecc4369c56c4600caf275dae8 (patch)
tree002c7b10c8fe260e3d8eb8bcc2bb9cd78036b6e6
parentd8dc37c85bb68d83c05b819e03a7b5c05c815e42 (diff)
downloadgo-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar.gz
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar.bz2
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar.lz
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar.xz
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.tar.zst
go-tangerine-d9403690ecaf1beecc4369c56c4600caf275dae8.zip
swarm/pss: Fix flaky TestProxNetwork (#19471)
-rw-r--r--swarm/network/simulation/node.go2
-rw-r--r--swarm/pss/prox_test.go364
-rw-r--r--swarm/pss/pss_test.go2
3 files changed, 181 insertions, 187 deletions
diff --git a/swarm/network/simulation/node.go b/swarm/network/simulation/node.go
index 46c2bb866..f66b0afd0 100644
--- a/swarm/network/simulation/node.go
+++ b/swarm/network/simulation/node.go
@@ -234,9 +234,9 @@ func (s *Simulation) UploadSnapshot(ctx context.Context, snapshotFile string, op
if err != nil {
return err
}
- defer f.Close()
jsonbyte, err := ioutil.ReadAll(f)
+ f.Close()
if err != nil {
return err
}
diff --git a/swarm/pss/prox_test.go b/swarm/pss/prox_test.go
index bc32e612d..908a0d330 100644
--- a/swarm/pss/prox_test.go
+++ b/swarm/pss/prox_test.go
@@ -4,10 +4,7 @@ import (
"context"
"crypto/ecdsa"
"encoding/binary"
- "errors"
"fmt"
- "strconv"
- "strings"
"sync"
"testing"
"time"
@@ -39,24 +36,20 @@ type handlerNotification struct {
}
type testData struct {
- mu sync.Mutex
- sim *simulation.Simulation
- handlerDone bool // set to true on termination of the simulation run
- requiredMessages int
- allowedMessages int
- messageCount int
- kademlias map[enode.ID]*network.Kademlia
- nodeAddrs map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids
- recipients map[int][]enode.ID // for logging output only
- allowed map[int][]enode.ID // allowed recipients
- expectedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive
- allowedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive
- senders map[int]enode.ID // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood)
- handlerC chan handlerNotification // passes message from pss message handler to simulation driver
- doneC chan struct{} // terminates the handler channel listener
- errC chan error // error to pass to main sim thread
- msgC chan handlerNotification // message receipt notification to main sim thread
- msgs [][]byte // recipient addresses of messages
+ sim *simulation.Simulation
+ kademlias map[enode.ID]*network.Kademlia
+ nodeAddresses map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids
+ senders map[int]enode.ID // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood)
+ recipientAddresses [][]byte
+
+ requiredMsgCount int
+ requiredMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive
+ allowedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive
+
+ notifications []handlerNotification // notification queue
+ totalMsgCount int
+ handlerDone bool // set to true on termination of the simulation run
+ mu sync.Mutex
}
var (
@@ -64,67 +57,60 @@ var (
topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82})
)
-func (d *testData) getMsgCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.messageCount
+func (td *testData) pushNotification(val handlerNotification) {
+ td.mu.Lock()
+ td.notifications = append(td.notifications, val)
+ td.mu.Unlock()
}
-func (d *testData) incrementMsgCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.messageCount++
- return d.messageCount
+func (td *testData) popNotification() (first handlerNotification, exist bool) {
+ td.mu.Lock()
+ if len(td.notifications) > 0 {
+ exist = true
+ first = td.notifications[0]
+ td.notifications = td.notifications[1:]
+ }
+ td.mu.Unlock()
+ return first, exist
}
-func (d *testData) isDone() bool {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.handlerDone
+func (td *testData) getMsgCount() int {
+ td.mu.Lock()
+ defer td.mu.Unlock()
+ return td.totalMsgCount
}
-func (d *testData) setDone() {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.handlerDone = true
+func (td *testData) incrementMsgCount() int {
+ td.mu.Lock()
+ defer td.mu.Unlock()
+ td.totalMsgCount++
+ return td.totalMsgCount
}
-func getCmdParams(t *testing.T) (int, int, time.Duration) {
- args := strings.Split(t.Name(), "/")
- msgCount, err := strconv.ParseInt(args[2], 10, 16)
- if err != nil {
- t.Fatal(err)
- }
- nodeCount, err := strconv.ParseInt(args[1], 10, 16)
- if err != nil {
- t.Fatal(err)
- }
- timeoutStr := fmt.Sprintf("%ss", args[3])
- timeoutDur, err := time.ParseDuration(timeoutStr)
- if err != nil {
- t.Fatal(err)
- }
- return int(msgCount), int(nodeCount), timeoutDur
+func (td *testData) isDone() bool {
+ td.mu.Lock()
+ defer td.mu.Unlock()
+ return td.handlerDone
+}
+
+func (td *testData) setDone() {
+ td.mu.Lock()
+ defer td.mu.Unlock()
+ td.handlerDone = true
}
func newTestData() *testData {
return &testData{
- kademlias: make(map[enode.ID]*network.Kademlia),
- nodeAddrs: make(map[enode.ID][]byte),
- recipients: make(map[int][]enode.ID),
- allowed: make(map[int][]enode.ID),
- expectedMsgs: make(map[enode.ID][]uint64),
- allowedMsgs: make(map[enode.ID][]uint64),
- senders: make(map[int]enode.ID),
- handlerC: make(chan handlerNotification),
- doneC: make(chan struct{}),
- errC: make(chan error),
- msgC: make(chan handlerNotification),
+ kademlias: make(map[enode.ID]*network.Kademlia),
+ nodeAddresses: make(map[enode.ID][]byte),
+ requiredMsgs: make(map[enode.ID][]uint64),
+ allowedMsgs: make(map[enode.ID][]uint64),
+ senders: make(map[int]enode.ID),
}
}
-func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
- kadif, ok := d.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia)
+func (td *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
+ kadif, ok := td.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia)
if !ok {
return nil, fmt.Errorf("no kademlia entry for %v", nodeId)
}
@@ -135,29 +121,29 @@ func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
return kad, nil
}
-func (d *testData) init(msgCount int) error {
+func (td *testData) init(msgCount int) error {
log.Debug("TestProxNetwork start")
- for _, nodeId := range d.sim.NodeIDs() {
- kad, err := d.getKademlia(&nodeId)
+ for _, nodeId := range td.sim.NodeIDs() {
+ kad, err := td.getKademlia(&nodeId)
if err != nil {
return err
}
- d.nodeAddrs[nodeId] = kad.BaseAddr()
+ td.nodeAddresses[nodeId] = kad.BaseAddr()
}
for i := 0; i < int(msgCount); i++ {
msgAddr := pot.RandomAddress() // we choose message addresses randomly
- d.msgs = append(d.msgs, msgAddr.Bytes())
+ td.recipientAddresses = append(td.recipientAddresses, msgAddr.Bytes())
smallestPo := 256
var targets []enode.ID
var closestPO int
// loop through all nodes and find the required and allowed recipients of each message
// (for more information, please see the comment to the main test function)
- for _, nod := range d.sim.Net.GetNodes() {
- po, _ := pof(d.msgs[i], d.nodeAddrs[nod.ID()], 0)
- depth := d.kademlias[nod.ID()].NeighbourhoodDepth()
+ for _, nod := range td.sim.Net.GetNodes() {
+ po, _ := pof(td.recipientAddresses[i], td.nodeAddresses[nod.ID()], 0)
+ depth := td.kademlias[nod.ID()].NeighbourhoodDepth()
// only nodes with closest IDs (wrt the msg address) will be required recipients
if po > closestPO {
@@ -169,28 +155,25 @@ func (d *testData) init(msgCount int) error {
}
if po >= depth {
- d.allowedMessages++
- d.allowed[i] = append(d.allowed[i], nod.ID())
- d.allowedMsgs[nod.ID()] = append(d.allowedMsgs[nod.ID()], uint64(i))
+ td.allowedMsgs[nod.ID()] = append(td.allowedMsgs[nod.ID()], uint64(i))
}
// a node with the smallest PO (wrt msg) will be the sender,
// in order to increase the distance the msg must travel
if po < smallestPo {
smallestPo = po
- d.senders[i] = nod.ID()
+ td.senders[i] = nod.ID()
}
}
- d.requiredMessages += len(targets)
+ td.requiredMsgCount += len(targets)
for _, id := range targets {
- d.recipients[i] = append(d.recipients[i], id)
- d.expectedMsgs[id] = append(d.expectedMsgs[id], uint64(i))
+ td.requiredMsgs[id] = append(td.requiredMsgs[id], uint64(i))
}
- log.Debug("nn for msg", "targets", len(d.recipients[i]), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", d.senders[i], "senderpo", smallestPo)
+ log.Debug("nn for msg", "targets", len(targets), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", td.senders[i], "senderpo", smallestPo)
}
- log.Debug("msgs to receive", "count", d.requiredMessages)
+ log.Debug("recipientAddresses to receive", "count", td.requiredMsgCount)
return nil
}
@@ -213,144 +196,161 @@ func (d *testData) init(msgCount int) error {
// nodes Y and Z will be considered required recipients of the msg,
// whereas nodes X, Y and Z will be allowed recipients.
func TestProxNetwork(t *testing.T) {
- t.Run("16/16/15", testProxNetwork)
+ t.Run("16_nodes,_16_messages,_16_seconds", func(t *testing.T) {
+ testProxNetwork(t, 16, 16, 16*time.Second)
+ })
}
-// params in run name: nodes/msgs
func TestProxNetworkLong(t *testing.T) {
if !*longrunning {
t.Skip("run with --longrunning flag to run extensive network tests")
}
- t.Run("8/100/30", testProxNetwork)
- t.Run("16/100/30", testProxNetwork)
- t.Run("32/100/60", testProxNetwork)
- t.Run("64/100/60", testProxNetwork)
- t.Run("128/100/120", testProxNetwork)
+ t.Run("8_nodes,_100_messages,_30_seconds", func(t *testing.T) {
+ testProxNetwork(t, 8, 100, 30*time.Second)
+ })
+ t.Run("16_nodes,_100_messages,_30_seconds", func(t *testing.T) {
+ testProxNetwork(t, 16, 100, 30*time.Second)
+ })
+ t.Run("32_nodes,_100_messages,_60_seconds", func(t *testing.T) {
+ testProxNetwork(t, 32, 100, 1*time.Minute)
+ })
+ t.Run("64_nodes,_100_messages,_60_seconds", func(t *testing.T) {
+ testProxNetwork(t, 64, 100, 1*time.Minute)
+ })
+ t.Run("128_nodes,_100_messages,_120_seconds", func(t *testing.T) {
+ testProxNetwork(t, 128, 100, 2*time.Minute)
+ })
}
-func testProxNetwork(t *testing.T) {
- tstdata := newTestData()
- msgCount, nodeCount, timeout := getCmdParams(t)
+func testProxNetwork(t *testing.T, nodeCount int, msgCount int, timeout time.Duration) {
+ td := newTestData()
handlerContextFuncs := make(map[Topic]handlerContextFunc)
handlerContextFuncs[topic] = nodeMsgHandler
- services := newProxServices(tstdata, true, handlerContextFuncs, tstdata.kademlias)
- tstdata.sim = simulation.New(services)
- defer tstdata.sim.Close()
+ services := newProxServices(td, true, handlerContextFuncs, td.kademlias)
+ td.sim = simulation.New(services)
+ defer td.sim.Close()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount)
- err := tstdata.sim.UploadSnapshot(ctx, filename)
+ err := td.sim.UploadSnapshot(ctx, filename)
if err != nil {
t.Fatal(err)
}
- err = tstdata.init(msgCount) // initialize the test data
+ err = td.init(msgCount) // initialize the test data
if err != nil {
t.Fatal(err)
}
wrapper := func(c context.Context, _ *simulation.Simulation) error {
- return testRoutine(tstdata, c)
+ return testRoutine(td, c)
}
- result := tstdata.sim.Run(ctx, wrapper) // call the main test function
+ result := td.sim.Run(ctx, wrapper) // call the main test function
if result.Error != nil {
- // context deadline exceeded
- // however, it might just mean that not all possible messages are received
- // now we must check if all required messages are received
- cnt := tstdata.getMsgCount()
- log.Debug("TestProxNetwork finished", "rcv", cnt)
- if cnt < tstdata.requiredMessages {
+ timedOut := result.Error == context.DeadlineExceeded
+ if !timedOut || td.getMsgCount() < td.requiredMsgCount {
t.Fatal(result.Error)
}
}
- t.Logf("completed %d", result.Duration)
}
-func (tstdata *testData) sendAllMsgs() {
- for i, msg := range tstdata.msgs {
- log.Debug("sending msg", "idx", i, "from", tstdata.senders[i])
- nodeClient, err := tstdata.sim.Net.GetNode(tstdata.senders[i]).Client()
+func (td *testData) sendAllMsgs() error {
+ nodes := make(map[int]*rpc.Client)
+ for i := range td.recipientAddresses {
+ nodeClient, err := td.sim.Net.GetNode(td.senders[i]).Client()
if err != nil {
- tstdata.errC <- err
+ return err
}
+ nodes[i] = nodeClient
+ }
+
+ for i, msg := range td.recipientAddresses {
+ log.Debug("sending msg", "idx", i, "from", td.senders[i])
+ nodeClient := nodes[i]
var uvarByte [8]byte
binary.PutUvarint(uvarByte[:], uint64(i))
nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:]))
}
- log.Debug("all messages sent")
+ return nil
+}
+
+func isMoreTimeLeft(ctx context.Context) bool {
+ select {
+ case <-ctx.Done():
+ return false
+ default:
+ return true
+ }
}
// testRoutine is the main test function, called by Simulation.Run()
-func testRoutine(tstdata *testData, ctx context.Context) error {
- go handlerChannelListener(tstdata, ctx)
- go tstdata.sendAllMsgs()
+func testRoutine(td *testData, ctx context.Context) error {
+
+ hasMoreRound := func(err error, hadMessage bool) bool {
+ return err == nil && (hadMessage || isMoreTimeLeft(ctx))
+ }
+
+ if err := td.sendAllMsgs(); err != nil {
+ return err
+ }
+
+ var err error
received := 0
+ hadMessage := false
- // collect incoming messages and terminate with corresponding status when message handler listener ends
- for {
- select {
- case err := <-tstdata.errC:
- return err
- case hn := <-tstdata.msgC:
- received++
- log.Debug("msg received", "msgs_received", received, "total_expected", tstdata.requiredMessages, "id", hn.id, "serial", hn.serial)
- if received == tstdata.allowedMessages {
- close(tstdata.doneC)
- return nil
+ for oneMoreRound := true; oneMoreRound; oneMoreRound = hasMoreRound(err, hadMessage) {
+ message, hadMessage := td.popNotification()
+
+ if !isMoreTimeLeft(ctx) {
+ // Stop handlers from sending more messages.
+ // Note: only best effort, race is possible.
+ td.setDone()
+ }
+
+ if hadMessage {
+ if td.isAllowedMessage(message) {
+ received++
+ log.Debug("msg received", "msgs_received", received, "total_expected", td.requiredMsgCount, "id", message.id, "serial", message.serial)
+ } else {
+ err = fmt.Errorf("message %d received by wrong recipient %v", message.serial, message.id)
}
+ } else {
+ time.Sleep(32 * time.Millisecond)
}
}
- return nil
-}
-func handlerChannelListener(tstdata *testData, ctx context.Context) {
- for {
- select {
- case <-tstdata.doneC: // graceful exit
- tstdata.setDone()
- tstdata.errC <- nil
- return
-
- case <-ctx.Done(): // timeout or cancel
- tstdata.setDone()
- tstdata.errC <- ctx.Err()
- return
-
- // incoming message from pss message handler
- case handlerNotification := <-tstdata.handlerC:
- // check if recipient has already received all its messages and notify to fail the test if so
- aMsgs := tstdata.allowedMsgs[handlerNotification.id]
- if len(aMsgs) == 0 {
- tstdata.setDone()
- tstdata.errC <- fmt.Errorf("too many messages received by recipient %x", handlerNotification.id)
- return
- }
+ if err != nil {
+ return err
+ }
- // check if message serial is in expected messages for this recipient and notify to fail the test if not
- idx := -1
- for i, msg := range aMsgs {
- if handlerNotification.serial == msg {
- idx = i
- break
- }
- }
- if idx == -1 {
- tstdata.setDone()
- tstdata.errC <- fmt.Errorf("message %d received by wrong recipient %v", handlerNotification.serial, handlerNotification.id)
- return
- }
+ if td.getMsgCount() < td.requiredMsgCount {
+ return ctx.Err()
+ }
+ return nil
+}
- // message is ok, so remove that message serial from the recipient expectation array and notify the main sim thread
- aMsgs[idx] = aMsgs[len(aMsgs)-1]
- aMsgs = aMsgs[:len(aMsgs)-1]
- tstdata.msgC <- handlerNotification
+func (td *testData) isAllowedMessage(n handlerNotification) bool {
+ // check if message serial is in expected messages for this recipient
+ for _, s := range td.allowedMsgs[n.id] {
+ if n.serial == s {
+ return true
}
}
+ return false
}
-func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler {
+func (td *testData) removeAllowedMessage(id enode.ID, index int) {
+ last := len(td.allowedMsgs[id]) - 1
+ td.allowedMsgs[id][index] = td.allowedMsgs[id][last]
+ td.allowedMsgs[id] = td.allowedMsgs[id][:last]
+}
+
+func nodeMsgHandler(td *testData, config *adapters.NodeConfig) *handler {
return &handler{
f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
- cnt := tstdata.incrementMsgCount()
- log.Debug("nodeMsgHandler rcv", "cnt", cnt)
+ if td.isDone() {
+ return nil // terminate if simulation is over
+ }
+
+ td.incrementMsgCount()
// using simple serial in message body, makes it easy to keep track of who's getting what
serial, c := binary.Uvarint(msg)
@@ -358,15 +358,7 @@ func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler {
log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c))
}
- if tstdata.isDone() {
- return errors.New("handlers aborted") // terminate if simulation is over
- }
-
- // pass message context to the listener in the simulation
- tstdata.handlerC <- handlerNotification{
- id: config.ID,
- serial: serial,
- }
+ td.pushNotification(handlerNotification{id: config.ID, serial: serial})
return nil
},
caps: &handlerCaps{
@@ -378,7 +370,7 @@ func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler {
// an adaptation of the same services setup as in pss_test.go
// replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package
-func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc {
+func newProxServices(td *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc {
stateStore := state.NewInmemoryStore()
kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia {
if k, ok := kademlias[id]; ok {
@@ -415,6 +407,9 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T
UnderlayAddr: addr.Under(),
HiveParams: hp,
}
+ bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
+ pskad := kademlia(ctx.Config.ID, bzzKey)
+ b.Store(simulation.BucketKeyKademlia, pskad)
return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil
},
"pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
@@ -434,6 +429,7 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T
}
bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
pskad := kademlia(ctx.Config.ID, bzzKey)
+ b.Store(simulation.BucketKeyKademlia, pskad)
ps, err := NewPss(pskad, pssp)
if err != nil {
return nil, nil, err
@@ -442,7 +438,7 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T
// register the handlers we've been passed
var deregisters []func()
for tpc, hndlrFunc := range handlerContextFuncs {
- deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(tstdata, ctx.Config)))
+ deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(td, ctx.Config)))
}
// if handshake mode is set, add the controller
@@ -459,8 +455,6 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T
Public: false,
})
- b.Store(simulation.BucketKeyKademlia, pskad)
-
// return Pss and cleanups
return ps, func() {
// run the handler deregister functions in reverse order
diff --git a/swarm/pss/pss_test.go b/swarm/pss/pss_test.go
index ea7a591b1..9884ffbe9 100644
--- a/swarm/pss/pss_test.go
+++ b/swarm/pss/pss_test.go
@@ -1364,7 +1364,7 @@ func TestNetwork(t *testing.T) {
}
// params in run name:
-// nodes/msgs/addrbytes/adaptertype
+// nodes/recipientAddresses/addrbytes/adaptertype
// if adaptertype is exec uses execadapter, simadapter otherwise
func TestNetwork2000(t *testing.T) {
if !*longrunning {