aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--accounts/abi/abi.go197
-rw-r--r--accounts/abi/abi_test.go802
-rw-r--r--accounts/abi/bind/backends/simulated.go4
-rw-r--r--accounts/abi/error.go5
-rw-r--r--accounts/abi/method.go2
-rw-r--r--accounts/abi/numbers.go13
-rw-r--r--accounts/abi/numbers_test.go38
-rw-r--r--accounts/abi/pack.go (renamed from accounts/abi/packing.go)15
-rw-r--r--accounts/abi/pack_test.go441
-rw-r--r--accounts/abi/reflect.go20
-rw-r--r--accounts/abi/type.go8
-rw-r--r--accounts/abi/type_test.go115
-rw-r--r--accounts/abi/unpack.go235
-rw-r--r--accounts/abi/unpack_test.go681
-rw-r--r--build/ci.go2
-rw-r--r--cmd/evm/runner.go6
-rw-r--r--cmd/geth/chaincmd.go2
-rw-r--r--cmd/puppeth/wizard_faucet.go3
-rw-r--r--common/hexutil/hexutil.go25
-rw-r--r--common/hexutil/json.go47
-rw-r--r--common/hexutil/json_test.go64
-rw-r--r--common/types.go15
-rw-r--r--common/types_test.go10
-rw-r--r--core/block_validator.go12
-rw-r--r--core/blockchain.go40
-rw-r--r--core/blockchain_test.go10
-rw-r--r--core/chain_makers.go4
-rw-r--r--core/genesis.go2
-rw-r--r--core/state/database.go154
-rw-r--r--core/state/dump.go2
-rw-r--r--core/state/iterator.go13
-rw-r--r--core/state/iterator_test.go8
-rw-r--r--core/state/managed_state_test.go2
-rw-r--r--core/state/state_object.go56
-rw-r--r--core/state/state_test.go39
-rw-r--r--core/state/statedb.go139
-rw-r--r--core/state/statedb_test.go40
-rw-r--r--core/state/sync_test.go75
-rw-r--r--core/tx_pool_test.go24
-rw-r--r--core/vm/gas_table.go16
-rw-r--r--core/vm/gas_table_test.go18
-rw-r--r--core/vm/instructions_test.go242
-rw-r--r--core/vm/runtime/runtime.go4
-rw-r--r--core/vm/runtime/runtime_test.go2
-rw-r--r--eth/api.go2
-rw-r--r--eth/api_backend.go35
-rw-r--r--eth/api_test.go2
-rw-r--r--eth/bind.go6
-rw-r--r--eth/downloader/downloader_test.go2
-rw-r--r--eth/handler_test.go2
-rw-r--r--internal/ethapi/api.go33
-rw-r--r--internal/ethapi/backend.go13
-rw-r--r--les/api_backend.go20
-rw-r--r--les/odr_test.go53
-rw-r--r--les/request_test.go2
-rw-r--r--light/lightchain.go5
-rw-r--r--light/odr.go8
-rw-r--r--light/odr_test.go164
-rw-r--r--light/odr_util.go19
-rw-r--r--light/state.go316
-rw-r--r--light/state_object.go275
-rw-r--r--light/state_test.go248
-rw-r--r--light/trie.go251
-rw-r--r--light/trie_test.go83
-rw-r--r--light/txpool.go32
-rw-r--r--light/vm_env.go194
-rw-r--r--miner/worker.go2
-rw-r--r--tests/block_test_util.go4
-rw-r--r--tests/state_test_util.go48
-rw-r--r--tests/util.go7
-rw-r--r--trie/proof.go5
-rw-r--r--trie/secure_trie.go5
72 files changed, 2660 insertions, 2828 deletions
diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go
index 3d1010229..2a06d474b 100644
--- a/accounts/abi/abi.go
+++ b/accounts/abi/abi.go
@@ -17,11 +17,9 @@
package abi
import (
- "encoding/binary"
"encoding/json"
"fmt"
"io"
- "math/big"
"reflect"
"strings"
@@ -67,7 +65,7 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) {
}
method = m
}
- arguments, err := method.pack(method, args...)
+ arguments, err := method.pack(args...)
if err != nil {
return nil, err
}
@@ -78,199 +76,6 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) {
return append(method.Id(), arguments...), nil
}
-// toGoSliceType parses the input and casts it to the proper slice defined by the ABI
-// argument in T.
-func toGoSlice(i int, t Argument, output []byte) (interface{}, error) {
- index := i * 32
- // The slice must, at very least be large enough for the index+32 which is exactly the size required
- // for the [offset in output, size of offset].
- if index+32 > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), index+32)
- }
- elem := t.Type.Elem
-
- // first we need to create a slice of the type
- var refSlice reflect.Value
- switch elem.T {
- case IntTy, UintTy, BoolTy:
- // create a new reference slice matching the element type
- switch t.Type.Kind {
- case reflect.Bool:
- refSlice = reflect.ValueOf([]bool(nil))
- case reflect.Uint8:
- refSlice = reflect.ValueOf([]uint8(nil))
- case reflect.Uint16:
- refSlice = reflect.ValueOf([]uint16(nil))
- case reflect.Uint32:
- refSlice = reflect.ValueOf([]uint32(nil))
- case reflect.Uint64:
- refSlice = reflect.ValueOf([]uint64(nil))
- case reflect.Int8:
- refSlice = reflect.ValueOf([]int8(nil))
- case reflect.Int16:
- refSlice = reflect.ValueOf([]int16(nil))
- case reflect.Int32:
- refSlice = reflect.ValueOf([]int32(nil))
- case reflect.Int64:
- refSlice = reflect.ValueOf([]int64(nil))
- default:
- refSlice = reflect.ValueOf([]*big.Int(nil))
- }
- case AddressTy: // address must be of slice Address
- refSlice = reflect.ValueOf([]common.Address(nil))
- case HashTy: // hash must be of slice hash
- refSlice = reflect.ValueOf([]common.Hash(nil))
- case FixedBytesTy:
- refSlice = reflect.ValueOf([][]byte(nil))
- default: // no other types are supported
- return nil, fmt.Errorf("abi: unsupported slice type %v", elem.T)
- }
-
- var slice []byte
- var size int
- var offset int
- if t.Type.IsSlice {
- // get the offset which determines the start of this array ...
- offset = int(binary.BigEndian.Uint64(output[index+24 : index+32]))
- if offset+32 > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32)
- }
-
- slice = output[offset:]
- // ... starting with the size of the array in elements ...
- size = int(binary.BigEndian.Uint64(slice[24:32]))
- slice = slice[32:]
- // ... and make sure that we've at the very least the amount of bytes
- // available in the buffer.
- if size*32 > len(slice) {
- return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), offset+32+size*32)
- }
-
- // reslice to match the required size
- slice = slice[:size*32]
- } else if t.Type.IsArray {
- //get the number of elements in the array
- size = t.Type.SliceSize
-
- //check to make sure array size matches up
- if index+32*size > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go array: offset %d would go over slice boundary (len=%d)", len(output), index+32*size)
- }
- //slice is there for a fixed amount of times
- slice = output[index : index+size*32]
- }
-
- for i := 0; i < size; i++ {
- var (
- inter interface{} // interface type
- returnOutput = slice[i*32 : i*32+32] // the return output
- )
- // set inter to the correct type (cast)
- switch elem.T {
- case IntTy, UintTy:
- inter = readInteger(t.Type.Kind, returnOutput)
- case BoolTy:
- inter = !allZero(returnOutput)
- case AddressTy:
- inter = common.BytesToAddress(returnOutput)
- case HashTy:
- inter = common.BytesToHash(returnOutput)
- case FixedBytesTy:
- inter = returnOutput
- }
- // append the item to our reflect slice
- refSlice = reflect.Append(refSlice, reflect.ValueOf(inter))
- }
-
- // return the interface
- return refSlice.Interface(), nil
-}
-
-func readInteger(kind reflect.Kind, b []byte) interface{} {
- switch kind {
- case reflect.Uint8:
- return uint8(b[len(b)-1])
- case reflect.Uint16:
- return binary.BigEndian.Uint16(b[len(b)-2:])
- case reflect.Uint32:
- return binary.BigEndian.Uint32(b[len(b)-4:])
- case reflect.Uint64:
- return binary.BigEndian.Uint64(b[len(b)-8:])
- case reflect.Int8:
- return int8(b[len(b)-1])
- case reflect.Int16:
- return int16(binary.BigEndian.Uint16(b[len(b)-2:]))
- case reflect.Int32:
- return int32(binary.BigEndian.Uint32(b[len(b)-4:]))
- case reflect.Int64:
- return int64(binary.BigEndian.Uint64(b[len(b)-8:]))
- default:
- return new(big.Int).SetBytes(b)
- }
-}
-
-func allZero(b []byte) bool {
- for _, byte := range b {
- if byte != 0 {
- return false
- }
- }
- return true
-}
-
-// toGoType parses the input and casts it to the proper type defined by the ABI
-// argument in T.
-func toGoType(i int, t Argument, output []byte) (interface{}, error) {
- // we need to treat slices differently
- if (t.Type.IsSlice || t.Type.IsArray) && t.Type.T != BytesTy && t.Type.T != StringTy && t.Type.T != FixedBytesTy && t.Type.T != FunctionTy {
- return toGoSlice(i, t, output)
- }
-
- index := i * 32
- if index+32 > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32)
- }
-
- // Parse the given index output and check whether we need to read
- // a different offset and length based on the type (i.e. string, bytes)
- var returnOutput []byte
- switch t.Type.T {
- case StringTy, BytesTy: // variable arrays are written at the end of the return bytes
- // parse offset from which we should start reading
- offset := int(binary.BigEndian.Uint64(output[index+24 : index+32]))
- if offset+32 > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32)
- }
- // parse the size up until we should be reading
- size := int(binary.BigEndian.Uint64(output[offset+24 : offset+32]))
- if offset+32+size > len(output) {
- return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+size)
- }
-
- // get the bytes for this return value
- returnOutput = output[offset+32 : offset+32+size]
- default:
- returnOutput = output[index : index+32]
- }
-
- // convert the bytes to whatever is specified by the ABI.
- switch t.Type.T {
- case IntTy, UintTy:
- return readInteger(t.Type.Kind, returnOutput), nil
- case BoolTy:
- return !allZero(returnOutput), nil
- case AddressTy:
- return common.BytesToAddress(returnOutput), nil
- case HashTy:
- return common.BytesToHash(returnOutput), nil
- case BytesTy, FixedBytesTy, FunctionTy:
- return returnOutput, nil
- case StringTy:
- return string(returnOutput), nil
- }
- return nil, fmt.Errorf("abi: unknown type %v", t.Type.T)
-}
-
// these variable are used to determine certain types during type assertion for
// assignment.
var (
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index a45bd6cc0..a3aa9446e 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -48,412 +48,6 @@ func pad(input []byte, size int, left bool) []byte {
return common.RightPadBytes(input, size)
}
-func TestTypeCheck(t *testing.T) {
- for i, test := range []struct {
- typ string
- input interface{}
- err string
- }{
- {"uint", big.NewInt(1), ""},
- {"int", big.NewInt(1), ""},
- {"uint30", big.NewInt(1), ""},
- {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"},
- {"uint16", uint16(1), ""},
- {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"},
- {"uint16[]", []uint16{1, 2, 3}, ""},
- {"uint16[]", [3]uint16{1, 2, 3}, ""},
- {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type []uint16 as argument"},
- {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"},
- {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
- {"uint16[3]", []uint16{1, 2, 3}, ""},
- {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
- {"address[]", []common.Address{{1}}, ""},
- {"address[1]", []common.Address{{1}}, ""},
- {"address[1]", [1]common.Address{{1}}, ""},
- {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"},
- {"bytes32", [32]byte{}, ""},
- {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"},
- {"bytes32", common.Hash{1}, ""},
- {"bytes31", [31]byte{}, ""},
- {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"},
- {"bytes", []byte{0, 1}, ""},
- {"bytes", [2]byte{0, 1}, ""},
- {"bytes", common.Hash{1}, ""},
- {"string", "hello world", ""},
- {"bytes32[]", [][32]byte{{}}, ""},
- {"function", [24]byte{}, ""},
- } {
- typ, err := NewType(test.typ)
- if err != nil {
- t.Fatal("unexpected parse error:", err)
- }
-
- err = typeCheck(typ, reflect.ValueOf(test.input))
- if err != nil && len(test.err) == 0 {
- t.Errorf("%d failed. Expected no err but got: %v", i, err)
- continue
- }
- if err == nil && len(test.err) != 0 {
- t.Errorf("%d failed. Expected err: %v but got none", i, test.err)
- continue
- }
-
- if err != nil && len(test.err) != 0 && err.Error() != test.err {
- t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err)
- }
- }
-}
-
-func TestSimpleMethodUnpack(t *testing.T) {
- for i, test := range []struct {
- def string // definition of the **output** ABI params
- marshalledOutput []byte // evm return data
- expectedOut interface{} // the expected output
- outVar string // the output variable (e.g. uint32, *big.Int, etc)
- err string // empty or error if expected
- }{
- {
- `[ { "type": "uint32" } ]`,
- pad([]byte{1}, 32, true),
- uint32(1),
- "uint32",
- "",
- },
- {
- `[ { "type": "uint32" } ]`,
- pad([]byte{1}, 32, true),
- nil,
- "uint16",
- "abi: cannot unmarshal uint32 in to uint16",
- },
- {
- `[ { "type": "uint17" } ]`,
- pad([]byte{1}, 32, true),
- nil,
- "uint16",
- "abi: cannot unmarshal *big.Int in to uint16",
- },
- {
- `[ { "type": "uint17" } ]`,
- pad([]byte{1}, 32, true),
- big.NewInt(1),
- "*big.Int",
- "",
- },
-
- {
- `[ { "type": "int32" } ]`,
- pad([]byte{1}, 32, true),
- int32(1),
- "int32",
- "",
- },
- {
- `[ { "type": "int32" } ]`,
- pad([]byte{1}, 32, true),
- nil,
- "int16",
- "abi: cannot unmarshal int32 in to int16",
- },
- {
- `[ { "type": "int17" } ]`,
- pad([]byte{1}, 32, true),
- nil,
- "int16",
- "abi: cannot unmarshal *big.Int in to int16",
- },
- {
- `[ { "type": "int17" } ]`,
- pad([]byte{1}, 32, true),
- big.NewInt(1),
- "*big.Int",
- "",
- },
-
- {
- `[ { "type": "address" } ]`,
- pad(pad([]byte{1}, 20, false), 32, true),
- common.Address{1},
- "address",
- "",
- },
- {
- `[ { "type": "bytes32" } ]`,
- pad([]byte{1}, 32, false),
- pad([]byte{1}, 32, false),
- "bytes",
- "",
- },
- {
- `[ { "type": "bytes32" } ]`,
- pad([]byte{1}, 32, false),
- pad([]byte{1}, 32, false),
- "hash",
- "",
- },
- {
- `[ { "type": "bytes32" } ]`,
- pad([]byte{1}, 32, false),
- pad([]byte{1}, 32, false),
- "interface",
- "",
- },
- {
- `[ { "type": "function" } ]`,
- pad([]byte{1}, 32, false),
- [24]byte{1},
- "function",
- "",
- },
- } {
- abiDefinition := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def)
- abi, err := JSON(strings.NewReader(abiDefinition))
- if err != nil {
- t.Errorf("%d failed. %v", i, err)
- continue
- }
-
- var outvar interface{}
- switch test.outVar {
- case "uint8":
- var v uint8
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "uint16":
- var v uint16
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "uint32":
- var v uint32
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "uint64":
- var v uint64
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "int8":
- var v int8
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "int16":
- var v int16
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "int32":
- var v int32
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "int64":
- var v int64
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "*big.Int":
- var v *big.Int
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "address":
- var v common.Address
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "bytes":
- var v []byte
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "hash":
- var v common.Hash
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "function":
- var v [24]byte
- err = abi.Unpack(&v, "method", test.marshalledOutput)
- outvar = v
- case "interface":
- err = abi.Unpack(&outvar, "method", test.marshalledOutput)
- default:
- t.Errorf("unsupported type '%v' please add it to the switch statement in this test", test.outVar)
- continue
- }
-
- if err != nil && len(test.err) == 0 {
- t.Errorf("%d failed. Expected no err but got: %v", i, err)
- continue
- }
- if err == nil && len(test.err) != 0 {
- t.Errorf("%d failed. Expected err: %v but got none", i, test.err)
- continue
- }
- if err != nil && len(test.err) != 0 && err.Error() != test.err {
- t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err)
- continue
- }
-
- if err == nil {
- // bit of an ugly hack for hash type but I don't feel like finding a proper solution
- if test.outVar == "hash" {
- tmp := outvar.(common.Hash) // without assignment it's unaddressable
- outvar = tmp[:]
- }
-
- if !reflect.DeepEqual(test.expectedOut, outvar) {
- t.Errorf("%d failed. Output error: expected %v, got %v", i, test.expectedOut, outvar)
- }
- }
- }
-}
-
-func TestUnpackSetInterfaceSlice(t *testing.T) {
- var (
- var1 = new(uint8)
- var2 = new(uint8)
- )
- out := []interface{}{var1, var2}
- abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint8"}, {"type":"uint8"}]}]`))
- if err != nil {
- t.Fatal(err)
- }
- marshalledReturn := append(pad([]byte{1}, 32, true), pad([]byte{2}, 32, true)...)
- err = abi.Unpack(&out, "ints", marshalledReturn)
- if err != nil {
- t.Fatal(err)
- }
- if *var1 != 1 {
- t.Error("expected var1 to be 1, got", *var1)
- }
- if *var2 != 2 {
- t.Error("expected var2 to be 2, got", *var2)
- }
-
- out = []interface{}{var1}
- err = abi.Unpack(&out, "ints", marshalledReturn)
-
- expErr := "abi: cannot marshal in to slices of unequal size (require: 2, got: 1)"
- if err == nil || err.Error() != expErr {
- t.Error("expected err:", expErr, "Got:", err)
- }
-}
-
-func TestUnpackSetInterfaceArrayOutput(t *testing.T) {
- var (
- var1 = new([1]uint32)
- var2 = new([1]uint32)
- )
- out := []interface{}{var1, var2}
- abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint32[1]"}, {"type":"uint32[1]"}]}]`))
- if err != nil {
- t.Fatal(err)
- }
- marshalledReturn := append(pad([]byte{1}, 32, true), pad([]byte{2}, 32, true)...)
- err = abi.Unpack(&out, "ints", marshalledReturn)
- if err != nil {
- t.Fatal(err)
- }
-
- if *var1 != [1]uint32{1} {
- t.Error("expected var1 to be [1], got", *var1)
- }
- if *var2 != [1]uint32{2} {
- t.Error("expected var2 to be [2], got", *var2)
- }
-}
-
-func TestPack(t *testing.T) {
- for i, test := range []struct {
- typ string
-
- input interface{}
- output []byte
- }{
- {"uint16", uint16(2), pad([]byte{2}, 32, true)},
- {"uint16[]", []uint16{1, 2}, formatSliceOutput([]byte{1}, []byte{2})},
- {"bytes20", [20]byte{1}, pad([]byte{1}, 32, false)},
- {"uint256[]", []*big.Int{big.NewInt(1), big.NewInt(2)}, formatSliceOutput([]byte{1}, []byte{2})},
- {"address[]", []common.Address{{1}, {2}}, formatSliceOutput(pad([]byte{1}, 20, false), pad([]byte{2}, 20, false))},
- {"bytes32[]", []common.Hash{{1}, {2}}, formatSliceOutput(pad([]byte{1}, 32, false), pad([]byte{2}, 32, false))},
- {"function", [24]byte{1}, pad([]byte{1}, 32, false)},
- } {
- typ, err := NewType(test.typ)
- if err != nil {
- t.Fatal("unexpected parse error:", err)
- }
-
- output, err := typ.pack(reflect.ValueOf(test.input))
- if err != nil {
- t.Fatal("unexpected pack error:", err)
- }
-
- if !bytes.Equal(output, test.output) {
- t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output)
- }
- }
-}
-
-func TestMethodPack(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
- if err != nil {
- t.Fatal(err)
- }
-
- sig := abi.Methods["slice"].Id()
- sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
-
- packed, err := abi.Pack("slice", []uint32{1, 2})
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(packed, sig) {
- t.Errorf("expected %x got %x", sig, packed)
- }
-
- var addrA, addrB = common.Address{1}, common.Address{2}
- sig = abi.Methods["sliceAddress"].Id()
- sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
- sig = append(sig, common.LeftPadBytes(addrA[:], 32)...)
- sig = append(sig, common.LeftPadBytes(addrB[:], 32)...)
-
- packed, err = abi.Pack("sliceAddress", []common.Address{addrA, addrB})
- if err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(packed, sig) {
- t.Errorf("expected %x got %x", sig, packed)
- }
-
- var addrC, addrD = common.Address{3}, common.Address{4}
- sig = abi.Methods["sliceMultiAddress"].Id()
- sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
- sig = append(sig, common.LeftPadBytes(addrA[:], 32)...)
- sig = append(sig, common.LeftPadBytes(addrB[:], 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
- sig = append(sig, common.LeftPadBytes(addrC[:], 32)...)
- sig = append(sig, common.LeftPadBytes(addrD[:], 32)...)
-
- packed, err = abi.Pack("sliceMultiAddress", []common.Address{addrA, addrB}, []common.Address{addrC, addrD})
- if err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(packed, sig) {
- t.Errorf("expected %x got %x", sig, packed)
- }
-
- sig = abi.Methods["slice256"].Id()
- sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
- sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
-
- packed, err = abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)})
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(packed, sig) {
- t.Errorf("expected %x got %x", sig, packed)
- }
-}
-
const jsondata = `
[
{ "type" : "function", "name" : "balance", "constant" : true },
@@ -843,399 +437,3 @@ func TestBareEvents(t *testing.T) {
}
}
}
-
-func TestMultiReturnWithStruct(t *testing.T) {
- const definition = `[
- { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
-
- abi, err := JSON(strings.NewReader(definition))
- if err != nil {
- t.Fatal(err)
- }
-
- // using buff to make the code readable
- buff := new(bytes.Buffer)
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
- stringOut := "hello"
- buff.Write(common.RightPadBytes([]byte(stringOut), 32))
-
- var inter struct {
- Int *big.Int
- String string
- }
- err = abi.Unpack(&inter, "multi", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if inter.Int == nil || inter.Int.Cmp(big.NewInt(1)) != 0 {
- t.Error("expected Int to be 1 got", inter.Int)
- }
-
- if inter.String != stringOut {
- t.Error("expected String to be", stringOut, "got", inter.String)
- }
-
- var reversed struct {
- String string
- Int *big.Int
- }
-
- err = abi.Unpack(&reversed, "multi", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if reversed.Int == nil || reversed.Int.Cmp(big.NewInt(1)) != 0 {
- t.Error("expected Int to be 1 got", reversed.Int)
- }
-
- if reversed.String != stringOut {
- t.Error("expected String to be", stringOut, "got", reversed.String)
- }
-}
-
-func TestMultiReturnWithSlice(t *testing.T) {
- const definition = `[
- { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
-
- abi, err := JSON(strings.NewReader(definition))
- if err != nil {
- t.Fatal(err)
- }
-
- // using buff to make the code readable
- buff := new(bytes.Buffer)
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
- stringOut := "hello"
- buff.Write(common.RightPadBytes([]byte(stringOut), 32))
-
- var inter []interface{}
- err = abi.Unpack(&inter, "multi", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if len(inter) != 2 {
- t.Fatal("expected 2 results got", len(inter))
- }
-
- if num, ok := inter[0].(*big.Int); !ok || num.Cmp(big.NewInt(1)) != 0 {
- t.Error("expected index 0 to be 1 got", num)
- }
-
- if str, ok := inter[1].(string); !ok || str != stringOut {
- t.Error("expected index 1 to be", stringOut, "got", str)
- }
-}
-
-func TestMarshalArrays(t *testing.T) {
- const definition = `[
- { "name" : "bytes32", "constant" : false, "outputs": [ { "type": "bytes32" } ] },
- { "name" : "bytes10", "constant" : false, "outputs": [ { "type": "bytes10" } ] }
- ]`
-
- abi, err := JSON(strings.NewReader(definition))
- if err != nil {
- t.Fatal(err)
- }
-
- output := common.LeftPadBytes([]byte{1}, 32)
-
- var bytes10 [10]byte
- err = abi.Unpack(&bytes10, "bytes32", output)
- if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" {
- t.Error("expected error or bytes32 not be assignable to bytes10:", err)
- }
-
- var bytes32 [32]byte
- err = abi.Unpack(&bytes32, "bytes32", output)
- if err != nil {
- t.Error("didn't expect error:", err)
- }
- if !bytes.Equal(bytes32[:], output) {
- t.Error("expected bytes32[31] to be 1 got", bytes32[31])
- }
-
- type (
- B10 [10]byte
- B32 [32]byte
- )
-
- var b10 B10
- err = abi.Unpack(&b10, "bytes32", output)
- if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" {
- t.Error("expected error or bytes32 not be assignable to bytes10:", err)
- }
-
- var b32 B32
- err = abi.Unpack(&b32, "bytes32", output)
- if err != nil {
- t.Error("didn't expect error:", err)
- }
- if !bytes.Equal(b32[:], output) {
- t.Error("expected bytes32[31] to be 1 got", bytes32[31])
- }
-
- output[10] = 1
- var shortAssignLong [32]byte
- err = abi.Unpack(&shortAssignLong, "bytes10", output)
- if err != nil {
- t.Error("didn't expect error:", err)
- }
- if !bytes.Equal(output, shortAssignLong[:]) {
- t.Errorf("expected %x to be %x", shortAssignLong, output)
- }
-}
-
-func TestUnmarshal(t *testing.T) {
- const definition = `[
- { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] },
- { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] },
- { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] },
- { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] },
- { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] },
- { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] },
- { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] },
- { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] },
- { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]`
-
- abi, err := JSON(strings.NewReader(definition))
- if err != nil {
- t.Fatal(err)
- }
- buff := new(bytes.Buffer)
-
- // marshal int
- var Int *big.Int
- err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
- if err != nil {
- t.Error(err)
- }
-
- if Int == nil || Int.Cmp(big.NewInt(1)) != 0 {
- t.Error("expected Int to be 1 got", Int)
- }
-
- // marshal bool
- var Bool bool
- err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
- if err != nil {
- t.Error(err)
- }
-
- if !Bool {
- t.Error("expected Bool to be true")
- }
-
- // marshal dynamic bytes max length 32
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- bytesOut := common.RightPadBytes([]byte("hello"), 32)
- buff.Write(bytesOut)
-
- var Bytes []byte
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(Bytes, bytesOut) {
- t.Errorf("expected %x got %x", bytesOut, Bytes)
- }
-
- // marshall dynamic bytes max length 64
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
- bytesOut = common.RightPadBytes([]byte("hello"), 64)
- buff.Write(bytesOut)
-
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(Bytes, bytesOut) {
- t.Errorf("expected %x got %x", bytesOut, Bytes)
- }
-
- // marshall dynamic bytes max length 63
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000003f"))
- bytesOut = common.RightPadBytes([]byte("hello"), 63)
- buff.Write(bytesOut)
-
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(Bytes, bytesOut) {
- t.Errorf("expected %x got %x", bytesOut, Bytes)
- }
-
- // marshal dynamic bytes output empty
- err = abi.Unpack(&Bytes, "bytes", nil)
- if err == nil {
- t.Error("expected error")
- }
-
- // marshal dynamic bytes length 5
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
- buff.Write(common.RightPadBytes([]byte("hello"), 32))
-
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- if !bytes.Equal(Bytes, []byte("hello")) {
- t.Errorf("expected %x got %x", bytesOut, Bytes)
- }
-
- // marshal dynamic bytes length 5
- buff.Reset()
- buff.Write(common.RightPadBytes([]byte("hello"), 32))
-
- var hash common.Hash
- err = abi.Unpack(&hash, "fixed", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
-
- helloHash := common.BytesToHash(common.RightPadBytes([]byte("hello"), 32))
- if hash != helloHash {
- t.Errorf("Expected %x to equal %x", hash, helloHash)
- }
-
- // marshal error
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
- if err == nil {
- t.Error("expected error")
- }
-
- err = abi.Unpack(&Bytes, "multi", make([]byte, 64))
- if err == nil {
- t.Error("expected error")
- }
-
- // marshal mixed bytes
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
- fixed := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")
- buff.Write(fixed)
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- bytesOut = common.RightPadBytes([]byte("hello"), 32)
- buff.Write(bytesOut)
-
- var out []interface{}
- err = abi.Unpack(&out, "mixedBytes", buff.Bytes())
- if err != nil {
- t.Fatal("didn't expect error:", err)
- }
-
- if !bytes.Equal(bytesOut, out[0].([]byte)) {
- t.Errorf("expected %x, got %x", bytesOut, out[0])
- }
-
- if !bytes.Equal(fixed, out[1].([]byte)) {
- t.Errorf("expected %x, got %x", fixed, out[1])
- }
-
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003"))
- // marshal int array
- var intArray [3]*big.Int
- err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes())
- if err != nil {
- t.Error(err)
- }
- var testAgainstIntArray [3]*big.Int
- testAgainstIntArray[0] = big.NewInt(1)
- testAgainstIntArray[1] = big.NewInt(2)
- testAgainstIntArray[2] = big.NewInt(3)
-
- for i, Int := range intArray {
- if Int.Cmp(testAgainstIntArray[i]) != 0 {
- t.Errorf("expected %v, got %v", testAgainstIntArray[i], Int)
- }
- }
- // marshal address slice
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) // offset
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size
- buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"))
-
- var outAddr []common.Address
- err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
- if err != nil {
- t.Fatal("didn't expect error:", err)
- }
-
- if len(outAddr) != 1 {
- t.Fatal("expected 1 item, got", len(outAddr))
- }
-
- if outAddr[0] != (common.Address{1}) {
- t.Errorf("expected %x, got %x", common.Address{1}, outAddr[0])
- }
-
- // marshal multiple address slice
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // offset
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // offset
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size
- buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // size
- buff.Write(common.Hex2Bytes("0000000000000000000000000200000000000000000000000000000000000000"))
- buff.Write(common.Hex2Bytes("0000000000000000000000000300000000000000000000000000000000000000"))
-
- var outAddrStruct struct {
- A []common.Address
- B []common.Address
- }
- err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes())
- if err != nil {
- t.Fatal("didn't expect error:", err)
- }
-
- if len(outAddrStruct.A) != 1 {
- t.Fatal("expected 1 item, got", len(outAddrStruct.A))
- }
-
- if outAddrStruct.A[0] != (common.Address{1}) {
- t.Errorf("expected %x, got %x", common.Address{1}, outAddrStruct.A[0])
- }
-
- if len(outAddrStruct.B) != 2 {
- t.Fatal("expected 1 item, got", len(outAddrStruct.B))
- }
-
- if outAddrStruct.B[0] != (common.Address{2}) {
- t.Errorf("expected %x, got %x", common.Address{2}, outAddrStruct.B[0])
- }
- if outAddrStruct.B[1] != (common.Address{3}) {
- t.Errorf("expected %x, got %x", common.Address{3}, outAddrStruct.B[1])
- }
-
- // marshal invalid address slice
- buff.Reset()
- buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100"))
-
- err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
- if err == nil {
- t.Fatal("expected error:", err)
- }
-}
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 159fca136..7ac8b5820 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -90,7 +90,7 @@ func (b *SimulatedBackend) Rollback() {
func (b *SimulatedBackend) rollback() {
blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.database, 1, func(int, *core.BlockGen) {})
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database)
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
}
// CodeAt returns the code associated with a certain account in the blockchain.
@@ -279,7 +279,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
block.AddTx(tx)
})
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database)
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
return nil
}
diff --git a/accounts/abi/error.go b/accounts/abi/error.go
index 67739c21d..420acf418 100644
--- a/accounts/abi/error.go
+++ b/accounts/abi/error.go
@@ -17,10 +17,15 @@
package abi
import (
+ "errors"
"fmt"
"reflect"
)
+var (
+ errBadBool = errors.New("abi: improperly encoded boolean value")
+)
+
// formatSliceString formats the reflection kind with the given slice size
// and returns a formatted string representation.
func formatSliceString(kind reflect.Kind, sliceSize int) string {
diff --git a/accounts/abi/method.go b/accounts/abi/method.go
index d56f3bc3d..62b3d2957 100644
--- a/accounts/abi/method.go
+++ b/accounts/abi/method.go
@@ -39,7 +39,7 @@ type Method struct {
Outputs []Argument
}
-func (m Method) pack(method Method, args ...interface{}) ([]byte, error) {
+func (method Method) pack(args ...interface{}) ([]byte, error) {
// Make sure arguments match up and pack them
if len(args) != len(method.Inputs) {
return nil, fmt.Errorf("argument count mismatch: %d for %d", len(args), len(method.Inputs))
diff --git a/accounts/abi/numbers.go b/accounts/abi/numbers.go
index 10afa6511..5d3efff52 100644
--- a/accounts/abi/numbers.go
+++ b/accounts/abi/numbers.go
@@ -62,19 +62,6 @@ func U256(n *big.Int) []byte {
return math.PaddedBigBytes(math.U256(n), 32)
}
-// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation
-func packNum(value reflect.Value) []byte {
- switch kind := value.Kind(); kind {
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return U256(new(big.Int).SetUint64(value.Uint()))
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return U256(big.NewInt(value.Int()))
- case reflect.Ptr:
- return U256(value.Interface().(*big.Int))
- }
- return nil
-}
-
// checks whether the given reflect value is signed. This also works for slices with a number type
func isSigned(v reflect.Value) bool {
switch v.Type() {
diff --git a/accounts/abi/numbers_test.go b/accounts/abi/numbers_test.go
index 44afe8647..b9ff5aef1 100644
--- a/accounts/abi/numbers_test.go
+++ b/accounts/abi/numbers_test.go
@@ -18,7 +18,6 @@ package abi
import (
"bytes"
- "math"
"math/big"
"reflect"
"testing"
@@ -34,43 +33,6 @@ func TestNumberTypes(t *testing.T) {
}
}
-func TestPackNumber(t *testing.T) {
- tests := []struct {
- value reflect.Value
- packed []byte
- }{
- // Protocol limits
- {reflect.ValueOf(0), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
- {reflect.ValueOf(1), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}},
- {reflect.ValueOf(-1), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
-
- // Type corner cases
- {reflect.ValueOf(uint8(math.MaxUint8)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255}},
- {reflect.ValueOf(uint16(math.MaxUint16)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255}},
- {reflect.ValueOf(uint32(math.MaxUint32)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255}},
- {reflect.ValueOf(uint64(math.MaxUint64)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255}},
-
- {reflect.ValueOf(int8(math.MaxInt8)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127}},
- {reflect.ValueOf(int16(math.MaxInt16)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255}},
- {reflect.ValueOf(int32(math.MaxInt32)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255, 255, 255}},
- {reflect.ValueOf(int64(math.MaxInt64)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255, 255, 255, 255, 255, 255, 255}},
-
- {reflect.ValueOf(int8(math.MinInt8)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128}},
- {reflect.ValueOf(int16(math.MinInt16)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0}},
- {reflect.ValueOf(int32(math.MinInt32)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0, 0, 0}},
- {reflect.ValueOf(int64(math.MinInt64)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0, 0, 0, 0, 0, 0, 0}},
- }
- for i, tt := range tests {
- packed := packNum(tt.value)
- if !bytes.Equal(packed, tt.packed) {
- t.Errorf("test %d: pack mismatch: have %x, want %x", i, packed, tt.packed)
- }
- }
- if packed := packNum(reflect.ValueOf("string")); packed != nil {
- t.Errorf("expected 'string' to pack to nil. got %x instead", packed)
- }
-}
-
func TestSigned(t *testing.T) {
if isSigned(reflect.ValueOf(uint(10))) {
t.Error("signed")
diff --git a/accounts/abi/packing.go b/accounts/abi/pack.go
index 1d7f85e2b..4d8a3f031 100644
--- a/accounts/abi/packing.go
+++ b/accounts/abi/pack.go
@@ -17,6 +17,7 @@
package abi
import (
+ "math/big"
"reflect"
"github.com/ethereum/go-ethereum/common"
@@ -59,8 +60,20 @@ func packElement(t Type, reflectValue reflect.Value) []byte {
if reflectValue.Kind() == reflect.Array {
reflectValue = mustArrayToByteSlice(reflectValue)
}
-
return common.RightPadBytes(reflectValue.Bytes(), 32)
}
panic("abi: fatal error")
}
+
+// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation
+func packNum(value reflect.Value) []byte {
+ switch kind := value.Kind(); kind {
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return U256(new(big.Int).SetUint64(value.Uint()))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return U256(big.NewInt(value.Int()))
+ case reflect.Ptr:
+ return U256(value.Interface().(*big.Int))
+ }
+ return nil
+}
diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go
new file mode 100644
index 000000000..c6cfb56ea
--- /dev/null
+++ b/accounts/abi/pack_test.go
@@ -0,0 +1,441 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import (
+ "bytes"
+ "math"
+ "math/big"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+func TestPack(t *testing.T) {
+ for i, test := range []struct {
+ typ string
+
+ input interface{}
+ output []byte
+ }{
+ {
+ "uint8",
+ uint8(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint8[]",
+ []uint8{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint16",
+ uint16(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint16[]",
+ []uint16{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint32",
+ uint32(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint32[]",
+ []uint32{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint64",
+ uint64(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint64[]",
+ []uint64{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint256",
+ big.NewInt(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "uint256[]",
+ []*big.Int{big.NewInt(1), big.NewInt(2)},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int8",
+ int8(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int8[]",
+ []int8{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int16",
+ int16(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int16[]",
+ []int16{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int32",
+ int32(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int32[]",
+ []int32{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int64",
+ int64(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int64[]",
+ []int64{1, 2},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int256",
+ big.NewInt(2),
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "int256[]",
+ []*big.Int{big.NewInt(1), big.NewInt(2)},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
+ },
+ {
+ "bytes1",
+ [1]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes2",
+ [2]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes3",
+ [3]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes4",
+ [4]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes5",
+ [5]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes6",
+ [6]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes7",
+ [7]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes8",
+ [8]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes9",
+ [9]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes10",
+ [10]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes11",
+ [11]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes12",
+ [12]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes13",
+ [13]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes14",
+ [14]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes15",
+ [15]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes16",
+ [16]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes17",
+ [17]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes18",
+ [18]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes19",
+ [19]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes20",
+ [20]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes21",
+ [21]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes22",
+ [22]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes23",
+ [23]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes24",
+ [24]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes24",
+ [24]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes25",
+ [25]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes26",
+ [26]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes27",
+ [27]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes28",
+ [28]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes29",
+ [29]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes30",
+ [30]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes31",
+ [31]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "bytes32",
+ [32]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "address[]",
+ []common.Address{{1}, {2}},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"),
+ },
+ {
+ "bytes32[]",
+ []common.Hash{{1}, {2}},
+ common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000201000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "function",
+ [24]byte{1},
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ {
+ "string",
+ "foobar",
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"),
+ },
+ } {
+ typ, err := NewType(test.typ)
+ if err != nil {
+ t.Fatal("unexpected parse error:", err)
+ }
+
+ output, err := typ.pack(reflect.ValueOf(test.input))
+ if err != nil {
+ t.Fatal("unexpected pack error:", err)
+ }
+
+ if !bytes.Equal(output, test.output) {
+ t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output)
+ }
+ }
+}
+
+func TestMethodPack(t *testing.T) {
+ abi, err := JSON(strings.NewReader(jsondata2))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ sig := abi.Methods["slice"].Id()
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+
+ packed, err := abi.Pack("slice", []uint32{1, 2})
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+
+ var addrA, addrB = common.Address{1}, common.Address{2}
+ sig = abi.Methods["sliceAddress"].Id()
+ sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes(addrA[:], 32)...)
+ sig = append(sig, common.LeftPadBytes(addrB[:], 32)...)
+
+ packed, err = abi.Pack("sliceAddress", []common.Address{addrA, addrB})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+
+ var addrC, addrD = common.Address{3}, common.Address{4}
+ sig = abi.Methods["sliceMultiAddress"].Id()
+ sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes(addrA[:], 32)...)
+ sig = append(sig, common.LeftPadBytes(addrB[:], 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes(addrC[:], 32)...)
+ sig = append(sig, common.LeftPadBytes(addrD[:], 32)...)
+
+ packed, err = abi.Pack("sliceMultiAddress", []common.Address{addrA, addrB}, []common.Address{addrC, addrD})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+
+ sig = abi.Methods["slice256"].Id()
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+
+ packed, err = abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)})
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+}
+
+func TestPackNumber(t *testing.T) {
+ tests := []struct {
+ value reflect.Value
+ packed []byte
+ }{
+ // Protocol limits
+ {reflect.ValueOf(0), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")},
+ {reflect.ValueOf(1), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")},
+ {reflect.ValueOf(-1), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")},
+
+ // Type corner cases
+ {reflect.ValueOf(uint8(math.MaxUint8)), common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ff")},
+ {reflect.ValueOf(uint16(math.MaxUint16)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000ffff")},
+ {reflect.ValueOf(uint32(math.MaxUint32)), common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000ffffffff")},
+ {reflect.ValueOf(uint64(math.MaxUint64)), common.Hex2Bytes("000000000000000000000000000000000000000000000000ffffffffffffffff")},
+
+ {reflect.ValueOf(int8(math.MaxInt8)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000007f")},
+ {reflect.ValueOf(int16(math.MaxInt16)), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000007fff")},
+ {reflect.ValueOf(int32(math.MaxInt32)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000007fffffff")},
+ {reflect.ValueOf(int64(math.MaxInt64)), common.Hex2Bytes("0000000000000000000000000000000000000000000000007fffffffffffffff")},
+
+ {reflect.ValueOf(int8(math.MinInt8)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff80")},
+ {reflect.ValueOf(int16(math.MinInt16)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8000")},
+ {reflect.ValueOf(int32(math.MinInt32)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffff80000000")},
+ {reflect.ValueOf(int64(math.MinInt64)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffff8000000000000000")},
+ }
+ for i, tt := range tests {
+ packed := packNum(tt.value)
+ if !bytes.Equal(packed, tt.packed) {
+ t.Errorf("test %d: pack mismatch: have %x, want %x", i, packed, tt.packed)
+ }
+ }
+ if packed := packNum(reflect.ValueOf("string")); packed != nil {
+ t.Errorf("expected 'string' to pack to nil. got %x instead", packed)
+ }
+}
diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go
index 7970ba8ac..8fa75df07 100644
--- a/accounts/abi/reflect.go
+++ b/accounts/abi/reflect.go
@@ -32,30 +32,30 @@ func indirect(v reflect.Value) reflect.Value {
// reflectIntKind returns the reflect using the given size and
// unsignedness.
-func reflectIntKind(unsigned bool, size int) reflect.Kind {
+func reflectIntKindAndType(unsigned bool, size int) (reflect.Kind, reflect.Type) {
switch size {
case 8:
if unsigned {
- return reflect.Uint8
+ return reflect.Uint8, uint8_t
}
- return reflect.Int8
+ return reflect.Int8, int8_t
case 16:
if unsigned {
- return reflect.Uint16
+ return reflect.Uint16, uint16_t
}
- return reflect.Int16
+ return reflect.Int16, int16_t
case 32:
if unsigned {
- return reflect.Uint32
+ return reflect.Uint32, uint32_t
}
- return reflect.Int32
+ return reflect.Int32, int32_t
case 64:
if unsigned {
- return reflect.Uint64
+ return reflect.Uint64, uint64_t
}
- return reflect.Int64
+ return reflect.Int64, int64_t
}
- return reflect.Ptr
+ return reflect.Ptr, big_t
}
// mustArrayToBytesSlice creates a new byte slice with the exact same size as value
diff --git a/accounts/abi/type.go b/accounts/abi/type.go
index f2832aef5..5f20babb3 100644
--- a/accounts/abi/type.go
+++ b/accounts/abi/type.go
@@ -33,7 +33,7 @@ const (
FixedBytesTy
BytesTy
HashTy
- FixedpointTy
+ FixedPointTy
FunctionTy
)
@@ -126,13 +126,11 @@ func NewType(t string) (typ Type, err error) {
switch varType {
case "int":
- typ.Kind = reflectIntKind(false, varSize)
- typ.Type = big_t
+ typ.Kind, typ.Type = reflectIntKindAndType(false, varSize)
typ.Size = varSize
typ.T = IntTy
case "uint":
- typ.Kind = reflectIntKind(true, varSize)
- typ.Type = ubig_t
+ typ.Kind, typ.Type = reflectIntKindAndType(true, varSize)
typ.Size = varSize
typ.T = UintTy
case "bool":
diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go
index 155806459..984a5bb4c 100644
--- a/accounts/abi/type_test.go
+++ b/accounts/abi/type_test.go
@@ -17,8 +17,11 @@
package abi
import (
+ "math/big"
"reflect"
"testing"
+
+ "github.com/ethereum/go-ethereum/common"
)
// typeWithoutStringer is a alias for the Type type which simply doesn't implement
@@ -31,26 +34,44 @@ func TestTypeRegexp(t *testing.T) {
blob string
kind Type
}{
- {"int", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}},
- {"int8", Type{Kind: reflect.Int8, Type: big_t, Size: 8, T: IntTy, stringKind: "int8"}},
+ {"bool", Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}},
+ {"bool[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Bool, T: BoolTy, Elem: &Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}},
+ {"bool[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Bool, T: BoolTy, Elem: &Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}},
+ {"int8", Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}},
+ {"int16", Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}},
+ {"int32", Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}},
+ {"int64", Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}},
{"int256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}},
- {"int[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}},
- {"int[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}},
- {"int32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}},
- {"int32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}},
- {"uint", Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}},
- {"uint8", Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}},
- {"uint256", Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}},
- {"uint[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}},
- {"uint[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}},
- {"uint32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint32, Type: ubig_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: big_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}},
- {"uint32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint32, Type: ubig_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: big_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}},
- {"bytes", Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}},
- {"bytes32", Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}},
- {"bytes[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}},
- {"bytes[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}},
- {"bytes32[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[]"}},
- {"bytes32[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[2]"}},
+ {"int8[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}},
+ {"int8[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}},
+ {"int16[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}},
+ {"int16[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}},
+ {"int32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}},
+ {"int32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}},
+ {"int64[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}},
+ {"int64[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}},
+ {"int256[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}},
+ {"int256[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}},
+ {"uint8", Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}},
+ {"uint16", Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}},
+ {"uint32", Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}},
+ {"uint64", Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}},
+ {"uint256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}},
+ {"uint8[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}},
+ {"uint8[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}},
+ {"uint16[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}},
+ {"uint16[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}},
+ {"uint32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}},
+ {"uint32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}},
+ {"uint64[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}},
+ {"uint64[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}},
+ {"uint256[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}},
+ {"uint256[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}},
+ {"bytes32", Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}},
+ {"bytes[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}},
+ {"bytes[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}},
+ {"bytes32[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[]"}},
+ {"bytes32[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[2]"}},
{"string", Type{Kind: reflect.String, Size: -1, T: StringTy, stringKind: "string"}},
{"string[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.String, T: StringTy, Size: -1, Elem: &Type{Kind: reflect.String, T: StringTy, Size: -1, stringKind: "string"}, stringKind: "string[]"}},
{"string[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.String, T: StringTy, Size: -1, Elem: &Type{Kind: reflect.String, T: StringTy, Size: -1, stringKind: "string"}, stringKind: "string[2]"}},
@@ -76,3 +97,59 @@ func TestTypeRegexp(t *testing.T) {
}
}
}
+
+func TestTypeCheck(t *testing.T) {
+ for i, test := range []struct {
+ typ string
+ input interface{}
+ err string
+ }{
+ {"uint", big.NewInt(1), ""},
+ {"int", big.NewInt(1), ""},
+ {"uint30", big.NewInt(1), ""},
+ {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"},
+ {"uint16", uint16(1), ""},
+ {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"},
+ {"uint16[]", []uint16{1, 2, 3}, ""},
+ {"uint16[]", [3]uint16{1, 2, 3}, ""},
+ {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type []uint16 as argument"},
+ {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"},
+ {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
+ {"uint16[3]", []uint16{1, 2, 3}, ""},
+ {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
+ {"address[]", []common.Address{{1}}, ""},
+ {"address[1]", []common.Address{{1}}, ""},
+ {"address[1]", [1]common.Address{{1}}, ""},
+ {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"},
+ {"bytes32", [32]byte{}, ""},
+ {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"},
+ {"bytes32", common.Hash{1}, ""},
+ {"bytes31", [31]byte{}, ""},
+ {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"},
+ {"bytes", []byte{0, 1}, ""},
+ {"bytes", [2]byte{0, 1}, ""},
+ {"bytes", common.Hash{1}, ""},
+ {"string", "hello world", ""},
+ {"bytes32[]", [][32]byte{{}}, ""},
+ {"function", [24]byte{}, ""},
+ } {
+ typ, err := NewType(test.typ)
+ if err != nil {
+ t.Fatal("unexpected parse error:", err)
+ }
+
+ err = typeCheck(typ, reflect.ValueOf(test.input))
+ if err != nil && len(test.err) == 0 {
+ t.Errorf("%d failed. Expected no err but got: %v", i, err)
+ continue
+ }
+ if err == nil && len(test.err) != 0 {
+ t.Errorf("%d failed. Expected err: %v but got none", i, test.err)
+ continue
+ }
+
+ if err != nil && len(test.err) != 0 && err.Error() != test.err {
+ t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err)
+ }
+ }
+}
diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go
new file mode 100644
index 000000000..fc41c88ac
--- /dev/null
+++ b/accounts/abi/unpack.go
@@ -0,0 +1,235 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math/big"
+ "reflect"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// toGoSliceType parses the input and casts it to the proper slice defined by the ABI
+// argument in T.
+func toGoSlice(i int, t Argument, output []byte) (interface{}, error) {
+ index := i * 32
+ // The slice must, at very least be large enough for the index+32 which is exactly the size required
+ // for the [offset in output, size of offset].
+ if index+32 > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), index+32)
+ }
+ elem := t.Type.Elem
+
+ // first we need to create a slice of the type
+ var refSlice reflect.Value
+ switch elem.T {
+ case IntTy, UintTy, BoolTy:
+ // create a new reference slice matching the element type
+ switch t.Type.Kind {
+ case reflect.Bool:
+ refSlice = reflect.ValueOf([]bool(nil))
+ case reflect.Uint8:
+ refSlice = reflect.ValueOf([]uint8(nil))
+ case reflect.Uint16:
+ refSlice = reflect.ValueOf([]uint16(nil))
+ case reflect.Uint32:
+ refSlice = reflect.ValueOf([]uint32(nil))
+ case reflect.Uint64:
+ refSlice = reflect.ValueOf([]uint64(nil))
+ case reflect.Int8:
+ refSlice = reflect.ValueOf([]int8(nil))
+ case reflect.Int16:
+ refSlice = reflect.ValueOf([]int16(nil))
+ case reflect.Int32:
+ refSlice = reflect.ValueOf([]int32(nil))
+ case reflect.Int64:
+ refSlice = reflect.ValueOf([]int64(nil))
+ default:
+ refSlice = reflect.ValueOf([]*big.Int(nil))
+ }
+ case AddressTy: // address must be of slice Address
+ refSlice = reflect.ValueOf([]common.Address(nil))
+ case HashTy: // hash must be of slice hash
+ refSlice = reflect.ValueOf([]common.Hash(nil))
+ case FixedBytesTy:
+ refSlice = reflect.ValueOf([][]byte(nil))
+ default: // no other types are supported
+ return nil, fmt.Errorf("abi: unsupported slice type %v", elem.T)
+ }
+
+ var slice []byte
+ var size int
+ var offset int
+ if t.Type.IsSlice {
+ // get the offset which determines the start of this array ...
+ offset = int(binary.BigEndian.Uint64(output[index+24 : index+32]))
+ if offset+32 > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32)
+ }
+
+ slice = output[offset:]
+ // ... starting with the size of the array in elements ...
+ size = int(binary.BigEndian.Uint64(slice[24:32]))
+ slice = slice[32:]
+ // ... and make sure that we've at the very least the amount of bytes
+ // available in the buffer.
+ if size*32 > len(slice) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), offset+32+size*32)
+ }
+
+ // reslice to match the required size
+ slice = slice[:size*32]
+ } else if t.Type.IsArray {
+ //get the number of elements in the array
+ size = t.Type.SliceSize
+
+ //check to make sure array size matches up
+ if index+32*size > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go array: offset %d would go over slice boundary (len=%d)", len(output), index+32*size)
+ }
+ //slice is there for a fixed amount of times
+ slice = output[index : index+size*32]
+ }
+
+ for i := 0; i < size; i++ {
+ var (
+ inter interface{} // interface type
+ returnOutput = slice[i*32 : i*32+32] // the return output
+ err error
+ )
+ // set inter to the correct type (cast)
+ switch elem.T {
+ case IntTy, UintTy:
+ inter = readInteger(t.Type.Kind, returnOutput)
+ case BoolTy:
+ inter, err = readBool(returnOutput)
+ if err != nil {
+ return nil, err
+ }
+ case AddressTy:
+ inter = common.BytesToAddress(returnOutput)
+ case HashTy:
+ inter = common.BytesToHash(returnOutput)
+ case FixedBytesTy:
+ inter = returnOutput
+ }
+ // append the item to our reflect slice
+ refSlice = reflect.Append(refSlice, reflect.ValueOf(inter))
+ }
+
+ // return the interface
+ return refSlice.Interface(), nil
+}
+
+func readInteger(kind reflect.Kind, b []byte) interface{} {
+ switch kind {
+ case reflect.Uint8:
+ return uint8(b[len(b)-1])
+ case reflect.Uint16:
+ return binary.BigEndian.Uint16(b[len(b)-2:])
+ case reflect.Uint32:
+ return binary.BigEndian.Uint32(b[len(b)-4:])
+ case reflect.Uint64:
+ return binary.BigEndian.Uint64(b[len(b)-8:])
+ case reflect.Int8:
+ return int8(b[len(b)-1])
+ case reflect.Int16:
+ return int16(binary.BigEndian.Uint16(b[len(b)-2:]))
+ case reflect.Int32:
+ return int32(binary.BigEndian.Uint32(b[len(b)-4:]))
+ case reflect.Int64:
+ return int64(binary.BigEndian.Uint64(b[len(b)-8:]))
+ default:
+ return new(big.Int).SetBytes(b)
+ }
+}
+
+func readBool(word []byte) (bool, error) {
+ if len(word) != 32 {
+ return false, fmt.Errorf("abi: fatal error: incorrect word length")
+ }
+
+ for i, b := range word {
+ if b != 0 && i != 31 {
+ return false, errBadBool
+ }
+ }
+ switch word[31] {
+ case 0:
+ return false, nil
+ case 1:
+ return true, nil
+ default:
+ return false, errBadBool
+ }
+
+}
+
+// toGoType parses the input and casts it to the proper type defined by the ABI
+// argument in T.
+func toGoType(i int, t Argument, output []byte) (interface{}, error) {
+ // we need to treat slices differently
+ if (t.Type.IsSlice || t.Type.IsArray) && t.Type.T != BytesTy && t.Type.T != StringTy && t.Type.T != FixedBytesTy && t.Type.T != FunctionTy {
+ return toGoSlice(i, t, output)
+ }
+
+ index := i * 32
+ if index+32 > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32)
+ }
+
+ // Parse the given index output and check whether we need to read
+ // a different offset and length based on the type (i.e. string, bytes)
+ var returnOutput []byte
+ switch t.Type.T {
+ case StringTy, BytesTy: // variable arrays are written at the end of the return bytes
+ // parse offset from which we should start reading
+ offset := int(binary.BigEndian.Uint64(output[index+24 : index+32]))
+ if offset+32 > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32)
+ }
+ // parse the size up until we should be reading
+ size := int(binary.BigEndian.Uint64(output[offset+24 : offset+32]))
+ if offset+32+size > len(output) {
+ return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+size)
+ }
+
+ // get the bytes for this return value
+ returnOutput = output[offset+32 : offset+32+size]
+ default:
+ returnOutput = output[index : index+32]
+ }
+
+ // convert the bytes to whatever is specified by the ABI.
+ switch t.Type.T {
+ case IntTy, UintTy:
+ return readInteger(t.Type.Kind, returnOutput), nil
+ case BoolTy:
+ return readBool(returnOutput)
+ case AddressTy:
+ return common.BytesToAddress(returnOutput), nil
+ case HashTy:
+ return common.BytesToHash(returnOutput), nil
+ case BytesTy, FixedBytesTy, FunctionTy:
+ return returnOutput, nil
+ case StringTy:
+ return string(returnOutput), nil
+ }
+ return nil, fmt.Errorf("abi: unknown type %v", t.Type.T)
+}
diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go
new file mode 100644
index 000000000..8e3afee4e
--- /dev/null
+++ b/accounts/abi/unpack_test.go
@@ -0,0 +1,681 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import (
+ "bytes"
+ "fmt"
+ "math/big"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+func TestSimpleMethodUnpack(t *testing.T) {
+ for i, test := range []struct {
+ def string // definition of the **output** ABI params
+ marshalledOutput []byte // evm return data
+ expectedOut interface{} // the expected output
+ outVar string // the output variable (e.g. uint32, *big.Int, etc)
+ err string // empty or error if expected
+ }{
+ {
+ `[ { "type": "bool" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ bool(true),
+ "bool",
+ "",
+ },
+ {
+ `[ { "type": "uint32" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ uint32(1),
+ "uint32",
+ "",
+ },
+ {
+ `[ { "type": "uint32" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ nil,
+ "uint16",
+ "abi: cannot unmarshal uint32 in to uint16",
+ },
+ {
+ `[ { "type": "uint17" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ nil,
+ "uint16",
+ "abi: cannot unmarshal *big.Int in to uint16",
+ },
+ {
+ `[ { "type": "uint17" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ big.NewInt(1),
+ "*big.Int",
+ "",
+ },
+
+ {
+ `[ { "type": "int32" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ int32(1),
+ "int32",
+ "",
+ },
+ {
+ `[ { "type": "int32" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ nil,
+ "int16",
+ "abi: cannot unmarshal int32 in to int16",
+ },
+ {
+ `[ { "type": "int17" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ nil,
+ "int16",
+ "abi: cannot unmarshal *big.Int in to int16",
+ },
+ {
+ `[ { "type": "int17" } ]`,
+ common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"),
+ big.NewInt(1),
+ "*big.Int",
+ "",
+ },
+
+ {
+ `[ { "type": "address" } ]`,
+ common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"),
+ common.Address{1},
+ "address",
+ "",
+ },
+ {
+ `[ { "type": "bytes32" } ]`,
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ "bytes",
+ "",
+ },
+ {
+ `[ { "type": "bytes32" } ]`,
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ "hash",
+ "",
+ },
+ {
+ `[ { "type": "bytes32" } ]`,
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ "interface",
+ "",
+ },
+ {
+ `[ { "type": "function" } ]`,
+ common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ [24]byte{1},
+ "function",
+ "",
+ },
+ } {
+ abiDefinition := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def)
+ abi, err := JSON(strings.NewReader(abiDefinition))
+ if err != nil {
+ t.Errorf("%d failed. %v", i, err)
+ continue
+ }
+
+ var outvar interface{}
+ switch test.outVar {
+ case "bool":
+ var v bool
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "uint8":
+ var v uint8
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "uint16":
+ var v uint16
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "uint32":
+ var v uint32
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "uint64":
+ var v uint64
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "int8":
+ var v int8
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "int16":
+ var v int16
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "int32":
+ var v int32
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "int64":
+ var v int64
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "*big.Int":
+ var v *big.Int
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "address":
+ var v common.Address
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "bytes":
+ var v []byte
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "hash":
+ var v common.Hash
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v.Bytes()[:]
+ case "function":
+ var v [24]byte
+ err = abi.Unpack(&v, "method", test.marshalledOutput)
+ outvar = v
+ case "interface":
+ err = abi.Unpack(&outvar, "method", test.marshalledOutput)
+ default:
+ t.Errorf("unsupported type '%v' please add it to the switch statement in this test", test.outVar)
+ continue
+ }
+
+ if err != nil && len(test.err) == 0 {
+ t.Errorf("%d failed. Expected no err but got: %v", i, err)
+ continue
+ }
+ if err == nil && len(test.err) != 0 {
+ t.Errorf("%d failed. Expected err: %v but got none", i, test.err)
+ continue
+ }
+ if err != nil && len(test.err) != 0 && err.Error() != test.err {
+ t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err)
+ continue
+ }
+
+ if err == nil {
+ if !reflect.DeepEqual(test.expectedOut, outvar) {
+ t.Errorf("%d failed. Output error: expected %v, got %v", i, test.expectedOut, outvar)
+ }
+ }
+ }
+}
+
+func TestUnpackSetInterfaceSlice(t *testing.T) {
+ var (
+ var1 = new(uint8)
+ var2 = new(uint8)
+ )
+ out := []interface{}{var1, var2}
+ abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint8"}, {"type":"uint8"}]}]`))
+ if err != nil {
+ t.Fatal(err)
+ }
+ marshalledReturn := append(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")...)
+ err = abi.Unpack(&out, "ints", marshalledReturn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if *var1 != 1 {
+ t.Error("expected var1 to be 1, got", *var1)
+ }
+ if *var2 != 2 {
+ t.Error("expected var2 to be 2, got", *var2)
+ }
+
+ out = []interface{}{var1}
+ err = abi.Unpack(&out, "ints", marshalledReturn)
+
+ expErr := "abi: cannot marshal in to slices of unequal size (require: 2, got: 1)"
+ if err == nil || err.Error() != expErr {
+ t.Error("expected err:", expErr, "Got:", err)
+ }
+}
+
+func TestUnpackSetInterfaceArrayOutput(t *testing.T) {
+ var (
+ var1 = new([1]uint32)
+ var2 = new([1]uint32)
+ )
+ out := []interface{}{var1, var2}
+ abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint32[1]"}, {"type":"uint32[1]"}]}]`))
+ if err != nil {
+ t.Fatal(err)
+ }
+ marshalledReturn := append(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")...)
+ err = abi.Unpack(&out, "ints", marshalledReturn)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if *var1 != [1]uint32{1} {
+ t.Error("expected var1 to be [1], got", *var1)
+ }
+ if *var2 != [1]uint32{2} {
+ t.Error("expected var2 to be [2], got", *var2)
+ }
+}
+
+func TestMultiReturnWithStruct(t *testing.T) {
+ const definition = `[
+ { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
+
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // using buff to make the code readable
+ buff := new(bytes.Buffer)
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
+ stringOut := "hello"
+ buff.Write(common.RightPadBytes([]byte(stringOut), 32))
+
+ var inter struct {
+ Int *big.Int
+ String string
+ }
+ err = abi.Unpack(&inter, "multi", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if inter.Int == nil || inter.Int.Cmp(big.NewInt(1)) != 0 {
+ t.Error("expected Int to be 1 got", inter.Int)
+ }
+
+ if inter.String != stringOut {
+ t.Error("expected String to be", stringOut, "got", inter.String)
+ }
+
+ var reversed struct {
+ String string
+ Int *big.Int
+ }
+
+ err = abi.Unpack(&reversed, "multi", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if reversed.Int == nil || reversed.Int.Cmp(big.NewInt(1)) != 0 {
+ t.Error("expected Int to be 1 got", reversed.Int)
+ }
+
+ if reversed.String != stringOut {
+ t.Error("expected String to be", stringOut, "got", reversed.String)
+ }
+}
+
+func TestMultiReturnWithSlice(t *testing.T) {
+ const definition = `[
+ { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
+
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // using buff to make the code readable
+ buff := new(bytes.Buffer)
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
+ stringOut := "hello"
+ buff.Write(common.RightPadBytes([]byte(stringOut), 32))
+
+ var inter []interface{}
+ err = abi.Unpack(&inter, "multi", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if len(inter) != 2 {
+ t.Fatal("expected 2 results got", len(inter))
+ }
+
+ if num, ok := inter[0].(*big.Int); !ok || num.Cmp(big.NewInt(1)) != 0 {
+ t.Error("expected index 0 to be 1 got", num)
+ }
+
+ if str, ok := inter[1].(string); !ok || str != stringOut {
+ t.Error("expected index 1 to be", stringOut, "got", str)
+ }
+}
+
+func TestMarshalArrays(t *testing.T) {
+ const definition = `[
+ { "name" : "bytes32", "constant" : false, "outputs": [ { "type": "bytes32" } ] },
+ { "name" : "bytes10", "constant" : false, "outputs": [ { "type": "bytes10" } ] }
+ ]`
+
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ output := common.LeftPadBytes([]byte{1}, 32)
+
+ var bytes10 [10]byte
+ err = abi.Unpack(&bytes10, "bytes32", output)
+ if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" {
+ t.Error("expected error or bytes32 not be assignable to bytes10:", err)
+ }
+
+ var bytes32 [32]byte
+ err = abi.Unpack(&bytes32, "bytes32", output)
+ if err != nil {
+ t.Error("didn't expect error:", err)
+ }
+ if !bytes.Equal(bytes32[:], output) {
+ t.Error("expected bytes32[31] to be 1 got", bytes32[31])
+ }
+
+ type (
+ B10 [10]byte
+ B32 [32]byte
+ )
+
+ var b10 B10
+ err = abi.Unpack(&b10, "bytes32", output)
+ if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" {
+ t.Error("expected error or bytes32 not be assignable to bytes10:", err)
+ }
+
+ var b32 B32
+ err = abi.Unpack(&b32, "bytes32", output)
+ if err != nil {
+ t.Error("didn't expect error:", err)
+ }
+ if !bytes.Equal(b32[:], output) {
+ t.Error("expected bytes32[31] to be 1 got", bytes32[31])
+ }
+
+ output[10] = 1
+ var shortAssignLong [32]byte
+ err = abi.Unpack(&shortAssignLong, "bytes10", output)
+ if err != nil {
+ t.Error("didn't expect error:", err)
+ }
+ if !bytes.Equal(output, shortAssignLong[:]) {
+ t.Errorf("expected %x to be %x", shortAssignLong, output)
+ }
+}
+
+func TestUnmarshal(t *testing.T) {
+ const definition = `[
+ { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] },
+ { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] },
+ { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] },
+ { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] },
+ { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] },
+ { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] },
+ { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] },
+ { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] },
+ { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]`
+
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buff := new(bytes.Buffer)
+
+ // marshal int
+ var Int *big.Int
+ err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ if err != nil {
+ t.Error(err)
+ }
+
+ if Int == nil || Int.Cmp(big.NewInt(1)) != 0 {
+ t.Error("expected Int to be 1 got", Int)
+ }
+
+ // marshal bool
+ var Bool bool
+ err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !Bool {
+ t.Error("expected Bool to be true")
+ }
+
+ // marshal dynamic bytes max length 32
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ bytesOut := common.RightPadBytes([]byte("hello"), 32)
+ buff.Write(bytesOut)
+
+ var Bytes []byte
+ err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(Bytes, bytesOut) {
+ t.Errorf("expected %x got %x", bytesOut, Bytes)
+ }
+
+ // marshall dynamic bytes max length 64
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
+ bytesOut = common.RightPadBytes([]byte("hello"), 64)
+ buff.Write(bytesOut)
+
+ err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(Bytes, bytesOut) {
+ t.Errorf("expected %x got %x", bytesOut, Bytes)
+ }
+
+ // marshall dynamic bytes max length 63
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000003f"))
+ bytesOut = common.RightPadBytes([]byte("hello"), 63)
+ buff.Write(bytesOut)
+
+ err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(Bytes, bytesOut) {
+ t.Errorf("expected %x got %x", bytesOut, Bytes)
+ }
+
+ // marshal dynamic bytes output empty
+ err = abi.Unpack(&Bytes, "bytes", nil)
+ if err == nil {
+ t.Error("expected error")
+ }
+
+ // marshal dynamic bytes length 5
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
+ buff.Write(common.RightPadBytes([]byte("hello"), 32))
+
+ err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(Bytes, []byte("hello")) {
+ t.Errorf("expected %x got %x", bytesOut, Bytes)
+ }
+
+ // marshal dynamic bytes length 5
+ buff.Reset()
+ buff.Write(common.RightPadBytes([]byte("hello"), 32))
+
+ var hash common.Hash
+ err = abi.Unpack(&hash, "fixed", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+
+ helloHash := common.BytesToHash(common.RightPadBytes([]byte("hello"), 32))
+ if hash != helloHash {
+ t.Errorf("Expected %x to equal %x", hash, helloHash)
+ }
+
+ // marshal error
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ if err == nil {
+ t.Error("expected error")
+ }
+
+ err = abi.Unpack(&Bytes, "multi", make([]byte, 64))
+ if err == nil {
+ t.Error("expected error")
+ }
+
+ // marshal mixed bytes
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040"))
+ fixed := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")
+ buff.Write(fixed)
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
+ bytesOut = common.RightPadBytes([]byte("hello"), 32)
+ buff.Write(bytesOut)
+
+ var out []interface{}
+ err = abi.Unpack(&out, "mixedBytes", buff.Bytes())
+ if err != nil {
+ t.Fatal("didn't expect error:", err)
+ }
+
+ if !bytes.Equal(bytesOut, out[0].([]byte)) {
+ t.Errorf("expected %x, got %x", bytesOut, out[0])
+ }
+
+ if !bytes.Equal(fixed, out[1].([]byte)) {
+ t.Errorf("expected %x, got %x", fixed, out[1])
+ }
+
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003"))
+ // marshal int array
+ var intArray [3]*big.Int
+ err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+ var testAgainstIntArray [3]*big.Int
+ testAgainstIntArray[0] = big.NewInt(1)
+ testAgainstIntArray[1] = big.NewInt(2)
+ testAgainstIntArray[2] = big.NewInt(3)
+
+ for i, Int := range intArray {
+ if Int.Cmp(testAgainstIntArray[i]) != 0 {
+ t.Errorf("expected %v, got %v", testAgainstIntArray[i], Int)
+ }
+ }
+ // marshal address slice
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) // offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size
+ buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"))
+
+ var outAddr []common.Address
+ err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
+ if err != nil {
+ t.Fatal("didn't expect error:", err)
+ }
+
+ if len(outAddr) != 1 {
+ t.Fatal("expected 1 item, got", len(outAddr))
+ }
+
+ if outAddr[0] != (common.Address{1}) {
+ t.Errorf("expected %x, got %x", common.Address{1}, outAddr[0])
+ }
+
+ // marshal multiple address slice
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size
+ buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // size
+ buff.Write(common.Hex2Bytes("0000000000000000000000000200000000000000000000000000000000000000"))
+ buff.Write(common.Hex2Bytes("0000000000000000000000000300000000000000000000000000000000000000"))
+
+ var outAddrStruct struct {
+ A []common.Address
+ B []common.Address
+ }
+ err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes())
+ if err != nil {
+ t.Fatal("didn't expect error:", err)
+ }
+
+ if len(outAddrStruct.A) != 1 {
+ t.Fatal("expected 1 item, got", len(outAddrStruct.A))
+ }
+
+ if outAddrStruct.A[0] != (common.Address{1}) {
+ t.Errorf("expected %x, got %x", common.Address{1}, outAddrStruct.A[0])
+ }
+
+ if len(outAddrStruct.B) != 2 {
+ t.Fatal("expected 1 item, got", len(outAddrStruct.B))
+ }
+
+ if outAddrStruct.B[0] != (common.Address{2}) {
+ t.Errorf("expected %x, got %x", common.Address{2}, outAddrStruct.B[0])
+ }
+ if outAddrStruct.B[1] != (common.Address{3}) {
+ t.Errorf("expected %x, got %x", common.Address{3}, outAddrStruct.B[1])
+ }
+
+ // marshal invalid address slice
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100"))
+
+ err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
+ if err == nil {
+ t.Fatal("expected error:", err)
+ }
+}
diff --git a/build/ci.go b/build/ci.go
index 47b1dc780..6a52077d4 100644
--- a/build/ci.go
+++ b/build/ci.go
@@ -175,7 +175,7 @@ func doInstall(cmdline []string) {
// Check Go version. People regularly open issues about compilation
// failure with outdated Go. This should save them the trouble.
- if runtime.Version() < "go1.7" && !strings.HasPrefix(runtime.Version(), "devel") {
+ if runtime.Version() < "go1.7" && !strings.Contains(runtime.Version(), "devel") {
log.Println("You have Go version", runtime.Version())
log.Println("go-ethereum requires at least Go version 1.7 and cannot")
log.Println("be compiled with an earlier version. Please upgrade your Go installation.")
diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go
index 2ce0920f6..3f95a0c93 100644
--- a/cmd/evm/runner.go
+++ b/cmd/evm/runner.go
@@ -98,8 +98,8 @@ func runCmd(ctx *cli.Context) error {
_, statedb = gen.ToBlock()
chainConfig = gen.Config
} else {
- var db, _ = ethdb.NewMemDatabase()
- statedb, _ = state.New(common.Hash{}, db)
+ db, _ := ethdb.NewMemDatabase()
+ statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
}
if ctx.GlobalString(SenderFlag.Name) != "" {
sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name))
@@ -188,7 +188,7 @@ func runCmd(ctx *cli.Context) error {
execTime := time.Since(tstart)
if ctx.GlobalBool(DumpFlag.Name) {
- statedb.Commit(true)
+ statedb.IntermediateRoot(true)
fmt.Println(string(statedb.Dump()))
}
diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go
index ab0e92f21..12bc1d7c6 100644
--- a/cmd/geth/chaincmd.go
+++ b/cmd/geth/chaincmd.go
@@ -312,7 +312,7 @@ func dump(ctx *cli.Context) error {
fmt.Println("{}")
utils.Fatalf("block not found")
} else {
- state, err := state.New(block.Root(), chainDb)
+ state, err := state.New(block.Root(), state.NewDatabase(chainDb))
if err != nil {
utils.Fatalf("could not create new state: %v", err)
}
diff --git a/cmd/puppeth/wizard_faucet.go b/cmd/puppeth/wizard_faucet.go
index 66ec98c73..51c4e2f7f 100644
--- a/cmd/puppeth/wizard_faucet.go
+++ b/cmd/puppeth/wizard_faucet.go
@@ -165,8 +165,7 @@ func (w *wizard) deployFaucet() {
}
// Load up the credential needed to release funds
if infos.node.keyJSON != "" {
- var key keystore.Key
- if err := json.Unmarshal([]byte(infos.node.keyJSON), &key); err != nil {
+ if key, err := keystore.DecryptKey([]byte(infos.node.keyJSON), infos.node.keyPass); err != nil {
infos.node.keyJSON, infos.node.keyPass = "", ""
} else {
fmt.Println()
diff --git a/common/hexutil/hexutil.go b/common/hexutil/hexutil.go
index 6b128ae36..582a67c22 100644
--- a/common/hexutil/hexutil.go
+++ b/common/hexutil/hexutil.go
@@ -32,7 +32,6 @@ package hexutil
import (
"encoding/hex"
- "errors"
"fmt"
"math/big"
"strconv"
@@ -41,17 +40,23 @@ import (
const uintBits = 32 << (uint64(^uint(0)) >> 63)
var (
- ErrEmptyString = errors.New("empty hex string")
- ErrMissingPrefix = errors.New("missing 0x prefix for hex data")
- ErrSyntax = errors.New("invalid hex")
- ErrEmptyNumber = errors.New("hex number has no digits after 0x")
- ErrLeadingZero = errors.New("hex number has leading zero digits after 0x")
- ErrOddLength = errors.New("hex string has odd length")
- ErrUint64Range = errors.New("hex number does not fit into 64 bits")
- ErrUintRange = fmt.Errorf("hex number does not fit into %d bits", uintBits)
- ErrBig256Range = errors.New("hex number does not fit into 256 bits")
+ ErrEmptyString = &decError{"empty hex string"}
+ ErrSyntax = &decError{"invalid hex string"}
+ ErrMissingPrefix = &decError{"hex string without 0x prefix"}
+ ErrOddLength = &decError{"hex string of odd length"}
+ ErrEmptyNumber = &decError{"hex string \"0x\""}
+ ErrLeadingZero = &decError{"hex number with leading zero digits"}
+ ErrUint64Range = &decError{"hex number > 64 bits"}
+ ErrUintRange = &decError{fmt.Sprintf("hex number > %d bits", uintBits)}
+ ErrBig256Range = &decError{"hex number > 256 bits"}
)
+type decError struct{ msg string }
+
+func (err decError) Error() string {
+ return string(err.msg)
+}
+
// Decode decodes a hex string with 0x prefix.
func Decode(input string) ([]byte, error) {
if len(input) == 0 {
diff --git a/common/hexutil/json.go b/common/hexutil/json.go
index 1bc1d014c..943288fad 100644
--- a/common/hexutil/json.go
+++ b/common/hexutil/json.go
@@ -18,15 +18,19 @@ package hexutil
import (
"encoding/hex"
- "errors"
+ "encoding/json"
"fmt"
"math/big"
+ "reflect"
"strconv"
)
var (
- textZero = []byte(`0x0`)
- errNonString = errors.New("cannot unmarshal non-string as hex data")
+ textZero = []byte(`0x0`)
+ bytesT = reflect.TypeOf(Bytes(nil))
+ bigT = reflect.TypeOf((*Big)(nil))
+ uintT = reflect.TypeOf(Uint(0))
+ uint64T = reflect.TypeOf(Uint64(0))
)
// Bytes marshals/unmarshals as a JSON string with 0x prefix.
@@ -44,9 +48,9 @@ func (b Bytes) MarshalText() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (b *Bytes) UnmarshalJSON(input []byte) error {
if !isString(input) {
- return errNonString
+ return errNonString(bytesT)
}
- return b.UnmarshalText(input[1 : len(input)-1])
+ return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), bytesT)
}
// UnmarshalText implements encoding.TextUnmarshaler.
@@ -69,6 +73,16 @@ func (b Bytes) String() string {
return Encode(b)
}
+// UnmarshalFixedJSON decodes the input as a string with 0x prefix. The length of out
+// determines the required input length. This function is commonly used to implement the
+// UnmarshalJSON method for fixed-size types.
+func UnmarshalFixedJSON(typ reflect.Type, input, out []byte) error {
+ if !isString(input) {
+ return errNonString(typ)
+ }
+ return wrapTypeError(UnmarshalFixedText(typ.String(), input[1:len(input)-1], out), typ)
+}
+
// UnmarshalFixedText decodes the input as a string with 0x prefix. The length of out
// determines the required input length. This function is commonly used to implement the
// UnmarshalText method for fixed-size types.
@@ -127,9 +141,9 @@ func (b Big) MarshalText() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (b *Big) UnmarshalJSON(input []byte) error {
if !isString(input) {
- return errNonString
+ return errNonString(bigT)
}
- return b.UnmarshalText(input[1 : len(input)-1])
+ return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), bigT)
}
// UnmarshalText implements encoding.TextUnmarshaler
@@ -189,9 +203,9 @@ func (b Uint64) MarshalText() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (b *Uint64) UnmarshalJSON(input []byte) error {
if !isString(input) {
- return errNonString
+ return errNonString(uint64T)
}
- return b.UnmarshalText(input[1 : len(input)-1])
+ return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), uint64T)
}
// UnmarshalText implements encoding.TextUnmarshaler
@@ -233,9 +247,9 @@ func (b Uint) MarshalText() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (b *Uint) UnmarshalJSON(input []byte) error {
if !isString(input) {
- return errNonString
+ return errNonString(uintT)
}
- return b.UnmarshalText(input[1 : len(input)-1])
+ return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), uintT)
}
// UnmarshalText implements encoding.TextUnmarshaler.
@@ -295,3 +309,14 @@ func checkNumberText(input []byte) (raw []byte, err error) {
}
return input, nil
}
+
+func wrapTypeError(err error, typ reflect.Type) error {
+ if _, ok := err.(*decError); ok {
+ return &json.UnmarshalTypeError{Value: err.Error(), Type: typ}
+ }
+ return err
+}
+
+func errNonString(typ reflect.Type) error {
+ return &json.UnmarshalTypeError{Value: "non-string", Type: typ}
+}
diff --git a/common/hexutil/json_test.go b/common/hexutil/json_test.go
index e4e827491..8a6b8643a 100644
--- a/common/hexutil/json_test.go
+++ b/common/hexutil/json_test.go
@@ -62,12 +62,12 @@ var errJSONEOF = errors.New("unexpected end of JSON input")
var unmarshalBytesTests = []unmarshalTest{
// invalid encoding
{input: "", wantErr: errJSONEOF},
- {input: "null", wantErr: errNonString},
- {input: "10", wantErr: errNonString},
- {input: `"0"`, wantErr: ErrMissingPrefix},
- {input: `"0x0"`, wantErr: ErrOddLength},
- {input: `"0xxx"`, wantErr: ErrSyntax},
- {input: `"0x01zz01"`, wantErr: ErrSyntax},
+ {input: "null", wantErr: errNonString(bytesT)},
+ {input: "10", wantErr: errNonString(bytesT)},
+ {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, bytesT)},
+ {input: `"0x0"`, wantErr: wrapTypeError(ErrOddLength, bytesT)},
+ {input: `"0xxx"`, wantErr: wrapTypeError(ErrSyntax, bytesT)},
+ {input: `"0x01zz01"`, wantErr: wrapTypeError(ErrSyntax, bytesT)},
// valid encoding
{input: `""`, want: referenceBytes("")},
@@ -127,16 +127,16 @@ func TestMarshalBytes(t *testing.T) {
var unmarshalBigTests = []unmarshalTest{
// invalid encoding
{input: "", wantErr: errJSONEOF},
- {input: "null", wantErr: errNonString},
- {input: "10", wantErr: errNonString},
- {input: `"0"`, wantErr: ErrMissingPrefix},
- {input: `"0x"`, wantErr: ErrEmptyNumber},
- {input: `"0x01"`, wantErr: ErrLeadingZero},
- {input: `"0xx"`, wantErr: ErrSyntax},
- {input: `"0x1zz01"`, wantErr: ErrSyntax},
+ {input: "null", wantErr: errNonString(bigT)},
+ {input: "10", wantErr: errNonString(bigT)},
+ {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, bigT)},
+ {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, bigT)},
+ {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, bigT)},
+ {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, bigT)},
+ {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, bigT)},
{
input: `"0x10000000000000000000000000000000000000000000000000000000000000000"`,
- wantErr: ErrBig256Range,
+ wantErr: wrapTypeError(ErrBig256Range, bigT),
},
// valid encoding
@@ -208,14 +208,14 @@ func TestMarshalBig(t *testing.T) {
var unmarshalUint64Tests = []unmarshalTest{
// invalid encoding
{input: "", wantErr: errJSONEOF},
- {input: "null", wantErr: errNonString},
- {input: "10", wantErr: errNonString},
- {input: `"0"`, wantErr: ErrMissingPrefix},
- {input: `"0x"`, wantErr: ErrEmptyNumber},
- {input: `"0x01"`, wantErr: ErrLeadingZero},
- {input: `"0xfffffffffffffffff"`, wantErr: ErrUint64Range},
- {input: `"0xx"`, wantErr: ErrSyntax},
- {input: `"0x1zz01"`, wantErr: ErrSyntax},
+ {input: "null", wantErr: errNonString(uint64T)},
+ {input: "10", wantErr: errNonString(uint64T)},
+ {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, uint64T)},
+ {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, uint64T)},
+ {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, uint64T)},
+ {input: `"0xfffffffffffffffff"`, wantErr: wrapTypeError(ErrUint64Range, uint64T)},
+ {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, uint64T)},
+ {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, uint64T)},
// valid encoding
{input: `""`, want: uint64(0)},
@@ -298,15 +298,15 @@ var (
var unmarshalUintTests = []unmarshalTest{
// invalid encoding
{input: "", wantErr: errJSONEOF},
- {input: "null", wantErr: errNonString},
- {input: "10", wantErr: errNonString},
- {input: `"0"`, wantErr: ErrMissingPrefix},
- {input: `"0x"`, wantErr: ErrEmptyNumber},
- {input: `"0x01"`, wantErr: ErrLeadingZero},
- {input: `"0x100000000"`, want: uint(maxUint33bits), wantErr32bit: ErrUintRange},
- {input: `"0xfffffffffffffffff"`, wantErr: ErrUintRange},
- {input: `"0xx"`, wantErr: ErrSyntax},
- {input: `"0x1zz01"`, wantErr: ErrSyntax},
+ {input: "null", wantErr: errNonString(uintT)},
+ {input: "10", wantErr: errNonString(uintT)},
+ {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, uintT)},
+ {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, uintT)},
+ {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, uintT)},
+ {input: `"0x100000000"`, want: uint(maxUint33bits), wantErr32bit: wrapTypeError(ErrUintRange, uintT)},
+ {input: `"0xfffffffffffffffff"`, wantErr: wrapTypeError(ErrUintRange, uintT)},
+ {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, uintT)},
+ {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, uintT)},
// valid encoding
{input: `""`, want: uint(0)},
@@ -317,7 +317,7 @@ var unmarshalUintTests = []unmarshalTest{
{input: `"0x1122aaff"`, want: uint(0x1122aaff)},
{input: `"0xbbb"`, want: uint(0xbbb)},
{input: `"0xffffffff"`, want: uint(0xffffffff)},
- {input: `"0xffffffffffffffff"`, want: uint(maxUint64bits), wantErr32bit: ErrUintRange},
+ {input: `"0xffffffffffffffff"`, want: uint(maxUint64bits), wantErr32bit: wrapTypeError(ErrUintRange, uintT)},
}
func TestUnmarshalUint(t *testing.T) {
diff --git a/common/types.go b/common/types.go
index 05288bf46..803726634 100644
--- a/common/types.go
+++ b/common/types.go
@@ -31,6 +31,11 @@ const (
AddressLength = 20
)
+var (
+ hashT = reflect.TypeOf(Hash{})
+ addressT = reflect.TypeOf(Address{})
+)
+
// Hash represents the 32 byte Keccak256 hash of arbitrary data.
type Hash [HashLength]byte
@@ -72,6 +77,11 @@ func (h *Hash) UnmarshalText(input []byte) error {
return hexutil.UnmarshalFixedText("Hash", input, h[:])
}
+// UnmarshalJSON parses a hash in hex syntax.
+func (h *Hash) UnmarshalJSON(input []byte) error {
+ return hexutil.UnmarshalFixedJSON(hashT, input, h[:])
+}
+
// MarshalText returns the hex representation of h.
func (h Hash) MarshalText() ([]byte, error) {
return hexutil.Bytes(h[:]).MarshalText()
@@ -194,6 +204,11 @@ func (a *Address) UnmarshalText(input []byte) error {
return hexutil.UnmarshalFixedText("Address", input, a[:])
}
+// UnmarshalJSON parses a hash in hex syntax.
+func (a *Address) UnmarshalJSON(input []byte) error {
+ return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
+}
+
// UnprefixedHash allows marshaling an Address without 0x prefix.
type UnprefixedAddress Address
diff --git a/common/types_test.go b/common/types_test.go
index 9f9d8b767..154c33063 100644
--- a/common/types_test.go
+++ b/common/types_test.go
@@ -21,8 +21,6 @@ import (
"math/big"
"strings"
"testing"
-
- "github.com/ethereum/go-ethereum/common/hexutil"
)
func TestBytesConversion(t *testing.T) {
@@ -43,10 +41,10 @@ func TestHashJsonValidation(t *testing.T) {
Size int
Error string
}{
- {"", 62, hexutil.ErrMissingPrefix.Error()},
- {"0x", 66, "hex string has length 66, want 64 for Hash"},
- {"0x", 63, hexutil.ErrOddLength.Error()},
- {"0x", 0, "hex string has length 0, want 64 for Hash"},
+ {"", 62, "json: cannot unmarshal hex string without 0x prefix into Go value of type common.Hash"},
+ {"0x", 66, "hex string has length 66, want 64 for common.Hash"},
+ {"0x", 63, "json: cannot unmarshal hex string of odd length into Go value of type common.Hash"},
+ {"0x", 0, "hex string has length 0, want 64 for common.Hash"},
{"0x", 64, ""},
{"0X", 64, ""},
}
diff --git a/core/block_validator.go b/core/block_validator.go
index 4f85df007..e9cfd0482 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -52,16 +52,10 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin
// validated at this point.
func (v *BlockValidator) ValidateBody(block *types.Block) error {
// Check whether the block's known, and if not, that it's linkable
- if v.bc.HasBlock(block.Hash()) {
- if _, err := state.New(block.Root(), v.bc.chainDb); err == nil {
- return ErrKnownBlock
- }
+ if v.bc.HasBlockAndState(block.Hash()) {
+ return ErrKnownBlock
}
- parent := v.bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
- if parent == nil {
- return consensus.ErrUnknownAncestor
- }
- if _, err := state.New(parent.Root(), v.bc.chainDb); err != nil {
+ if !v.bc.HasBlockAndState(block.ParentHash()) {
return consensus.ErrUnknownAncestor
}
// Header validity is known at this point, check the uncles and transactions
diff --git a/core/blockchain.go b/core/blockchain.go
index 073b91bab..6772ea284 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -92,7 +92,7 @@ type BlockChain struct {
currentBlock *types.Block // Current head of the block chain
currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!)
- stateCache *state.StateDB // State database to reuse between imports (contains state cache)
+ stateCache state.Database // State database to reuse between imports (contains state cache)
bodyCache *lru.Cache // Cache for the most recent block bodies
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
blockCache *lru.Cache // Cache for the most recent entire blocks
@@ -125,6 +125,7 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
bc := &BlockChain{
config: config,
chainDb: chainDb,
+ stateCache: state.NewDatabase(chainDb),
eventMux: mux,
quit: make(chan struct{}),
bodyCache: bodyCache,
@@ -190,7 +191,7 @@ func (bc *BlockChain) loadLastState() error {
return bc.Reset()
}
// Make sure the state associated with the block is available
- if _, err := state.New(currentBlock.Root(), bc.chainDb); err != nil {
+ if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
// Dangling block without a state associated, init from scratch
log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
return bc.Reset()
@@ -214,12 +215,6 @@ func (bc *BlockChain) loadLastState() error {
bc.currentFastBlock = block
}
}
- // Initialize a statedb cache to ensure singleton account bloom filter generation
- statedb, err := state.New(bc.currentBlock.Root(), bc.chainDb)
- if err != nil {
- return err
- }
- bc.stateCache = statedb
// Issue a status log for the user
headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64())
@@ -261,7 +256,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())
}
if bc.currentBlock != nil {
- if _, err := state.New(bc.currentBlock.Root(), bc.chainDb); err != nil {
+ if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil {
// Rewound state missing, rolled back to before pivot, reset to genesis
bc.currentBlock = nil
}
@@ -384,7 +379,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) {
// StateAt returns a new mutable state based on a particular point in time.
func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) {
- return bc.stateCache.New(root)
+ return state.New(root, bc.stateCache)
}
// Reset purges the entire blockchain, restoring it to its genesis state.
@@ -531,7 +526,7 @@ func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
return false
}
// Ensure the associated state is also present
- _, err := state.New(block.Root(), bc.chainDb)
+ _, err := bc.stateCache.OpenTrie(block.Root())
return err == nil
}
@@ -959,31 +954,30 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
}
// Create a new statedb using the parent block and report an
// error if it fails.
- switch {
- case i == 0:
- err = bc.stateCache.Reset(bc.GetBlock(block.ParentHash(), block.NumberU64()-1).Root())
- default:
- err = bc.stateCache.Reset(chain[i-1].Root())
+ var parent *types.Block
+ if i == 0 {
+ parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
+ } else {
+ parent = chain[i-1]
}
+ state, err := state.New(parent.Root(), bc.stateCache)
if err != nil {
- bc.reportBlock(block, nil, err)
return i, err
}
// Process block using the parent state as reference point.
- receipts, logs, usedGas, err := bc.processor.Process(block, bc.stateCache, bc.vmConfig)
+ receipts, logs, usedGas, err := bc.processor.Process(block, state, bc.vmConfig)
if err != nil {
bc.reportBlock(block, receipts, err)
return i, err
}
// Validate the state using the default validator
- err = bc.Validator().ValidateState(block, bc.GetBlock(block.ParentHash(), block.NumberU64()-1), bc.stateCache, receipts, usedGas)
+ err = bc.Validator().ValidateState(block, parent, state, receipts, usedGas)
if err != nil {
bc.reportBlock(block, receipts, err)
return i, err
}
// Write state changes to database
- _, err = bc.stateCache.Commit(bc.config.IsEIP158(block.Number()))
- if err != nil {
+ if _, err = state.CommitTo(bc.chainDb, bc.config.IsEIP158(block.Number())); err != nil {
return i, err
}
@@ -1021,7 +1015,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
return i, err
}
// Write hash preimages
- if err := WritePreimages(bc.chainDb, block.NumberU64(), bc.stateCache.Preimages()); err != nil {
+ if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil {
return i, err
}
case SideStatTy:
@@ -1079,7 +1073,7 @@ func (st *insertStats) report(chain []*types.Block, index int) {
}
log.Info("Imported new chain segment", context...)
- *st = insertStats{startTime: now, lastIndex: index}
+ *st = insertStats{startTime: now, lastIndex: index + 1}
}
}
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index 7505208e1..371522ab7 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -131,7 +131,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
}
return err
}
- statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.chainDb)
+ statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache)
if err != nil {
return err
}
@@ -148,7 +148,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
blockchain.mu.Lock()
WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
WriteBlock(blockchain.chainDb, block)
- statedb.Commit(false)
+ statedb.CommitTo(blockchain.chainDb, false)
blockchain.mu.Unlock()
}
return nil
@@ -1131,7 +1131,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[0]}); err != nil {
t.Fatal(err)
}
- if !blockchain.stateCache.Exist(theAddr) {
+ if st, _ := blockchain.State(); !st.Exist(theAddr) {
t.Error("expected account to exist")
}
@@ -1139,7 +1139,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[1]}); err != nil {
t.Fatal(err)
}
- if blockchain.stateCache.Exist(theAddr) {
+ if st, _ := blockchain.State(); st.Exist(theAddr) {
t.Error("account should not exist")
}
@@ -1147,7 +1147,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil {
t.Fatal(err)
}
- if blockchain.stateCache.Exist(theAddr) {
+ if st, _ := blockchain.State(); st.Exist(theAddr) {
t.Error("account should not exist")
}
}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index cc14f8fb8..38a69d42a 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -181,7 +181,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat
gen(i, b)
}
ethash.AccumulateRewards(statedb, h, b.uncles)
- root, err := statedb.Commit(config.IsEIP158(h.Number))
+ root, err := statedb.CommitTo(db, config.IsEIP158(h.Number))
if err != nil {
panic(fmt.Sprintf("state write error: %v", err))
}
@@ -189,7 +189,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat
return types.NewBlock(h, b.txs, b.uncles, b.receipts), b.receipts
}
for i := 0; i < n; i++ {
- statedb, err := state.New(parent.Root(), db)
+ statedb, err := state.New(parent.Root(), state.NewDatabase(db))
if err != nil {
panic(err)
}
diff --git a/core/genesis.go b/core/genesis.go
index 947a53c70..5815d5901 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -176,7 +176,7 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
// ToBlock creates the block and state of a genesis specification.
func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
for addr, account := range g.Alloc {
statedb.AddBalance(addr, account.Balance)
statedb.SetCode(addr, account.Code)
diff --git a/core/state/database.go b/core/state/database.go
new file mode 100644
index 000000000..946625e76
--- /dev/null
+++ b/core/state/database.go
@@ -0,0 +1,154 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "fmt"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/trie"
+ lru "github.com/hashicorp/golang-lru"
+)
+
+// Trie cache generation limit after which to evic trie nodes from memory.
+var MaxTrieCacheGen = uint16(120)
+
+const (
+ // Number of past tries to keep. This value is chosen such that
+ // reasonable chain reorg depths will hit an existing trie.
+ maxPastTries = 12
+
+ // Number of codehash->size associations to keep.
+ codeSizeCacheSize = 100000
+)
+
+// Database wraps access to tries and contract code.
+type Database interface {
+ // Accessing tries:
+ // OpenTrie opens the main account trie.
+ // OpenStorageTrie opens the storage trie of an account.
+ OpenTrie(root common.Hash) (Trie, error)
+ OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
+ // Accessing contract code:
+ ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
+ ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
+ // CopyTrie returns an independent copy of the given trie.
+ CopyTrie(Trie) Trie
+}
+
+// Trie is a Ethereum Merkle Trie.
+type Trie interface {
+ TryGet(key []byte) ([]byte, error)
+ TryUpdate(key, value []byte) error
+ TryDelete(key []byte) error
+ CommitTo(trie.DatabaseWriter) (common.Hash, error)
+ Hash() common.Hash
+ NodeIterator(startKey []byte) trie.NodeIterator
+ GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed
+}
+
+// NewDatabase creates a backing store for state. The returned database is safe for
+// concurrent use and retains cached trie nodes in memory.
+func NewDatabase(db ethdb.Database) Database {
+ csc, _ := lru.New(codeSizeCacheSize)
+ return &cachingDB{db: db, codeSizeCache: csc}
+}
+
+type cachingDB struct {
+ db ethdb.Database
+ mu sync.Mutex
+ pastTries []*trie.SecureTrie
+ codeSizeCache *lru.Cache
+}
+
+func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ for i := len(db.pastTries) - 1; i >= 0; i-- {
+ if db.pastTries[i].Hash() == root {
+ return cachedTrie{db.pastTries[i].Copy(), db}, nil
+ }
+ }
+ tr, err := trie.NewSecure(root, db.db, MaxTrieCacheGen)
+ if err != nil {
+ return nil, err
+ }
+ return cachedTrie{tr, db}, nil
+}
+
+func (db *cachingDB) pushTrie(t *trie.SecureTrie) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ if len(db.pastTries) >= maxPastTries {
+ copy(db.pastTries, db.pastTries[1:])
+ db.pastTries[len(db.pastTries)-1] = t
+ } else {
+ db.pastTries = append(db.pastTries, t)
+ }
+}
+
+func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
+ return trie.NewSecure(root, db.db, 0)
+}
+
+func (db *cachingDB) CopyTrie(t Trie) Trie {
+ switch t := t.(type) {
+ case cachedTrie:
+ return cachedTrie{t.SecureTrie.Copy(), db}
+ case *trie.SecureTrie:
+ return t.Copy()
+ default:
+ panic(fmt.Errorf("unknown trie type %T", t))
+ }
+}
+
+func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
+ code, err := db.db.Get(codeHash[:])
+ if err == nil {
+ db.codeSizeCache.Add(codeHash, len(code))
+ }
+ return code, err
+}
+
+func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
+ if cached, ok := db.codeSizeCache.Get(codeHash); ok {
+ return cached.(int), nil
+ }
+ code, err := db.ContractCode(addrHash, codeHash)
+ if err == nil {
+ db.codeSizeCache.Add(codeHash, len(code))
+ }
+ return len(code), err
+}
+
+// cachedTrie inserts its trie into a cachingDB on commit.
+type cachedTrie struct {
+ *trie.SecureTrie
+ db *cachingDB
+}
+
+func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) {
+ root, err := m.SecureTrie.CommitTo(dbw)
+ if err == nil {
+ m.db.pushTrie(m.SecureTrie)
+ }
+ return root, err
+}
diff --git a/core/state/dump.go b/core/state/dump.go
index ffa1a7283..46e612850 100644
--- a/core/state/dump.go
+++ b/core/state/dump.go
@@ -41,7 +41,7 @@ type Dump struct {
func (self *StateDB) RawDump() Dump {
dump := Dump{
- Root: common.Bytes2Hex(self.trie.Root()),
+ Root: fmt.Sprintf("%x", self.trie.Hash()),
Accounts: make(map[string]DumpAccount),
}
diff --git a/core/state/iterator.go b/core/state/iterator.go
index a8a2722ae..6a5c73d3d 100644
--- a/core/state/iterator.go
+++ b/core/state/iterator.go
@@ -19,7 +19,6 @@ package state
import (
"bytes"
"fmt"
- "math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
@@ -105,16 +104,11 @@ func (it *NodeIterator) step() error {
return nil
}
// Otherwise we've reached an account node, initiate data iteration
- var account struct {
- Nonce uint64
- Balance *big.Int
- Root common.Hash
- CodeHash []byte
- }
+ var account Account
if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil {
return err
}
- dataTrie, err := trie.New(account.Root, it.state.db)
+ dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root)
if err != nil {
return err
}
@@ -124,7 +118,8 @@ func (it *NodeIterator) step() error {
}
if !bytes.Equal(account.CodeHash, emptyCodeHash) {
it.codeHash = common.BytesToHash(account.CodeHash)
- it.code, err = it.state.db.Get(account.CodeHash)
+ addrHash := common.BytesToHash(it.stateIt.LeafKey())
+ it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash))
if err != nil {
return fmt.Errorf("code %x: %v", account.CodeHash, err)
}
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index aa9c5b728..ff66ba7a9 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -21,13 +21,12 @@ import (
"testing"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/ethdb"
)
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate
- db, root, _ := makeTestState()
+ db, mem, root, _ := makeTestState()
state, err := New(root, db)
if err != nil {
@@ -40,13 +39,14 @@ func TestNodeIteratorCoverage(t *testing.T) {
hashes[it.Hash] = struct{}{}
}
}
+
// Cross check the hashes and the database itself
for hash := range hashes {
- if _, err := db.Get(hash.Bytes()); err != nil {
+ if _, err := mem.Get(hash.Bytes()); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err)
}
}
- for _, key := range db.(*ethdb.MemDatabase).Keys() {
+ for _, key := range mem.Keys() {
if bytes.HasPrefix(key, []byte("secure-key-")) {
continue
}
diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go
index ea5737a08..1cfdd3a89 100644
--- a/core/state/managed_state_test.go
+++ b/core/state/managed_state_test.go
@@ -27,7 +27,7 @@ var addr = common.BytesToAddress([]byte("test"))
func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := New(common.Hash{}, db)
+ statedb, _ := New(common.Hash{}, NewDatabase(db))
ms := ManageState(statedb)
ms.StateDB.SetNonce(addr, 100)
ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr))
diff --git a/core/state/state_object.go b/core/state/state_object.go
index dcad9d068..b2378c69c 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -62,9 +62,10 @@ func (self Storage) Copy() Storage {
// Account values can be accessed and modified through the object.
// Finally, call CommitTrie to write the modified storage trie into a database.
type stateObject struct {
- address common.Address // Ethereum address of this account
- data Account
- db *StateDB
+ address common.Address
+ addrHash common.Hash // hash of ethereum address of the account
+ data Account
+ db *StateDB
// DB error.
// State objects are used by the consensus core and VM which are
@@ -74,8 +75,8 @@ type stateObject struct {
dbErr error
// Write caches.
- trie *trie.SecureTrie // storage trie, which becomes non-nil on first access
- code Code // contract bytecode, which gets set when code is loaded
+ trie Trie // storage trie, which becomes non-nil on first access
+ code Code // contract bytecode, which gets set when code is loaded
cachedStorage Storage // Storage entry cache to avoid duplicate reads
dirtyStorage Storage // Storage entries that need to be flushed to disk
@@ -112,7 +113,15 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
- return &stateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
+ return &stateObject{
+ db: db,
+ address: address,
+ addrHash: crypto.Keccak256Hash(address[:]),
+ data: data,
+ cachedStorage: make(Storage),
+ dirtyStorage: make(Storage),
+ onDirty: onDirty,
+ }
}
// EncodeRLP implements rlp.Encoder.
@@ -148,12 +157,12 @@ func (c *stateObject) touch() {
c.touched = true
}
-func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie {
+func (c *stateObject) getTrie(db Database) Trie {
if c.trie == nil {
var err error
- c.trie, err = trie.NewSecure(c.data.Root, db, 0)
+ c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root)
if err != nil {
- c.trie, _ = trie.NewSecure(common.Hash{}, db, 0)
+ c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{})
c.setError(fmt.Errorf("can't create storage trie: %v", err))
}
}
@@ -161,13 +170,18 @@ func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie {
}
// GetState returns a value in account storage.
-func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash {
+func (self *stateObject) GetState(db Database, key common.Hash) common.Hash {
value, exists := self.cachedStorage[key]
if exists {
return value
}
// Load from DB in case it is missing.
- if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 {
+ enc, err := self.getTrie(db).TryGet(key[:])
+ if err != nil {
+ self.setError(err)
+ return common.Hash{}
+ }
+ if len(enc) > 0 {
_, content, _, err := rlp.Split(enc)
if err != nil {
self.setError(err)
@@ -181,7 +195,7 @@ func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash
}
// SetState updates a value in account storage.
-func (self *stateObject) SetState(db trie.Database, key, value common.Hash) {
+func (self *stateObject) SetState(db Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{
account: &self.address,
key: key,
@@ -201,30 +215,30 @@ func (self *stateObject) setState(key, value common.Hash) {
}
// updateTrie writes cached storage modifications into the object's storage trie.
-func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie {
+func (self *stateObject) updateTrie(db Database) Trie {
tr := self.getTrie(db)
for key, value := range self.dirtyStorage {
delete(self.dirtyStorage, key)
if (value == common.Hash{}) {
- tr.Delete(key[:])
+ self.setError(tr.TryDelete(key[:]))
continue
}
// Encoding []byte cannot fail, ok to ignore the error.
v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
- tr.Update(key[:], v)
+ self.setError(tr.TryUpdate(key[:], v))
}
return tr
}
// UpdateRoot sets the trie root to the current root hash of
-func (self *stateObject) updateRoot(db trie.Database) {
+func (self *stateObject) updateRoot(db Database) {
self.updateTrie(db)
self.data.Root = self.trie.Hash()
}
// CommitTrie the storage trie of the object to dwb.
// This updates the trie root.
-func (self *stateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error {
+func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error {
self.updateTrie(db)
if self.dbErr != nil {
return self.dbErr
@@ -282,9 +296,7 @@ func (c *stateObject) ReturnGas(gas *big.Int) {}
func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject {
stateObject := newObject(db, self.address, self.data, onDirty)
if self.trie != nil {
- // A shallow copy makes the two tries independent.
- cpy := *self.trie
- stateObject.trie = &cpy
+ stateObject.trie = db.db.CopyTrie(self.trie)
}
stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy()
@@ -305,14 +317,14 @@ func (c *stateObject) Address() common.Address {
}
// Code returns the contract code associated with this object, if any.
-func (self *stateObject) Code(db trie.Database) []byte {
+func (self *stateObject) Code(db Database) []byte {
if self.code != nil {
return self.code
}
if bytes.Equal(self.CodeHash(), emptyCodeHash) {
return nil
}
- code, err := db.Get(self.CodeHash())
+ code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash()))
if err != nil {
self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err))
}
diff --git a/core/state/state_test.go b/core/state/state_test.go
index 3bc63c148..bbae3685b 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -21,14 +21,14 @@ import (
"math/big"
"testing"
- checker "gopkg.in/check.v1"
-
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
+ checker "gopkg.in/check.v1"
)
type StateSuite struct {
+ db *ethdb.MemDatabase
state *StateDB
}
@@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) {
// write some of them to the trie
s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2)
- s.state.Commit(false)
+ s.state.CommitTo(s.db, false)
// check that dump contains the state objects that are in trie
got := string(s.state.Dump())
@@ -87,23 +87,20 @@ func (s *StateSuite) TestDump(c *checker.C) {
}
func (s *StateSuite) SetUpTest(c *checker.C) {
- db, _ := ethdb.NewMemDatabase()
- s.state, _ = New(common.Hash{}, db)
+ s.db, _ = ethdb.NewMemDatabase()
+ s.state, _ = New(common.Hash{}, NewDatabase(s.db))
}
-func TestNull(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- state, _ := New(common.Hash{}, db)
-
+func (s *StateSuite) TestNull(c *checker.C) {
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
- state.CreateAccount(address)
+ s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
- state.SetState(address, common.Hash{}, value)
- state.Commit(false)
- value = state.GetState(address, common.Hash{})
+ s.state.SetState(address, common.Hash{}, value)
+ s.state.CommitTo(s.db, false)
+ value = s.state.GetState(address, common.Hash{})
if !common.EmptyHash(value) {
- t.Errorf("expected empty hash. got %x", value)
+ c.Errorf("expected empty hash. got %x", value)
}
}
@@ -129,17 +126,15 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res)
}
-func TestSnapshotEmpty(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- state, _ := New(common.Hash{}, db)
- state.RevertToSnapshot(state.Snapshot())
+func (s *StateSuite) TestSnapshotEmpty(c *checker.C) {
+ s.state.RevertToSnapshot(s.state.Snapshot())
}
// use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
- state, _ := New(common.Hash{}, db)
+ state, _ := New(common.Hash{}, NewDatabase(db))
stateobjaddr0 := toAddr([]byte("so0"))
stateobjaddr1 := toAddr([]byte("so1"))
@@ -160,7 +155,7 @@ func TestSnapshot2(t *testing.T) {
so0.deleted = false
state.setStateObject(so0)
- root, _ := state.Commit(false)
+ root, _ := state.CommitTo(db, false)
state.Reset(root)
// and one with deleted == true
@@ -182,8 +177,8 @@ func TestSnapshot2(t *testing.T) {
so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
- so0Restored.GetState(db, storageaddr)
- so0Restored.Code(db)
+ so0Restored.GetState(state.db, storageaddr)
+ so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 05869a0c8..694374f82 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -26,23 +26,9 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
- lru "github.com/hashicorp/golang-lru"
-)
-
-// Trie cache generation limit after which to evic trie nodes from memory.
-var MaxTrieCacheGen = uint16(120)
-
-const (
- // Number of past tries to keep. This value is chosen such that
- // reasonable chain reorg depths will hit an existing trie.
- maxPastTries = 12
-
- // Number of codehash->size associations to keep.
- codeSizeCacheSize = 100000
)
type revision struct {
@@ -56,16 +42,21 @@ type revision struct {
// * Contracts
// * Accounts
type StateDB struct {
- db ethdb.Database
- trie *trie.SecureTrie
- pastTries []*trie.SecureTrie
- codeSizeCache *lru.Cache
+ db Database
+ trie Trie
// This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*stateObject
stateObjectsDirty map[common.Address]struct{}
stateObjectsDestructed map[common.Address]struct{}
+ // DB error.
+ // State objects are used by the consensus core and VM which are
+ // unable to deal with database-level errors. Any error that occurs
+ // during a database read is memoized here and will eventually be returned
+ // by StateDB.Commit.
+ dbErr error
+
// The refund counter, also used by state transitioning.
refund *big.Int
@@ -86,16 +77,14 @@ type StateDB struct {
}
// Create a new state from a given trie
-func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
- tr, err := trie.NewSecure(root, db, MaxTrieCacheGen)
+func New(root common.Hash, db Database) (*StateDB, error) {
+ tr, err := db.OpenTrie(root)
if err != nil {
return nil, err
}
- csc, _ := lru.New(codeSizeCacheSize)
return &StateDB{
db: db,
trie: tr,
- codeSizeCache: csc,
stateObjects: make(map[common.Address]*stateObject),
stateObjectsDirty: make(map[common.Address]struct{}),
stateObjectsDestructed: make(map[common.Address]struct{}),
@@ -105,36 +94,21 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
}, nil
}
-// New creates a new statedb by reusing any journalled tries to avoid costly
-// disk io.
-func (self *StateDB) New(root common.Hash) (*StateDB, error) {
- self.lock.Lock()
- defer self.lock.Unlock()
-
- tr, err := self.openTrie(root)
- if err != nil {
- return nil, err
+// setError remembers the first non-nil error it is called with.
+func (self *StateDB) setError(err error) {
+ if self.dbErr == nil {
+ self.dbErr = err
}
- return &StateDB{
- db: self.db,
- trie: tr,
- codeSizeCache: self.codeSizeCache,
- stateObjects: make(map[common.Address]*stateObject),
- stateObjectsDirty: make(map[common.Address]struct{}),
- stateObjectsDestructed: make(map[common.Address]struct{}),
- refund: new(big.Int),
- logs: make(map[common.Hash][]*types.Log),
- preimages: make(map[common.Hash][]byte),
- }, nil
+}
+
+func (self *StateDB) Error() error {
+ return self.dbErr
}
// Reset clears out all emphemeral state objects from the state db, but keeps
// the underlying state trie to avoid reloading data for the next operations.
func (self *StateDB) Reset(root common.Hash) error {
- self.lock.Lock()
- defer self.lock.Unlock()
-
- tr, err := self.openTrie(root)
+ tr, err := self.db.OpenTrie(root)
if err != nil {
return err
}
@@ -149,34 +123,9 @@ func (self *StateDB) Reset(root common.Hash) error {
self.logSize = 0
self.preimages = make(map[common.Hash][]byte)
self.clearJournalAndRefund()
-
return nil
}
-// openTrie creates a trie. It uses an existing trie if one is available
-// from the journal if available.
-func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) {
- for i := len(self.pastTries) - 1; i >= 0; i-- {
- if self.pastTries[i].Hash() == root {
- tr := *self.pastTries[i]
- return &tr, nil
- }
- }
- return trie.NewSecure(root, self.db, MaxTrieCacheGen)
-}
-
-func (self *StateDB) pushTrie(t *trie.SecureTrie) {
- self.lock.Lock()
- defer self.lock.Unlock()
-
- if len(self.pastTries) >= maxPastTries {
- copy(self.pastTries, self.pastTries[1:])
- self.pastTries[len(self.pastTries)-1] = t
- } else {
- self.pastTries = append(self.pastTries, t)
- }
-}
-
func (self *StateDB) AddLog(log *types.Log) {
self.journal = append(self.journal, addLogChange{txhash: self.thash})
@@ -254,10 +203,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 {
func (self *StateDB) GetCode(addr common.Address) []byte {
stateObject := self.getStateObject(addr)
if stateObject != nil {
- code := stateObject.Code(self.db)
- key := common.BytesToHash(stateObject.CodeHash())
- self.codeSizeCache.Add(key, len(code))
- return code
+ return stateObject.Code(self.db)
}
return nil
}
@@ -267,13 +213,12 @@ func (self *StateDB) GetCodeSize(addr common.Address) int {
if stateObject == nil {
return 0
}
- key := common.BytesToHash(stateObject.CodeHash())
- if cached, ok := self.codeSizeCache.Get(key); ok {
- return cached.(int)
+ if stateObject.code != nil {
+ return len(stateObject.code)
}
- size := len(stateObject.Code(self.db))
- if stateObject.dbErr == nil {
- self.codeSizeCache.Add(key, size)
+ size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash()))
+ if err != nil {
+ self.setError(err)
}
return size
}
@@ -296,7 +241,7 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
// StorageTrie returns the storage trie of an account.
// The return value is a copy and is nil for non-existent accounts.
-func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie {
+func (self *StateDB) StorageTrie(a common.Address) Trie {
stateObject := self.getStateObject(a)
if stateObject == nil {
return nil
@@ -394,14 +339,14 @@ func (self *StateDB) updateStateObject(stateObject *stateObject) {
if err != nil {
panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err))
}
- self.trie.Update(addr[:], data)
+ self.setError(self.trie.TryUpdate(addr[:], data))
}
// deleteStateObject removes the given object from the state trie.
func (self *StateDB) deleteStateObject(stateObject *stateObject) {
stateObject.deleted = true
addr := stateObject.Address()
- self.trie.Delete(addr[:])
+ self.setError(self.trie.TryDelete(addr[:]))
}
// Retrieve a state object given my the address. Returns nil if not found.
@@ -415,8 +360,9 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje
}
// Load the object from the database.
- enc := self.trie.Get(addr[:])
+ enc, err := self.trie.TryGet(addr[:])
if len(enc) == 0 {
+ self.setError(err)
return nil
}
var data Account
@@ -512,8 +458,6 @@ func (self *StateDB) Copy() *StateDB {
state := &StateDB{
db: self.db,
trie: self.trie,
- pastTries: self.pastTries,
- codeSizeCache: self.codeSizeCache,
stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)),
stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)),
stateObjectsDestructed: make(map[common.Address]struct{}, len(self.stateObjectsDestructed)),
@@ -636,23 +580,6 @@ func (s *StateDB) DeleteSuicides() {
}
}
-// Commit commits all state changes to the database.
-func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
- root, batch := s.CommitBatch(deleteEmptyObjects)
- return root, batch.Write()
-}
-
-// CommitBatch commits all state changes to a write batch but does not
-// execute the batch. It is used to validate state changes against
-// the root hash stored in a block.
-func (s *StateDB) CommitBatch(deleteEmptyObjects bool) (root common.Hash, batch ethdb.Batch) {
- batch = s.db.NewBatch()
- root, _ = s.CommitTo(batch, deleteEmptyObjects)
-
- log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
- return root, batch
-}
-
func (s *StateDB) clearJournalAndRefund() {
s.journal = nil
s.validRevisions = s.validRevisions[:0]
@@ -690,8 +617,6 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
}
// Write trie changes.
root, err = s.trie.CommitTo(dbw)
- if err == nil {
- s.pushTrie(s.trie)
- }
+ log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
return root, err
}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 72b638f97..b2bd18e65 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -28,6 +28,8 @@ import (
"testing"
"testing/quick"
+ check "gopkg.in/check.v1"
+
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
@@ -38,7 +40,7 @@ import (
func TestUpdateLeaks(t *testing.T) {
// Create an empty state database
db, _ := ethdb.NewMemDatabase()
- state, _ := New(common.Hash{}, db)
+ state, _ := New(common.Hash{}, NewDatabase(db))
// Update it with some accounts
for i := byte(0); i < 255; i++ {
@@ -66,8 +68,8 @@ func TestIntermediateLeaks(t *testing.T) {
// Create two state databases, one transitioning to the final state, the other final from the beginning
transDb, _ := ethdb.NewMemDatabase()
finalDb, _ := ethdb.NewMemDatabase()
- transState, _ := New(common.Hash{}, transDb)
- finalState, _ := New(common.Hash{}, finalDb)
+ transState, _ := New(common.Hash{}, NewDatabase(transDb))
+ finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
modify := func(state *StateDB, addr common.Address, i, tweak byte) {
state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
@@ -95,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) {
}
// Commit and cross check the databases.
- if _, err := transState.Commit(false); err != nil {
+ if _, err := transState.CommitTo(transDb, false); err != nil {
t.Fatalf("failed to commit transition state: %v", err)
}
- if _, err := finalState.Commit(false); err != nil {
+ if _, err := finalState.CommitTo(finalDb, false); err != nil {
t.Fatalf("failed to commit final state: %v", err)
}
for _, key := range finalDb.Keys() {
@@ -282,7 +284,7 @@ func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
db, _ = ethdb.NewMemDatabase()
- state, _ = New(common.Hash{}, db)
+ state, _ = New(common.Hash{}, NewDatabase(db))
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
@@ -297,7 +299,7 @@ func (test *snapshotTest) run() bool {
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
- checkstate, _ := New(common.Hash{}, db)
+ checkstate, _ := New(common.Hash{}, NewDatabase(db))
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
@@ -354,21 +356,19 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return nil
}
-func TestTouchDelete(t *testing.T) {
- db, _ := ethdb.NewMemDatabase()
- state, _ := New(common.Hash{}, db)
- state.GetOrNewStateObject(common.Address{})
- root, _ := state.Commit(false)
- state.Reset(root)
+func (s *StateSuite) TestTouchDelete(c *check.C) {
+ s.state.GetOrNewStateObject(common.Address{})
+ root, _ := s.state.CommitTo(s.db, false)
+ s.state.Reset(root)
- snapshot := state.Snapshot()
- state.AddBalance(common.Address{}, new(big.Int))
- if len(state.stateObjectsDirty) != 1 {
- t.Fatal("expected one dirty state object")
+ snapshot := s.state.Snapshot()
+ s.state.AddBalance(common.Address{}, new(big.Int))
+ if len(s.state.stateObjectsDirty) != 1 {
+ c.Fatal("expected one dirty state object")
}
- state.RevertToSnapshot(snapshot)
- if len(state.stateObjectsDirty) != 0 {
- t.Fatal("expected no dirty state object")
+ s.state.RevertToSnapshot(snapshot)
+ if len(s.state.stateObjectsDirty) != 0 {
+ c.Fatal("expected no dirty state object")
}
}
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 108ebb320..06c572ea6 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -36,9 +36,10 @@ type testAccount struct {
}
// makeTestState create a sample test state to test node-wise reconstruction.
-func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
+func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) {
// Create an empty state
- db, _ := ethdb.NewMemDatabase()
+ mem, _ := ethdb.NewMemDatabase()
+ db := NewDatabase(mem)
state, _ := New(common.Hash{}, db)
// Fill it with some arbitrary data
@@ -60,17 +61,17 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
state.updateStateObject(obj)
accounts = append(accounts, acc)
}
- root, _ := state.Commit(false)
+ root, _ := state.CommitTo(mem, false)
// Return the generated state
- return db, root, accounts
+ return db, mem, root, accounts
}
// checkStateAccounts cross references a reconstructed state with an expected
// account array.
func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) {
// Check root availability and state contents
- state, err := New(root, db)
+ state, err := New(root, NewDatabase(db))
if err != nil {
t.Fatalf("failed to create state trie at %x: %v", root, err)
}
@@ -90,13 +91,28 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
}
}
-// checkStateConsistency checks that all nodes in a state trie are indeed present.
+// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
+func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
+ if v, _ := db.Get(root[:]); v == nil {
+ return nil // Consider a non existent state consistent.
+ }
+ trie, err := trie.New(root, db)
+ if err != nil {
+ return err
+ }
+ it := trie.NodeIterator(nil)
+ for it.Next(true) {
+ }
+ return it.Error()
+}
+
+// checkStateConsistency checks that all data of a state root is present.
func checkStateConsistency(db ethdb.Database, root common.Hash) error {
// Create and iterate a state trie rooted in a sub-node
if _, err := db.Get(root.Bytes()); err != nil {
- return nil // Consider a non existent state consistent
+ return nil // Consider a non existent state consistent.
}
- state, err := New(root, db)
+ state, err := New(root, NewDatabase(db))
if err != nil {
return err
}
@@ -122,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t,
func testIterativeStateSync(t *testing.T, batch int) {
// Create a random state to copy
- srcDb, srcRoot, srcAccounts := makeTestState()
+ _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -132,7 +148,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
for len(queue) > 0 {
results := make([]trie.SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcMem.Get(hash.Bytes())
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -154,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
// partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy
- srcDb, srcRoot, srcAccounts := makeTestState()
+ _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -165,7 +181,7 @@ func TestIterativeDelayedStateSync(t *testing.T) {
// Sync only half of the scheduled nodes
results := make([]trie.SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcMem.Get(hash.Bytes())
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -191,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, batch int) {
// Create a random state to copy
- srcDb, srcRoot, srcAccounts := makeTestState()
+ _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -205,7 +221,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order
results := make([]trie.SyncResult, 0, len(queue))
for hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcMem.Get(hash.Bytes())
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -231,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy
- srcDb, srcRoot, srcAccounts := makeTestState()
+ _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -247,7 +263,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range queue {
delete(queue, hash)
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcMem.Get(hash.Bytes())
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -276,7 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
// the database.
func TestIncompleteStateSync(t *testing.T) {
// Create a random state to copy
- srcDb, srcRoot, srcAccounts := makeTestState()
+ _, srcMem, srcRoot, srcAccounts := makeTestState()
+
+ checkTrieConsistency(srcMem, srcRoot)
// Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase()
@@ -288,7 +306,7 @@ func TestIncompleteStateSync(t *testing.T) {
// Fetch a batch of state nodes
results := make([]trie.SyncResult, len(queue))
for i, hash := range queue {
- data, err := srcDb.Get(hash.Bytes())
+ data, err := srcMem.Get(hash.Bytes())
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
}
@@ -304,21 +322,18 @@ func TestIncompleteStateSync(t *testing.T) {
for _, result := range results {
added = append(added, result.Hash)
}
- // Check that all known sub-tries in the synced state is complete
- for _, root := range added {
- // Skim through the accounts and make sure the root hash is not a code node
- codeHash := false
+ // Check that all known sub-tries added so far are complete or missing entirely.
+ checkSubtries:
+ for _, hash := range added {
for _, acc := range srcAccounts {
- if root == crypto.Keccak256Hash(acc.code) {
- codeHash = true
- break
+ if hash == crypto.Keccak256Hash(acc.code) {
+ continue checkSubtries // skip trie check of code nodes.
}
}
- // If the root is a real trie node, check consistency
- if !codeHash {
- if err := checkStateConsistency(dstDb, root); err != nil {
- t.Fatalf("state inconsistent: %v", err)
- }
+ // Can't use checkStateConsistency here because subtrie keys may have odd
+ // length and crash in LeafKey.
+ if err := checkTrieConsistency(dstDb, hash); err != nil {
+ t.Fatalf("state inconsistent: %v", err)
}
}
// Fetch the next batch to retrieve
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index 4e28522e9..4903bc3ca 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -44,7 +44,7 @@ func pricedTransaction(nonce uint64, gaslimit, gasprice *big.Int, key *ecdsa.Pri
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
key, _ := crypto.GenerateKey()
newPool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
@@ -95,7 +95,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) {
key, _ = crypto.GenerateKey()
address = crypto.PubkeyToAddress(key.PublicKey)
mux = new(event.TypeMux)
- statedb, _ = state.New(common.Hash{}, db)
+ statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
trigger = false
)
@@ -114,7 +114,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) {
// a state change between those fetches.
stdb := statedb
if trigger {
- statedb, _ = state.New(common.Hash{}, db)
+ statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
// simulate that the new head block included tx0 and tx1
statedb.SetNonce(address, 2)
statedb.SetBalance(address, new(big.Int).SetUint64(params.Ether))
@@ -292,7 +292,7 @@ func TestTransactionChainFork(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool.currentState = func() (*state.StateDB, error) { return statedb, nil }
currentState, _ := pool.currentState()
currentState.AddBalance(addr, big.NewInt(100000000000000))
@@ -318,7 +318,7 @@ func TestTransactionDoubleNonce(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool.currentState = func() (*state.StateDB, error) { return statedb, nil }
currentState, _ := pool.currentState()
currentState.AddBalance(addr, big.NewInt(100000000000000))
@@ -628,7 +628,7 @@ func TestTransactionQueueGlobalLimiting(t *testing.T) {
// Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -783,7 +783,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) {
// Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -835,7 +835,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) {
// Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -868,7 +868,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) {
// Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -913,7 +913,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) {
func TestTransactionPoolRepricing(t *testing.T) {
// Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -1006,7 +1006,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) {
// Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
@@ -1091,7 +1091,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) {
func TestTransactionReplacement(t *testing.T) {
// Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase()
- statedb, _ := state.New(common.Hash{}, db)
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState()
diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go
index 24ad6caa5..761ca4450 100644
--- a/core/vm/gas_table.go
+++ b/core/vm/gas_table.go
@@ -17,7 +17,6 @@
package vm
import (
- gmath "math"
"math/big"
"github.com/ethereum/go-ethereum/common"
@@ -28,15 +27,20 @@ import (
// memoryGasCosts calculates the quadratic gas for memory expansion. It does so
// only for the memory region that is expanded, not the total memory.
func memoryGasCost(mem *Memory, newMemSize uint64) (uint64, error) {
- // The maximum that will fit in a uint64 is max_word_count - 1
- // anything above that will result in an overflow.
- if newMemSize > gmath.MaxUint64-32 {
- return 0, errGasUintOverflow
- }
if newMemSize == 0 {
return 0, nil
}
+ // The maximum that will fit in a uint64 is max_word_count - 1
+ // anything above that will result in an overflow.
+ // Additionally, a newMemSize which results in a
+ // newMemSizeWords larger than 0x7ffffffff will cause the square operation
+ // to overflow.
+ // The constant 0xffffffffe0 is the highest number that can be used without
+ // overflowing the gas calculation
+ if newMemSize > 0xffffffffe0 {
+ return 0, errGasUintOverflow
+ }
newMemSizeWords := toWordSize(newMemSize)
newMemSize = newMemSizeWords * 32
diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go
index 1ee909e92..1b91aee56 100644
--- a/core/vm/gas_table_test.go
+++ b/core/vm/gas_table_test.go
@@ -16,24 +16,20 @@
package vm
-import (
- "math"
- "testing"
-)
+import "testing"
func TestMemoryGasCost(t *testing.T) {
- size := uint64(math.MaxUint64 - 64)
- _, err := memoryGasCost(&Memory{}, size)
+ //size := uint64(math.MaxUint64 - 64)
+ size := uint64(0xffffffffe0)
+ v, err := memoryGasCost(&Memory{}, size)
if err != nil {
t.Error("didn't expect error:", err)
}
-
- _, err = memoryGasCost(&Memory{}, size+32)
- if err != nil {
- t.Error("didn't expect error:", err)
+ if v != 36028899963961341 {
+ t.Errorf("Expected: 36028899963961341, got %d", v)
}
- _, err = memoryGasCost(&Memory{}, size+33)
+ _, err = memoryGasCost(&Memory{}, size+1)
if err == nil {
t.Error("expected error")
}
diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go
index ae428aeab..03c42c561 100644
--- a/core/vm/instructions_test.go
+++ b/core/vm/instructions_test.go
@@ -1,6 +1,7 @@
package vm
import (
+ "fmt"
"math/big"
"testing"
@@ -40,3 +41,244 @@ func TestByteOp(t *testing.T) {
}
}
}
+
+func opBenchmark(bench *testing.B, op func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error), args ...string) {
+ var (
+ env = NewEVM(Context{}, nil, params.TestChainConfig, Config{EnableJit: false, ForceJit: false})
+ stack = newstack()
+ )
+ // convert args
+ byteArgs := make([][]byte, len(args))
+ for i, arg := range args {
+ byteArgs[i] = common.Hex2Bytes(arg)
+ }
+ pc := uint64(0)
+ bench.ResetTimer()
+ for i := 0; i < bench.N; i++ {
+ for _, arg := range byteArgs {
+ a := new(big.Int).SetBytes(arg)
+ stack.push(a)
+ }
+ op(&pc, env, nil, nil, stack)
+ stack.pop()
+ }
+}
+
+func precompiledBenchmark(addr, input, expected string, gas uint64, bench *testing.B) {
+
+ contract := NewContract(AccountRef(common.HexToAddress("1337")),
+ nil, new(big.Int), gas)
+
+ p := PrecompiledContracts[common.HexToAddress(addr)]
+ in := common.Hex2Bytes(input)
+ var (
+ res []byte
+ err error
+ )
+ data := make([]byte, len(in))
+ bench.ResetTimer()
+ for i := 0; i < bench.N; i++ {
+ contract.Gas = gas
+ copy(data, in)
+ res, err = RunPrecompiledContract(p, data, contract)
+ }
+ bench.StopTimer()
+ //Check if it is correct
+ if err != nil {
+ bench.Error(err)
+ return
+ }
+ if common.Bytes2Hex(res) != expected {
+ bench.Error(fmt.Sprintf("Expected %v, got %v", expected, common.Bytes2Hex(res)))
+ return
+ }
+}
+
+func BenchmarkPrecompiledEcdsa(bench *testing.B) {
+ var (
+ addr = "01"
+ inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02"
+ exp = "000000000000000000000000ceaccac640adf55b2028469bd36ba501f28b699d"
+ gas = uint64(4000000)
+ )
+ precompiledBenchmark(addr, inp, exp, gas, bench)
+}
+func BenchmarkPrecompiledSha256(bench *testing.B) {
+ var (
+ addr = "02"
+ inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02"
+ exp = "811c7003375852fabd0d362e40e68607a12bdabae61a7d068fe5fdd1dbbf2a5d"
+ gas = uint64(4000000)
+ )
+ precompiledBenchmark(addr, inp, exp, gas, bench)
+}
+func BenchmarkPrecompiledRipeMD(bench *testing.B) {
+ var (
+ addr = "03"
+ inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02"
+ exp = "0000000000000000000000009215b8d9882ff46f0dfde6684d78e831467f65e6"
+ gas = uint64(4000000)
+ )
+ precompiledBenchmark(addr, inp, exp, gas, bench)
+}
+func BenchmarkPrecompiledIdentity(bench *testing.B) {
+ var (
+ addr = "04"
+ inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02"
+ exp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02"
+ gas = uint64(4000000)
+ )
+ precompiledBenchmark(addr, inp, exp, gas, bench)
+}
+func BenchmarkOpAdd(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opAdd, x, y)
+
+}
+func BenchmarkOpSub(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSub, x, y)
+
+}
+func BenchmarkOpMul(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opMul, x, y)
+
+}
+func BenchmarkOpDiv(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opDiv, x, y)
+
+}
+func BenchmarkOpSdiv(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSdiv, x, y)
+
+}
+func BenchmarkOpMod(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opMod, x, y)
+
+}
+func BenchmarkOpSmod(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSmod, x, y)
+
+}
+func BenchmarkOpExp(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opExp, x, y)
+
+}
+func BenchmarkOpSignExtend(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSignExtend, x, y)
+
+}
+func BenchmarkOpLt(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opLt, x, y)
+
+}
+func BenchmarkOpGt(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opGt, x, y)
+
+}
+func BenchmarkOpSlt(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSlt, x, y)
+
+}
+func BenchmarkOpSgt(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opSgt, x, y)
+
+}
+func BenchmarkOpEq(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opEq, x, y)
+
+}
+func BenchmarkOpAnd(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opAnd, x, y)
+
+}
+func BenchmarkOpOr(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opOr, x, y)
+
+}
+func BenchmarkOpXor(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opXor, x, y)
+
+}
+func BenchmarkOpByte(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opByte, x, y)
+
+}
+
+func BenchmarkOpAddmod(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ z := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opAddmod, x, y, z)
+
+}
+func BenchmarkOpMulmod(b *testing.B) {
+ x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+ z := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
+
+ opBenchmark(b, opMulmod, x, y, z)
+
+}
+
+//func BenchmarkOpSha3(b *testing.B) {
+// x := "0"
+// y := "32"
+//
+// opBenchmark(b,opSha3, x, y)
+//
+//
+//}
diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go
index aa386a995..44cde4f70 100644
--- a/core/vm/runtime/runtime.go
+++ b/core/vm/runtime/runtime.go
@@ -102,7 +102,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
if cfg.State == nil {
db, _ := ethdb.NewMemDatabase()
- cfg.State, _ = state.New(common.Hash{}, db)
+ cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
}
var (
address = common.StringToAddress("contract")
@@ -133,7 +133,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) {
if cfg.State == nil {
db, _ := ethdb.NewMemDatabase()
- cfg.State, _ = state.New(common.Hash{}, db)
+ cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
}
var (
vmenv = NewEnv(cfg, cfg.State)
diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go
index 7f40770d2..2c4dc5026 100644
--- a/core/vm/runtime/runtime_test.go
+++ b/core/vm/runtime/runtime_test.go
@@ -95,7 +95,7 @@ func TestExecute(t *testing.T) {
func TestCall(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
- state, _ := state.New(common.Hash{}, db)
+ state, _ := state.New(common.Hash{}, state.NewDatabase(db))
address := common.HexToAddress("0x0a")
state.SetCode(address, []byte{
byte(vm.PUSH1), 10,
diff --git a/eth/api.go b/eth/api.go
index 81570988c..0d90759b6 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -637,7 +637,7 @@ func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common
return storageRangeAt(st, keyStart, maxResult), nil
}
-func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult {
+func storageRangeAt(st state.Trie, start []byte, maxResult int) StorageRangeResult {
it := trie.NewIterator(st.NodeIterator(start))
result := StorageRangeResult{Storage: storageMap{}}
for i := 0; i < maxResult && it.Next(); i++ {
diff --git a/eth/api_backend.go b/eth/api_backend.go
index fe108d272..166b5084d 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -81,11 +80,11 @@ func (b *EthApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb
return b.eth.blockchain.GetBlockByNumber(uint64(blockNr)), nil
}
-func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) {
+func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
// Pending state is only known by the miner
if blockNr == rpc.PendingBlockNumber {
block, state := b.eth.miner.Pending()
- return EthApiState{state}, block.Header(), nil
+ return state, block.Header(), nil
}
// Otherwise resolve the block number and return its state
header, err := b.HeaderByNumber(ctx, blockNr)
@@ -93,7 +92,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.
return nil, nil, err
}
stateDb, err := b.eth.BlockChain().StateAt(header.Root)
- return EthApiState{stateDb}, header, err
+ return stateDb, header, err
}
func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) {
@@ -108,14 +107,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash)
}
-func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
- statedb := state.(EthApiState).state
- from := statedb.GetOrNewStateObject(msg.From())
- from.SetBalance(math.MaxBig256)
+func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
+ state.SetBalance(msg.From(), math.MaxBig256)
vmError := func() error { return nil }
context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil)
- return vm.NewEVM(context, statedb, b.eth.chainConfig, vmCfg), vmError, nil
+ return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), vmError, nil
}
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
@@ -200,23 +197,3 @@ func (b *EthApiBackend) EventMux() *event.TypeMux {
func (b *EthApiBackend) AccountManager() *accounts.Manager {
return b.eth.AccountManager()
}
-
-type EthApiState struct {
- state *state.StateDB
-}
-
-func (s EthApiState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) {
- return s.state.GetBalance(addr), nil
-}
-
-func (s EthApiState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) {
- return s.state.GetCode(addr), nil
-}
-
-func (s EthApiState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) {
- return s.state.GetState(a, b), nil
-}
-
-func (s EthApiState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
- return s.state.GetNonce(addr), nil
-}
diff --git a/eth/api_test.go b/eth/api_test.go
index f8d2e9c76..49ce38688 100644
--- a/eth/api_test.go
+++ b/eth/api_test.go
@@ -32,7 +32,7 @@ func TestStorageRangeAt(t *testing.T) {
// Create a state where account 0x010000... has a few storage entries.
var (
db, _ = ethdb.NewMemDatabase()
- state, _ = state.New(common.Hash{}, db)
+ state, _ = state.New(common.Hash{}, state.NewDatabase(db))
addr = common.Address{0x01}
keys = []common.Hash{ // hashes of Keys of storage
common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"),
diff --git a/eth/bind.go b/eth/bind.go
index e5abd8617..0385db1f9 100644
--- a/eth/bind.go
+++ b/eth/bind.go
@@ -54,14 +54,12 @@ func NewContractBackend(apiBackend ethapi.Backend) *ContractBackend {
// CodeAt retrieves any code associated with the contract from the local API.
func (b *ContractBackend) CodeAt(ctx context.Context, contract common.Address, blockNum *big.Int) ([]byte, error) {
- out, err := b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum))
- return common.FromHex(out), err
+ return b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum))
}
// CodeAt retrieves any code associated with the contract from the local API.
func (b *ContractBackend) PendingCodeAt(ctx context.Context, contract common.Address) ([]byte, error) {
- out, err := b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber)
- return common.FromHex(out), err
+ return b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber)
}
// ContractCall implements bind.ContractCaller executing an Ethereum contract
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 267a0def9..1fb5a0910 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -657,7 +657,7 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
}
if index > 0 {
- if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil {
+ if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil {
t.Fatalf("state reconstruction failed: %v", err)
}
}
diff --git a/eth/handler_test.go b/eth/handler_test.go
index 413ed2bff..ca9c9e1b4 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -374,7 +374,7 @@ func testGetNodeData(t *testing.T, protocol int) {
}
accounts := []common.Address{testBank, acc1Addr, acc2Addr}
for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
- trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), statedb)
+ trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb))
for j, acc := range accounts {
state, _ := pm.blockchain.State()
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index da5dc5d58..c22c56dfb 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -447,8 +447,8 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add
if state == nil || err != nil {
return nil, err
}
-
- return state.GetBalance(ctx, address)
+ b := state.GetBalance(address)
+ return b, state.Error()
}
// GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all
@@ -529,31 +529,25 @@ func (s *PublicBlockChainAPI) GetUncleCountByBlockHash(ctx context.Context, bloc
}
// GetCode returns the code stored at the given address in the state for the given block number.
-func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (string, error) {
+func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr)
if state == nil || err != nil {
- return "", err
- }
- res, err := state.GetCode(ctx, address)
- if len(res) == 0 || err != nil { // backwards compatibility
- return "0x", err
+ return nil, err
}
- return common.ToHex(res), nil
+ code := state.GetCode(address)
+ return code, state.Error()
}
// GetStorageAt returns the storage from the state at the given address, key and
// block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block
// numbers are also allowed.
-func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (string, error) {
+func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr)
if state == nil || err != nil {
- return "0x", err
- }
- res, err := state.GetState(ctx, address, common.HexToHash(key))
- if err != nil {
- return "0x", err
+ return nil, err
}
- return res.Hex(), nil
+ res := state.GetState(address, common.HexToHash(key))
+ return res[:], state.Error()
}
// callmsg is the message type used for call transitions.
@@ -978,11 +972,8 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr
if state == nil || err != nil {
return nil, err
}
- nonce, err := state.GetNonce(ctx, address)
- if err != nil {
- return nil, err
- }
- return (*hexutil.Uint64)(&nonce), nil
+ nonce := state.GetNonce(address)
+ return (*hexutil.Uint64)(&nonce), state.Error()
}
// getTransactionBlockData fetches the meta data for the given transaction from the chain database. This is useful to
diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go
index 68b5069d0..d122b7915 100644
--- a/internal/ethapi/backend.go
+++ b/internal/ethapi/backend.go
@@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader"
@@ -47,11 +48,12 @@ type Backend interface {
SetHead(number uint64)
HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error)
BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error)
- StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (State, *types.Header, error)
+ StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error)
GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error)
GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error)
GetTd(blockHash common.Hash) *big.Int
- GetEVM(ctx context.Context, msg core.Message, state State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error)
+ GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error)
+
// TxPool API
SendTx(ctx context.Context, signedTx *types.Transaction) error
RemoveTx(txHash common.Hash)
@@ -65,13 +67,6 @@ type Backend interface {
CurrentBlock() *types.Block
}
-type State interface {
- GetBalance(ctx context.Context, addr common.Address) (*big.Int, error)
- GetCode(ctx context.Context, addr common.Address) ([]byte, error)
- GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error)
- GetNonce(ctx context.Context, addr common.Address) (uint64, error)
-}
-
func GetAPIs(apiBackend Backend) []rpc.API {
nonceLock := new(AddrLocker)
return []rpc.API{
diff --git a/les/api_backend.go b/les/api_backend.go
index 7d69046de..7a3c2447c 100644
--- a/les/api_backend.go
+++ b/les/api_backend.go
@@ -24,13 +24,13 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
- "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
@@ -70,12 +70,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb
return b.GetBlock(ctx, header.Hash())
}
-func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) {
+func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
header, err := b.HeaderByNumber(ctx, blockNr)
if header == nil || err != nil {
return nil, nil, err
}
- return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil
+ return light.NewState(ctx, header, b.eth.odr), header, nil
}
func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) {
@@ -90,18 +90,10 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash)
}
-func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
- stateDb := state.(*light.LightState).Copy()
- addr := msg.From()
- from, err := stateDb.GetOrNewStateObject(ctx, addr)
- if err != nil {
- return nil, nil, err
- }
- from.SetBalance(math.MaxBig256)
-
- vmstate := light.NewVMState(ctx, stateDb)
+func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
+ state.SetBalance(msg.From(), math.MaxBig256)
context := core.NewEVMContext(msg, header, b.eth.blockchain, nil)
- return vm.NewEVM(context, vmstate, b.eth.chainConfig, vmCfg), vmstate.Error, nil
+ return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), state.Error, nil
}
func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
diff --git a/les/odr_test.go b/les/odr_test.go
index 7b34996ce..3a0fd6738 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -75,24 +75,23 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr}
- var res []byte
+ var (
+ res []byte
+ st *state.StateDB
+ err error
+ )
for _, addr := range acc {
if bc != nil {
header := bc.GetHeaderByHash(bhash)
- st, err := state.New(header.Root, db)
- if err == nil {
- bal := st.GetBalance(addr)
- rlp, _ := rlp.EncodeToBytes(bal)
- res = append(res, rlp...)
- }
+ st, err = state.New(header.Root, state.NewDatabase(db))
} else {
header := lc.GetHeaderByHash(bhash)
- st := light.NewLightState(light.StateTrieID(header), lc.Odr())
- bal, err := st.GetBalance(ctx, addr)
- if err == nil {
- rlp, _ := rlp.EncodeToBytes(bal)
- res = append(res, rlp...)
- }
+ st = light.NewState(ctx, header, lc.Odr())
+ }
+ if err == nil {
+ bal := st.GetBalance(addr)
+ rlp, _ := rlp.EncodeToBytes(bal)
+ res = append(res, rlp...)
}
}
@@ -115,7 +114,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
data[35] = byte(i)
if bc != nil {
header := bc.GetHeaderByHash(bhash)
- statedb, err := state.New(header.Root, db)
+ statedb, err := state.New(header.Root, state.NewDatabase(db))
if err == nil {
from := statedb.GetOrNewStateObject(testBankAddress)
@@ -133,23 +132,15 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
}
} else {
header := lc.GetHeaderByHash(bhash)
- state := light.NewLightState(light.StateTrieID(header), lc.Odr())
- vmstate := light.NewVMState(ctx, state)
- from, err := state.GetOrNewStateObject(ctx, testBankAddress)
- if err == nil {
- from.SetBalance(math.MaxBig256)
-
- msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)}
-
- context := core.NewEVMContext(msg, header, lc, nil)
- vmenv := vm.NewEVM(context, vmstate, config, vm.Config{})
-
- //vmenv := light.NewEnv(ctx, state, config, lc, msg, header, vm.Config{})
- gp := new(core.GasPool).AddGas(math.MaxBig256)
- ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
- if vmstate.Error() == nil {
- res = append(res, ret...)
- }
+ state := light.NewState(ctx, header, lc.Odr())
+ state.SetBalance(testBankAddress, math.MaxBig256)
+ msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)}
+ context := core.NewEVMContext(msg, header, lc, nil)
+ vmenv := vm.NewEVM(context, state, config, vm.Config{})
+ gp := new(core.GasPool).AddGas(math.MaxBig256)
+ ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
+ if state.Error() == nil {
+ res = append(res, ret...)
}
}
}
diff --git a/les/request_test.go b/les/request_test.go
index 3add5f20d..6b594462d 100644
--- a/les/request_test.go
+++ b/les/request_test.go
@@ -62,7 +62,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr
return nil
}
sti := light.StateTrieID(header)
- ci := light.StorageTrieID(sti, testContractAddr, common.Hash{})
+ ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{})
return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)}
}
diff --git a/light/lightchain.go b/light/lightchain.go
index 5b7e57041..87436f4a5 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -180,11 +180,6 @@ func (self *LightChain) Status() (td *big.Int, currentBlock common.Hash, genesis
return self.GetTd(hash, header.Number.Uint64()), hash, self.genesisBlock.Hash()
}
-// State returns a new mutable state based on the current HEAD block.
-func (self *LightChain) State() *LightState {
- return NewLightState(StateTrieID(self.hc.CurrentHeader()), self.odr)
-}
-
// Reset purges the entire blockchain, restoring it to its genesis state.
func (bc *LightChain) Reset() {
bc.ResetWithGenesisBlock(bc.genesisBlock)
diff --git a/light/odr.go b/light/odr.go
index ca6364f28..d19a488f6 100644
--- a/light/odr.go
+++ b/light/odr.go
@@ -34,7 +34,7 @@ import (
// service is not required.
var NoOdr = context.Background()
-// OdrBackend is an interface to a backend service that handles ODR retrievals
+// OdrBackend is an interface to a backend service that handles ODR retrievals type
type OdrBackend interface {
Database() ethdb.Database
Retrieve(ctx context.Context, req OdrRequest) error
@@ -66,11 +66,11 @@ func StateTrieID(header *types.Header) *TrieID {
// StorageTrieID returns a TrieID for a contract storage trie at a given account
// of a given state trie. It also requires the root hash of the trie for
// checking Merkle proofs.
-func StorageTrieID(state *TrieID, addr common.Address, root common.Hash) *TrieID {
+func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID {
return &TrieID{
BlockHash: state.BlockHash,
BlockNumber: state.BlockNumber,
- AccKey: crypto.Keccak256(addr[:]),
+ AccKey: addrHash[:],
Root: root,
}
}
@@ -102,7 +102,7 @@ func storeProof(db ethdb.Database, proof []rlp.RawValue) {
// CodeRequest is the ODR request type for retrieving contract code
type CodeRequest struct {
OdrRequest
- Id *TrieID
+ Id *TrieID // references storage trie of the account
Hash common.Hash
Data []byte
}
diff --git a/light/odr_test.go b/light/odr_test.go
index 576e3abc9..544b64eff 100644
--- a/light/odr_test.go
+++ b/light/odr_test.go
@@ -86,11 +86,11 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
return nil
}
-type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte
+type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error)
-func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetBlock) }
+func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, odrGetBlock) }
-func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte {
+func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
var block *types.Block
if bc != nil {
block = bc.GetBlockByHash(bhash)
@@ -98,15 +98,15 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc
block, _ = lc.GetBlockByHash(ctx, bhash)
}
if block == nil {
- return nil
+ return nil, nil
}
rlp, _ := rlp.EncodeToBytes(block)
- return rlp
+ return rlp, nil
}
-func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetReceipts) }
+func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) }
-func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte {
+func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
var receipts types.Receipts
if bc != nil {
receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
@@ -114,43 +114,37 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain,
receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
}
if receipts == nil {
- return nil
+ return nil, nil
}
rlp, _ := rlp.EncodeToBytes(receipts)
- return rlp
+ return rlp, nil
}
-func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrAccounts) }
+func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, odrAccounts) }
-func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte {
+func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr}
+ var st *state.StateDB
+ if bc == nil {
+ header := lc.GetHeaderByHash(bhash)
+ st = NewState(ctx, header, lc.Odr())
+ } else {
+ header := bc.GetHeaderByHash(bhash)
+ st, _ = state.New(header.Root, state.NewDatabase(db))
+ }
+
var res []byte
for _, addr := range acc {
- if bc != nil {
- header := bc.GetHeaderByHash(bhash)
- st, err := state.New(header.Root, db)
- if err == nil {
- bal := st.GetBalance(addr)
- rlp, _ := rlp.EncodeToBytes(bal)
- res = append(res, rlp...)
- }
- } else {
- header := lc.GetHeaderByHash(bhash)
- st := NewLightState(StateTrieID(header), lc.Odr())
- bal, err := st.GetBalance(ctx, addr)
- if err == nil {
- rlp, _ := rlp.EncodeToBytes(bal)
- res = append(res, rlp...)
- }
- }
+ bal := st.GetBalance(addr)
+ rlp, _ := rlp.EncodeToBytes(bal)
+ res = append(res, rlp...)
}
-
- return res
+ return res, st.Error()
}
-func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, 2, odrContractCall) }
+func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) }
type callmsg struct {
types.Message
@@ -158,50 +152,42 @@ type callmsg struct {
func (callmsg) CheckNonce() bool { return false }
-func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte {
+func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000")
-
config := params.TestChainConfig
var res []byte
for i := 0; i < 3; i++ {
data[35] = byte(i)
- if bc != nil {
- header := bc.GetHeaderByHash(bhash)
- statedb, err := state.New(header.Root, db)
- if err == nil {
- from := statedb.GetOrNewStateObject(testBankAddress)
- from.SetBalance(math.MaxBig256)
-
- msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)}
- context := core.NewEVMContext(msg, header, bc, nil)
- vmenv := vm.NewEVM(context, statedb, config, vm.Config{})
-
- gp := new(core.GasPool).AddGas(math.MaxBig256)
- ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
- res = append(res, ret...)
- }
+ var (
+ st *state.StateDB
+ header *types.Header
+ chain core.ChainContext
+ )
+ if bc == nil {
+ chain = lc
+ header = lc.GetHeaderByHash(bhash)
+ st = NewState(ctx, header, lc.Odr())
} else {
- header := lc.GetHeaderByHash(bhash)
- state := NewLightState(StateTrieID(header), lc.Odr())
- vmstate := NewVMState(ctx, state)
- from, err := state.GetOrNewStateObject(ctx, testBankAddress)
- if err == nil {
- from.SetBalance(math.MaxBig256)
-
- msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)}
- context := core.NewEVMContext(msg, header, lc, nil)
- vmenv := vm.NewEVM(context, vmstate, config, vm.Config{})
- gp := new(core.GasPool).AddGas(math.MaxBig256)
- ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
- if vmstate.Error() == nil {
- res = append(res, ret...)
- }
- }
+ chain = bc
+ header = bc.GetHeaderByHash(bhash)
+ st, _ = state.New(header.Root, state.NewDatabase(db))
+ }
+
+ // Perform read-only call.
+ st.SetBalance(testBankAddress, math.MaxBig256)
+ msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)}
+ context := core.NewEVMContext(msg, header, chain, nil)
+ vmenv := vm.NewEVM(context, st, config, vm.Config{})
+ gp := new(core.GasPool).AddGas(math.MaxBig256)
+ ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
+ res = append(res, ret...)
+ if st.Error() != nil {
+ return res, st.Error()
}
}
- return res
+ return res, nil
}
func testChainGen(i int, block *core.BlockGen) {
@@ -245,7 +231,7 @@ func testChainGen(i int, block *core.BlockGen) {
}
}
-func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
+func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
var (
evmux = new(event.TypeMux)
sdb, _ = ethdb.NewMemDatabase()
@@ -258,46 +244,58 @@ func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil {
- panic(err)
+ t.Fatal(err)
}
odr := &testOdr{sdb: sdb, ldb: ldb}
- lightchain, _ := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux)
+ lightchain, err := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux)
+ if err != nil {
+ t.Fatal(err)
+ }
headers := make([]*types.Header, len(gchain))
for i, block := range gchain {
headers[i] = block.Header()
}
if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil {
- panic(err)
+ t.Fatal(err)
}
- test := func(expFail uint64) {
+ test := func(expFail int) {
for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := core.GetCanonicalHash(sdb, i)
- b1 := fn(NoOdr, sdb, blockchain, nil, bhash)
+ b1, err := fn(NoOdr, sdb, blockchain, nil, bhash)
+ if err != nil {
+ t.Fatalf("error in full-node test for block %d: %v", i, err)
+ }
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
- b2 := fn(ctx, ldb, nil, lightchain, bhash)
+
+ exp := i < uint64(expFail)
+ b2, err := fn(ctx, ldb, nil, lightchain, bhash)
+ if err != nil && exp {
+ t.Errorf("error in ODR test for block %d: %v", i, err)
+ }
eq := bytes.Equal(b1, b2)
- exp := i < expFail
if exp && !eq {
- t.Errorf("odr mismatch")
- }
- if !exp && eq {
- t.Errorf("unexpected odr match")
+ t.Errorf("ODR test output for block %d doesn't match full node", i)
}
}
}
- odr.disable = true
// expect retrievals to fail (except genesis block) without a les peer
- test(expFail)
- odr.disable = false
- // expect all retrievals to pass
- test(5)
+ t.Log("checking without ODR")
odr.disable = true
+ test(1)
+
+ // expect all retrievals to pass with ODR enabled
+ t.Log("checking with ODR")
+ odr.disable = false
+ test(len(gchain))
+
// still expect all retrievals to pass, now data should be cached locally
- test(5)
+ t.Log("checking without ODR, should be cached")
+ odr.disable = true
+ test(len(gchain))
}
diff --git a/light/odr_util.go b/light/odr_util.go
index d7f8458f1..fcdfdb82c 100644
--- a/light/odr_util.go
+++ b/light/odr_util.go
@@ -106,25 +106,6 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo
return common.Hash{}, err
}
-// retrieveContractCode tries to retrieve the contract code of the given account
-// with the given hash from the network (id points to the storage trie belonging
-// to the same account)
-func retrieveContractCode(ctx context.Context, odr OdrBackend, id *TrieID, hash common.Hash) ([]byte, error) {
- if hash == sha3_nil {
- return nil, nil
- }
- res, _ := odr.Database().Get(hash[:])
- if res != nil {
- return res, nil
- }
- r := &CodeRequest{Id: id, Hash: hash}
- if err := odr.Retrieve(ctx, r); err != nil {
- return nil, err
- } else {
- return r.Data, nil
- }
-}
-
// GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding.
func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) {
if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil {
diff --git a/light/state.go b/light/state.go
deleted file mode 100644
index b184dc3a5..000000000
--- a/light/state.go
+++ /dev/null
@@ -1,316 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package light
-
-import (
- "context"
- "math/big"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
-)
-
-// LightState is a memory representation of a state.
-// This version is ODR capable, caching only the already accessed part of the
-// state, retrieving unknown parts on-demand from the ODR backend. Changes are
-// never stored in the local database, only in the memory objects.
-type LightState struct {
- odr OdrBackend
- trie *LightTrie
- id *TrieID
- stateObjects map[string]*StateObject
- refund *big.Int
-}
-
-// NewLightState creates a new LightState with the specified root.
-// Note that the creation of a light state is always successful, even if the
-// root is non-existent. In that case, ODR retrieval will always be unsuccessful
-// and every operation will return with an error or wait for the context to be
-// cancelled.
-func NewLightState(id *TrieID, odr OdrBackend) *LightState {
- var tr *LightTrie
- if id != nil {
- tr = NewLightTrie(id, odr, true)
- }
- return &LightState{
- odr: odr,
- trie: tr,
- id: id,
- stateObjects: make(map[string]*StateObject),
- refund: new(big.Int),
- }
-}
-
-// AddRefund adds an amount to the refund value collected during a vm execution
-func (self *LightState) AddRefund(gas *big.Int) {
- self.refund.Add(self.refund, gas)
-}
-
-// HasAccount returns true if an account exists at the given address
-func (self *LightState) HasAccount(ctx context.Context, addr common.Address) (bool, error) {
- so, err := self.GetStateObject(ctx, addr)
- return so != nil, err
-}
-
-// GetBalance retrieves the balance from the given address or 0 if the account does
-// not exist
-func (self *LightState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) {
- stateObject, err := self.GetStateObject(ctx, addr)
- if err != nil {
- return common.Big0, err
- }
- if stateObject != nil {
- return stateObject.balance, nil
- }
-
- return common.Big0, nil
-}
-
-// GetNonce returns the nonce at the given address or 0 if the account does
-// not exist
-func (self *LightState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
- stateObject, err := self.GetStateObject(ctx, addr)
- if err != nil {
- return 0, err
- }
- if stateObject != nil {
- return stateObject.nonce, nil
- }
- return 0, nil
-}
-
-// GetCode returns the contract code at the given address or nil if the account
-// does not exist
-func (self *LightState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) {
- stateObject, err := self.GetStateObject(ctx, addr)
- if err != nil {
- return nil, err
- }
- if stateObject != nil {
- return stateObject.code, nil
- }
- return nil, nil
-}
-
-// GetState returns the contract storage value at storage address b from the
-// contract address a or common.Hash{} if the account does not exist
-func (self *LightState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) {
- stateObject, err := self.GetStateObject(ctx, a)
- if err == nil && stateObject != nil {
- return stateObject.GetState(ctx, b)
- }
- return common.Hash{}, err
-}
-
-// HasSuicided returns true if the given account has been marked for deletion
-// or false if the account does not exist
-func (self *LightState) HasSuicided(ctx context.Context, addr common.Address) (bool, error) {
- stateObject, err := self.GetStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- return stateObject.remove, nil
- }
- return false, err
-}
-
-/*
- * SETTERS
- */
-
-// AddBalance adds the given amount to the balance of the specified account
-func (self *LightState) AddBalance(ctx context.Context, addr common.Address, amount *big.Int) error {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.AddBalance(amount)
- }
- return err
-}
-
-// SubBalance adds the given amount to the balance of the specified account
-func (self *LightState) SubBalance(ctx context.Context, addr common.Address, amount *big.Int) error {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.SubBalance(amount)
- }
- return err
-}
-
-// SetNonce sets the nonce of the specified account
-func (self *LightState) SetNonce(ctx context.Context, addr common.Address, nonce uint64) error {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.SetNonce(nonce)
- }
- return err
-}
-
-// SetCode sets the contract code at the specified account
-func (self *LightState) SetCode(ctx context.Context, addr common.Address, code []byte) error {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.SetCode(crypto.Keccak256Hash(code), code)
- }
- return err
-}
-
-// SetState sets the storage value at storage address key of the account addr
-func (self *LightState) SetState(ctx context.Context, addr common.Address, key common.Hash, value common.Hash) error {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.SetState(key, value)
- }
- return err
-}
-
-// Delete marks an account to be removed and clears its balance
-func (self *LightState) Suicide(ctx context.Context, addr common.Address) (bool, error) {
- stateObject, err := self.GetOrNewStateObject(ctx, addr)
- if err == nil && stateObject != nil {
- stateObject.MarkForDeletion()
- stateObject.balance = new(big.Int)
-
- return true, nil
- }
-
- return false, err
-}
-
-//
-// Get, set, new state object methods
-//
-
-// GetStateObject returns the state object of the given account or nil if the
-// account does not exist
-func (self *LightState) GetStateObject(ctx context.Context, addr common.Address) (stateObject *StateObject, err error) {
- stateObject = self.stateObjects[addr.Str()]
- if stateObject != nil {
- if stateObject.deleted {
- stateObject = nil
- }
- return stateObject, nil
- }
- data, err := self.trie.Get(ctx, addr[:])
- if err != nil {
- return nil, err
- }
- if len(data) == 0 {
- return nil, nil
- }
-
- stateObject, err = DecodeObject(ctx, self.id, addr, self.odr, []byte(data))
- if err != nil {
- return nil, err
- }
-
- self.SetStateObject(stateObject)
-
- return stateObject, nil
-}
-
-// SetStateObject sets the state object of the given account
-func (self *LightState) SetStateObject(object *StateObject) {
- self.stateObjects[object.Address().Str()] = object
-}
-
-// GetOrNewStateObject returns the state object of the given account or creates a
-// new one if the account does not exist
-func (self *LightState) GetOrNewStateObject(ctx context.Context, addr common.Address) (*StateObject, error) {
- stateObject, err := self.GetStateObject(ctx, addr)
- if err == nil && (stateObject == nil || stateObject.deleted) {
- stateObject, err = self.CreateStateObject(ctx, addr)
- }
- return stateObject, err
-}
-
-// newStateObject creates a state object whether it exists in the state or not
-func (self *LightState) newStateObject(addr common.Address) *StateObject {
- stateObject := NewStateObject(addr, self.odr)
- self.stateObjects[addr.Str()] = stateObject
-
- return stateObject
-}
-
-// CreateStateObject creates creates a new state object and takes ownership.
-// This is different from "NewStateObject"
-func (self *LightState) CreateStateObject(ctx context.Context, addr common.Address) (*StateObject, error) {
- // Get previous (if any)
- so, err := self.GetStateObject(ctx, addr)
- if err != nil {
- return nil, err
- }
- // Create a new one
- newSo := self.newStateObject(addr)
-
- // If it existed set the balance to the new account
- if so != nil {
- newSo.balance = so.balance
- }
-
- return newSo, nil
-}
-
-// ForEachStorage calls a callback function for every key/value pair found
-// in the local storage cache. Note that unlike core/state.StateObject,
-// light.StateObject only returns cached values and doesn't download the
-// entire storage tree.
-func (self *LightState) ForEachStorage(ctx context.Context, addr common.Address, cb func(key, value common.Hash) bool) error {
- so, err := self.GetStateObject(ctx, addr)
- if err != nil {
- return err
- }
-
- if so == nil {
- return nil
- }
-
- for h, v := range so.storage {
- cb(h, v)
- }
- return nil
-}
-
-//
-// Setting, copying of the state methods
-//
-
-// Copy creates a copy of the state
-func (self *LightState) Copy() *LightState {
- // ignore error - we assume state-to-be-copied always exists
- state := NewLightState(nil, self.odr)
- state.trie = self.trie
- state.id = self.id
- for k, stateObject := range self.stateObjects {
- if stateObject.dirty {
- state.stateObjects[k] = stateObject.Copy()
- }
- }
-
- state.refund.Set(self.refund)
- return state
-}
-
-// Set copies the contents of the given state onto this state, overwriting
-// its contents
-func (self *LightState) Set(state *LightState) {
- self.trie = state.trie
- self.stateObjects = state.stateObjects
- self.refund = state.refund
-}
-
-// GetRefund returns the refund value collected during a vm execution
-func (self *LightState) GetRefund() *big.Int {
- return self.refund
-}
diff --git a/light/state_object.go b/light/state_object.go
deleted file mode 100644
index a54ea1d9f..000000000
--- a/light/state_object.go
+++ /dev/null
@@ -1,275 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package light
-
-import (
- "bytes"
- "context"
- "fmt"
- "math/big"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/rlp"
-)
-
-var emptyCodeHash = crypto.Keccak256(nil)
-
-// Code represents a contract code in binary form
-type Code []byte
-
-// String returns a string representation of the code
-func (self Code) String() string {
- return string(self) //strings.Join(Disassemble(self), " ")
-}
-
-// Storage is a memory map cache of a contract storage
-type Storage map[common.Hash]common.Hash
-
-// String returns a string representation of the storage cache
-func (self Storage) String() (str string) {
- for key, value := range self {
- str += fmt.Sprintf("%X : %X\n", key, value)
- }
-
- return
-}
-
-// Copy copies the contents of a storage cache
-func (self Storage) Copy() Storage {
- cpy := make(Storage)
- for key, value := range self {
- cpy[key] = value
- }
-
- return cpy
-}
-
-// StateObject is a memory representation of an account or contract and its storage.
-// This version is ODR capable, caching only the already accessed part of the
-// storage, retrieving unknown parts on-demand from the ODR backend. Changes are
-// never stored in the local database, only in the memory objects.
-type StateObject struct {
- odr OdrBackend
- trie *LightTrie
-
- // Address belonging to this account
- address common.Address
- // The balance of the account
- balance *big.Int
- // The nonce of the account
- nonce uint64
- // The code hash if code is present (i.e. a contract)
- codeHash []byte
- // The code for this account
- code Code
- // Cached storage (flushed when updated)
- storage Storage
-
- // Mark for deletion
- // When an object is marked for deletion it will be delete from the trie
- // during the "update" phase of the state transition
- remove bool
- deleted bool
- dirty bool
-}
-
-// NewStateObject creates a new StateObject of the specified account address
-func NewStateObject(address common.Address, odr OdrBackend) *StateObject {
- object := &StateObject{
- odr: odr,
- address: address,
- balance: new(big.Int),
- dirty: true,
- codeHash: emptyCodeHash,
- storage: make(Storage),
- }
- object.trie = NewLightTrie(&TrieID{}, odr, true)
- return object
-}
-
-// MarkForDeletion marks an account to be removed
-func (self *StateObject) MarkForDeletion() {
- self.remove = true
- self.dirty = true
-}
-
-// getAddr gets the storage value at the given address from the trie
-func (c *StateObject) getAddr(ctx context.Context, addr common.Hash) (common.Hash, error) {
- var ret []byte
- val, err := c.trie.Get(ctx, addr[:])
- if err != nil {
- return common.Hash{}, err
- }
- rlp.DecodeBytes(val, &ret)
- return common.BytesToHash(ret), nil
-}
-
-// Storage returns the storage cache object of the account
-func (self *StateObject) Storage() Storage {
- return self.storage
-}
-
-// GetState returns the storage value at the given address from either the cache
-// or the trie
-func (self *StateObject) GetState(ctx context.Context, key common.Hash) (common.Hash, error) {
- value, exists := self.storage[key]
- if !exists {
- var err error
- value, err = self.getAddr(ctx, key)
- if err != nil {
- return common.Hash{}, err
- }
- if (value != common.Hash{}) {
- self.storage[key] = value
- }
- }
-
- return value, nil
-}
-
-// SetState sets the storage value at the given address
-func (self *StateObject) SetState(k, value common.Hash) {
- self.storage[k] = value
- self.dirty = true
-}
-
-// AddBalance adds the given amount to the account balance
-func (c *StateObject) AddBalance(amount *big.Int) {
- c.SetBalance(new(big.Int).Add(c.balance, amount))
-}
-
-// SubBalance subtracts the given amount from the account balance
-func (c *StateObject) SubBalance(amount *big.Int) {
- c.SetBalance(new(big.Int).Sub(c.balance, amount))
-}
-
-// SetBalance sets the account balance to the given amount
-func (c *StateObject) SetBalance(amount *big.Int) {
- c.balance = amount
- c.dirty = true
-}
-
-// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures
-func (c *StateObject) ReturnGas(gas *big.Int) {}
-
-// Copy creates a copy of the state object
-func (self *StateObject) Copy() *StateObject {
- stateObject := NewStateObject(self.Address(), self.odr)
- stateObject.balance.Set(self.balance)
- stateObject.codeHash = common.CopyBytes(self.codeHash)
- stateObject.nonce = self.nonce
- stateObject.trie = self.trie
- stateObject.code = self.code
- stateObject.storage = self.storage.Copy()
- stateObject.remove = self.remove
- stateObject.dirty = self.dirty
- stateObject.deleted = self.deleted
-
- return stateObject
-}
-
-//
-// Attribute accessors
-//
-
-// empty returns whether the account is considered empty.
-func (self *StateObject) empty() bool {
- return self.nonce == 0 && self.balance.Sign() == 0 && bytes.Equal(self.codeHash, emptyCodeHash)
-}
-
-// Balance returns the account balance
-func (self *StateObject) Balance() *big.Int {
- return self.balance
-}
-
-// Address returns the address of the contract/account
-func (self *StateObject) Address() common.Address {
- return self.address
-}
-
-// Code returns the contract code
-func (self *StateObject) Code() []byte {
- return self.code
-}
-
-// SetCode sets the contract code
-func (self *StateObject) SetCode(hash common.Hash, code []byte) {
- self.code = code
- self.codeHash = hash[:]
- self.dirty = true
-}
-
-// SetNonce sets the account nonce
-func (self *StateObject) SetNonce(nonce uint64) {
- self.nonce = nonce
- self.dirty = true
-}
-
-// Nonce returns the account nonce
-func (self *StateObject) Nonce() uint64 {
- return self.nonce
-}
-
-// ForEachStorage calls a callback function for every key/value pair found
-// in the local storage cache. Note that unlike core/state.StateObject,
-// light.StateObject only returns cached values and doesn't download the
-// entire storage tree.
-func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
- for h, v := range self.storage {
- cb(h, v)
- }
-}
-
-// Never called, but must be present to allow StateObject to be used
-// as a vm.Account interface that also satisfies the vm.ContractRef
-// interface. Interfaces are awesome.
-func (self *StateObject) Value() *big.Int {
- panic("Value on StateObject should never be called")
-}
-
-// Encoding
-
-type extStateObject struct {
- Nonce uint64
- Balance *big.Int
- Root common.Hash
- CodeHash []byte
-}
-
-// DecodeObject decodes an RLP-encoded state object.
-func DecodeObject(ctx context.Context, stateID *TrieID, address common.Address, odr OdrBackend, data []byte) (*StateObject, error) {
- var (
- obj = &StateObject{address: address, odr: odr, storage: make(Storage)}
- ext extStateObject
- err error
- )
- if err = rlp.DecodeBytes(data, &ext); err != nil {
- return nil, err
- }
- trieID := StorageTrieID(stateID, address, ext.Root)
- obj.trie = NewLightTrie(trieID, odr, true)
- if !bytes.Equal(ext.CodeHash, emptyCodeHash) {
- if obj.code, err = retrieveContractCode(ctx, obj.odr, trieID, common.BytesToHash(ext.CodeHash)); err != nil {
- return nil, fmt.Errorf("can't find code for hash %x: %v", ext.CodeHash, err)
- }
- }
- obj.nonce = ext.Nonce
- obj.balance = ext.Balance
- obj.codeHash = ext.CodeHash
- return obj, nil
-}
diff --git a/light/state_test.go b/light/state_test.go
deleted file mode 100644
index e776efec8..000000000
--- a/light/state_test.go
+++ /dev/null
@@ -1,248 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package light
-
-import (
- "bytes"
- "context"
- "math/big"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/state"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/ethdb"
-)
-
-func makeTestState() (common.Hash, ethdb.Database) {
- sdb, _ := ethdb.NewMemDatabase()
- st, _ := state.New(common.Hash{}, sdb)
- for i := byte(0); i < 100; i++ {
- addr := common.Address{i}
- for j := byte(0); j < 100; j++ {
- st.SetState(addr, common.Hash{j}, common.Hash{i, j})
- }
- st.SetNonce(addr, 100)
- st.AddBalance(addr, big.NewInt(int64(i)))
- st.SetCode(addr, []byte{i, i, i})
- }
- root, _ := st.Commit(false)
- return root, sdb
-}
-
-func TestLightStateOdr(t *testing.T) {
- root, sdb := makeTestState()
- header := &types.Header{Root: root, Number: big.NewInt(0)}
- core.WriteHeader(sdb, header)
- ldb, _ := ethdb.NewMemDatabase()
- odr := &testOdr{sdb: sdb, ldb: ldb}
- ls := NewLightState(StateTrieID(header), odr)
- ctx := context.Background()
-
- for i := byte(0); i < 100; i++ {
- addr := common.Address{i}
- err := ls.AddBalance(ctx, addr, big.NewInt(1000))
- if err != nil {
- t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
- }
- err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100})
- if err != nil {
- t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
- }
- }
-
- addr := common.Address{100}
- _, err := ls.CreateStateObject(ctx, addr)
- if err != nil {
- t.Fatalf("Error creating state object: %v", err)
- }
- err = ls.SetCode(ctx, addr, []byte{100, 100, 100})
- if err != nil {
- t.Fatalf("Error setting code: %v", err)
- }
- err = ls.AddBalance(ctx, addr, big.NewInt(1100))
- if err != nil {
- t.Fatalf("Error adding balance to acc[100]: %v", err)
- }
- for j := byte(0); j < 101; j++ {
- err = ls.SetState(ctx, addr, common.Hash{j}, common.Hash{100, j})
- if err != nil {
- t.Fatalf("Error setting storage of acc[100]: %v", err)
- }
- }
- err = ls.SetNonce(ctx, addr, 100)
- if err != nil {
- t.Fatalf("Error setting nonce for acc[100]: %v", err)
- }
-
- for i := byte(0); i < 101; i++ {
- addr := common.Address{i}
-
- bal, err := ls.GetBalance(ctx, addr)
- if err != nil {
- t.Fatalf("Error getting balance of acc[%d]: %v", i, err)
- }
- if bal.Int64() != int64(i)+1000 {
- t.Fatalf("Incorrect balance at acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
- }
-
- nonce, err := ls.GetNonce(ctx, addr)
- if err != nil {
- t.Fatalf("Error getting nonce of acc[%d]: %v", i, err)
- }
- if nonce != 100 {
- t.Fatalf("Incorrect nonce at acc[%d]: expected %v, got %v", i, 100, nonce)
- }
-
- code, err := ls.GetCode(ctx, addr)
- exp := []byte{i, i, i}
- if err != nil {
- t.Fatalf("Error getting code of acc[%d]: %v", i, err)
- }
- if !bytes.Equal(code, exp) {
- t.Fatalf("Incorrect code at acc[%d]: expected %v, got %v", i, exp, code)
- }
-
- for j := byte(0); j < 101; j++ {
- exp := common.Hash{i, j}
- val, err := ls.GetState(ctx, addr, common.Hash{j})
- if err != nil {
- t.Fatalf("Error retrieving acc[%d].storage[%d]: %v", i, j, err)
- }
- if val != exp {
- t.Fatalf("Retrieved wrong value from acc[%d].storage[%d]: expected %04x, got %04x", i, j, exp, val)
- }
- }
- }
-}
-
-func TestLightStateSetCopy(t *testing.T) {
- root, sdb := makeTestState()
- header := &types.Header{Root: root, Number: big.NewInt(0)}
- core.WriteHeader(sdb, header)
- ldb, _ := ethdb.NewMemDatabase()
- odr := &testOdr{sdb: sdb, ldb: ldb}
- ls := NewLightState(StateTrieID(header), odr)
- ctx := context.Background()
-
- for i := byte(0); i < 100; i++ {
- addr := common.Address{i}
- err := ls.AddBalance(ctx, addr, big.NewInt(1000))
- if err != nil {
- t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
- }
- err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100})
- if err != nil {
- t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
- }
- }
-
- ls2 := ls.Copy()
-
- for i := byte(0); i < 100; i++ {
- addr := common.Address{i}
- err := ls2.AddBalance(ctx, addr, big.NewInt(1000))
- if err != nil {
- t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
- }
- err = ls2.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 200})
- if err != nil {
- t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
- }
- }
-
- lsx := ls.Copy()
- ls.Set(ls2)
- ls2.Set(lsx)
-
- for i := byte(0); i < 100; i++ {
- addr := common.Address{i}
- // check balance in ls
- bal, err := ls.GetBalance(ctx, addr)
- if err != nil {
- t.Fatalf("Error getting balance to acc[%d]: %v", i, err)
- }
- if bal.Int64() != int64(i)+2000 {
- t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
- }
- // check balance in ls2
- bal, err = ls2.GetBalance(ctx, addr)
- if err != nil {
- t.Fatalf("Error getting balance to acc[%d]: %v", i, err)
- }
- if bal.Int64() != int64(i)+1000 {
- t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
- }
- // check storage in ls
- exp := common.Hash{i, 200}
- val, err := ls.GetState(ctx, addr, common.Hash{100})
- if err != nil {
- t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err)
- }
- if val != exp {
- t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val)
- }
- // check storage in ls2
- exp = common.Hash{i, 100}
- val, err = ls2.GetState(ctx, addr, common.Hash{100})
- if err != nil {
- t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err)
- }
- if val != exp {
- t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val)
- }
- }
-}
-
-func TestLightStateDelete(t *testing.T) {
- root, sdb := makeTestState()
- header := &types.Header{Root: root, Number: big.NewInt(0)}
- core.WriteHeader(sdb, header)
- ldb, _ := ethdb.NewMemDatabase()
- odr := &testOdr{sdb: sdb, ldb: ldb}
- ls := NewLightState(StateTrieID(header), odr)
- ctx := context.Background()
-
- addr := common.Address{42}
-
- b, err := ls.HasAccount(ctx, addr)
- if err != nil {
- t.Fatalf("HasAccount error: %v", err)
- }
- if !b {
- t.Fatalf("HasAccount returned false, expected true")
- }
-
- b, err = ls.HasSuicided(ctx, addr)
- if err != nil {
- t.Fatalf("HasSuicided error: %v", err)
- }
- if b {
- t.Fatalf("HasSuicided returned true, expected false")
- }
-
- ls.Suicide(ctx, addr)
-
- b, err = ls.HasSuicided(ctx, addr)
- if err != nil {
- t.Fatalf("HasSuicided error: %v", err)
- }
- if !b {
- t.Fatalf("HasSuicided returned false, expected true")
- }
-}
diff --git a/light/trie.go b/light/trie.go
index 2988a16cf..7502b6e5d 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -18,99 +18,216 @@ package light
import (
"context"
+ "fmt"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
)
-// LightTrie is an ODR-capable wrapper around trie.SecureTrie
-type LightTrie struct {
- trie *trie.SecureTrie
+func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB {
+ state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr))
+ return state
+}
+
+func NewStateDatabase(ctx context.Context, head *types.Header, odr OdrBackend) state.Database {
+ return &odrDatabase{ctx, StateTrieID(head), odr}
+}
+
+type odrDatabase struct {
+ ctx context.Context
+ id *TrieID
+ backend OdrBackend
+}
+
+func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) {
+ return &odrTrie{db: db, id: db.id}, nil
+}
+
+func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, error) {
+ return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil
+}
+
+func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
+ switch t := t.(type) {
+ case *odrTrie:
+ cpy := &odrTrie{db: t.db, id: t.id}
+ if t.trie != nil {
+ cpytrie := *t.trie
+ cpy.trie = &cpytrie
+ }
+ return cpy
+ default:
+ panic(fmt.Errorf("unknown trie type %T", t))
+ }
+}
+
+func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
+ if codeHash == sha3_nil {
+ return nil, nil
+ }
+ if code, err := db.backend.Database().Get(codeHash[:]); err == nil {
+ return code, nil
+ }
+ id := *db.id
+ id.AccKey = addrHash[:]
+ req := &CodeRequest{Id: &id, Hash: codeHash}
+ err := db.backend.Retrieve(db.ctx, req)
+ return req.Data, err
+}
+
+func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
+ code, err := db.ContractCode(addrHash, codeHash)
+ return len(code), err
+}
+
+type odrTrie struct {
+ db *odrDatabase
id *TrieID
- odr OdrBackend
- db ethdb.Database
-}
-
-// NewLightTrie creates a new LightTrie instance. It doesn't instantly try to
-// access the db or network and retrieve the root node, it only initializes its
-// encapsulated SecureTrie at the first actual operation.
-func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie {
- return &LightTrie{
- // SecureTrie is initialized before first request
- id: id,
- odr: odr,
- db: odr.Database(),
+ trie *trie.Trie
+}
+
+func (t *odrTrie) TryGet(key []byte) ([]byte, error) {
+ key = crypto.Keccak256(key)
+ var res []byte
+ err := t.do(key, func() (err error) {
+ res, err = t.trie.TryGet(key)
+ return err
+ })
+ return res, err
+}
+
+func (t *odrTrie) TryUpdate(key, value []byte) error {
+ key = crypto.Keccak256(key)
+ return t.do(key, func() error {
+ return t.trie.TryDelete(key)
+ })
+}
+
+func (t *odrTrie) TryDelete(key []byte) error {
+ key = crypto.Keccak256(key)
+ return t.do(key, func() error {
+ return t.trie.TryDelete(key)
+ })
+}
+
+func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) {
+ if t.trie == nil {
+ return t.id.Root, nil
+ }
+ return t.trie.CommitTo(db)
+}
+
+func (t *odrTrie) Hash() common.Hash {
+ if t.trie == nil {
+ return t.id.Root
}
+ return t.trie.Hash()
+}
+
+func (t *odrTrie) NodeIterator(startkey []byte) trie.NodeIterator {
+ return newNodeIterator(t, startkey)
}
-// retrieveKey retrieves a single key, returns true and stores nodes in local
-// database if successful
-func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool {
- r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)}
- return t.odr.Retrieve(ctx, r) == nil
+func (t *odrTrie) GetKey(sha []byte) []byte {
+ return nil
}
// do tries and retries to execute a function until it returns with no error or
// an error type other than MissingNodeError
-func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error {
- err := fn()
- for err != nil {
+func (t *odrTrie) do(key []byte, fn func() error) error {
+ for {
+ var err error
+ if t.trie == nil {
+ t.trie, err = trie.New(t.id.Root, t.db.backend.Database())
+ }
+ if err == nil {
+ err = fn()
+ }
if _, ok := err.(*trie.MissingNodeError); !ok {
return err
}
- if !t.retrieveKey(ctx, key) {
- break
+ r := &TrieRequest{Id: t.id, Key: key}
+ if err := t.db.backend.Retrieve(t.db.ctx, r); err != nil {
+ return fmt.Errorf("can't fetch trie key %x: %v", key, err)
}
- err = fn()
}
- return err
}
-// Get returns the value for key stored in the trie.
-// The value bytes must not be modified by the caller.
-func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) {
- err = t.do(ctx, key, func() (err error) {
- if t.trie == nil {
- t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
- }
- if err == nil {
- res, err = t.trie.TryGet(key)
- }
- return
+type nodeIterator struct {
+ trie.NodeIterator
+ t *odrTrie
+ err error
+}
+
+func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
+ it := &nodeIterator{t: t}
+ // Open the actual non-ODR trie if that hasn't happened yet.
+ if t.trie == nil {
+ it.do(func() error {
+ t, err := trie.New(t.id.Root, t.db.backend.Database())
+ if err == nil {
+ it.t.trie = t
+ }
+ return err
+ })
+ }
+ it.do(func() error {
+ it.NodeIterator = it.t.trie.NodeIterator(startkey)
+ return it.NodeIterator.Error()
})
- return
+ return it
}
-// Update associates key with value in the trie. Subsequent calls to
-// Get will return value. If value has length zero, any existing value
-// is deleted from the trie and calls to Get will return nil.
-//
-// The value bytes must not be modified by the caller while they are
-// stored in the trie.
-func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) {
- err = t.do(ctx, key, func() (err error) {
- if t.trie == nil {
- t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
- }
- if err == nil {
- err = t.trie.TryUpdate(key, value)
- }
- return
+func (it *nodeIterator) Next(descend bool) bool {
+ var ok bool
+ it.do(func() error {
+ ok = it.NodeIterator.Next(descend)
+ return it.NodeIterator.Error()
})
- return
+ return ok
}
-// Delete removes any existing value for key from the trie.
-func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) {
- err = t.do(ctx, key, func() (err error) {
- if t.trie == nil {
- t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
+// do runs fn and attempts to fill in missing nodes by retrieving.
+func (it *nodeIterator) do(fn func() error) {
+ var lasthash common.Hash
+ for {
+ it.err = fn()
+ missing, ok := it.err.(*trie.MissingNodeError)
+ if !ok {
+ return
}
- if err == nil {
- err = t.trie.TryDelete(key)
+ if missing.NodeHash == lasthash {
+ it.err = fmt.Errorf("retrieve loop for trie node %x", missing.NodeHash)
+ return
}
- return
- })
- return
+ lasthash = missing.NodeHash
+ r := &TrieRequest{Id: it.t.id, Key: nibblesToKey(missing.Path)}
+ if it.err = it.t.db.backend.Retrieve(it.t.db.ctx, r); it.err != nil {
+ return
+ }
+ }
+}
+
+func (it *nodeIterator) Error() error {
+ if it.err != nil {
+ return it.err
+ }
+ return it.NodeIterator.Error()
+}
+
+func nibblesToKey(nib []byte) []byte {
+ if len(nib) > 0 && nib[len(nib)-1] == 0x10 {
+ nib = nib[:len(nib)-1] // drop terminator
+ }
+ if len(nib)&1 == 1 {
+ nib = append(nib, 0) // make even
+ }
+ key := make([]byte, len(nib)/2)
+ for bi, ni := 0, 0; ni < len(nib); bi, ni = bi+1, ni+2 {
+ key[bi] = nib[ni]<<4 | nib[ni+1]
+ }
+ return key
}
diff --git a/light/trie_test.go b/light/trie_test.go
new file mode 100644
index 000000000..9b2cf7c2b
--- /dev/null
+++ b/light/trie_test.go
@@ -0,0 +1,83 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package light
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "testing"
+
+ "github.com/davecgh/go-spew/spew"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+func TestNodeIterator(t *testing.T) {
+ var (
+ fulldb, _ = ethdb.NewMemDatabase()
+ lightdb, _ = ethdb.NewMemDatabase()
+ gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}}
+ genesis = gspec.MustCommit(fulldb)
+ )
+ gspec.MustCommit(lightdb)
+ blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), new(event.TypeMux), vm.Config{})
+ gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, fulldb, 4, testChainGen)
+ if _, err := blockchain.InsertChain(gchain); err != nil {
+ panic(err)
+ }
+
+ ctx := context.Background()
+ odr := &testOdr{sdb: fulldb, ldb: lightdb}
+ head := blockchain.CurrentHeader()
+ lightTrie, _ := NewStateDatabase(ctx, head, odr).OpenTrie(head.Root)
+ fullTrie, _ := state.NewDatabase(fulldb).OpenTrie(head.Root)
+ if err := diffTries(fullTrie, lightTrie); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func diffTries(t1, t2 state.Trie) error {
+ i1 := trie.NewIterator(t1.NodeIterator(nil))
+ i2 := trie.NewIterator(t2.NodeIterator(nil))
+ for i1.Next() && i2.Next() {
+ if !bytes.Equal(i1.Key, i2.Key) {
+ spew.Dump(i2)
+ return fmt.Errorf("tries have different keys %x, %x", i1.Key, i2.Key)
+ }
+ if !bytes.Equal(i2.Value, i2.Value) {
+ return fmt.Errorf("tries differ at key %x", i1.Key)
+ }
+ }
+ switch {
+ case i1.Err != nil:
+ return fmt.Errorf("full trie iterator error: %v", i1.Err)
+ case i2.Err != nil:
+ return fmt.Errorf("light trie iterator error: %v", i1.Err)
+ case i1.Next():
+ return fmt.Errorf("full trie iterator has more k/v pairs")
+ case i2.Next():
+ return fmt.Errorf("light trie iterator has more k/v pairs")
+ }
+ return nil
+}
diff --git a/light/txpool.go b/light/txpool.go
index 7276874b8..0430b280f 100644
--- a/light/txpool.go
+++ b/light/txpool.go
@@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
@@ -100,17 +101,18 @@ func NewTxPool(config *params.ChainConfig, eventMux *event.TypeMux, chain *Light
}
// currentState returns the light state of the current head header
-func (pool *TxPool) currentState() *LightState {
- return NewLightState(StateTrieID(pool.chain.CurrentHeader()), pool.odr)
+func (pool *TxPool) currentState(ctx context.Context) *state.StateDB {
+ return NewState(ctx, pool.chain.CurrentHeader(), pool.odr)
}
// GetNonce returns the "pending" nonce of a given address. It always queries
// the nonce belonging to the latest header too in order to detect if another
// client using the same key sent a transaction.
func (pool *TxPool) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
- nonce, err := pool.currentState().GetNonce(ctx, addr)
- if err != nil {
- return 0, err
+ state := pool.currentState(ctx)
+ nonce := state.GetNonce(addr)
+ if state.Error() != nil {
+ return 0, state.Error()
}
sn, ok := pool.nonce[addr]
if ok && sn > nonce {
@@ -357,13 +359,9 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error
return core.ErrInvalidSender
}
// Last but not least check for nonce errors
- currentState := pool.currentState()
- if n, err := currentState.GetNonce(ctx, from); err == nil {
- if n > tx.Nonce() {
- return core.ErrNonceTooLow
- }
- } else {
- return err
+ currentState := pool.currentState(ctx)
+ if n := currentState.GetNonce(from); n > tx.Nonce() {
+ return core.ErrNonceTooLow
}
// Check the transaction doesn't exceed the current
@@ -382,12 +380,8 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error
// Transactor should have enough funds to cover the costs
// cost == V + GP * GL
- if b, err := currentState.GetBalance(ctx, from); err == nil {
- if b.Cmp(tx.Cost()) < 0 {
- return core.ErrInsufficientFunds
- }
- } else {
- return err
+ if b := currentState.GetBalance(from); b.Cmp(tx.Cost()) < 0 {
+ return core.ErrInsufficientFunds
}
// Should supply enough intrinsic gas
@@ -395,7 +389,7 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error
return core.ErrIntrinsicGas
}
- return nil
+ return currentState.Error()
}
// add validates a new transaction and sets its state pending if processable.
diff --git a/light/vm_env.go b/light/vm_env.go
deleted file mode 100644
index 54aa12875..000000000
--- a/light/vm_env.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2016 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package light
-
-import (
- "context"
- "math/big"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/crypto"
-)
-
-// VMState is a wrapper for the light state that holds the actual context and
-// passes it to any state operation that requires it.
-type VMState struct {
- ctx context.Context
- state *LightState
- snapshots []*LightState
- err error
-}
-
-func NewVMState(ctx context.Context, state *LightState) *VMState {
- return &VMState{ctx: ctx, state: state}
-}
-
-func (s *VMState) Error() error {
- return s.err
-}
-
-func (s *VMState) AddLog(log *types.Log) {}
-
-func (s *VMState) AddPreimage(hash common.Hash, preimage []byte) {}
-
-// errHandler handles and stores any state error that happens during execution.
-func (s *VMState) errHandler(err error) {
- if err != nil && s.err == nil {
- s.err = err
- }
-}
-
-func (self *VMState) Snapshot() int {
- self.snapshots = append(self.snapshots, self.state.Copy())
- return len(self.snapshots) - 1
-}
-
-func (self *VMState) RevertToSnapshot(idx int) {
- self.state.Set(self.snapshots[idx])
- self.snapshots = self.snapshots[:idx]
-}
-
-// CreateAccount creates creates a new account object and takes ownership.
-func (s *VMState) CreateAccount(addr common.Address) {
- _, err := s.state.CreateStateObject(s.ctx, addr)
- s.errHandler(err)
-}
-
-// AddBalance adds the given amount to the balance of the specified account
-func (s *VMState) AddBalance(addr common.Address, amount *big.Int) {
- err := s.state.AddBalance(s.ctx, addr, amount)
- s.errHandler(err)
-}
-
-// SubBalance adds the given amount to the balance of the specified account
-func (s *VMState) SubBalance(addr common.Address, amount *big.Int) {
- err := s.state.SubBalance(s.ctx, addr, amount)
- s.errHandler(err)
-}
-
-// ForEachStorage calls a callback function for every key/value pair found
-// in the local storage cache. Note that unlike core/state.StateObject,
-// light.StateObject only returns cached values and doesn't download the
-// entire storage tree.
-func (s *VMState) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) {
- err := s.state.ForEachStorage(s.ctx, addr, cb)
- s.errHandler(err)
-}
-
-// GetBalance retrieves the balance from the given address or 0 if the account does
-// not exist
-func (s *VMState) GetBalance(addr common.Address) *big.Int {
- res, err := s.state.GetBalance(s.ctx, addr)
- s.errHandler(err)
- return res
-}
-
-// GetNonce returns the nonce at the given address or 0 if the account does
-// not exist
-func (s *VMState) GetNonce(addr common.Address) uint64 {
- res, err := s.state.GetNonce(s.ctx, addr)
- s.errHandler(err)
- return res
-}
-
-// SetNonce sets the nonce of the specified account
-func (s *VMState) SetNonce(addr common.Address, nonce uint64) {
- err := s.state.SetNonce(s.ctx, addr, nonce)
- s.errHandler(err)
-}
-
-// GetCode returns the contract code at the given address or nil if the account
-// does not exist
-func (s *VMState) GetCode(addr common.Address) []byte {
- res, err := s.state.GetCode(s.ctx, addr)
- s.errHandler(err)
- return res
-}
-
-// GetCodeHash returns the contract code hash at the given address
-func (s *VMState) GetCodeHash(addr common.Address) common.Hash {
- res, err := s.state.GetCode(s.ctx, addr)
- s.errHandler(err)
- return crypto.Keccak256Hash(res)
-}
-
-// GetCodeSize returns the contract code size at the given address
-func (s *VMState) GetCodeSize(addr common.Address) int {
- res, err := s.state.GetCode(s.ctx, addr)
- s.errHandler(err)
- return len(res)
-}
-
-// SetCode sets the contract code at the specified account
-func (s *VMState) SetCode(addr common.Address, code []byte) {
- err := s.state.SetCode(s.ctx, addr, code)
- s.errHandler(err)
-}
-
-// AddRefund adds an amount to the refund value collected during a vm execution
-func (s *VMState) AddRefund(gas *big.Int) {
- s.state.AddRefund(gas)
-}
-
-// GetRefund returns the refund value collected during a vm execution
-func (s *VMState) GetRefund() *big.Int {
- return s.state.GetRefund()
-}
-
-// GetState returns the contract storage value at storage address b from the
-// contract address a or common.Hash{} if the account does not exist
-func (s *VMState) GetState(a common.Address, b common.Hash) common.Hash {
- res, err := s.state.GetState(s.ctx, a, b)
- s.errHandler(err)
- return res
-}
-
-// SetState sets the storage value at storage address key of the account addr
-func (s *VMState) SetState(addr common.Address, key common.Hash, value common.Hash) {
- err := s.state.SetState(s.ctx, addr, key, value)
- s.errHandler(err)
-}
-
-// Suicide marks an account to be removed and clears its balance
-func (s *VMState) Suicide(addr common.Address) bool {
- res, err := s.state.Suicide(s.ctx, addr)
- s.errHandler(err)
- return res
-}
-
-// Exist returns true if an account exists at the given address
-func (s *VMState) Exist(addr common.Address) bool {
- res, err := s.state.HasAccount(s.ctx, addr)
- s.errHandler(err)
- return res
-}
-
-// Empty returns true if the account at the given address is considered empty
-func (s *VMState) Empty(addr common.Address) bool {
- so, err := s.state.GetStateObject(s.ctx, addr)
- s.errHandler(err)
- return so == nil || so.empty()
-}
-
-// HasSuicided returns true if the given account has been marked for deletion
-// or false if the account does not exist
-func (s *VMState) HasSuicided(addr common.Address) bool {
- res, err := s.state.HasSuicided(s.ctx, addr)
- s.errHandler(err)
- return res
-}
diff --git a/miner/worker.go b/miner/worker.go
index 803015390..e44514755 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -274,7 +274,7 @@ func (self *worker) wait() {
}
go self.mux.Post(core.NewMinedBlockEvent{Block: block})
} else {
- work.state.Commit(self.config.IsEIP158(block.Number()))
+ work.state.CommitTo(self.chainDb, self.config.IsEIP158(block.Number()))
stat, err := self.chain.WriteBlock(block)
if err != nil {
log.Error("Failed writing block to chain", "err", err)
diff --git a/tests/block_test_util.go b/tests/block_test_util.go
index b9678a77b..24d4672b6 100644
--- a/tests/block_test_util.go
+++ b/tests/block_test_util.go
@@ -204,7 +204,7 @@ func runBlockTest(homesteadBlock, daoForkBlock, gasPriceFork *big.Int, test *Blo
// InsertPreState populates the given database with the genesis
// accounts defined by the test.
func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) {
- statedb, err := state.New(common.Hash{}, db)
+ statedb, err := state.New(common.Hash{}, state.NewDatabase(db))
if err != nil {
return nil, err
}
@@ -232,7 +232,7 @@ func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) {
}
}
- root, err := statedb.Commit(false)
+ root, err := statedb.CommitTo(db, false)
if err != nil {
return nil, fmt.Errorf("error writing state: %v", err)
}
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index c1892cdcc..58acdd488 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -20,7 +20,6 @@ import (
"bytes"
"fmt"
"io"
- "math/big"
"strconv"
"strings"
"testing"
@@ -99,7 +98,7 @@ func benchStateTest(chainConfig *params.ChainConfig, test VmTest, env map[string
statedb := makePreState(db, test.Pre)
b.StartTimer()
- RunState(chainConfig, statedb, env, test.Exec)
+ RunState(chainConfig, statedb, db, env, test.Exec)
}
func runStateTests(chainConfig *params.ChainConfig, tests map[string]VmTest, skipTests []string) error {
@@ -143,16 +142,9 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error {
env["currentTimestamp"] = test.Env.CurrentTimestamp.(string)
}
- var (
- ret []byte
- // gas *big.Int
- // err error
- logs []*types.Log
- )
+ ret, logs, root, _ := RunState(chainConfig, statedb, db, env, test.Transaction)
- ret, logs, _, _ = RunState(chainConfig, statedb, env, test.Transaction)
-
- // Compare expected and actual return
+ // Return value:
var rexp []byte
if strings.HasPrefix(test.Out, "#") {
n, _ := strconv.Atoi(test.Out[1:])
@@ -163,61 +155,43 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error {
if !bytes.Equal(rexp, ret) {
return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret)
}
-
- // check post state
+ // Post state content:
for addr, account := range test.Post {
address := common.HexToAddress(addr)
if !statedb.Exist(address) {
return fmt.Errorf("did not find expected post-state account: %s", addr)
}
-
if balance := statedb.GetBalance(address); balance.Cmp(math.MustParseBig256(account.Balance)) != 0 {
return fmt.Errorf("(%x) balance failed. Expected: %v have: %v\n", address[:4], math.MustParseBig256(account.Balance), balance)
}
-
if nonce := statedb.GetNonce(address); nonce != math.MustParseUint64(account.Nonce) {
return fmt.Errorf("(%x) nonce failed. Expected: %v have: %v\n", address[:4], account.Nonce, nonce)
}
-
for addr, value := range account.Storage {
v := statedb.GetState(address, common.HexToHash(addr))
vexp := common.HexToHash(value)
-
if v != vexp {
return fmt.Errorf("storage failed:\n%x: %s:\nexpected: %x\nhave: %x\n(%v %v)\n", address[:4], addr, vexp, v, vexp.Big(), v.Big())
}
}
}
-
- root, _ := statedb.Commit(false)
+ // Root:
if common.HexToHash(test.PostStateRoot) != root {
return fmt.Errorf("Post state root error. Expected: %s have: %x", test.PostStateRoot, root)
}
-
- // check logs
- if len(test.Logs) > 0 {
- if err := checkLogs(test.Logs, logs); err != nil {
- return err
- }
- }
-
- return nil
+ // Logs:
+ return checkLogs(test.Logs, logs)
}
-func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, env, tx map[string]string) ([]byte, []*types.Log, *big.Int, error) {
+func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, db ethdb.Database, env, tx map[string]string) ([]byte, []*types.Log, common.Hash, error) {
environment, msg := NewEVMEnvironment(false, chainConfig, statedb, env, tx)
gaspool := new(core.GasPool).AddGas(math.MustParseBig256(env["currentGasLimit"]))
- root, _ := statedb.Commit(false)
- statedb.Reset(root)
-
snapshot := statedb.Snapshot()
-
- ret, gasUsed, err := core.ApplyMessage(environment, msg, gaspool)
+ ret, _, err := core.ApplyMessage(environment, msg, gaspool)
if err != nil {
statedb.RevertToSnapshot(snapshot)
}
- statedb.Commit(chainConfig.IsEIP158(environment.Context.BlockNumber))
-
- return ret, statedb.Logs(), gasUsed, err
+ root, _ := statedb.CommitTo(db, chainConfig.IsEIP158(environment.Context.BlockNumber))
+ return ret, statedb.Logs(), root, err
}
diff --git a/tests/util.go b/tests/util.go
index a3a9a1f64..ff02679ec 100644
--- a/tests/util.go
+++ b/tests/util.go
@@ -48,7 +48,6 @@ func init() {
}
func checkLogs(tlog []Log, logs []*types.Log) error {
-
if len(tlog) != len(logs) {
return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs))
} else {
@@ -106,10 +105,14 @@ func (self Log) Topics() [][]byte {
}
func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
- statedb, _ := state.New(common.Hash{}, db)
+ sdb := state.NewDatabase(db)
+ statedb, _ := state.New(common.Hash{}, sdb)
for addr, account := range accounts {
insertAccount(statedb, addr, account)
}
+ // Commit and re-open to start with a clean state.
+ root, _ := statedb.CommitTo(db, false)
+ statedb, _ = state.New(root, sdb)
return statedb
}
diff --git a/trie/proof.go b/trie/proof.go
index 1f8f76b1b..298f648c4 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -125,7 +125,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
}
func get(tn node, key []byte) ([]byte, node) {
- for len(key) > 0 {
+ for {
switch n := tn.(type) {
case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
@@ -140,9 +140,10 @@ func get(tn node, key []byte) ([]byte, node) {
return key, n
case nil:
return key, nil
+ case valueNode:
+ return nil, n
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
- return nil, tn.(valueNode)
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 37d1d4b09..20c303f31 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -156,6 +156,11 @@ func (t *SecureTrie) Root() []byte {
return t.trie.Root()
}
+func (t *SecureTrie) Copy() *SecureTrie {
+ cpy := *t
+ return &cpy
+}
+
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {