diff options
Diffstat (limited to 'tests/state_test_util.go')
-rw-r--r-- | tests/state_test_util.go | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/tests/state_test_util.go b/tests/state_test_util.go new file mode 100644 index 000000000..e9abad788 --- /dev/null +++ b/tests/state_test_util.go @@ -0,0 +1,181 @@ +package tests + +import ( + "bytes" + "fmt" + "io" + "math/big" + "strconv" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/logger/glog" +) + +func RunStateTestWithReader(r io.Reader, skipTests []string) error { + tests := make(map[string]VmTest) + if err := readJson(r, &tests); err != nil { + return err + } + + if err := runStateTests(tests, skipTests); err != nil { + return err + } + + return nil +} + +func RunStateTest(p string, skipTests []string) error { + tests := make(map[string]VmTest) + if err := readJsonFile(p, &tests); err != nil { + return err + } + + if err := runStateTests(tests, skipTests); err != nil { + return err + } + + return nil + +} + +func runStateTests(tests map[string]VmTest, skipTests []string) error { + skipTest := make(map[string]bool, len(skipTests)) + for _, name := range skipTests { + skipTest[name] = true + } + + for name, test := range tests { + if skipTest[name] { + glog.Infoln("Skipping state test", name) + return nil + } + + if err := runStateTest(test); err != nil { + return fmt.Errorf("%s: %s\n", name, err.Error()) + } + + glog.Infoln("State test passed: ", name) + //fmt.Println(string(statedb.Dump())) + } + return nil + +} + +func runStateTest(test VmTest) error { + db, _ := ethdb.NewMemDatabase() + statedb := state.New(common.Hash{}, db) + for addr, account := range test.Pre { + obj := StateObjectFromAccount(db, addr, account) + statedb.SetStateObject(obj) + for a, v := range account.Storage { + obj.SetState(common.HexToHash(a), common.HexToHash(v)) + } + } + + // XXX Yeah, yeah... + env := make(map[string]string) + env["currentCoinbase"] = test.Env.CurrentCoinbase + env["currentDifficulty"] = test.Env.CurrentDifficulty + env["currentGasLimit"] = test.Env.CurrentGasLimit + env["currentNumber"] = test.Env.CurrentNumber + env["previousHash"] = test.Env.PreviousHash + if n, ok := test.Env.CurrentTimestamp.(float64); ok { + env["currentTimestamp"] = strconv.Itoa(int(n)) + } else { + env["currentTimestamp"] = test.Env.CurrentTimestamp.(string) + } + + var ( + ret []byte + // gas *big.Int + // err error + logs state.Logs + ) + + ret, logs, _, _ = RunState(statedb, env, test.Transaction) + + // // Compare expected and actual return + rexp := common.FromHex(test.Out) + if bytes.Compare(rexp, ret) != 0 { + return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret) + } + + // check post state + for addr, account := range test.Post { + obj := statedb.GetStateObject(common.HexToAddress(addr)) + if obj == nil { + continue + } + + if obj.Balance().Cmp(common.Big(account.Balance)) != 0 { + return fmt.Errorf("(%x) balance failed. Expected %v, got %v => %v\n", obj.Address().Bytes()[:4], account.Balance, obj.Balance(), new(big.Int).Sub(common.Big(account.Balance), obj.Balance())) + } + + if obj.Nonce() != common.String2Big(account.Nonce).Uint64() { + return fmt.Errorf("(%x) nonce failed. Expected %v, got %v\n", obj.Address().Bytes()[:4], account.Nonce, obj.Nonce()) + } + + for addr, value := range account.Storage { + v := obj.GetState(common.HexToHash(addr)) + vexp := common.HexToHash(value) + + if v != vexp { + return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big()) + } + } + } + + statedb.Sync() + if common.HexToHash(test.PostStateRoot) != statedb.Root() { + return fmt.Errorf("Post state root error. Expected %s, got %x", test.PostStateRoot, statedb.Root()) + } + + // check logs + if len(test.Logs) > 0 { + if err := checkLogs(test.Logs, logs); err != nil { + return err + } + } + + return nil +} + +func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state.Logs, *big.Int, error) { + var ( + keyPair, _ = crypto.NewKeyPairFromSec([]byte(common.Hex2Bytes(tx["secretKey"]))) + data = common.FromHex(tx["data"]) + gas = common.Big(tx["gasLimit"]) + price = common.Big(tx["gasPrice"]) + value = common.Big(tx["value"]) + nonce = common.Big(tx["nonce"]).Uint64() + caddr = common.HexToAddress(env["currentCoinbase"]) + ) + + var to *common.Address + if len(tx["to"]) > 2 { + t := common.HexToAddress(tx["to"]) + to = &t + } + // Set pre compiled contracts + vm.Precompiled = vm.PrecompiledContracts() + + snapshot := statedb.Copy() + coinbase := statedb.GetOrNewStateObject(caddr) + coinbase.SetGasPool(common.Big(env["currentGasLimit"])) + + message := NewMessage(common.BytesToAddress(keyPair.Address()), to, data, value, gas, price, nonce) + vmenv := NewEnvFromMap(statedb, env, tx) + vmenv.origin = common.BytesToAddress(keyPair.Address()) + ret, _, err := core.ApplyMessage(vmenv, message, coinbase) + if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) { + statedb.Set(snapshot) + } + statedb.Update() + + return ret, vmenv.state.Logs(), vmenv.Gas, err +} |