aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/agreement-state.go45
-rw-r--r--core/agreement-state_test.go206
-rw-r--r--core/agreement.go32
3 files changed, 111 insertions, 172 deletions
diff --git a/core/agreement-state.go b/core/agreement-state.go
index 6022bf3..892d7c3 100644
--- a/core/agreement-state.go
+++ b/core/agreement-state.go
@@ -81,7 +81,7 @@ func (s *prepareState) nextState() (agreementState, error) {
delete(s.a.blocks, s.a.ID)
}
}
- s.a.blockChan <- hash
+ s.a.recv.proposeBlock(hash)
return newAckState(s.a), nil
}
func (s *prepareState) receiveVote() error { return nil }
@@ -112,11 +112,11 @@ func (s *ackState) nextState() (agreementState, error) {
if hash == nullBlockHash {
hash = s.a.leader.leaderBlockHash()
}
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VoteAck,
BlockHash: hash,
Period: s.a.period,
- }
+ })
return newConfirmState(s.a), nil
}
func (s *ackState) receiveVote() error { return nil }
@@ -151,11 +151,11 @@ func (s *confirmState) receiveVote() error {
return nil
}
if hash != nullBlockHash {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VoteConfirm,
BlockHash: hash,
Period: s.a.period,
- }
+ })
}
s.voted.Store(true)
return nil
@@ -179,22 +179,22 @@ func (s *pass1State) nextState() (agreementState, error) {
defer s.a.votesLock.RUnlock()
if vote, exist :=
s.a.votes[s.a.period][types.VoteConfirm][s.a.ID]; exist {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VotePass,
BlockHash: vote.BlockHash,
Period: s.a.period,
- }
+ })
} else if s.a.period == 1 {
voteDefault = true
} else {
hash, ok := s.a.countVote(s.a.period-1, types.VotePass)
if ok {
if hash == nullBlockHash {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VotePass,
BlockHash: hash,
Period: s.a.period,
- }
+ })
} else {
voteDefault = true
}
@@ -203,11 +203,11 @@ func (s *pass1State) nextState() (agreementState, error) {
}
}
if voteDefault {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VotePass,
BlockHash: s.a.defaultBlock,
Period: s.a.period,
- }
+ })
}
return newPass2State(s.a), nil
}
@@ -227,7 +227,7 @@ func newPass2State(a *agreementData) *pass2State {
return &pass2State{
a: a,
voted: voted,
- enoughPassVote: make(chan common.Hash),
+ enoughPassVote: make(chan common.Hash, 1),
terminateChan: make(chan struct{}),
}
}
@@ -259,30 +259,29 @@ func (s *pass2State) receiveVote() error {
}
ackHash, ok := s.a.countVote(s.a.period, types.VoteAck)
if ok && ackHash != nullBlockHash {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VotePass,
BlockHash: ackHash,
Period: s.a.period,
- }
+ })
+ s.voted.Store(true)
} else if s.a.period > 1 {
if _, exist :=
s.a.votes[s.a.period][types.VoteConfirm][s.a.ID]; !exist {
hash, ok := s.a.countVote(s.a.period-1, types.VotePass)
if ok && hash == nullBlockHash {
- s.a.voteChan <- &types.Vote{
+ s.a.recv.proposeVote(&types.Vote{
Type: types.VotePass,
BlockHash: hash,
Period: s.a.period,
- }
+ })
+ s.voted.Store(true)
}
}
}
- go func() {
- hash, ok := s.a.countVote(s.a.period, types.VotePass)
- if ok {
- s.enoughPassVote <- hash
- }
- }()
- s.voted.Store(true)
+ hash, ok := s.a.countVote(s.a.period, types.VotePass)
+ if ok {
+ s.enoughPassVote <- hash
+ }
return nil
}
diff --git a/core/agreement-state_test.go b/core/agreement-state_test.go
index 14b7c6a..5fd6214 100644
--- a/core/agreement-state_test.go
+++ b/core/agreement-state_test.go
@@ -39,6 +39,22 @@ type AgreementTestSuite struct {
block map[common.Hash]*types.Block
}
+type agreementTestReceiver struct {
+ s *AgreementTestSuite
+}
+
+func (r *agreementTestReceiver) proposeVote(vote *types.Vote) {
+ r.s.voteChan <- vote
+}
+
+func (r *agreementTestReceiver) proposeBlock(block common.Hash) {
+ r.s.blockChan <- block
+}
+
+func (r *agreementTestReceiver) confirmBlock(block common.Hash) {
+ r.s.confirmChan <- block
+}
+
func (s *AgreementTestSuite) blockProposer() *types.Block {
block := &types.Block{
ProposerID: s.ID,
@@ -88,27 +104,13 @@ func (s *AgreementTestSuite) newAgreement(numValidator int) *agreement {
s.prvKey[validators[i]] = prvKey
}
validators = append(validators, s.ID)
- agreement, voteChan, blockChan, confirmChan := newAgreement(
+ agreement := newAgreement(
s.ID,
+ &agreementTestReceiver{s},
validators,
eth.SigToPub,
s.blockProposer,
)
- go func() {
- for {
- s.voteChan <- <-voteChan
- }
- }()
- go func() {
- for {
- s.blockChan <- <-blockChan
- }
- }()
- go func() {
- for {
- s.confirmChan <- <-confirmChan
- }
- }()
return agreement
}
@@ -122,15 +124,10 @@ func (s *AgreementTestSuite) TestPrepareState() {
a.data.period = 1
newState, err := state.nextState()
s.Require().Nil(err)
- var proposedBlock common.Hash
- select {
- case proposedBlock = <-s.blockChan:
- s.NotEqual(common.Hash{}, proposedBlock)
- err := a.processBlock(s.block[proposedBlock])
- s.Require().Nil(err)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed block.\n")
- }
+ s.Require().True(len(s.blockChan) > 0)
+ proposedBlock := <-s.blockChan
+ s.NotEqual(common.Hash{}, proposedBlock)
+ s.Require().Nil(a.processBlock(s.block[proposedBlock]))
s.Equal(stateAck, newState.state())
// For period >= 2, if the pass-vote for block b equal to {}
@@ -146,12 +143,9 @@ func (s *AgreementTestSuite) TestPrepareState() {
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case hash := <-s.blockChan:
- s.Equal(proposedBlock, hash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed block.\n")
- }
+ s.Require().True(len(s.blockChan) > 0)
+ hash := <-s.blockChan
+ s.Equal(proposedBlock, hash)
s.Equal(stateAck, newState.state())
// For period >= 2, if the pass-vote for block v not equal to {}
@@ -168,12 +162,9 @@ func (s *AgreementTestSuite) TestPrepareState() {
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case hash := <-s.blockChan:
- s.Equal(block.Hash, hash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed block.\n")
- }
+ s.Require().True(len(s.blockChan) > 0)
+ hash = <-s.blockChan
+ s.Equal(block.Hash, hash)
s.Equal(stateAck, newState.state())
}
@@ -195,13 +186,10 @@ func (s *AgreementTestSuite) TestAckState() {
a.data.period = 1
newState, err := state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VoteAck, vote.Type)
- s.NotEqual(common.Hash{}, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote := <-s.voteChan
+ s.Equal(types.VoteAck, vote.Type)
+ s.NotEqual(common.Hash{}, vote.BlockHash)
s.Equal(stateConfirm, newState.state())
// For period >= 2, if block v equal to {} has more than 2f+1 pass-vote
@@ -213,13 +201,10 @@ func (s *AgreementTestSuite) TestAckState() {
}
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VoteAck, vote.Type)
- s.NotEqual(common.Hash{}, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VoteAck, vote.Type)
+ s.NotEqual(common.Hash{}, vote.BlockHash)
s.Equal(stateConfirm, newState.state())
// For period >= 2, if block v not equal to {} has more than 2f+1 pass-vote
@@ -232,13 +217,10 @@ func (s *AgreementTestSuite) TestAckState() {
}
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VoteAck, vote.Type)
- s.Equal(hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VoteAck, vote.Type)
+ s.Equal(hash, vote.BlockHash)
s.Equal(stateConfirm, newState.state())
}
@@ -260,13 +242,10 @@ func (s *AgreementTestSuite) TestConfirmState() {
s.Require().Nil(state.receiveVote())
newState, err := state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VoteConfirm, vote.Type)
- s.Equal(block.Hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote := <-s.voteChan
+ s.Equal(types.VoteConfirm, vote.Type)
+ s.Equal(block.Hash, vote.BlockHash)
s.Equal(statePass1, newState.state())
// Else, no vote is propose in this state.
@@ -274,11 +253,7 @@ func (s *AgreementTestSuite) TestConfirmState() {
s.Require().Nil(state.receiveVote())
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case <-s.voteChan:
- s.FailNow("Unexpected proposed vote.\n")
- case <-time.After(50 * time.Millisecond):
- }
+ s.Require().True(len(s.voteChan) == 0)
s.Equal(statePass1, newState.state())
// If there are 2f+1 ack-vote for block v equal to {},
@@ -291,11 +266,7 @@ func (s *AgreementTestSuite) TestConfirmState() {
s.Require().Nil(state.receiveVote())
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case <-s.voteChan:
- s.FailNow("Unexpected proposed vote.\n")
- case <-time.After(50 * time.Millisecond):
- }
+ s.Require().True(len(s.voteChan) == 0)
s.Equal(statePass1, newState.state())
}
@@ -313,13 +284,10 @@ func (s *AgreementTestSuite) TestPass1State() {
s.Require().Nil(a.processVote(vote))
newState, err := state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(hash, vote.BlockHash)
s.Equal(statePass2, newState.state())
// Else if period >= 2 and has 2f+1 pass-vote in period-1 for block {},
@@ -333,13 +301,10 @@ func (s *AgreementTestSuite) TestPass1State() {
s.Require().Nil(a.processVote(vote))
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(common.Hash{}, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(common.Hash{}, vote.BlockHash)
s.Equal(statePass2, newState.state())
// Else, propose pass-vote for default block.
@@ -355,13 +320,10 @@ func (s *AgreementTestSuite) TestPass1State() {
s.Require().Nil(a.processVote(vote))
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(block.Hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(block.Hash, vote.BlockHash)
s.Equal(statePass2, newState.state())
// Period == 1 is also else condition.
@@ -375,13 +337,10 @@ func (s *AgreementTestSuite) TestPass1State() {
s.Require().Nil(a.processVote(vote))
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(block.Hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(block.Hash, vote.BlockHash)
s.Equal(statePass2, newState.state())
// No enought pass-vote for period-1.
@@ -390,13 +349,10 @@ func (s *AgreementTestSuite) TestPass1State() {
s.Require().Nil(a.processVote(vote))
newState, err = state.nextState()
s.Require().Nil(err)
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(block.Hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(block.Hash, vote.BlockHash)
s.Equal(statePass2, newState.state())
}
@@ -414,20 +370,13 @@ func (s *AgreementTestSuite) TestPass2State() {
s.Require().Nil(a.processVote(vote))
}
s.Require().Nil(state.receiveVote())
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(block.Hash, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote := <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(block.Hash, vote.BlockHash)
// Only propose one vote.
s.Require().Nil(state.receiveVote())
- select {
- case <-s.voteChan:
- s.FailNow("Unexpected proposed vote.\n")
- case <-time.After(50 * time.Millisecond):
- }
+ s.Require().True(len(s.voteChan) == 0)
// If period >= 2 and
// there are 2f+1 pass-vote in period-1 for block v equal to {} and
@@ -439,16 +388,13 @@ func (s *AgreementTestSuite) TestPass2State() {
vote := s.prepareVote(vID, types.VotePass, common.Hash{}, 1)
s.Require().Nil(a.processVote(vote))
}
- vote := s.prepareVote(s.ID, types.VoteAck, common.Hash{}, 2)
+ vote = s.prepareVote(s.ID, types.VoteAck, common.Hash{}, 2)
s.Require().Nil(a.processVote(vote))
s.Require().Nil(state.receiveVote())
- select {
- case vote := <-s.voteChan:
- s.Equal(types.VotePass, vote.Type)
- s.Equal(common.Hash{}, vote.BlockHash)
- case <-time.After(50 * time.Millisecond):
- s.FailNow("Expecting a proposed vote.\n")
- }
+ s.Require().True(len(s.voteChan) > 0)
+ vote = <-s.voteChan
+ s.Equal(types.VotePass, vote.Type)
+ s.Equal(common.Hash{}, vote.BlockHash)
// Test terminate.
ok := make(chan struct{})
diff --git a/core/agreement.go b/core/agreement.go
index 6aeae07..9d0440e 100644
--- a/core/agreement.go
+++ b/core/agreement.go
@@ -55,6 +55,13 @@ func newVoteListMap() []map[types.ValidatorID]*types.Vote {
return listMap
}
+// agreementReceiver is the interface receiving agreement event.
+type agreementReceiver interface {
+ proposeVote(vote *types.Vote)
+ proposeBlock(common.Hash)
+ confirmBlock(common.Hash)
+}
+
// position is the current round of the agreement.
type position struct {
ShardID uint64
@@ -64,8 +71,7 @@ type position struct {
// agreementData is the data for agreementState.
type agreementData struct {
- voteChan chan *types.Vote
- blockChan chan common.Hash
+ recv agreementReceiver
ID types.ValidatorID
leader *leaderSelector
@@ -80,8 +86,6 @@ type agreementData struct {
// agreement is the agreement protocal describe in the Crypto Shuffle Algorithm.
type agreement struct {
- confirmChan chan common.Hash
-
state agreementState
data *agreementData
aID *atomic.Value
@@ -93,32 +97,22 @@ type agreement struct {
// newAgreement creates a agreement instance.
func newAgreement(
ID types.ValidatorID,
+ recv agreementReceiver,
validators types.ValidatorIDs,
sigToPub SigToPubFn,
- blockProposer blockProposerFn) (
- *agreement,
- <-chan *types.Vote,
- <-chan common.Hash,
- <-chan common.Hash,
-) {
- // TODO(jimmy-dexon): use callback instead of channel.
- voteChan := make(chan *types.Vote, 3)
- blockChan := make(chan common.Hash)
- confirmChan := make(chan common.Hash)
+ blockProposer blockProposerFn) *agreement {
agreement := &agreement{
- confirmChan: confirmChan,
data: &agreementData{
+ recv: recv,
ID: ID,
leader: newLeaderSelector(),
- voteChan: voteChan,
- blockChan: blockChan,
blockProposer: blockProposer,
},
aID: &atomic.Value{},
sigToPub: sigToPub,
}
agreement.restart(validators)
- return agreement, voteChan, blockChan, confirmChan
+ return agreement
}
// terminate the current running state.
@@ -213,7 +207,7 @@ func (a *agreement) processVote(vote *types.Vote) error {
if len(a.data.votes[vote.Period][types.VoteConfirm]) >=
a.data.requiredVote {
a.hasOutput = true
- a.confirmChan <- vote.BlockHash
+ a.data.recv.confirmBlock(vote.BlockHash)
}
}
return true