aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/chain_makers.go2
-rw-r--r--core/execution.go8
-rw-r--r--core/state/dump.go2
-rw-r--r--core/state/journal.go117
-rw-r--r--core/state/managed_state_test.go7
-rw-r--r--core/state/state_object.go54
-rw-r--r--core/state/state_test.go22
-rw-r--r--core/state/statedb.go210
-rw-r--r--core/state/statedb_test.go313
-rw-r--r--core/state/sync_test.go2
-rw-r--r--core/tx_pool.go2
-rw-r--r--core/vm/environment.go4
-rw-r--r--core/vm/jit_test.go4
-rw-r--r--core/vm/runtime/env.go8
-rw-r--r--core/vm_env.go8
15 files changed, 602 insertions, 161 deletions
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 0b9a5f75d..e3ad9cda0 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 {
- if !b.statedb.HasAccount(addr) {
+ if !b.statedb.Exist(addr) {
panic("account does not exist")
}
return b.statedb.GetNonce(addr)
diff --git a/core/execution.go b/core/execution.go
index 1bc02f7fb..1cb507ee7 100644
--- a/core/execution.go
+++ b/core/execution.go
@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true
}
- snapshotPreTransfer := env.MakeSnapshot()
+ snapshotPreTransfer := env.SnapshotDatabase()
var (
from = env.Db().GetAccount(caller.Address())
to vm.Account
@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas)
- env.SetSnapshot(snapshotPreTransfer)
+ env.RevertToSnapshot(snapshotPreTransfer)
}
return ret, addr, err
@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError
}
- snapshot := env.MakeSnapshot()
+ snapshot := env.SnapshotDatabase()
var to vm.Account
if !env.Db().Exist(*toAddr) {
@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil {
contract.UseGas(contract.Gas)
- env.SetSnapshot(snapshot)
+ env.RevertToSnapshot(snapshot)
}
return ret, addr, err
diff --git a/core/state/dump.go b/core/state/dump.go
index 58ecd852b..8294d61b9 100644
--- a/core/state/dump.go
+++ b/core/state/dump.go
@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err)
}
- obj := NewObject(common.BytesToAddress(addr), data, nil)
+ obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{
Balance: data.Balance.String(),
Nonce: data.Nonce,
diff --git a/core/state/journal.go b/core/state/journal.go
new file mode 100644
index 000000000..540ade6fb
--- /dev/null
+++ b/core/state/journal.go
@@ -0,0 +1,117 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+type journalEntry interface {
+ undo(*StateDB)
+}
+
+type journal []journalEntry
+
+type (
+ // Changes to the account trie.
+ createObjectChange struct {
+ account *common.Address
+ }
+ resetObjectChange struct {
+ prev *StateObject
+ }
+ deleteAccountChange struct {
+ account *common.Address
+ prev bool // whether account had already suicided
+ prevbalance *big.Int
+ }
+
+ // Changes to individual accounts.
+ balanceChange struct {
+ account *common.Address
+ prev *big.Int
+ }
+ nonceChange struct {
+ account *common.Address
+ prev uint64
+ }
+ storageChange struct {
+ account *common.Address
+ key, prevalue common.Hash
+ }
+ codeChange struct {
+ account *common.Address
+ prevcode, prevhash []byte
+ }
+
+ // Changes to other state values.
+ refundChange struct {
+ prev *big.Int
+ }
+ addLogChange struct {
+ txhash common.Hash
+ }
+)
+
+func (ch createObjectChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).deleted = true
+ delete(s.stateObjects, *ch.account)
+ delete(s.stateObjectsDirty, *ch.account)
+}
+
+func (ch resetObjectChange) undo(s *StateDB) {
+ s.setStateObject(ch.prev)
+}
+
+func (ch deleteAccountChange) undo(s *StateDB) {
+ obj := s.GetStateObject(*ch.account)
+ if obj != nil {
+ obj.remove = ch.prev
+ obj.setBalance(ch.prevbalance)
+ }
+}
+
+func (ch balanceChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setBalance(ch.prev)
+}
+
+func (ch nonceChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setNonce(ch.prev)
+}
+
+func (ch codeChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
+}
+
+func (ch storageChange) undo(s *StateDB) {
+ s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
+}
+
+func (ch refundChange) undo(s *StateDB) {
+ s.refund = ch.prev
+}
+
+func (ch addLogChange) undo(s *StateDB) {
+ logs := s.logs[ch.txhash]
+ if len(logs) == 1 {
+ delete(s.logs, ch.txhash)
+ } else {
+ s.logs[ch.txhash] = logs[:len(logs)-1]
+ }
+}
diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go
index baa53428f..3f7bc2aa8 100644
--- a/core/state/managed_state_test.go
+++ b/core/state/managed_state_test.go
@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb)
- so := &StateObject{address: addr}
- so.SetNonce(100)
- ms.StateDB.stateObjects[addr] = so
- ms.accounts[addr] = newAccount(so)
-
+ ms.StateDB.SetNonce(addr, 100)
+ ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
return ms, ms.accounts[addr]
}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index cbd50e2a3..31ff9bcd8 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct {
address common.Address // Ethereum address of this account
data Account
+ db *StateDB
// DB error.
// State objects are used by the consensus core and VM which are
@@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte
}
-// NewObject creates a state object.
-func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
+// newObject creates a state object.
+func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil {
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
- return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
+ return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
}
// EncodeRLP implements rlp.Encoder.
@@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) {
}
}
-func (self *StateObject) MarkForDeletion() {
+func (self *StateObject) markForDeletion() {
self.remove = true
if self.onDirty != nil {
self.onDirty(self.Address())
@@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
}
// SetState updates a value in account storage.
-func (self *StateObject) SetState(key, value common.Hash) {
+func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
+ self.db.journal = append(self.db.journal, storageChange{
+ account: &self.address,
+ key: key,
+ prevalue: self.GetState(db, key),
+ })
+ self.setState(key, value)
+}
+
+func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value
self.dirtyStorage[key] = value
@@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
}
// UpdateRoot sets the trie root to the current root hash of
-func (self *StateObject) UpdateRoot(db trie.Database) {
+func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db)
self.data.Root = self.trie.Hash()
}
@@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
}
func (self *StateObject) SetBalance(amount *big.Int) {
+ self.db.journal = append(self.db.journal, balanceChange{
+ account: &self.address,
+ prev: new(big.Int).Set(self.data.Balance),
+ })
+ self.setBalance(amount)
+}
+
+func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount
if self.onDirty != nil {
self.onDirty(self.Address())
@@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {}
-func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject {
- stateObject := NewObject(self.address, self.data, onDirty)
+func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
+ stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie
stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy()
@@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
}
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
+ prevcode := self.Code(self.db.db)
+ self.db.journal = append(self.db.journal, codeChange{
+ account: &self.address,
+ prevhash: self.CodeHash(),
+ prevcode: prevcode,
+ })
+ self.setCode(codeHash, code)
+}
+
+func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code
self.data.CodeHash = codeHash[:]
self.dirtyCode = true
@@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
}
func (self *StateObject) SetNonce(nonce uint64) {
+ self.db.journal = append(self.db.journal, nonceChange{
+ account: &self.address,
+ prev: self.data.Nonce,
+ })
+ self.setNonce(nonce)
+}
+
+func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce
if self.onDirty != nil {
self.onDirty(self.Address())
@@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value)
}
- it := self.trie.Iterator()
+ it := self.getTrie(self.db.db).Iterator()
for it.Next() {
// ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key))
diff --git a/core/state/state_test.go b/core/state/state_test.go
index 7b9b39e06..b86d8b140 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44))
// write some of them to the trie
- s.state.UpdateStateObject(obj1)
- s.state.UpdateStateObject(obj2)
+ s.state.updateStateObject(obj1)
+ s.state.updateStateObject(obj2)
s.state.Commit()
// check that dump contains the state objects that are in trie
@@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state
- snapshot := s.state.Copy()
+ snapshot := s.state.Snapshot()
// set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot
- s.state.Set(snapshot)
+ s.state.RevertToSnapshot(snapshot)
// get state storage value
res := s.state.GetState(stateobjaddr, storageaddr)
@@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res)
}
+func TestSnapshotEmpty(t *testing.T) {
+ db, _ := ethdb.NewMemDatabase()
+ state, _ := New(common.Hash{}, db)
+ state.RevertToSnapshot(state.Snapshot())
+}
+
// use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) {
@@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) {
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.remove = false
so0.deleted = false
- state.SetStateObject(so0)
+ state.setStateObject(so0)
root, _ := state.Commit()
state.Reset(root)
@@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) {
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true
so1.deleted = true
- state.SetStateObject(so1)
+ state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil {
t.Fatalf("deleted object not nil when getting")
}
- snapshot := state.Copy()
- state.Set(snapshot)
+ snapshot := state.Snapshot()
+ state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 4204c456e..4f74302c3 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -20,6 +20,7 @@ package state
import (
"fmt"
"math/big"
+ "sort"
"sync"
"github.com/ethereum/go-ethereum/common"
@@ -40,12 +41,17 @@ var StartingNonce uint64
const (
// Number of past tries to keep. The arbitrarily chosen value here
// is max uncle depth + 1.
- maxJournalLength = 8
+ maxTrieCacheLength = 8
// Number of codehash->size associations to keep.
codeSizeCacheSize = 100000
)
+type revision struct {
+ id int
+ journalIndex int
+}
+
// StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
@@ -69,6 +75,12 @@ type StateDB struct {
logs map[common.Hash]vm.Logs
logSize uint
+ // Journal of state modifications. This is the backbone of
+ // Snapshot and RevertToSnapshot.
+ journal journal
+ validRevisions []revision
+ nextRevisionId int
+
lock sync.Mutex
}
@@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error {
self.trie = tr
self.stateObjects = make(map[common.Address]*StateObject)
self.stateObjectsDirty = make(map[common.Address]struct{})
- self.refund = new(big.Int)
self.thash = common.Hash{}
self.bhash = common.Hash{}
self.txIndex = 0
self.logs = make(map[common.Hash]vm.Logs)
self.logSize = 0
+ self.clearJournalAndRefund()
return nil
}
@@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock()
defer self.lock.Unlock()
- if len(self.pastTries) >= maxJournalLength {
+ if len(self.pastTries) >= maxTrieCacheLength {
copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t
} else {
@@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) {
}
func (self *StateDB) AddLog(log *vm.Log) {
+ self.journal = append(self.journal, addLogChange{txhash: self.thash})
+
log.TxHash = self.thash
log.BlockHash = self.bhash
log.TxIndex = uint(self.txIndex)
@@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs {
}
func (self *StateDB) AddRefund(gas *big.Int) {
+ self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)})
self.refund.Add(self.refund, gas)
}
-func (self *StateDB) HasAccount(addr common.Address) bool {
- return self.GetStateObject(addr) != nil
-}
-
+// Exist reports whether the given account address exists in the state.
+// Notably this also returns true for suicided accounts.
func (self *StateDB) Exist(addr common.Address) bool {
return self.GetStateObject(addr) != nil
}
@@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int {
if stateObject != nil {
return stateObject.Balance()
}
-
return common.Big0
}
@@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) {
}
}
+func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) {
+ stateObject := self.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetBalance(amount)
+ }
+}
+
func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
@@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) {
func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
- stateObject.SetState(key, value)
+ stateObject.SetState(self.db, key, value)
}
}
+// Delete marks the given account as suicided.
+// This clears the account balance.
+//
+// The account's state object is still available until the state is committed,
+// GetStateObject will return a non-nil account after Delete.
func (self *StateDB) Delete(addr common.Address) bool {
stateObject := self.GetStateObject(addr)
- if stateObject != nil {
- stateObject.MarkForDeletion()
- stateObject.data.Balance = new(big.Int)
- return true
+ if stateObject == nil {
+ return false
}
-
- return false
+ self.journal = append(self.journal, deleteAccountChange{
+ account: &addr,
+ prev: stateObject.remove,
+ prevbalance: new(big.Int).Set(stateObject.Balance()),
+ })
+ stateObject.markForDeletion()
+ stateObject.data.Balance = new(big.Int)
+ return true
}
//
// Setting, updating & deleting state object methods
//
-// Update the given state object and apply it to state trie
-func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
+// updateStateObject writes the given object to the trie.
+func (self *StateDB) updateStateObject(stateObject *StateObject) {
addr := stateObject.Address()
data, err := rlp.EncodeToBytes(stateObject)
if err != nil {
@@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
self.trie.Update(addr[:], data)
}
-// Delete the given state object and delete it from the state trie
-func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
+// deleteStateObject removes the given object from the state trie.
+func (self *StateDB) deleteStateObject(stateObject *StateObject) {
stateObject.deleted = true
-
addr := stateObject.Address()
self.trie.Delete(addr[:])
}
@@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje
return nil
}
// Insert into the live set.
- obj := NewObject(addr, data, self.MarkStateObjectDirty)
- self.SetStateObject(obj)
+ obj := newObject(self, addr, data, self.MarkStateObjectDirty)
+ self.setStateObject(obj)
return obj
}
-func (self *StateDB) SetStateObject(object *StateObject) {
+func (self *StateDB) setStateObject(object *StateObject) {
self.stateObjects[object.Address()] = object
}
@@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) {
func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject {
stateObject := self.GetStateObject(addr)
if stateObject == nil || stateObject.deleted {
- stateObject = self.CreateStateObject(addr)
+ stateObject, _ = self.createObject(addr)
}
-
return stateObject
}
-// NewStateObject create a state object whether it exist in the trie or not
-func (self *StateDB) newStateObject(addr common.Address) *StateObject {
- if glog.V(logger.Core) {
- glog.Infof("(+) %x\n", addr)
- }
- obj := NewObject(addr, Account{}, self.MarkStateObjectDirty)
- obj.SetNonce(StartingNonce) // sets the object to dirty
- self.stateObjects[addr] = obj
- return obj
-}
-
// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
// state object cache iteration to find a handful of modified ones.
func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
self.stateObjectsDirty[addr] = struct{}{}
}
-// Creates creates a new state object and takes ownership.
-func (self *StateDB) CreateStateObject(addr common.Address) *StateObject {
- // Get previous (if any)
- so := self.GetStateObject(addr)
- // Create a new one
- newSo := self.newStateObject(addr)
-
- // If it existed set the balance to the new account
- if so != nil {
- newSo.data.Balance = so.data.Balance
+// createObject creates a new state object. If there is an existing account with
+// the given address, it is overwritten and returned as the second return value.
+func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) {
+ prev = self.GetStateObject(addr)
+ newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty)
+ newobj.setNonce(StartingNonce) // sets the object to dirty
+ if prev == nil {
+ if glog.V(logger.Core) {
+ glog.Infof("(+) %x\n", addr)
+ }
+ self.journal = append(self.journal, createObjectChange{account: &addr})
+ } else {
+ self.journal = append(self.journal, resetObjectChange{prev: prev})
}
-
- return newSo
-}
-
-func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
- return self.CreateStateObject(addr)
+ self.setStateObject(newobj)
+ return newobj, prev
}
+// CreateAccount explicitly creates a state object. If a state object with the address
+// already exists the balance is carried over to the new account.
+//
+// CreateAccount is called during the EVM CREATE operation. The situation might arise that
+// a contract does the following:
//
-// Setting, copying of the state methods
+// 1. sends funds to sha(account ++ (nonce + 1))
+// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
+// Carrying over the balance ensures that Ether doesn't disappear.
+func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
+ new, prev := self.createObject(addr)
+ if prev != nil {
+ new.setBalance(prev.data.Balance)
+ }
+ return new
+}
+// Copy creates a deep, independent copy of the state.
+// Snapshots of the copied state cannot be applied to the copy.
func (self *StateDB) Copy() *StateDB {
self.lock.Lock()
defer self.lock.Unlock()
@@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB {
}
// Copy the dirty states and logs
for addr, _ := range self.stateObjectsDirty {
- state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty)
+ state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty)
state.stateObjectsDirty[addr] = struct{}{}
}
for hash, logs := range self.logs {
@@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB {
return state
}
-func (self *StateDB) Set(state *StateDB) {
- self.lock.Lock()
- defer self.lock.Unlock()
+// Snapshot returns an identifier for the current revision of the state.
+func (self *StateDB) Snapshot() int {
+ id := self.nextRevisionId
+ self.nextRevisionId++
+ self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)})
+ return id
+}
+
+// RevertToSnapshot reverts all state changes made since the given revision.
+func (self *StateDB) RevertToSnapshot(revid int) {
+ // Find the snapshot in the stack of valid snapshots.
+ idx := sort.Search(len(self.validRevisions), func(i int) bool {
+ return self.validRevisions[i].id >= revid
+ })
+ if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid {
+ panic(fmt.Errorf("revision id %v cannot be reverted", revid))
+ }
+ snapshot := self.validRevisions[idx].journalIndex
+
+ // Replay the journal to undo changes.
+ for i := len(self.journal) - 1; i >= snapshot; i-- {
+ self.journal[i].undo(self)
+ }
+ self.journal = self.journal[:snapshot]
- self.db = state.db
- self.trie = state.trie
- self.pastTries = state.pastTries
- self.stateObjects = state.stateObjects
- self.stateObjectsDirty = state.stateObjectsDirty
- self.codeSizeCache = state.codeSizeCache
- self.refund = state.refund
- self.logs = state.logs
- self.logSize = state.logSize
+ // Remove invalidated snapshots from the stack.
+ self.validRevisions = self.validRevisions[:idx]
}
+// GetRefund returns the current value of the refund counter.
+// The return value must not be modified by the caller and will become
+// invalid at the next call to AddRefund.
func (self *StateDB) GetRefund() *big.Int {
return self.refund
}
@@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int {
// It is called in between transactions to get the root hash that
// goes into transaction receipts.
func (s *StateDB) IntermediateRoot() common.Hash {
- s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr]
if stateObject.remove {
- s.DeleteStateObject(stateObject)
+ s.deleteStateObject(stateObject)
} else {
- stateObject.UpdateRoot(s.db)
- s.UpdateStateObject(stateObject)
+ stateObject.updateRoot(s.db)
+ s.updateStateObject(stateObject)
}
}
+ // Invalidate journal because reverting across transactions is not allowed.
+ s.clearJournalAndRefund()
return s.trie.Hash()
}
@@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash {
// DeleteSuicides should not be used for consensus related updates
// under any circumstances.
func (s *StateDB) DeleteSuicides() {
- // Reset refund so that any used-gas calculations can use
- // this method.
- s.refund = new(big.Int)
+ // Reset refund so that any used-gas calculations can use this method.
+ s.clearJournalAndRefund()
+
for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr]
@@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
return root, batch
}
-func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
+func (s *StateDB) clearJournalAndRefund() {
+ s.journal = nil
+ s.validRevisions = s.validRevisions[:0]
s.refund = new(big.Int)
+}
+
+func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
+ defer s.clearJournalAndRefund()
// Commit objects to the trie.
for addr, stateObject := range s.stateObjects {
if stateObject.remove {
// If the object has been removed, don't bother syncing it
// and just mark it for deletion in the trie.
- s.DeleteStateObject(stateObject)
+ s.deleteStateObject(stateObject)
} else if _, ok := s.stateObjectsDirty[addr]; ok {
// Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode {
@@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
return common.Hash{}, err
}
// Update the object in the main account trie.
- s.UpdateStateObject(stateObject)
+ s.updateStateObject(stateObject)
}
delete(s.stateObjectsDirty, addr)
}
@@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
}
return root, err
}
-
-func (self *StateDB) Refunds() *big.Int {
- return self.refund
-}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 7930b620d..e236cb8f3 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -17,11 +17,19 @@
package state
import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
"math/big"
+ "math/rand"
+ "reflect"
+ "strings"
"testing"
+ "testing/quick"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
)
@@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts
for i := byte(0); i < 255; i++ {
- obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.AddBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ addr := common.BytesToAddress([]byte{i})
+ state.AddBalance(addr, big.NewInt(int64(11*i)))
+ state.SetNonce(addr, uint64(42*i))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
+ state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
+ state.SetCode(addr, []byte{i, i, i, i, i})
}
- state.UpdateStateObject(obj)
+ state.IntermediateRoot()
}
// Ensure that no data was leaked into the database
for _, key := range db.Keys() {
@@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb)
- // Update the states with some objects
- for i := byte(0); i < 255; i++ {
- // Create a new state object with some data into the transition database
- obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ modify := func(state *StateDB, addr common.Address, i, tweak byte) {
+ state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
+ state.SetNonce(addr, uint64(42*i+tweak))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
+ state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
+ state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0})
+ state.SetCode(addr, []byte{i, i, i, i, i, tweak})
}
- transState.UpdateStateObject(obj)
+ }
- // Overwrite all the data with new values in the transition database
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- transState.UpdateStateObject(obj)
+ // Modify the transient state.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 0)
+ }
+ // Write modifications to trie.
+ transState.IntermediateRoot()
- // Create the final state object directly in the final database
- obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- finalState.UpdateStateObject(obj)
+ // Overwrite all the data with new values in the transient database.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 99)
+ modify(finalState, common.Address{byte(i)}, i, 99)
}
+
+ // Commit and cross check the databases.
if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
- // Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key)
@@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
}
}
}
+
+func TestSnapshotRandom(t *testing.T) {
+ config := &quick.Config{MaxCount: 1000}
+ err := quick.Check((*snapshotTest).run, config)
+ if cerr, ok := err.(*quick.CheckError); ok {
+ test := cerr.In[0].(*snapshotTest)
+ t.Errorf("%v:\n%s", test.err, test)
+ } else if err != nil {
+ t.Error(err)
+ }
+}
+
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
+// captured by the snapshot. Instances of this test with pseudorandom content are created
+// by Generate.
+//
+// The test works as follows:
+//
+// A new state is created and all actions are applied to it. Several snapshots are taken
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
+// accessor methods on the reverted state must match the return value of the equivalent
+// methods on the replayed state.
+type snapshotTest struct {
+ addrs []common.Address // all account addresses
+ actions []testAction // modifications to the state
+ snapshots []int // actions indexes at which snapshot is taken
+ err error // failure details are reported through this field
+}
+
+type testAction struct {
+ name string
+ fn func(testAction, *StateDB)
+ args []int64
+ noAddr bool
+}
+
+// newTestAction creates a random action that changes state.
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
+ actions := []testAction{
+ {
+ name: "SetBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.SetBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.AddBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetNonce",
+ fn: func(a testAction, s *StateDB) {
+ s.SetNonce(addr, uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "SetCode",
+ fn: func(a testAction, s *StateDB) {
+ code := make([]byte, 16)
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
+ s.SetCode(addr, code)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "CreateAccount",
+ fn: func(a testAction, s *StateDB) {
+ s.CreateAccount(addr)
+ },
+ },
+ {
+ name: "Delete",
+ fn: func(a testAction, s *StateDB) {
+ s.Delete(addr)
+ },
+ },
+ {
+ name: "AddRefund",
+ fn: func(a testAction, s *StateDB) {
+ s.AddRefund(big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ noAddr: true,
+ },
+ {
+ name: "AddLog",
+ fn: func(a testAction, s *StateDB) {
+ data := make([]byte, 2)
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
+ s.AddLog(&vm.Log{Address: addr, Data: data})
+ },
+ args: make([]int64, 1),
+ },
+ }
+ action := actions[r.Intn(len(actions))]
+ var nameargs []string
+ if !action.noAddr {
+ nameargs = append(nameargs, addr.Hex())
+ }
+ for _, i := range action.args {
+ action.args[i] = rand.Int63n(100)
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
+ }
+ action.name += strings.Join(nameargs, ", ")
+ return action
+}
+
+// Generate returns a new snapshot test of the given size. All randomness is
+// derived from r.
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
+ // Generate random actions.
+ addrs := make([]common.Address, 50)
+ for i := range addrs {
+ addrs[i][0] = byte(i)
+ }
+ actions := make([]testAction, size)
+ for i := range actions {
+ addr := addrs[r.Intn(len(addrs))]
+ actions[i] = newTestAction(addr, r)
+ }
+ // Generate snapshot indexes.
+ nsnapshots := int(math.Sqrt(float64(size)))
+ if size > 0 && nsnapshots == 0 {
+ nsnapshots = 1
+ }
+ snapshots := make([]int, nsnapshots)
+ snaplen := len(actions) / nsnapshots
+ for i := range snapshots {
+ // Try to place the snapshots some number of actions apart from each other.
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
+ }
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
+}
+
+func (test *snapshotTest) String() string {
+ out := new(bytes.Buffer)
+ sindex := 0
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
+ sindex++
+ }
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
+ }
+ return out.String()
+}
+
+func (test *snapshotTest) run() bool {
+ // Run all actions and create snapshots.
+ var (
+ db, _ = ethdb.NewMemDatabase()
+ state, _ = New(common.Hash{}, db)
+ snapshotRevs = make([]int, len(test.snapshots))
+ sindex = 0
+ )
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ snapshotRevs[sindex] = state.Snapshot()
+ sindex++
+ }
+ action.fn(action, state)
+ }
+
+ // Revert all snapshots in reverse order. Each revert must yield a state
+ // that is equivalent to fresh state with all actions up the snapshot applied.
+ for sindex--; sindex >= 0; sindex-- {
+ checkstate, _ := New(common.Hash{}, db)
+ for _, action := range test.actions[:test.snapshots[sindex]] {
+ action.fn(action, checkstate)
+ }
+ state.RevertToSnapshot(snapshotRevs[sindex])
+ if err := test.checkEqual(state, checkstate); err != nil {
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
+ return false
+ }
+ }
+ return true
+}
+
+// checkEqual checks that methods of state and checkstate return the same values.
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
+ for _, addr := range test.addrs {
+ var err error
+ checkeq := func(op string, a, b interface{}) bool {
+ if err == nil && !reflect.DeepEqual(a, b) {
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
+ return false
+ }
+ return true
+ }
+ // Check basic accessor methods.
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
+ checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr))
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
+ // Check storage.
+ if obj := state.GetStateObject(addr); obj != nil {
+ obj.ForEachStorage(func(key, val common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
+ })
+ checkobj := checkstate.GetStateObject(addr)
+ checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
+ })
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
+ state.GetRefund(), checkstate.GetRefund())
+ }
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
+ state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
+ }
+ return nil
+}
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 670e1fb1b..949df7301 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i}
}
- state.UpdateStateObject(obj)
+ state.updateStateObject(obj)
accounts = append(accounts, acc)
}
root, _ := state.Commit()
diff --git a/core/tx_pool.go b/core/tx_pool.go
index f8b11a7ce..10a110e0b 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass.
- if !currentState.HasAccount(from) {
+ if !currentState.Exist(from) {
return ErrNonExistentAccount
}
diff --git a/core/vm/environment.go b/core/vm/environment.go
index daf6fb90d..1038e69d5 100644
--- a/core/vm/environment.go
+++ b/core/vm/environment.go
@@ -36,9 +36,9 @@ type Environment interface {
// The state database
Db() Database
// Creates a restorable snapshot
- MakeSnapshot() Database
+ SnapshotDatabase() int
// Set database to previous snapshot
- SetSnapshot(Database)
+ RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address
// The block number this VM is invoked on
diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go
index e6922aeb7..a6de710e1 100644
--- a/core/vm/jit_test.go
+++ b/core/vm/jit_test.go
@@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} }
-func (self *Env) MakeSnapshot() Database { return nil }
-func (self *Env) SetSnapshot(Database) {}
+func (self *Env) SnapshotDatabase() int { return 0 }
+func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil }
diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go
index a4793c98f..59fbaa792 100644
--- a/core/vm/runtime/env.go
+++ b/core/vm/runtime/env.go
@@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
-func (self *Env) MakeSnapshot() vm.Database {
- return self.state.Copy()
+func (self *Env) SnapshotDatabase() int {
+ return self.state.Snapshot()
}
-func (self *Env) SetSnapshot(copy vm.Database) {
- self.state.Set(copy.(*state.StateDB))
+func (self *Env) RevertToSnapshot(snapshot int) {
+ self.state.RevertToSnapshot(snapshot)
}
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
diff --git a/core/vm_env.go b/core/vm_env.go
index e541eaef4..d62eebbd9 100644
--- a/core/vm_env.go
+++ b/core/vm_env.go
@@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
-func (self *VMEnv) MakeSnapshot() vm.Database {
- return self.state.Copy()
+func (self *VMEnv) SnapshotDatabase() int {
+ return self.state.Snapshot()
}
-func (self *VMEnv) SetSnapshot(copy vm.Database) {
- self.state.Set(copy.(*state.StateDB))
+func (self *VMEnv) RevertToSnapshot(snapshot int) {
+ self.state.RevertToSnapshot(snapshot)
}
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {