aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJeffrey Wilcke <jeffrey@ethereum.org>2016-10-06 22:14:22 +0800
committerGitHub <noreply@github.com>2016-10-06 22:14:22 +0800
commit7335a70a020517cc9cebe7ae82c0e49ba133abf1 (patch)
treecc63625fa07bf3fb28326a01d5c8255a83a83bd1
parent07caa3fccdfe11bbee084c043ac11e7cfae9a6b7 (diff)
parent3c836dd71b192de24774b1848173a4eb0ca9a63b (diff)
downloaddexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar.gz
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar.bz2
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar.lz
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar.xz
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.tar.zst
dexon-7335a70a020517cc9cebe7ae82c0e49ba133abf1.zip
Merge pull request #3092 from fjl/state-journal
core/state: implement reverts by journaling all changes
-rw-r--r--accounts/abi/bind/backends/simulated.go6
-rw-r--r--cmd/evm/main.go32
-rw-r--r--core/chain_makers.go2
-rw-r--r--core/execution.go8
-rw-r--r--core/state/dump.go2
-rw-r--r--core/state/journal.go117
-rw-r--r--core/state/managed_state_test.go7
-rw-r--r--core/state/state_object.go76
-rw-r--r--core/state/state_test.go30
-rw-r--r--core/state/statedb.go222
-rw-r--r--core/state/statedb_test.go313
-rw-r--r--core/state/sync_test.go2
-rw-r--r--core/tx_pool.go2
-rw-r--r--core/vm/environment.go11
-rw-r--r--core/vm/instructions.go2
-rw-r--r--core/vm/jit.go2
-rw-r--r--core/vm/jit_test.go4
-rw-r--r--core/vm/runtime/env.go8
-rw-r--r--core/vm/vm.go2
-rw-r--r--core/vm_env.go8
-rw-r--r--eth/api_backend.go6
-rw-r--r--internal/ethapi/tracer_test.go16
-rw-r--r--light/state_test.go14
-rw-r--r--miner/worker.go6
-rw-r--r--tests/state_test_util.go22
-rw-r--r--tests/util.go34
-rw-r--r--tests/vm_test_util.go18
27 files changed, 697 insertions, 275 deletions
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 7e09abb11..74203a468 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM
func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
b.mu.Lock()
defer b.mu.Unlock()
+ defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
- rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy())
+ rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return rval, err
}
@@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error
func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) {
b.mu.Lock()
defer b.mu.Unlock()
+ defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
- _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy())
+ _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return gas, err
}
diff --git a/cmd/evm/main.go b/cmd/evm/main.go
index 09ade1577..22707c1cc 100644
--- a/cmd/evm/main.go
+++ b/cmd/evm/main.go
@@ -227,22 +227,22 @@ type ruleSet struct{}
func (ruleSet) IsHomestead(*big.Int) bool { return true }
-func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
-func (self *VMEnv) Vm() vm.Vm { return self.evm }
-func (self *VMEnv) Db() vm.Database { return self.state }
-func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() }
-func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) }
-func (self *VMEnv) Origin() common.Address { return *self.transactor }
-func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
-func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
-func (self *VMEnv) Time() *big.Int { return self.time }
-func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
-func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
-func (self *VMEnv) Value() *big.Int { return self.value }
-func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
-func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
-func (self *VMEnv) Depth() int { return 0 }
-func (self *VMEnv) SetDepth(i int) { self.depth = i }
+func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
+func (self *VMEnv) Vm() vm.Vm { return self.evm }
+func (self *VMEnv) Db() vm.Database { return self.state }
+func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
+func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
+func (self *VMEnv) Origin() common.Address { return *self.transactor }
+func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
+func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
+func (self *VMEnv) Time() *big.Int { return self.time }
+func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
+func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
+func (self *VMEnv) Value() *big.Int { return self.value }
+func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
+func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
+func (self *VMEnv) Depth() int { return 0 }
+func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) GetHash(n uint64) common.Hash {
if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 {
return self.block.Hash()
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 0b9a5f75d..e3ad9cda0 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 {
- if !b.statedb.HasAccount(addr) {
+ if !b.statedb.Exist(addr) {
panic("account does not exist")
}
return b.statedb.GetNonce(addr)
diff --git a/core/execution.go b/core/execution.go
index 1bc02f7fb..1cb507ee7 100644
--- a/core/execution.go
+++ b/core/execution.go
@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true
}
- snapshotPreTransfer := env.MakeSnapshot()
+ snapshotPreTransfer := env.SnapshotDatabase()
var (
from = env.Db().GetAccount(caller.Address())
to vm.Account
@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas)
- env.SetSnapshot(snapshotPreTransfer)
+ env.RevertToSnapshot(snapshotPreTransfer)
}
return ret, addr, err
@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError
}
- snapshot := env.MakeSnapshot()
+ snapshot := env.SnapshotDatabase()
var to vm.Account
if !env.Db().Exist(*toAddr) {
@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil {
contract.UseGas(contract.Gas)
- env.SetSnapshot(snapshot)
+ env.RevertToSnapshot(snapshot)
}
return ret, addr, err
diff --git a/core/state/dump.go b/core/state/dump.go
index 58ecd852b..8294d61b9 100644
--- a/core/state/dump.go
+++ b/core/state/dump.go
@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err)
}
- obj := NewObject(common.BytesToAddress(addr), data, nil)
+ obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{
Balance: data.Balance.String(),
Nonce: data.Nonce,
diff --git a/core/state/journal.go b/core/state/journal.go
new file mode 100644
index 000000000..720c821b9
--- /dev/null
+++ b/core/state/journal.go
@@ -0,0 +1,117 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+type journalEntry interface {
+ undo(*StateDB)
+}
+
+type journal []journalEntry
+
+type (
+ // Changes to the account trie.
+ createObjectChange struct {
+ account *common.Address
+ }
+ resetObjectChange struct {
+ prev *StateObject
+ }
+ suicideChange struct {
+ account *common.Address
+ prev bool // whether account had already suicided
+ prevbalance *big.Int
+ }
+
+ // Changes to individual accounts.
+ balanceChange struct {
+ account *common.Address
+ prev *big.Int
+ }
+ nonceChange struct {
+ account *common.Address
+ prev uint64
+ }
+ storageChange struct {
+ account *common.Address
+ key, prevalue common.Hash
+ }
+ codeChange struct {
+ account *common.Address
+ prevcode, prevhash []byte
+ }
+
+ // Changes to other state values.
+ refundChange struct {
+ prev *big.Int
+ }
+ addLogChange struct {
+ txhash common.Hash
+ }
+)
+
+func (ch createObjectChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).deleted = true
+ delete(s.stateObjects, *ch.account)
+ delete(s.stateObjectsDirty, *ch.account)
+}
+
+func (ch resetObjectChange) undo(s *StateDB) {
+ s.setStateObject(ch.prev)
+}
+
+func (ch suicideChange) undo(s *StateDB) {
+ obj := s.GetStateObject(*ch.account)
+ if obj != nil {
+ obj.suicided = ch.prev
+ obj.setBalance(ch.prevbalance)
+ }
+}
+
+func (ch balanceChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setBalance(ch.prev)
+}
+
+func (ch nonceChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setNonce(ch.prev)
+}
+
+func (ch codeChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
+}
+
+func (ch storageChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
+}
+
+func (ch refundChange) undo(s *StateDB) {
+ s.refund = ch.prev
+}
+
+func (ch addLogChange) undo(s *StateDB) {
+ logs := s.logs[ch.txhash]
+ if len(logs) == 1 {
+ delete(s.logs, ch.txhash)
+ } else {
+ s.logs[ch.txhash] = logs[:len(logs)-1]
+ }
+}
diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go
index baa53428f..3f7bc2aa8 100644
--- a/core/state/managed_state_test.go
+++ b/core/state/managed_state_test.go
@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb)
- so := &StateObject{address: addr}
- so.SetNonce(100)
- ms.StateDB.stateObjects[addr] = so
- ms.accounts[addr] = newAccount(so)
-
+ ms.StateDB.SetNonce(addr, 100)
+ ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
return ms, ms.accounts[addr]
}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index cbd50e2a3..6eab27d9e 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct {
address common.Address // Ethereum address of this account
data Account
+ db *StateDB
// DB error.
// State objects are used by the consensus core and VM which are
@@ -82,10 +83,10 @@ type StateObject struct {
dirtyStorage Storage // Storage entries that need to be flushed to disk
// Cache flags.
- // When an object is marked for deletion it will be delete from the trie
- // during the "update" phase of the state transition
+ // When an object is marked suicided it will be delete from the trie
+ // during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
- remove bool
+ suicided bool
deleted bool
onDirty func(addr common.Address) // Callback method to mark a state object newly dirty
}
@@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte
}
-// NewObject creates a state object.
-func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
+// newObject creates a state object.
+func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil {
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
- return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
+ return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
}
// EncodeRLP implements rlp.Encoder.
@@ -122,8 +123,8 @@ func (self *StateObject) setError(err error) {
}
}
-func (self *StateObject) MarkForDeletion() {
- self.remove = true
+func (self *StateObject) markSuicided() {
+ self.suicided = true
if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
@@ -152,10 +153,13 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
return value
}
// Load from DB in case it is missing.
- tr := self.getTrie(db)
- var ret []byte
- rlp.DecodeBytes(tr.Get(key[:]), &ret)
- value = common.BytesToHash(ret)
+ if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 {
+ _, content, _, err := rlp.Split(enc)
+ if err != nil {
+ self.setError(err)
+ }
+ value.SetBytes(content)
+ }
if (value != common.Hash{}) {
self.cachedStorage[key] = value
}
@@ -163,7 +167,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
}
// SetState updates a value in account storage.
-func (self *StateObject) SetState(key, value common.Hash) {
+func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
+ self.db.journal = append(self.db.journal, storageChange{
+ account: &self.address,
+ key: key,
+ prevalue: self.GetState(db, key),
+ })
+ self.setState(key, value)
+}
+
+func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value
self.dirtyStorage[key] = value
@@ -189,7 +202,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
}
// UpdateRoot sets the trie root to the current root hash of
-func (self *StateObject) UpdateRoot(db trie.Database) {
+func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db)
self.data.Root = self.trie.Hash()
}
@@ -199,7 +212,6 @@ func (self *StateObject) UpdateRoot(db trie.Database) {
func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error {
self.updateTrie(db)
if self.dbErr != nil {
- fmt.Println("dbErr:", self.dbErr)
return self.dbErr
}
root, err := self.trie.CommitTo(dbw)
@@ -232,6 +244,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
}
func (self *StateObject) SetBalance(amount *big.Int) {
+ self.db.journal = append(self.db.journal, balanceChange{
+ account: &self.address,
+ prev: new(big.Int).Set(self.data.Balance),
+ })
+ self.setBalance(amount)
+}
+
+func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount
if self.onDirty != nil {
self.onDirty(self.Address())
@@ -242,13 +262,13 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {}
-func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject {
- stateObject := NewObject(self.address, self.data, onDirty)
+func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
+ stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie
stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy()
stateObject.cachedStorage = self.dirtyStorage.Copy()
- stateObject.remove = self.remove
+ stateObject.suicided = self.suicided
stateObject.dirtyCode = self.dirtyCode
stateObject.deleted = self.deleted
return stateObject
@@ -280,6 +300,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
}
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
+ prevcode := self.Code(self.db.db)
+ self.db.journal = append(self.db.journal, codeChange{
+ account: &self.address,
+ prevhash: self.CodeHash(),
+ prevcode: prevcode,
+ })
+ self.setCode(codeHash, code)
+}
+
+func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code
self.data.CodeHash = codeHash[:]
self.dirtyCode = true
@@ -290,6 +320,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
}
func (self *StateObject) SetNonce(nonce uint64) {
+ self.db.journal = append(self.db.journal, nonceChange{
+ account: &self.address,
+ prev: self.data.Nonce,
+ })
+ self.setNonce(nonce)
+}
+
+func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce
if self.onDirty != nil {
self.onDirty(self.Address())
@@ -322,7 +360,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value)
}
- it := self.trie.Iterator()
+ it := self.getTrie(self.db.db).Iterator()
for it.Next() {
// ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key))
diff --git a/core/state/state_test.go b/core/state/state_test.go
index 7b9b39e06..f188bc271 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44))
// write some of them to the trie
- s.state.UpdateStateObject(obj1)
- s.state.UpdateStateObject(obj2)
+ s.state.updateStateObject(obj1)
+ s.state.updateStateObject(obj2)
s.state.Commit()
// check that dump contains the state objects that are in trie
@@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state
- snapshot := s.state.Copy()
+ snapshot := s.state.Snapshot()
// set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot
- s.state.Set(snapshot)
+ s.state.RevertToSnapshot(snapshot)
// get state storage value
res := s.state.GetState(stateobjaddr, storageaddr)
@@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res)
}
+func TestSnapshotEmpty(t *testing.T) {
+ db, _ := ethdb.NewMemDatabase()
+ state, _ := New(common.Hash{}, db)
+ state.RevertToSnapshot(state.Snapshot())
+}
+
// use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) {
@@ -150,9 +156,9 @@ func TestSnapshot2(t *testing.T) {
so0.SetBalance(big.NewInt(42))
so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
- so0.remove = false
+ so0.suicided = false
so0.deleted = false
- state.SetStateObject(so0)
+ state.setStateObject(so0)
root, _ := state.Commit()
state.Reset(root)
@@ -162,17 +168,17 @@ func TestSnapshot2(t *testing.T) {
so1.SetBalance(big.NewInt(52))
so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
- so1.remove = true
+ so1.suicided = true
so1.deleted = true
- state.SetStateObject(so1)
+ state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil {
t.Fatalf("deleted object not nil when getting")
}
- snapshot := state.Copy()
- state.Set(snapshot)
+ snapshot := state.Snapshot()
+ state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
@@ -222,8 +228,8 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
}
}
- if so0.remove != so1.remove {
- t.Fatalf("Remove mismatch: have %v, want %v", so0.remove, so1.remove)
+ if so0.suicided != so1.suicided {
+ t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided)
}
if so0.deleted != so1.deleted {
t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted)
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 4204c456e..ec9e9392f 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -20,6 +20,7 @@ package state
import (
"fmt"
"math/big"
+ "sort"
"sync"
"github.com/ethereum/go-ethereum/common"
@@ -40,12 +41,17 @@ var StartingNonce uint64
const (
// Number of past tries to keep. The arbitrarily chosen value here
// is max uncle depth + 1.
- maxJournalLength = 8
+ maxTrieCacheLength = 8
// Number of codehash->size associations to keep.
codeSizeCacheSize = 100000
)
+type revision struct {
+ id int
+ journalIndex int
+}
+
// StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
@@ -69,6 +75,12 @@ type StateDB struct {
logs map[common.Hash]vm.Logs
logSize uint
+ // Journal of state modifications. This is the backbone of
+ // Snapshot and RevertToSnapshot.
+ journal journal
+ validRevisions []revision
+ nextRevisionId int
+
lock sync.Mutex
}
@@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error {
self.trie = tr
self.stateObjects = make(map[common.Address]*StateObject)
self.stateObjectsDirty = make(map[common.Address]struct{})
- self.refund = new(big.Int)
self.thash = common.Hash{}
self.bhash = common.Hash{}
self.txIndex = 0
self.logs = make(map[common.Hash]vm.Logs)
self.logSize = 0
+ self.clearJournalAndRefund()
return nil
}
@@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock()
defer self.lock.Unlock()
- if len(self.pastTries) >= maxJournalLength {
+ if len(self.pastTries) >= maxTrieCacheLength {
copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t
} else {
@@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) {
}
func (self *StateDB) AddLog(log *vm.Log) {
+ self.journal = append(self.journal, addLogChange{txhash: self.thash})
+
log.TxHash = self.thash
log.BlockHash = self.bhash
log.TxIndex = uint(self.txIndex)
@@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs {
}
func (self *StateDB) AddRefund(gas *big.Int) {
+ self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)})
self.refund.Add(self.refund, gas)
}
-func (self *StateDB) HasAccount(addr common.Address) bool {
- return self.GetStateObject(addr) != nil
-}
-
+// Exist reports whether the given account address exists in the state.
+// Notably this also returns true for suicided accounts.
func (self *StateDB) Exist(addr common.Address) bool {
return self.GetStateObject(addr) != nil
}
@@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int {
if stateObject != nil {
return stateObject.Balance()
}
-
return common.Big0
}
@@ -263,10 +275,10 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
return common.Hash{}
}
-func (self *StateDB) IsDeleted(addr common.Address) bool {
+func (self *StateDB) HasSuicided(addr common.Address) bool {
stateObject := self.GetStateObject(addr)
if stateObject != nil {
- return stateObject.remove
+ return stateObject.suicided
}
return false
}
@@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) {
}
}
+func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) {
+ stateObject := self.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetBalance(amount)
+ }
+}
+
func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
@@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) {
func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
- stateObject.SetState(key, value)
+ stateObject.SetState(self.db, key, value)
}
}
-func (self *StateDB) Delete(addr common.Address) bool {
+// Suicide marks the given account as suicided.
+// This clears the account balance.
+//
+// The account's state object is still available until the state is committed,
+// GetStateObject will return a non-nil account after Suicide.
+func (self *StateDB) Suicide(addr common.Address) bool {
stateObject := self.GetStateObject(addr)
- if stateObject != nil {
- stateObject.MarkForDeletion()
- stateObject.data.Balance = new(big.Int)
- return true
+ if stateObject == nil {
+ return false
}
-
- return false
+ self.journal = append(self.journal, suicideChange{
+ account: &addr,
+ prev: stateObject.suicided,
+ prevbalance: new(big.Int).Set(stateObject.Balance()),
+ })
+ stateObject.markSuicided()
+ stateObject.data.Balance = new(big.Int)
+ return true
}
//
// Setting, updating & deleting state object methods
//
-// Update the given state object and apply it to state trie
-func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
+// updateStateObject writes the given object to the trie.
+func (self *StateDB) updateStateObject(stateObject *StateObject) {
addr := stateObject.Address()
data, err := rlp.EncodeToBytes(stateObject)
if err != nil {
@@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
self.trie.Update(addr[:], data)
}
-// Delete the given state object and delete it from the state trie
-func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
+// deleteStateObject removes the given object from the state trie.
+func (self *StateDB) deleteStateObject(stateObject *StateObject) {
stateObject.deleted = true
-
addr := stateObject.Address()
self.trie.Delete(addr[:])
}
@@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje
return nil
}
// Insert into the live set.
- obj := NewObject(addr, data, self.MarkStateObjectDirty)
- self.SetStateObject(obj)
+ obj := newObject(self, addr, data, self.MarkStateObjectDirty)
+ self.setStateObject(obj)
return obj
}
-func (self *StateDB) SetStateObject(object *StateObject) {
+func (self *StateDB) setStateObject(object *StateObject) {
self.stateObjects[object.Address()] = object
}
@@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) {
func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject {
stateObject := self.GetStateObject(addr)
if stateObject == nil || stateObject.deleted {
- stateObject = self.CreateStateObject(addr)
+ stateObject, _ = self.createObject(addr)
}
-
return stateObject
}
-// NewStateObject create a state object whether it exist in the trie or not
-func (self *StateDB) newStateObject(addr common.Address) *StateObject {
- if glog.V(logger.Core) {
- glog.Infof("(+) %x\n", addr)
- }
- obj := NewObject(addr, Account{}, self.MarkStateObjectDirty)
- obj.SetNonce(StartingNonce) // sets the object to dirty
- self.stateObjects[addr] = obj
- return obj
-}
-
// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
// state object cache iteration to find a handful of modified ones.
func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
self.stateObjectsDirty[addr] = struct{}{}
}
-// Creates creates a new state object and takes ownership.
-func (self *StateDB) CreateStateObject(addr common.Address) *StateObject {
- // Get previous (if any)
- so := self.GetStateObject(addr)
- // Create a new one
- newSo := self.newStateObject(addr)
-
- // If it existed set the balance to the new account
- if so != nil {
- newSo.data.Balance = so.data.Balance
+// createObject creates a new state object. If there is an existing account with
+// the given address, it is overwritten and returned as the second return value.
+func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) {
+ prev = self.GetStateObject(addr)
+ newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty)
+ newobj.setNonce(StartingNonce) // sets the object to dirty
+ if prev == nil {
+ if glog.V(logger.Core) {
+ glog.Infof("(+) %x\n", addr)
+ }
+ self.journal = append(self.journal, createObjectChange{account: &addr})
+ } else {
+ self.journal = append(self.journal, resetObjectChange{prev: prev})
}
-
- return newSo
-}
-
-func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
- return self.CreateStateObject(addr)
+ self.setStateObject(newobj)
+ return newobj, prev
}
+// CreateAccount explicitly creates a state object. If a state object with the address
+// already exists the balance is carried over to the new account.
+//
+// CreateAccount is called during the EVM CREATE operation. The situation might arise that
+// a contract does the following:
//
-// Setting, copying of the state methods
+// 1. sends funds to sha(account ++ (nonce + 1))
+// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
+// Carrying over the balance ensures that Ether doesn't disappear.
+func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
+ new, prev := self.createObject(addr)
+ if prev != nil {
+ new.setBalance(prev.data.Balance)
+ }
+ return new
+}
+// Copy creates a deep, independent copy of the state.
+// Snapshots of the copied state cannot be applied to the copy.
func (self *StateDB) Copy() *StateDB {
self.lock.Lock()
defer self.lock.Unlock()
@@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB {
}
// Copy the dirty states and logs
for addr, _ := range self.stateObjectsDirty {
- state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty)
+ state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty)
state.stateObjectsDirty[addr] = struct{}{}
}
for hash, logs := range self.logs {
@@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB {
return state
}
-func (self *StateDB) Set(state *StateDB) {
- self.lock.Lock()
- defer self.lock.Unlock()
+// Snapshot returns an identifier for the current revision of the state.
+func (self *StateDB) Snapshot() int {
+ id := self.nextRevisionId
+ self.nextRevisionId++
+ self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)})
+ return id
+}
+
+// RevertToSnapshot reverts all state changes made since the given revision.
+func (self *StateDB) RevertToSnapshot(revid int) {
+ // Find the snapshot in the stack of valid snapshots.
+ idx := sort.Search(len(self.validRevisions), func(i int) bool {
+ return self.validRevisions[i].id >= revid
+ })
+ if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid {
+ panic(fmt.Errorf("revision id %v cannot be reverted", revid))
+ }
+ snapshot := self.validRevisions[idx].journalIndex
+
+ // Replay the journal to undo changes.
+ for i := len(self.journal) - 1; i >= snapshot; i-- {
+ self.journal[i].undo(self)
+ }
+ self.journal = self.journal[:snapshot]
- self.db = state.db
- self.trie = state.trie
- self.pastTries = state.pastTries
- self.stateObjects = state.stateObjects
- self.stateObjectsDirty = state.stateObjectsDirty
- self.codeSizeCache = state.codeSizeCache
- self.refund = state.refund
- self.logs = state.logs
- self.logSize = state.logSize
+ // Remove invalidated snapshots from the stack.
+ self.validRevisions = self.validRevisions[:idx]
}
+// GetRefund returns the current value of the refund counter.
+// The return value must not be modified by the caller and will become
+// invalid at the next call to AddRefund.
func (self *StateDB) GetRefund() *big.Int {
return self.refund
}
@@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int {
// It is called in between transactions to get the root hash that
// goes into transaction receipts.
func (s *StateDB) IntermediateRoot() common.Hash {
- s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr]
- if stateObject.remove {
- s.DeleteStateObject(stateObject)
+ if stateObject.suicided {
+ s.deleteStateObject(stateObject)
} else {
- stateObject.UpdateRoot(s.db)
- s.UpdateStateObject(stateObject)
+ stateObject.updateRoot(s.db)
+ s.updateStateObject(stateObject)
}
}
+ // Invalidate journal because reverting across transactions is not allowed.
+ s.clearJournalAndRefund()
return s.trie.Hash()
}
@@ -486,15 +534,15 @@ func (s *StateDB) IntermediateRoot() common.Hash {
// DeleteSuicides should not be used for consensus related updates
// under any circumstances.
func (s *StateDB) DeleteSuicides() {
- // Reset refund so that any used-gas calculations can use
- // this method.
- s.refund = new(big.Int)
+ // Reset refund so that any used-gas calculations can use this method.
+ s.clearJournalAndRefund()
+
for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr]
// If the object has been removed by a suicide
// flag the object as deleted.
- if stateObject.remove {
+ if stateObject.suicided {
stateObject.deleted = true
}
delete(s.stateObjectsDirty, addr)
@@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
return root, batch
}
-func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
+func (s *StateDB) clearJournalAndRefund() {
+ s.journal = nil
+ s.validRevisions = s.validRevisions[:0]
s.refund = new(big.Int)
+}
+
+func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
+ defer s.clearJournalAndRefund()
// Commit objects to the trie.
for addr, stateObject := range s.stateObjects {
- if stateObject.remove {
+ if stateObject.suicided {
// If the object has been removed, don't bother syncing it
// and just mark it for deletion in the trie.
- s.DeleteStateObject(stateObject)
+ s.deleteStateObject(stateObject)
} else if _, ok := s.stateObjectsDirty[addr]; ok {
// Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode {
@@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
return common.Hash{}, err
}
// Update the object in the main account trie.
- s.UpdateStateObject(stateObject)
+ s.updateStateObject(stateObject)
}
delete(s.stateObjectsDirty, addr)
}
@@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
}
return root, err
}
-
-func (self *StateDB) Refunds() *big.Int {
- return self.refund
-}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 7930b620d..5d041c740 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -17,11 +17,19 @@
package state
import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
"math/big"
+ "math/rand"
+ "reflect"
+ "strings"
"testing"
+ "testing/quick"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
)
@@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts
for i := byte(0); i < 255; i++ {
- obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.AddBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ addr := common.BytesToAddress([]byte{i})
+ state.AddBalance(addr, big.NewInt(int64(11*i)))
+ state.SetNonce(addr, uint64(42*i))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
+ state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
+ state.SetCode(addr, []byte{i, i, i, i, i})
}
- state.UpdateStateObject(obj)
+ state.IntermediateRoot()
}
// Ensure that no data was leaked into the database
for _, key := range db.Keys() {
@@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb)
- // Update the states with some objects
- for i := byte(0); i < 255; i++ {
- // Create a new state object with some data into the transition database
- obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ modify := func(state *StateDB, addr common.Address, i, tweak byte) {
+ state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
+ state.SetNonce(addr, uint64(42*i+tweak))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
+ state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
+ state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0})
+ state.SetCode(addr, []byte{i, i, i, i, i, tweak})
}
- transState.UpdateStateObject(obj)
+ }
- // Overwrite all the data with new values in the transition database
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- transState.UpdateStateObject(obj)
+ // Modify the transient state.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 0)
+ }
+ // Write modifications to trie.
+ transState.IntermediateRoot()
- // Create the final state object directly in the final database
- obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- finalState.UpdateStateObject(obj)
+ // Overwrite all the data with new values in the transient database.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 99)
+ modify(finalState, common.Address{byte(i)}, i, 99)
}
+
+ // Commit and cross check the databases.
if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
- // Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key)
@@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
}
}
}
+
+func TestSnapshotRandom(t *testing.T) {
+ config := &quick.Config{MaxCount: 1000}
+ err := quick.Check((*snapshotTest).run, config)
+ if cerr, ok := err.(*quick.CheckError); ok {
+ test := cerr.In[0].(*snapshotTest)
+ t.Errorf("%v:\n%s", test.err, test)
+ } else if err != nil {
+ t.Error(err)
+ }
+}
+
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
+// captured by the snapshot. Instances of this test with pseudorandom content are created
+// by Generate.
+//
+// The test works as follows:
+//
+// A new state is created and all actions are applied to it. Several snapshots are taken
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
+// accessor methods on the reverted state must match the return value of the equivalent
+// methods on the replayed state.
+type snapshotTest struct {
+ addrs []common.Address // all account addresses
+ actions []testAction // modifications to the state
+ snapshots []int // actions indexes at which snapshot is taken
+ err error // failure details are reported through this field
+}
+
+type testAction struct {
+ name string
+ fn func(testAction, *StateDB)
+ args []int64
+ noAddr bool
+}
+
+// newTestAction creates a random action that changes state.
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
+ actions := []testAction{
+ {
+ name: "SetBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.SetBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.AddBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetNonce",
+ fn: func(a testAction, s *StateDB) {
+ s.SetNonce(addr, uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "SetCode",
+ fn: func(a testAction, s *StateDB) {
+ code := make([]byte, 16)
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
+ s.SetCode(addr, code)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "CreateAccount",
+ fn: func(a testAction, s *StateDB) {
+ s.CreateAccount(addr)
+ },
+ },
+ {
+ name: "Suicide",
+ fn: func(a testAction, s *StateDB) {
+ s.Suicide(addr)
+ },
+ },
+ {
+ name: "AddRefund",
+ fn: func(a testAction, s *StateDB) {
+ s.AddRefund(big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ noAddr: true,
+ },
+ {
+ name: "AddLog",
+ fn: func(a testAction, s *StateDB) {
+ data := make([]byte, 2)
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
+ s.AddLog(&vm.Log{Address: addr, Data: data})
+ },
+ args: make([]int64, 1),
+ },
+ }
+ action := actions[r.Intn(len(actions))]
+ var nameargs []string
+ if !action.noAddr {
+ nameargs = append(nameargs, addr.Hex())
+ }
+ for _, i := range action.args {
+ action.args[i] = rand.Int63n(100)
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
+ }
+ action.name += strings.Join(nameargs, ", ")
+ return action
+}
+
+// Generate returns a new snapshot test of the given size. All randomness is
+// derived from r.
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
+ // Generate random actions.
+ addrs := make([]common.Address, 50)
+ for i := range addrs {
+ addrs[i][0] = byte(i)
+ }
+ actions := make([]testAction, size)
+ for i := range actions {
+ addr := addrs[r.Intn(len(addrs))]
+ actions[i] = newTestAction(addr, r)
+ }
+ // Generate snapshot indexes.
+ nsnapshots := int(math.Sqrt(float64(size)))
+ if size > 0 && nsnapshots == 0 {
+ nsnapshots = 1
+ }
+ snapshots := make([]int, nsnapshots)
+ snaplen := len(actions) / nsnapshots
+ for i := range snapshots {
+ // Try to place the snapshots some number of actions apart from each other.
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
+ }
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
+}
+
+func (test *snapshotTest) String() string {
+ out := new(bytes.Buffer)
+ sindex := 0
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
+ sindex++
+ }
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
+ }
+ return out.String()
+}
+
+func (test *snapshotTest) run() bool {
+ // Run all actions and create snapshots.
+ var (
+ db, _ = ethdb.NewMemDatabase()
+ state, _ = New(common.Hash{}, db)
+ snapshotRevs = make([]int, len(test.snapshots))
+ sindex = 0
+ )
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ snapshotRevs[sindex] = state.Snapshot()
+ sindex++
+ }
+ action.fn(action, state)
+ }
+
+ // Revert all snapshots in reverse order. Each revert must yield a state
+ // that is equivalent to fresh state with all actions up the snapshot applied.
+ for sindex--; sindex >= 0; sindex-- {
+ checkstate, _ := New(common.Hash{}, db)
+ for _, action := range test.actions[:test.snapshots[sindex]] {
+ action.fn(action, checkstate)
+ }
+ state.RevertToSnapshot(snapshotRevs[sindex])
+ if err := test.checkEqual(state, checkstate); err != nil {
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
+ return false
+ }
+ }
+ return true
+}
+
+// checkEqual checks that methods of state and checkstate return the same values.
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
+ for _, addr := range test.addrs {
+ var err error
+ checkeq := func(op string, a, b interface{}) bool {
+ if err == nil && !reflect.DeepEqual(a, b) {
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
+ return false
+ }
+ return true
+ }
+ // Check basic accessor methods.
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
+ checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
+ // Check storage.
+ if obj := state.GetStateObject(addr); obj != nil {
+ obj.ForEachStorage(func(key, val common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
+ })
+ checkobj := checkstate.GetStateObject(addr)
+ checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
+ })
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
+ state.GetRefund(), checkstate.GetRefund())
+ }
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
+ state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
+ }
+ return nil
+}
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 670e1fb1b..949df7301 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i}
}
- state.UpdateStateObject(obj)
+ state.updateStateObject(obj)
accounts = append(accounts, acc)
}
root, _ := state.Commit()
diff --git a/core/tx_pool.go b/core/tx_pool.go
index f8b11a7ce..10a110e0b 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass.
- if !currentState.HasAccount(from) {
+ if !currentState.Exist(from) {
return ErrNonExistentAccount
}
diff --git a/core/vm/environment.go b/core/vm/environment.go
index daf6fb90d..a4b2ac196 100644
--- a/core/vm/environment.go
+++ b/core/vm/environment.go
@@ -36,9 +36,9 @@ type Environment interface {
// The state database
Db() Database
// Creates a restorable snapshot
- MakeSnapshot() Database
+ SnapshotDatabase() int
// Set database to previous snapshot
- SetSnapshot(Database)
+ RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address
// The block number this VM is invoked on
@@ -105,9 +105,12 @@ type Database interface {
GetState(common.Address, common.Hash) common.Hash
SetState(common.Address, common.Hash, common.Hash)
- Delete(common.Address) bool
+ Suicide(common.Address) bool
+ HasSuicided(common.Address) bool
+
+ // Exist reports whether the given account exists in state.
+ // Notably this should also return true for suicided accounts.
Exist(common.Address) bool
- IsDeleted(common.Address) bool
}
// Account represents a contract or basic ethereum account.
diff --git a/core/vm/instructions.go b/core/vm/instructions.go
index 849a8463c..79aee60d2 100644
--- a/core/vm/instructions.go
+++ b/core/vm/instructions.go
@@ -614,7 +614,7 @@ func opSuicide(instr instruction, pc *uint64, env Environment, contract *Contrac
balance := env.Db().GetBalance(contract.Address())
env.Db().AddBalance(common.BigToAddress(stack.pop()), balance)
- env.Db().Delete(contract.Address())
+ env.Db().Suicide(contract.Address())
}
// following functions are used by the instruction jump table
diff --git a/core/vm/jit.go b/core/vm/jit.go
index 460a68ddd..55d2e0477 100644
--- a/core/vm/jit.go
+++ b/core/vm/jit.go
@@ -425,7 +425,7 @@ func jitCalculateGasAndSize(env Environment, contract *Contract, instr instructi
}
gas.Set(g)
case SUICIDE:
- if !statedb.IsDeleted(contract.Address()) {
+ if !statedb.HasSuicided(contract.Address()) {
statedb.AddRefund(params.SuicideRefundGas)
}
case MLOAD:
diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go
index e6922aeb7..a6de710e1 100644
--- a/core/vm/jit_test.go
+++ b/core/vm/jit_test.go
@@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} }
-func (self *Env) MakeSnapshot() Database { return nil }
-func (self *Env) SetSnapshot(Database) {}
+func (self *Env) SnapshotDatabase() int { return 0 }
+func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil }
diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go
index a4793c98f..59fbaa792 100644
--- a/core/vm/runtime/env.go
+++ b/core/vm/runtime/env.go
@@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
-func (self *Env) MakeSnapshot() vm.Database {
- return self.state.Copy()
+func (self *Env) SnapshotDatabase() int {
+ return self.state.Snapshot()
}
-func (self *Env) SetSnapshot(copy vm.Database) {
- self.state.Set(copy.(*state.StateDB))
+func (self *Env) RevertToSnapshot(snapshot int) {
+ self.state.RevertToSnapshot(snapshot)
}
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
diff --git a/core/vm/vm.go b/core/vm/vm.go
index 5d78b4a2a..033ada21c 100644
--- a/core/vm/vm.go
+++ b/core/vm/vm.go
@@ -303,7 +303,7 @@ func calculateGasAndSize(env Environment, contract *Contract, caller ContractRef
}
gas.Set(g)
case SUICIDE:
- if !statedb.IsDeleted(contract.Address()) {
+ if !statedb.HasSuicided(contract.Address()) {
statedb.AddRefund(params.SuicideRefundGas)
}
case MLOAD:
diff --git a/core/vm_env.go b/core/vm_env.go
index e541eaef4..d62eebbd9 100644
--- a/core/vm_env.go
+++ b/core/vm_env.go
@@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
-func (self *VMEnv) MakeSnapshot() vm.Database {
- return self.state.Copy()
+func (self *VMEnv) SnapshotDatabase() int {
+ return self.state.Snapshot()
}
-func (self *VMEnv) SetSnapshot(copy vm.Database) {
- self.state.Set(copy.(*state.StateDB))
+func (self *VMEnv) RevertToSnapshot(snapshot int) {
+ self.state.RevertToSnapshot(snapshot)
}
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {
diff --git a/eth/api_backend.go b/eth/api_backend.go
index 4adeb0aa0..42b84bf9b 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
}
func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
- stateDb := state.(EthApiState).state.Copy()
+ statedb := state.(EthApiState).state
addr, _ := msg.From()
- from := stateDb.GetOrNewStateObject(addr)
+ from := statedb.GetOrNewStateObject(addr)
from.SetBalance(common.MaxBig)
vmError := func() error { return nil }
- return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
+ return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
}
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
diff --git a/internal/ethapi/tracer_test.go b/internal/ethapi/tracer_test.go
index 7c831d299..127af32a8 100644
--- a/internal/ethapi/tracer_test.go
+++ b/internal/ethapi/tracer_test.go
@@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} }
func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent }
-func (self *Env) Coinbase() common.Address { return common.Address{} }
-func (self *Env) MakeSnapshot() vm.Database { return nil }
-func (self *Env) SetSnapshot(vm.Database) {}
-func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
-func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
-func (self *Env) Db() vm.Database { return nil }
-func (self *Env) GasLimit() *big.Int { return self.gasLimit }
-func (self *Env) VmType() vm.Type { return vm.StdVmTy }
+func (self *Env) Coinbase() common.Address { return common.Address{} }
+func (self *Env) SnapshotDatabase() int { return 0 }
+func (self *Env) RevertToSnapshot(int) {}
+func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
+func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
+func (self *Env) Db() vm.Database { return nil }
+func (self *Env) GasLimit() *big.Int { return self.gasLimit }
+func (self *Env) VmType() vm.Type { return vm.StdVmTy }
func (self *Env) GetHash(n uint64) common.Hash {
return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
}
diff --git a/light/state_test.go b/light/state_test.go
index d4fe95022..a6b115786 100644
--- a/light/state_test.go
+++ b/light/state_test.go
@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
- "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
"golang.org/x/net/context"
@@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ {
- so := st.GetOrNewStateObject(common.Address{i})
+ addr := common.Address{i}
for j := byte(0); j < 100; j++ {
- val := common.Hash{i, j}
- so.SetState(common.Hash{j}, val)
- so.SetNonce(100)
+ st.SetState(addr, common.Hash{j}, common.Hash{i, j})
}
- so.AddBalance(big.NewInt(int64(i)))
- so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i})
- so.UpdateRoot(sdb)
- st.UpdateStateObject(so)
+ st.SetNonce(addr, 100)
+ st.AddBalance(addr, big.NewInt(int64(i)))
+ st.SetCode(addr, []byte{i, i, i})
}
root, _ := st.Commit()
return root, sdb
diff --git a/miner/worker.go b/miner/worker.go
index ac1ef5ba3..e5348cef4 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.receipts,
), self.current.state
}
- return self.current.Block, self.current.state
+ return self.current.Block, self.current.state.Copy()
}
func (self *worker) start() {
@@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB
}
func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) {
- snap := env.state.Copy()
+ snap := env.state.Snapshot()
// this is a bit of a hack to force jit for the miners
config := env.config.VmConfig
@@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g
receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config)
if err != nil {
- env.state.Set(snap)
+ env.state.RevertToSnapshot(snap)
return err, nil
}
env.txs = append(env.txs, tx)
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index 67e4bf832..3c4b42a18 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error {
func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) {
b.StopTimer()
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
- for addr, account := range test.Pre {
- obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
- statedb.SetStateObject(obj)
- for a, v := range account.Storage {
- obj.SetState(common.HexToHash(a), common.HexToHash(v))
- }
- }
+ statedb := makePreState(db, test.Pre)
b.StartTimer()
RunState(ruleSet, statedb, env, test.Exec)
@@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string)
func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
- for addr, account := range test.Pre {
- obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
- statedb.SetStateObject(obj)
- for a, v := range account.Storage {
- obj.SetState(common.HexToHash(a), common.HexToHash(v))
- }
- }
+ statedb := makePreState(db, test.Pre)
// XXX Yeah, yeah...
env := make(map[string]string)
@@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
}
// Set pre compiled contracts
vm.Precompiled = vm.PrecompiledContracts()
- snapshot := statedb.Copy()
+ snapshot := statedb.Snapshot()
gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"]))
key, _ := hex.DecodeString(tx["secretKey"])
@@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
vmenv.origin = addr
ret, _, err := core.ApplyMessage(vmenv, message, gaspool)
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) {
- statedb.Set(snapshot)
+ statedb.RevertToSnapshot(snapshot)
}
statedb.Commit()
diff --git a/tests/util.go b/tests/util.go
index ffbcb9d56..8a9d09213 100644
--- a/tests/util.go
+++ b/tests/util.go
@@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte {
return t
}
-func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject {
+func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
+ statedb, _ := state.New(common.Hash{}, db)
+ for addr, account := range accounts {
+ insertAccount(statedb, addr, account)
+ }
+ return statedb
+}
+
+func insertAccount(state *state.StateDB, saddr string, account Account) {
if common.IsHex(account.Code) {
account.Code = account.Code[2:]
}
- code := common.Hex2Bytes(account.Code)
- codeHash := crypto.Keccak256Hash(code)
- obj := state.NewObject(common.HexToAddress(addr), state.Account{
- Balance: common.Big(account.Balance),
- CodeHash: codeHash[:],
- Nonce: common.Big(account.Nonce).Uint64(),
- }, onDirty)
- obj.SetCode(codeHash, code)
- return obj
+ addr := common.HexToAddress(saddr)
+ state.SetCode(addr, common.Hex2Bytes(account.Code))
+ state.SetNonce(addr, common.Big(account.Nonce).Uint64())
+ state.SetBalance(addr, common.Big(account.Balance))
+ for a, v := range account.Storage {
+ state.SetState(addr, common.HexToHash(a), common.HexToHash(v))
+ }
}
type VmEnv struct {
@@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
-func (self *Env) MakeSnapshot() vm.Database {
- return self.state.Copy()
+func (self *Env) SnapshotDatabase() int {
+ return self.state.Snapshot()
}
-func (self *Env) SetSnapshot(copy vm.Database) {
- self.state.Set(copy.(*state.StateDB))
+func (self *Env) RevertToSnapshot(snapshot int) {
+ self.state.RevertToSnapshot(snapshot)
}
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go
index 4ad72d91c..c269f21e0 100644
--- a/tests/vm_test_util.go
+++ b/tests/vm_test_util.go
@@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error {
func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
b.StopTimer()
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
- for addr, account := range test.Pre {
- obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
- statedb.SetStateObject(obj)
- for a, v := range account.Storage {
- obj.SetState(common.HexToHash(a), common.HexToHash(v))
- }
- }
+ statedb := makePreState(db, test.Pre)
b.StartTimer()
RunVm(statedb, env, test.Exec)
@@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error {
func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
- for addr, account := range test.Pre {
- obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
- statedb.SetStateObject(obj)
- for a, v := range account.Storage {
- obj.SetState(common.HexToHash(a), common.HexToHash(v))
- }
- }
+ statedb := makePreState(db, test.Pre)
// XXX Yeah, yeah...
env := make(map[string]string)