aboutsummaryrefslogtreecommitdiffstats
path: root/core/test/state.go
diff options
context:
space:
mode:
Diffstat (limited to 'core/test/state.go')
-rw-r--r--core/test/state.go44
1 files changed, 44 insertions, 0 deletions
diff --git a/core/test/state.go b/core/test/state.go
index 02ee412..f1cf365 100644
--- a/core/test/state.go
+++ b/core/test/state.go
@@ -68,6 +68,9 @@ var (
ErrStateDKGFinalsNotEqual = errors.New("dkg finalizations not equal")
// ErrStateCRSsNotEqual means CRSs of two states are not equal.
ErrStateCRSsNotEqual = errors.New("crs not equal")
+ // ErrStateDKGResetCountNotEqual means dkgResetCount of two states are not
+ // equal.
+ ErrStateDKGResetCountNotEqual = errors.New("dkg reset count not equal")
// ErrStatePendingChangesNotEqual means pending change requests of two
// states are not equal.
ErrStatePendingChangesNotEqual = errors.New("pending changes not equal")
@@ -99,6 +102,7 @@ type State struct {
dkgReadys map[uint64]map[types.NodeID]*typesDKG.MPKReady
dkgFinals map[uint64]map[types.NodeID]*typesDKG.Finalize
crs []common.Hash
+ dkgResetCount map[uint64]uint64
// Other stuffs
local bool
logger common.Logger
@@ -143,6 +147,7 @@ func NewState(
map[uint64]map[types.NodeID][]*typesDKG.Complaint),
dkgMasterPublicKeys: make(
map[uint64]map[types.NodeID]*typesDKG.MasterPublicKey),
+ dkgResetCount: make(map[uint64]uint64),
appliedRequests: make(map[common.Hash]struct{}),
}
}
@@ -200,6 +205,10 @@ func (s *State) unpackPayload(
case StateAddDKGFinal:
v = &typesDKG.Finalize{}
err = rlp.DecodeBytes(raw.Payload, v)
+ case StateResetDKG:
+ var tmp common.Hash
+ err = rlp.DecodeBytes(raw.Payload, &tmp)
+ v = tmp
case StateChangeLambdaBA:
var tmp uint64
err = rlp.DecodeBytes(raw.Payload, &tmp)
@@ -392,6 +401,15 @@ func (s *State) Equal(other *State) error {
return ErrStateCRSsNotEqual
}
}
+ // Check dkgResetCount.
+ if len(s.dkgResetCount) != len(other.dkgResetCount) {
+ return ErrStateDKGResetCountNotEqual
+ }
+ for idx, count := range s.dkgResetCount {
+ if count != other.dkgResetCount[idx] {
+ return ErrStateDKGResetCountNotEqual
+ }
+ }
// Check pending changes.
checkPending := func(
src, target map[common.Hash]*StateChangeRequest) error {
@@ -478,6 +496,10 @@ func (s *State) Clone() (copied *State) {
for _, crs := range s.crs {
copied.crs = append(copied.crs, crs)
}
+ copied.dkgResetCount = make(map[uint64]uint64, len(s.dkgResetCount))
+ for round, count := range s.dkgResetCount {
+ copied.dkgResetCount[round] = count
+ }
for hash := range s.appliedRequests {
copied.appliedRequests[hash] = struct{}{}
}
@@ -654,6 +676,11 @@ func (s *State) isValidRequest(req *StateChangeRequest) (err error) {
} else {
return ErrMissingPreviousCRS
}
+ case StateResetDKG:
+ newCRS := req.Payload.(common.Hash)
+ if s.crs[len(s.crs)-1].Equal(newCRS) {
+ return ErrDuplicatedChange
+ }
}
return nil
}
@@ -702,6 +729,14 @@ func (s *State) applyRequest(req *StateChangeRequest) error {
s.dkgFinals[final.Round] = make(map[types.NodeID]*typesDKG.Finalize)
}
s.dkgFinals[final.Round][final.ProposerID] = final
+ case StateResetDKG:
+ round := uint64(len(s.crs) - 1)
+ s.crs[round] = req.Payload.(common.Hash)
+ s.dkgResetCount[round]++
+ delete(s.dkgMasterPublicKeys, round)
+ delete(s.dkgReadys, round)
+ delete(s.dkgComplaints, round)
+ delete(s.dkgFinals, round)
case StateChangeLambdaBA:
s.lambdaBA = time.Duration(req.Payload.(uint64))
case StateChangeLambdaDKG:
@@ -752,6 +787,8 @@ func (s *State) RequestChange(
payload = payload.(*typesDKG.MasterPublicKey)
case StateAddDKGComplaint:
payload = payload.(*typesDKG.Complaint)
+ case StateResetDKG:
+ payload = payload.(common.Hash)
}
req := NewStateChangeRequest(t, payload)
s.lock.Lock()
@@ -830,3 +867,10 @@ func (s *State) IsDKGFinal(round uint64, threshold int) bool {
defer s.lock.RUnlock()
return len(s.dkgFinals[round]) > threshold
}
+
+// DKGResetCount returns the reset count for DKG of given round.
+func (s *State) DKGResetCount(round uint64) uint64 {
+ s.lock.RLock()
+ defer s.lock.RUnlock()
+ return s.dkgResetCount[round]
+}