aboutsummaryrefslogtreecommitdiffstats
path: root/core/test/stopper.go
diff options
context:
space:
mode:
Diffstat (limited to 'core/test/stopper.go')
-rw-r--r--core/test/stopper.go58
1 files changed, 56 insertions, 2 deletions
diff --git a/core/test/stopper.go b/core/test/stopper.go
index 9fe5592..71b215d 100644
--- a/core/test/stopper.go
+++ b/core/test/stopper.go
@@ -40,7 +40,6 @@ func NewStopByConfirmedBlocks(
blockCount int,
apps map[types.NodeID]*App,
dbs map[types.NodeID]blockdb.BlockDatabase) *StopByConfirmedBlocks {
-
confirmedBlocks := make(map[types.NodeID]int)
for nID := range apps {
confirmedBlocks[nID] = 0
@@ -58,7 +57,6 @@ func NewStopByConfirmedBlocks(
func (s *StopByConfirmedBlocks) ShouldStop(nID types.NodeID) bool {
s.lock.Lock()
defer s.lock.Unlock()
-
// Accumulate confirmed blocks proposed by this node in this round.
lastChecked := s.lastCheckDelivered[nID]
currentConfirmedBlocks := s.confirmedBlocks[nID]
@@ -84,3 +82,59 @@ func (s *StopByConfirmedBlocks) ShouldStop(nID types.NodeID) bool {
}
return true
}
+
+// StopByRound would make sure at least one block at round R is delivered
+// at each node.
+type StopByRound struct {
+ untilRound uint64
+ currentRounds map[types.NodeID]uint64
+ lastCheckDelivered map[types.NodeID]int
+ apps map[types.NodeID]*App
+ dbs map[types.NodeID]blockdb.BlockDatabase
+ lock sync.Mutex
+}
+
+// NewStopByRound constructs an StopByRound instance.
+func NewStopByRound(
+ round uint64,
+ apps map[types.NodeID]*App,
+ dbs map[types.NodeID]blockdb.BlockDatabase) *StopByRound {
+ return &StopByRound{
+ untilRound: round,
+ currentRounds: make(map[types.NodeID]uint64),
+ lastCheckDelivered: make(map[types.NodeID]int),
+ apps: apps,
+ dbs: dbs,
+ }
+}
+
+// ShouldStop implements Stopper interface.
+func (s *StopByRound) ShouldStop(nID types.NodeID) bool {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ // Cache latest round of this node.
+ if curRound := s.currentRounds[nID]; curRound < s.untilRound {
+ lastChecked := s.lastCheckDelivered[nID]
+ db := s.dbs[nID]
+ s.apps[nID].Check(func(app *App) {
+ for _, h := range app.DeliverSequence[lastChecked:] {
+ b, err := db.Get(h)
+ if err != nil {
+ panic(err)
+ }
+ if b.Position.Round > curRound {
+ curRound = b.Position.Round
+ }
+ }
+ s.lastCheckDelivered[nID] = len(app.DeliverSequence)
+ s.currentRounds[nID] = curRound
+ })
+ }
+ // Check if latest round on each node is later than untilRound.
+ for _, round := range s.currentRounds {
+ if round < s.untilRound {
+ return false
+ }
+ }
+ return true
+}