From 702f086745d19e51657502de7a94d39690be55f7 Mon Sep 17 00:00:00 2001
From: Bojie Wu <bojie@dexon.org>
Date: Tue, 9 Oct 2018 13:28:45 +0800
Subject: app: using lock correctly to use map safely

---
 dex/app.go | 140 +++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 77 insertions(+), 63 deletions(-)

(limited to 'dex')

diff --git a/dex/app.go b/dex/app.go
index b924ab620..6c6d8da53 100644
--- a/dex/app.go
+++ b/dex/app.go
@@ -29,7 +29,6 @@ import (
 
 	"github.com/dexon-foundation/dexon/common"
 	"github.com/dexon-foundation/dexon/core"
-	"github.com/dexon-foundation/dexon/core/state"
 	"github.com/dexon-foundation/dexon/core/types"
 	"github.com/dexon-foundation/dexon/core/vm"
 	"github.com/dexon-foundation/dexon/ethdb"
@@ -47,13 +46,15 @@ type DexconApp struct {
 	vmConfig   vm.Config
 
 	notifyChan map[uint64]*notify
-	notifyMu   *sync.Mutex
+	notifyMu   sync.Mutex
 
-	lastPendingHeight uint64
-	insertMu          sync.Mutex
+	chainLatestRootMu sync.RWMutex
+	chainLatestRoot   map[uint32]*common.Hash
 
-	chainLocksInitMu *sync.Mutex
-	chainLocks       map[uint32]*sync.Mutex
+	insertMu sync.Mutex
+
+	chainLocksInitMu sync.Mutex
+	chainLocks       map[uint32]*sync.RWMutex
 }
 
 type notify struct {
@@ -68,16 +69,15 @@ type witnessData struct {
 
 func NewDexconApp(txPool *core.TxPool, blockchain *core.BlockChain, gov *DexconGovernance, chainDB ethdb.Database, config *Config, vmConfig vm.Config) *DexconApp {
 	return &DexconApp{
-		txPool:           txPool,
-		blockchain:       blockchain,
-		gov:              gov,
-		chainDB:          chainDB,
-		config:           config,
-		vmConfig:         vmConfig,
-		notifyChan:       make(map[uint64]*notify),
-		notifyMu:         &sync.Mutex{},
-		chainLocksInitMu: &sync.Mutex{},
-		chainLocks:       make(map[uint32]*sync.Mutex),
+		txPool:          txPool,
+		blockchain:      blockchain,
+		gov:             gov,
+		chainDB:         chainDB,
+		config:          config,
+		vmConfig:        vmConfig,
+		notifyChan:      make(map[uint64]*notify),
+		chainLocks:      make(map[uint32]*sync.RWMutex),
+		chainLatestRoot: make(map[uint32]*common.Hash),
 	}
 }
 
@@ -105,7 +105,6 @@ func (d *DexconApp) notify(height uint64) {
 			delete(d.notifyChan, h)
 		}
 	}
-	d.lastPendingHeight = height
 }
 
 func (d *DexconApp) checkChain(address common.Address, chainSize, chainID *big.Int) bool {
@@ -115,8 +114,8 @@ func (d *DexconApp) checkChain(address common.Address, chainSize, chainID *big.I
 
 // PreparePayload is called when consensus core is preparing payload for block.
 func (d *DexconApp) PreparePayload(position coreTypes.Position) (payload []byte, err error) {
-	d.chainLock(position.ChainID)
-	defer d.chainUnlock(position.ChainID)
+	d.chainRLock(position.ChainID)
+	defer d.chainRUnlock(position.ChainID)
 
 	if position.Height != 0 {
 		// check if chain block height is sequential
@@ -127,21 +126,16 @@ func (d *DexconApp) PreparePayload(position coreTypes.Position) (payload []byte,
 		}
 	}
 
-	// set state to the pending height
-	var latestState *state.StateDB
-	lastPendingBlock := d.blockchain.GetPendingBlockByHeight(d.lastPendingHeight)
-	if d.lastPendingHeight == 0 || lastPendingBlock == nil {
-		latestState, err = d.blockchain.State()
-		if err != nil {
-			log.Error("Get current state", "error", err)
-			return nil, fmt.Errorf("get current state error %v", err)
-		}
-	} else {
-		latestState, err = d.blockchain.StateAt(lastPendingBlock.Root())
-		if err != nil {
-			log.Error("Get pending state", "error", err)
-			return nil, fmt.Errorf("get pending state error: %v", err)
-		}
+	root := d.getChainLatestRoot(position.ChainID)
+	if root == nil {
+		currentRoot := d.blockchain.CurrentBlock().Root()
+		root = &currentRoot
+	}
+	// set state to the chain latest height
+	latestState, err := d.blockchain.StateAt(*root)
+	if err != nil {
+		log.Error("Get pending state", "error", err)
+		return nil, fmt.Errorf("get pending state error: %v", err)
 	}
 
 	txsMap, err := d.txPool.Pending()
@@ -220,16 +214,13 @@ addressMap:
 // PrepareWitness will return the witness data no lower than consensusHeight.
 func (d *DexconApp) PrepareWitness(consensusHeight uint64) (witness coreTypes.Witness, err error) {
 	var witnessBlock *types.Block
-	if d.lastPendingHeight == 0 && consensusHeight == 0 {
+	lastPendingHeight := d.blockchain.GetLastPendingHeight()
+	if lastPendingHeight == 0 && consensusHeight == 0 {
 		witnessBlock = d.blockchain.CurrentBlock()
-	} else if d.lastPendingHeight >= consensusHeight {
-		d.insertMu.Lock()
-		witnessBlock = d.blockchain.GetPendingBlockByHeight(d.lastPendingHeight)
-		d.insertMu.Unlock()
+	} else if lastPendingHeight >= consensusHeight {
+		witnessBlock = d.blockchain.GetLastPendingBlock()
 	} else if h := <-d.addNotify(consensusHeight); h >= consensusHeight {
-		d.insertMu.Lock()
-		witnessBlock = d.blockchain.GetPendingBlockByHeight(h)
-		d.insertMu.Unlock()
+		witnessBlock = d.blockchain.GetLastPendingBlock()
 	} else {
 		log.Error("need pending block")
 		return witness, fmt.Errorf("need pending block")
@@ -273,8 +264,8 @@ func (d *DexconApp) VerifyBlock(block *coreTypes.Block) coreTypes.BlockVerifySta
 		return coreTypes.VerifyRetryLater
 	}
 
-	d.chainLock(block.Position.ChainID)
-	defer d.chainUnlock(block.Position.ChainID)
+	d.chainRLock(block.Position.ChainID)
+	defer d.chainRUnlock(block.Position.ChainID)
 
 	if block.Position.Height != 0 {
 		// check if chain block height is sequential
@@ -285,21 +276,16 @@ func (d *DexconApp) VerifyBlock(block *coreTypes.Block) coreTypes.BlockVerifySta
 		}
 	}
 
-	// set state to the pending height
-	var latestState *state.StateDB
-	lastPendingBlock := d.blockchain.GetPendingBlockByHeight(d.lastPendingHeight)
-	if d.lastPendingHeight == 0 || lastPendingBlock == nil {
-		latestState, err = d.blockchain.State()
-		if err != nil {
-			log.Error("Get current state", "error", err)
-			return coreTypes.VerifyInvalidBlock
-		}
-	} else {
-		latestState, err = d.blockchain.StateAt(lastPendingBlock.Root())
-		if err != nil {
-			log.Error("Get pending state", "error", err)
-			return coreTypes.VerifyInvalidBlock
-		}
+	root := d.getChainLatestRoot(block.Position.ChainID)
+	if root == nil {
+		currentRoot := d.blockchain.CurrentBlock().Root()
+		root = &currentRoot
+	}
+	// set state to the chain latest height
+	latestState, err := d.blockchain.StateAt(*root)
+	if err != nil {
+		log.Error("Get pending state", "error", err)
+		return coreTypes.VerifyInvalidBlock
 	}
 
 	var transactions types.Transactions
@@ -430,11 +416,12 @@ func (d *DexconApp) BlockDelivered(blockHash coreCommon.Hash, result coreTypes.F
 		Randomness: result.Randomness,
 	}, transactions, nil, nil)
 
-	_, err = d.blockchain.ProcessPendingBlock(newBlock, &block.Witness)
+	root, err := d.blockchain.ProcessPendingBlock(newBlock, &block.Witness)
 	if err != nil {
 		log.Error("Insert chain", "error", err)
 		panic(err)
 	}
+	d.setChainLatestRoot(block.Position.ChainID, root)
 
 	log.Info("Insert pending block success", "height", result.Height)
 	d.blockchain.RemoveConfirmedBlock(blockHash)
@@ -473,17 +460,44 @@ func (d *DexconApp) validateNonce(txs types.Transactions) (map[common.Address]ui
 	return addressFirstNonce, nil
 }
 
-func (d *DexconApp) chainLock(chainID uint32) {
+func (d *DexconApp) getChainLatestRoot(chainID uint32) *common.Hash {
+	d.chainLatestRootMu.RLock()
+	defer d.chainLatestRootMu.RUnlock()
+
+	return d.chainLatestRoot[chainID]
+}
+
+func (d *DexconApp) setChainLatestRoot(chainID uint32, root *common.Hash) {
+	d.chainLatestRootMu.Lock()
+	defer d.chainLatestRootMu.Unlock()
+
+	d.chainLatestRoot[chainID] = root
+}
+
+func (d *DexconApp) chainLockInit(chainID uint32) {
 	d.chainLocksInitMu.Lock()
+	defer d.chainLocksInitMu.Unlock()
+
 	_, exist := d.chainLocks[chainID]
 	if !exist {
-		d.chainLocks[chainID] = &sync.Mutex{}
+		d.chainLocks[chainID] = &sync.RWMutex{}
 	}
-	d.chainLocksInitMu.Unlock()
+}
 
+func (d *DexconApp) chainLock(chainID uint32) {
+	d.chainLockInit(chainID)
 	d.chainLocks[chainID].Lock()
 }
 
 func (d *DexconApp) chainUnlock(chainID uint32) {
 	d.chainLocks[chainID].Unlock()
 }
+
+func (d *DexconApp) chainRLock(chainID uint32) {
+	d.chainLockInit(chainID)
+	d.chainLocks[chainID].RLock()
+}
+
+func (d *DexconApp) chainRUnlock(chainID uint32) {
+	d.chainLocks[chainID].RUnlock()
+}
-- 
cgit v1.2.3