aboutsummaryrefslogtreecommitdiffstats
path: root/core/state/statedb_test.go
diff options
context:
space:
mode:
authorFelix Lange <fjl@twurst.com>2016-10-04 18:36:02 +0800
committerFelix Lange <fjl@twurst.com>2016-10-06 21:32:16 +0800
commit1f1ea18b5414bea22332bb4fce53cc95b5c6a07d (patch)
treed1aa3051f9c4d9f33a24519c18b70f0dd2f00644 /core/state/statedb_test.go
parentab7adb0027dbcf09cf75a533be356c1e24c46c90 (diff)
downloadgo-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.gz
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.bz2
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.lz
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.xz
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.zst
go-tangerine-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.zip
core/state: implement reverts by journaling all changes
This commit replaces the deep-copy based state revert mechanism with a linear complexity journal. This commit also hides several internal StateDB methods to limit the number of ways in which calling code can use the journal incorrectly. As usual consultation and bug fixes to the initial implementation were provided by @karalabe, @obscuren and @Arachnid. Thank you!
Diffstat (limited to 'core/state/statedb_test.go')
-rw-r--r--core/state/statedb_test.go313
1 files changed, 274 insertions, 39 deletions
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 7930b620d..e236cb8f3 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -17,11 +17,19 @@
package state
import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
"math/big"
+ "math/rand"
+ "reflect"
+ "strings"
"testing"
+ "testing/quick"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
)
@@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts
for i := byte(0); i < 255; i++ {
- obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.AddBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ addr := common.BytesToAddress([]byte{i})
+ state.AddBalance(addr, big.NewInt(int64(11*i)))
+ state.SetNonce(addr, uint64(42*i))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
+ state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
+ state.SetCode(addr, []byte{i, i, i, i, i})
}
- state.UpdateStateObject(obj)
+ state.IntermediateRoot()
}
// Ensure that no data was leaked into the database
for _, key := range db.Keys() {
@@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb)
- // Update the states with some objects
- for i := byte(0); i < 255; i++ {
- // Create a new state object with some data into the transition database
- obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ modify := func(state *StateDB, addr common.Address, i, tweak byte) {
+ state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
+ state.SetNonce(addr, uint64(42*i+tweak))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
+ state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
+ state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0})
+ state.SetCode(addr, []byte{i, i, i, i, i, tweak})
}
- transState.UpdateStateObject(obj)
+ }
- // Overwrite all the data with new values in the transition database
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- transState.UpdateStateObject(obj)
+ // Modify the transient state.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 0)
+ }
+ // Write modifications to trie.
+ transState.IntermediateRoot()
- // Create the final state object directly in the final database
- obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- finalState.UpdateStateObject(obj)
+ // Overwrite all the data with new values in the transient database.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 99)
+ modify(finalState, common.Address{byte(i)}, i, 99)
}
+
+ // Commit and cross check the databases.
if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
- // Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key)
@@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
}
}
}
+
+func TestSnapshotRandom(t *testing.T) {
+ config := &quick.Config{MaxCount: 1000}
+ err := quick.Check((*snapshotTest).run, config)
+ if cerr, ok := err.(*quick.CheckError); ok {
+ test := cerr.In[0].(*snapshotTest)
+ t.Errorf("%v:\n%s", test.err, test)
+ } else if err != nil {
+ t.Error(err)
+ }
+}
+
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
+// captured by the snapshot. Instances of this test with pseudorandom content are created
+// by Generate.
+//
+// The test works as follows:
+//
+// A new state is created and all actions are applied to it. Several snapshots are taken
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
+// accessor methods on the reverted state must match the return value of the equivalent
+// methods on the replayed state.
+type snapshotTest struct {
+ addrs []common.Address // all account addresses
+ actions []testAction // modifications to the state
+ snapshots []int // actions indexes at which snapshot is taken
+ err error // failure details are reported through this field
+}
+
+type testAction struct {
+ name string
+ fn func(testAction, *StateDB)
+ args []int64
+ noAddr bool
+}
+
+// newTestAction creates a random action that changes state.
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
+ actions := []testAction{
+ {
+ name: "SetBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.SetBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.AddBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetNonce",
+ fn: func(a testAction, s *StateDB) {
+ s.SetNonce(addr, uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "SetCode",
+ fn: func(a testAction, s *StateDB) {
+ code := make([]byte, 16)
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
+ s.SetCode(addr, code)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "CreateAccount",
+ fn: func(a testAction, s *StateDB) {
+ s.CreateAccount(addr)
+ },
+ },
+ {
+ name: "Delete",
+ fn: func(a testAction, s *StateDB) {
+ s.Delete(addr)
+ },
+ },
+ {
+ name: "AddRefund",
+ fn: func(a testAction, s *StateDB) {
+ s.AddRefund(big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ noAddr: true,
+ },
+ {
+ name: "AddLog",
+ fn: func(a testAction, s *StateDB) {
+ data := make([]byte, 2)
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
+ s.AddLog(&vm.Log{Address: addr, Data: data})
+ },
+ args: make([]int64, 1),
+ },
+ }
+ action := actions[r.Intn(len(actions))]
+ var nameargs []string
+ if !action.noAddr {
+ nameargs = append(nameargs, addr.Hex())
+ }
+ for _, i := range action.args {
+ action.args[i] = rand.Int63n(100)
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
+ }
+ action.name += strings.Join(nameargs, ", ")
+ return action
+}
+
+// Generate returns a new snapshot test of the given size. All randomness is
+// derived from r.
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
+ // Generate random actions.
+ addrs := make([]common.Address, 50)
+ for i := range addrs {
+ addrs[i][0] = byte(i)
+ }
+ actions := make([]testAction, size)
+ for i := range actions {
+ addr := addrs[r.Intn(len(addrs))]
+ actions[i] = newTestAction(addr, r)
+ }
+ // Generate snapshot indexes.
+ nsnapshots := int(math.Sqrt(float64(size)))
+ if size > 0 && nsnapshots == 0 {
+ nsnapshots = 1
+ }
+ snapshots := make([]int, nsnapshots)
+ snaplen := len(actions) / nsnapshots
+ for i := range snapshots {
+ // Try to place the snapshots some number of actions apart from each other.
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
+ }
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
+}
+
+func (test *snapshotTest) String() string {
+ out := new(bytes.Buffer)
+ sindex := 0
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
+ sindex++
+ }
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
+ }
+ return out.String()
+}
+
+func (test *snapshotTest) run() bool {
+ // Run all actions and create snapshots.
+ var (
+ db, _ = ethdb.NewMemDatabase()
+ state, _ = New(common.Hash{}, db)
+ snapshotRevs = make([]int, len(test.snapshots))
+ sindex = 0
+ )
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ snapshotRevs[sindex] = state.Snapshot()
+ sindex++
+ }
+ action.fn(action, state)
+ }
+
+ // Revert all snapshots in reverse order. Each revert must yield a state
+ // that is equivalent to fresh state with all actions up the snapshot applied.
+ for sindex--; sindex >= 0; sindex-- {
+ checkstate, _ := New(common.Hash{}, db)
+ for _, action := range test.actions[:test.snapshots[sindex]] {
+ action.fn(action, checkstate)
+ }
+ state.RevertToSnapshot(snapshotRevs[sindex])
+ if err := test.checkEqual(state, checkstate); err != nil {
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
+ return false
+ }
+ }
+ return true
+}
+
+// checkEqual checks that methods of state and checkstate return the same values.
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
+ for _, addr := range test.addrs {
+ var err error
+ checkeq := func(op string, a, b interface{}) bool {
+ if err == nil && !reflect.DeepEqual(a, b) {
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
+ return false
+ }
+ return true
+ }
+ // Check basic accessor methods.
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
+ checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr))
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
+ // Check storage.
+ if obj := state.GetStateObject(addr); obj != nil {
+ obj.ForEachStorage(func(key, val common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
+ })
+ checkobj := checkstate.GetStateObject(addr)
+ checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
+ })
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
+ state.GetRefund(), checkstate.GetRefund())
+ }
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
+ state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
+ }
+ return nil
+}