diff options
Diffstat (limited to 'core/test/stopper.go')
-rw-r--r-- | core/test/stopper.go | 58 |
1 files changed, 56 insertions, 2 deletions
diff --git a/core/test/stopper.go b/core/test/stopper.go index 9fe5592..71b215d 100644 --- a/core/test/stopper.go +++ b/core/test/stopper.go @@ -40,7 +40,6 @@ func NewStopByConfirmedBlocks( blockCount int, apps map[types.NodeID]*App, dbs map[types.NodeID]blockdb.BlockDatabase) *StopByConfirmedBlocks { - confirmedBlocks := make(map[types.NodeID]int) for nID := range apps { confirmedBlocks[nID] = 0 @@ -58,7 +57,6 @@ func NewStopByConfirmedBlocks( func (s *StopByConfirmedBlocks) ShouldStop(nID types.NodeID) bool { s.lock.Lock() defer s.lock.Unlock() - // Accumulate confirmed blocks proposed by this node in this round. lastChecked := s.lastCheckDelivered[nID] currentConfirmedBlocks := s.confirmedBlocks[nID] @@ -84,3 +82,59 @@ func (s *StopByConfirmedBlocks) ShouldStop(nID types.NodeID) bool { } return true } + +// StopByRound would make sure at least one block at round R is delivered +// at each node. +type StopByRound struct { + untilRound uint64 + currentRounds map[types.NodeID]uint64 + lastCheckDelivered map[types.NodeID]int + apps map[types.NodeID]*App + dbs map[types.NodeID]blockdb.BlockDatabase + lock sync.Mutex +} + +// NewStopByRound constructs an StopByRound instance. +func NewStopByRound( + round uint64, + apps map[types.NodeID]*App, + dbs map[types.NodeID]blockdb.BlockDatabase) *StopByRound { + return &StopByRound{ + untilRound: round, + currentRounds: make(map[types.NodeID]uint64), + lastCheckDelivered: make(map[types.NodeID]int), + apps: apps, + dbs: dbs, + } +} + +// ShouldStop implements Stopper interface. +func (s *StopByRound) ShouldStop(nID types.NodeID) bool { + s.lock.Lock() + defer s.lock.Unlock() + // Cache latest round of this node. + if curRound := s.currentRounds[nID]; curRound < s.untilRound { + lastChecked := s.lastCheckDelivered[nID] + db := s.dbs[nID] + s.apps[nID].Check(func(app *App) { + for _, h := range app.DeliverSequence[lastChecked:] { + b, err := db.Get(h) + if err != nil { + panic(err) + } + if b.Position.Round > curRound { + curRound = b.Position.Round + } + } + s.lastCheckDelivered[nID] = len(app.DeliverSequence) + s.currentRounds[nID] = curRound + }) + } + // Check if latest round on each node is later than untilRound. + for _, round := range s.currentRounds { + if round < s.untilRound { + return false + } + } + return true +} |