diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/agreement-state.go | 45 | ||||
-rw-r--r-- | core/agreement-state_test.go | 206 | ||||
-rw-r--r-- | core/agreement.go | 32 |
3 files changed, 111 insertions, 172 deletions
diff --git a/core/agreement-state.go b/core/agreement-state.go index 6022bf3..892d7c3 100644 --- a/core/agreement-state.go +++ b/core/agreement-state.go @@ -81,7 +81,7 @@ func (s *prepareState) nextState() (agreementState, error) { delete(s.a.blocks, s.a.ID) } } - s.a.blockChan <- hash + s.a.recv.proposeBlock(hash) return newAckState(s.a), nil } func (s *prepareState) receiveVote() error { return nil } @@ -112,11 +112,11 @@ func (s *ackState) nextState() (agreementState, error) { if hash == nullBlockHash { hash = s.a.leader.leaderBlockHash() } - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VoteAck, BlockHash: hash, Period: s.a.period, - } + }) return newConfirmState(s.a), nil } func (s *ackState) receiveVote() error { return nil } @@ -151,11 +151,11 @@ func (s *confirmState) receiveVote() error { return nil } if hash != nullBlockHash { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VoteConfirm, BlockHash: hash, Period: s.a.period, - } + }) } s.voted.Store(true) return nil @@ -179,22 +179,22 @@ func (s *pass1State) nextState() (agreementState, error) { defer s.a.votesLock.RUnlock() if vote, exist := s.a.votes[s.a.period][types.VoteConfirm][s.a.ID]; exist { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VotePass, BlockHash: vote.BlockHash, Period: s.a.period, - } + }) } else if s.a.period == 1 { voteDefault = true } else { hash, ok := s.a.countVote(s.a.period-1, types.VotePass) if ok { if hash == nullBlockHash { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VotePass, BlockHash: hash, Period: s.a.period, - } + }) } else { voteDefault = true } @@ -203,11 +203,11 @@ func (s *pass1State) nextState() (agreementState, error) { } } if voteDefault { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VotePass, BlockHash: s.a.defaultBlock, Period: s.a.period, - } + }) } return newPass2State(s.a), nil } @@ -227,7 +227,7 @@ func newPass2State(a *agreementData) *pass2State { return &pass2State{ a: a, voted: voted, - enoughPassVote: make(chan common.Hash), + enoughPassVote: make(chan common.Hash, 1), terminateChan: make(chan struct{}), } } @@ -259,30 +259,29 @@ func (s *pass2State) receiveVote() error { } ackHash, ok := s.a.countVote(s.a.period, types.VoteAck) if ok && ackHash != nullBlockHash { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VotePass, BlockHash: ackHash, Period: s.a.period, - } + }) + s.voted.Store(true) } else if s.a.period > 1 { if _, exist := s.a.votes[s.a.period][types.VoteConfirm][s.a.ID]; !exist { hash, ok := s.a.countVote(s.a.period-1, types.VotePass) if ok && hash == nullBlockHash { - s.a.voteChan <- &types.Vote{ + s.a.recv.proposeVote(&types.Vote{ Type: types.VotePass, BlockHash: hash, Period: s.a.period, - } + }) + s.voted.Store(true) } } } - go func() { - hash, ok := s.a.countVote(s.a.period, types.VotePass) - if ok { - s.enoughPassVote <- hash - } - }() - s.voted.Store(true) + hash, ok := s.a.countVote(s.a.period, types.VotePass) + if ok { + s.enoughPassVote <- hash + } return nil } diff --git a/core/agreement-state_test.go b/core/agreement-state_test.go index 14b7c6a..5fd6214 100644 --- a/core/agreement-state_test.go +++ b/core/agreement-state_test.go @@ -39,6 +39,22 @@ type AgreementTestSuite struct { block map[common.Hash]*types.Block } +type agreementTestReceiver struct { + s *AgreementTestSuite +} + +func (r *agreementTestReceiver) proposeVote(vote *types.Vote) { + r.s.voteChan <- vote +} + +func (r *agreementTestReceiver) proposeBlock(block common.Hash) { + r.s.blockChan <- block +} + +func (r *agreementTestReceiver) confirmBlock(block common.Hash) { + r.s.confirmChan <- block +} + func (s *AgreementTestSuite) blockProposer() *types.Block { block := &types.Block{ ProposerID: s.ID, @@ -88,27 +104,13 @@ func (s *AgreementTestSuite) newAgreement(numValidator int) *agreement { s.prvKey[validators[i]] = prvKey } validators = append(validators, s.ID) - agreement, voteChan, blockChan, confirmChan := newAgreement( + agreement := newAgreement( s.ID, + &agreementTestReceiver{s}, validators, eth.SigToPub, s.blockProposer, ) - go func() { - for { - s.voteChan <- <-voteChan - } - }() - go func() { - for { - s.blockChan <- <-blockChan - } - }() - go func() { - for { - s.confirmChan <- <-confirmChan - } - }() return agreement } @@ -122,15 +124,10 @@ func (s *AgreementTestSuite) TestPrepareState() { a.data.period = 1 newState, err := state.nextState() s.Require().Nil(err) - var proposedBlock common.Hash - select { - case proposedBlock = <-s.blockChan: - s.NotEqual(common.Hash{}, proposedBlock) - err := a.processBlock(s.block[proposedBlock]) - s.Require().Nil(err) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed block.\n") - } + s.Require().True(len(s.blockChan) > 0) + proposedBlock := <-s.blockChan + s.NotEqual(common.Hash{}, proposedBlock) + s.Require().Nil(a.processBlock(s.block[proposedBlock])) s.Equal(stateAck, newState.state()) // For period >= 2, if the pass-vote for block b equal to {} @@ -146,12 +143,9 @@ func (s *AgreementTestSuite) TestPrepareState() { newState, err = state.nextState() s.Require().Nil(err) - select { - case hash := <-s.blockChan: - s.Equal(proposedBlock, hash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed block.\n") - } + s.Require().True(len(s.blockChan) > 0) + hash := <-s.blockChan + s.Equal(proposedBlock, hash) s.Equal(stateAck, newState.state()) // For period >= 2, if the pass-vote for block v not equal to {} @@ -168,12 +162,9 @@ func (s *AgreementTestSuite) TestPrepareState() { newState, err = state.nextState() s.Require().Nil(err) - select { - case hash := <-s.blockChan: - s.Equal(block.Hash, hash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed block.\n") - } + s.Require().True(len(s.blockChan) > 0) + hash = <-s.blockChan + s.Equal(block.Hash, hash) s.Equal(stateAck, newState.state()) } @@ -195,13 +186,10 @@ func (s *AgreementTestSuite) TestAckState() { a.data.period = 1 newState, err := state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VoteAck, vote.Type) - s.NotEqual(common.Hash{}, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote := <-s.voteChan + s.Equal(types.VoteAck, vote.Type) + s.NotEqual(common.Hash{}, vote.BlockHash) s.Equal(stateConfirm, newState.state()) // For period >= 2, if block v equal to {} has more than 2f+1 pass-vote @@ -213,13 +201,10 @@ func (s *AgreementTestSuite) TestAckState() { } newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VoteAck, vote.Type) - s.NotEqual(common.Hash{}, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VoteAck, vote.Type) + s.NotEqual(common.Hash{}, vote.BlockHash) s.Equal(stateConfirm, newState.state()) // For period >= 2, if block v not equal to {} has more than 2f+1 pass-vote @@ -232,13 +217,10 @@ func (s *AgreementTestSuite) TestAckState() { } newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VoteAck, vote.Type) - s.Equal(hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VoteAck, vote.Type) + s.Equal(hash, vote.BlockHash) s.Equal(stateConfirm, newState.state()) } @@ -260,13 +242,10 @@ func (s *AgreementTestSuite) TestConfirmState() { s.Require().Nil(state.receiveVote()) newState, err := state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VoteConfirm, vote.Type) - s.Equal(block.Hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote := <-s.voteChan + s.Equal(types.VoteConfirm, vote.Type) + s.Equal(block.Hash, vote.BlockHash) s.Equal(statePass1, newState.state()) // Else, no vote is propose in this state. @@ -274,11 +253,7 @@ func (s *AgreementTestSuite) TestConfirmState() { s.Require().Nil(state.receiveVote()) newState, err = state.nextState() s.Require().Nil(err) - select { - case <-s.voteChan: - s.FailNow("Unexpected proposed vote.\n") - case <-time.After(50 * time.Millisecond): - } + s.Require().True(len(s.voteChan) == 0) s.Equal(statePass1, newState.state()) // If there are 2f+1 ack-vote for block v equal to {}, @@ -291,11 +266,7 @@ func (s *AgreementTestSuite) TestConfirmState() { s.Require().Nil(state.receiveVote()) newState, err = state.nextState() s.Require().Nil(err) - select { - case <-s.voteChan: - s.FailNow("Unexpected proposed vote.\n") - case <-time.After(50 * time.Millisecond): - } + s.Require().True(len(s.voteChan) == 0) s.Equal(statePass1, newState.state()) } @@ -313,13 +284,10 @@ func (s *AgreementTestSuite) TestPass1State() { s.Require().Nil(a.processVote(vote)) newState, err := state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(hash, vote.BlockHash) s.Equal(statePass2, newState.state()) // Else if period >= 2 and has 2f+1 pass-vote in period-1 for block {}, @@ -333,13 +301,10 @@ func (s *AgreementTestSuite) TestPass1State() { s.Require().Nil(a.processVote(vote)) newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(common.Hash{}, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(common.Hash{}, vote.BlockHash) s.Equal(statePass2, newState.state()) // Else, propose pass-vote for default block. @@ -355,13 +320,10 @@ func (s *AgreementTestSuite) TestPass1State() { s.Require().Nil(a.processVote(vote)) newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(block.Hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(block.Hash, vote.BlockHash) s.Equal(statePass2, newState.state()) // Period == 1 is also else condition. @@ -375,13 +337,10 @@ func (s *AgreementTestSuite) TestPass1State() { s.Require().Nil(a.processVote(vote)) newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(block.Hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(block.Hash, vote.BlockHash) s.Equal(statePass2, newState.state()) // No enought pass-vote for period-1. @@ -390,13 +349,10 @@ func (s *AgreementTestSuite) TestPass1State() { s.Require().Nil(a.processVote(vote)) newState, err = state.nextState() s.Require().Nil(err) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(block.Hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(block.Hash, vote.BlockHash) s.Equal(statePass2, newState.state()) } @@ -414,20 +370,13 @@ func (s *AgreementTestSuite) TestPass2State() { s.Require().Nil(a.processVote(vote)) } s.Require().Nil(state.receiveVote()) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(block.Hash, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote := <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(block.Hash, vote.BlockHash) // Only propose one vote. s.Require().Nil(state.receiveVote()) - select { - case <-s.voteChan: - s.FailNow("Unexpected proposed vote.\n") - case <-time.After(50 * time.Millisecond): - } + s.Require().True(len(s.voteChan) == 0) // If period >= 2 and // there are 2f+1 pass-vote in period-1 for block v equal to {} and @@ -439,16 +388,13 @@ func (s *AgreementTestSuite) TestPass2State() { vote := s.prepareVote(vID, types.VotePass, common.Hash{}, 1) s.Require().Nil(a.processVote(vote)) } - vote := s.prepareVote(s.ID, types.VoteAck, common.Hash{}, 2) + vote = s.prepareVote(s.ID, types.VoteAck, common.Hash{}, 2) s.Require().Nil(a.processVote(vote)) s.Require().Nil(state.receiveVote()) - select { - case vote := <-s.voteChan: - s.Equal(types.VotePass, vote.Type) - s.Equal(common.Hash{}, vote.BlockHash) - case <-time.After(50 * time.Millisecond): - s.FailNow("Expecting a proposed vote.\n") - } + s.Require().True(len(s.voteChan) > 0) + vote = <-s.voteChan + s.Equal(types.VotePass, vote.Type) + s.Equal(common.Hash{}, vote.BlockHash) // Test terminate. ok := make(chan struct{}) diff --git a/core/agreement.go b/core/agreement.go index 6aeae07..9d0440e 100644 --- a/core/agreement.go +++ b/core/agreement.go @@ -55,6 +55,13 @@ func newVoteListMap() []map[types.ValidatorID]*types.Vote { return listMap } +// agreementReceiver is the interface receiving agreement event. +type agreementReceiver interface { + proposeVote(vote *types.Vote) + proposeBlock(common.Hash) + confirmBlock(common.Hash) +} + // position is the current round of the agreement. type position struct { ShardID uint64 @@ -64,8 +71,7 @@ type position struct { // agreementData is the data for agreementState. type agreementData struct { - voteChan chan *types.Vote - blockChan chan common.Hash + recv agreementReceiver ID types.ValidatorID leader *leaderSelector @@ -80,8 +86,6 @@ type agreementData struct { // agreement is the agreement protocal describe in the Crypto Shuffle Algorithm. type agreement struct { - confirmChan chan common.Hash - state agreementState data *agreementData aID *atomic.Value @@ -93,32 +97,22 @@ type agreement struct { // newAgreement creates a agreement instance. func newAgreement( ID types.ValidatorID, + recv agreementReceiver, validators types.ValidatorIDs, sigToPub SigToPubFn, - blockProposer blockProposerFn) ( - *agreement, - <-chan *types.Vote, - <-chan common.Hash, - <-chan common.Hash, -) { - // TODO(jimmy-dexon): use callback instead of channel. - voteChan := make(chan *types.Vote, 3) - blockChan := make(chan common.Hash) - confirmChan := make(chan common.Hash) + blockProposer blockProposerFn) *agreement { agreement := &agreement{ - confirmChan: confirmChan, data: &agreementData{ + recv: recv, ID: ID, leader: newLeaderSelector(), - voteChan: voteChan, - blockChan: blockChan, blockProposer: blockProposer, }, aID: &atomic.Value{}, sigToPub: sigToPub, } agreement.restart(validators) - return agreement, voteChan, blockChan, confirmChan + return agreement } // terminate the current running state. @@ -213,7 +207,7 @@ func (a *agreement) processVote(vote *types.Vote) error { if len(a.data.votes[vote.Period][types.VoteConfirm]) >= a.data.requiredVote { a.hasOutput = true - a.confirmChan <- vote.BlockHash + a.data.recv.confirmBlock(vote.BlockHash) } } return true |