aboutsummaryrefslogtreecommitdiffstats
path: root/core/state/statedb_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'core/state/statedb_test.go')
-rw-r--r--core/state/statedb_test.go313
1 files changed, 274 insertions, 39 deletions
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 7930b620d..e236cb8f3 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -17,11 +17,19 @@
package state
import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
"math/big"
+ "math/rand"
+ "reflect"
+ "strings"
"testing"
+ "testing/quick"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
)
@@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts
for i := byte(0); i < 255; i++ {
- obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.AddBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ addr := common.BytesToAddress([]byte{i})
+ state.AddBalance(addr, big.NewInt(int64(11*i)))
+ state.SetNonce(addr, uint64(42*i))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
+ state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
+ state.SetCode(addr, []byte{i, i, i, i, i})
}
- state.UpdateStateObject(obj)
+ state.IntermediateRoot()
}
// Ensure that no data was leaked into the database
for _, key := range db.Keys() {
@@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb)
- // Update the states with some objects
- for i := byte(0); i < 255; i++ {
- // Create a new state object with some data into the transition database
- obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11 * i)))
- obj.SetNonce(uint64(42 * i))
+ modify := func(state *StateDB, addr common.Address, i, tweak byte) {
+ state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
+ state.SetNonce(addr, uint64(42*i+tweak))
if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
+ state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
+ state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
}
if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0})
+ state.SetCode(addr, []byte{i, i, i, i, i, tweak})
}
- transState.UpdateStateObject(obj)
+ }
- // Overwrite all the data with new values in the transition database
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- transState.UpdateStateObject(obj)
+ // Modify the transient state.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 0)
+ }
+ // Write modifications to trie.
+ transState.IntermediateRoot()
- // Create the final state object directly in the final database
- obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
- obj.SetNonce(uint64(42*i + 1))
- if i%2 == 0 {
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
- }
- if i%3 == 0 {
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
- }
- finalState.UpdateStateObject(obj)
+ // Overwrite all the data with new values in the transient database.
+ for i := byte(0); i < 255; i++ {
+ modify(transState, common.Address{byte(i)}, i, 99)
+ modify(finalState, common.Address{byte(i)}, i, 99)
}
+
+ // Commit and cross check the databases.
if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
- // Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key)
@@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
}
}
}
+
+func TestSnapshotRandom(t *testing.T) {
+ config := &quick.Config{MaxCount: 1000}
+ err := quick.Check((*snapshotTest).run, config)
+ if cerr, ok := err.(*quick.CheckError); ok {
+ test := cerr.In[0].(*snapshotTest)
+ t.Errorf("%v:\n%s", test.err, test)
+ } else if err != nil {
+ t.Error(err)
+ }
+}
+
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
+// captured by the snapshot. Instances of this test with pseudorandom content are created
+// by Generate.
+//
+// The test works as follows:
+//
+// A new state is created and all actions are applied to it. Several snapshots are taken
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
+// accessor methods on the reverted state must match the return value of the equivalent
+// methods on the replayed state.
+type snapshotTest struct {
+ addrs []common.Address // all account addresses
+ actions []testAction // modifications to the state
+ snapshots []int // actions indexes at which snapshot is taken
+ err error // failure details are reported through this field
+}
+
+type testAction struct {
+ name string
+ fn func(testAction, *StateDB)
+ args []int64
+ noAddr bool
+}
+
+// newTestAction creates a random action that changes state.
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
+ actions := []testAction{
+ {
+ name: "SetBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.SetBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.AddBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetNonce",
+ fn: func(a testAction, s *StateDB) {
+ s.SetNonce(addr, uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "SetCode",
+ fn: func(a testAction, s *StateDB) {
+ code := make([]byte, 16)
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
+ s.SetCode(addr, code)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "CreateAccount",
+ fn: func(a testAction, s *StateDB) {
+ s.CreateAccount(addr)
+ },
+ },
+ {
+ name: "Delete",
+ fn: func(a testAction, s *StateDB) {
+ s.Delete(addr)
+ },
+ },
+ {
+ name: "AddRefund",
+ fn: func(a testAction, s *StateDB) {
+ s.AddRefund(big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ noAddr: true,
+ },
+ {
+ name: "AddLog",
+ fn: func(a testAction, s *StateDB) {
+ data := make([]byte, 2)
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
+ s.AddLog(&vm.Log{Address: addr, Data: data})
+ },
+ args: make([]int64, 1),
+ },
+ }
+ action := actions[r.Intn(len(actions))]
+ var nameargs []string
+ if !action.noAddr {
+ nameargs = append(nameargs, addr.Hex())
+ }
+ for _, i := range action.args {
+ action.args[i] = rand.Int63n(100)
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
+ }
+ action.name += strings.Join(nameargs, ", ")
+ return action
+}
+
+// Generate returns a new snapshot test of the given size. All randomness is
+// derived from r.
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
+ // Generate random actions.
+ addrs := make([]common.Address, 50)
+ for i := range addrs {
+ addrs[i][0] = byte(i)
+ }
+ actions := make([]testAction, size)
+ for i := range actions {
+ addr := addrs[r.Intn(len(addrs))]
+ actions[i] = newTestAction(addr, r)
+ }
+ // Generate snapshot indexes.
+ nsnapshots := int(math.Sqrt(float64(size)))
+ if size > 0 && nsnapshots == 0 {
+ nsnapshots = 1
+ }
+ snapshots := make([]int, nsnapshots)
+ snaplen := len(actions) / nsnapshots
+ for i := range snapshots {
+ // Try to place the snapshots some number of actions apart from each other.
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
+ }
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
+}
+
+func (test *snapshotTest) String() string {
+ out := new(bytes.Buffer)
+ sindex := 0
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
+ sindex++
+ }
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
+ }
+ return out.String()
+}
+
+func (test *snapshotTest) run() bool {
+ // Run all actions and create snapshots.
+ var (
+ db, _ = ethdb.NewMemDatabase()
+ state, _ = New(common.Hash{}, db)
+ snapshotRevs = make([]int, len(test.snapshots))
+ sindex = 0
+ )
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ snapshotRevs[sindex] = state.Snapshot()
+ sindex++
+ }
+ action.fn(action, state)
+ }
+
+ // Revert all snapshots in reverse order. Each revert must yield a state
+ // that is equivalent to fresh state with all actions up the snapshot applied.
+ for sindex--; sindex >= 0; sindex-- {
+ checkstate, _ := New(common.Hash{}, db)
+ for _, action := range test.actions[:test.snapshots[sindex]] {
+ action.fn(action, checkstate)
+ }
+ state.RevertToSnapshot(snapshotRevs[sindex])
+ if err := test.checkEqual(state, checkstate); err != nil {
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
+ return false
+ }
+ }
+ return true
+}
+
+// checkEqual checks that methods of state and checkstate return the same values.
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
+ for _, addr := range test.addrs {
+ var err error
+ checkeq := func(op string, a, b interface{}) bool {
+ if err == nil && !reflect.DeepEqual(a, b) {
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
+ return false
+ }
+ return true
+ }
+ // Check basic accessor methods.
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
+ checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr))
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
+ // Check storage.
+ if obj := state.GetStateObject(addr); obj != nil {
+ obj.ForEachStorage(func(key, val common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
+ })
+ checkobj := checkstate.GetStateObject(addr)
+ checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
+ })
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
+ state.GetRefund(), checkstate.GetRefund())
+ }
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
+ state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
+ }
+ return nil
+}