diff options
72 files changed, 2660 insertions, 2828 deletions
diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 3d1010229..2a06d474b 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -17,11 +17,9 @@ package abi import ( - "encoding/binary" "encoding/json" "fmt" "io" - "math/big" "reflect" "strings" @@ -67,7 +65,7 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { } method = m } - arguments, err := method.pack(method, args...) + arguments, err := method.pack(args...) if err != nil { return nil, err } @@ -78,199 +76,6 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { return append(method.Id(), arguments...), nil } -// toGoSliceType parses the input and casts it to the proper slice defined by the ABI -// argument in T. -func toGoSlice(i int, t Argument, output []byte) (interface{}, error) { - index := i * 32 - // The slice must, at very least be large enough for the index+32 which is exactly the size required - // for the [offset in output, size of offset]. - if index+32 > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), index+32) - } - elem := t.Type.Elem - - // first we need to create a slice of the type - var refSlice reflect.Value - switch elem.T { - case IntTy, UintTy, BoolTy: - // create a new reference slice matching the element type - switch t.Type.Kind { - case reflect.Bool: - refSlice = reflect.ValueOf([]bool(nil)) - case reflect.Uint8: - refSlice = reflect.ValueOf([]uint8(nil)) - case reflect.Uint16: - refSlice = reflect.ValueOf([]uint16(nil)) - case reflect.Uint32: - refSlice = reflect.ValueOf([]uint32(nil)) - case reflect.Uint64: - refSlice = reflect.ValueOf([]uint64(nil)) - case reflect.Int8: - refSlice = reflect.ValueOf([]int8(nil)) - case reflect.Int16: - refSlice = reflect.ValueOf([]int16(nil)) - case reflect.Int32: - refSlice = reflect.ValueOf([]int32(nil)) - case reflect.Int64: - refSlice = reflect.ValueOf([]int64(nil)) - default: - refSlice = reflect.ValueOf([]*big.Int(nil)) - } - case AddressTy: // address must be of slice Address - refSlice = reflect.ValueOf([]common.Address(nil)) - case HashTy: // hash must be of slice hash - refSlice = reflect.ValueOf([]common.Hash(nil)) - case FixedBytesTy: - refSlice = reflect.ValueOf([][]byte(nil)) - default: // no other types are supported - return nil, fmt.Errorf("abi: unsupported slice type %v", elem.T) - } - - var slice []byte - var size int - var offset int - if t.Type.IsSlice { - // get the offset which determines the start of this array ... - offset = int(binary.BigEndian.Uint64(output[index+24 : index+32])) - if offset+32 > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32) - } - - slice = output[offset:] - // ... starting with the size of the array in elements ... - size = int(binary.BigEndian.Uint64(slice[24:32])) - slice = slice[32:] - // ... and make sure that we've at the very least the amount of bytes - // available in the buffer. - if size*32 > len(slice) { - return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), offset+32+size*32) - } - - // reslice to match the required size - slice = slice[:size*32] - } else if t.Type.IsArray { - //get the number of elements in the array - size = t.Type.SliceSize - - //check to make sure array size matches up - if index+32*size > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go array: offset %d would go over slice boundary (len=%d)", len(output), index+32*size) - } - //slice is there for a fixed amount of times - slice = output[index : index+size*32] - } - - for i := 0; i < size; i++ { - var ( - inter interface{} // interface type - returnOutput = slice[i*32 : i*32+32] // the return output - ) - // set inter to the correct type (cast) - switch elem.T { - case IntTy, UintTy: - inter = readInteger(t.Type.Kind, returnOutput) - case BoolTy: - inter = !allZero(returnOutput) - case AddressTy: - inter = common.BytesToAddress(returnOutput) - case HashTy: - inter = common.BytesToHash(returnOutput) - case FixedBytesTy: - inter = returnOutput - } - // append the item to our reflect slice - refSlice = reflect.Append(refSlice, reflect.ValueOf(inter)) - } - - // return the interface - return refSlice.Interface(), nil -} - -func readInteger(kind reflect.Kind, b []byte) interface{} { - switch kind { - case reflect.Uint8: - return uint8(b[len(b)-1]) - case reflect.Uint16: - return binary.BigEndian.Uint16(b[len(b)-2:]) - case reflect.Uint32: - return binary.BigEndian.Uint32(b[len(b)-4:]) - case reflect.Uint64: - return binary.BigEndian.Uint64(b[len(b)-8:]) - case reflect.Int8: - return int8(b[len(b)-1]) - case reflect.Int16: - return int16(binary.BigEndian.Uint16(b[len(b)-2:])) - case reflect.Int32: - return int32(binary.BigEndian.Uint32(b[len(b)-4:])) - case reflect.Int64: - return int64(binary.BigEndian.Uint64(b[len(b)-8:])) - default: - return new(big.Int).SetBytes(b) - } -} - -func allZero(b []byte) bool { - for _, byte := range b { - if byte != 0 { - return false - } - } - return true -} - -// toGoType parses the input and casts it to the proper type defined by the ABI -// argument in T. -func toGoType(i int, t Argument, output []byte) (interface{}, error) { - // we need to treat slices differently - if (t.Type.IsSlice || t.Type.IsArray) && t.Type.T != BytesTy && t.Type.T != StringTy && t.Type.T != FixedBytesTy && t.Type.T != FunctionTy { - return toGoSlice(i, t, output) - } - - index := i * 32 - if index+32 > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32) - } - - // Parse the given index output and check whether we need to read - // a different offset and length based on the type (i.e. string, bytes) - var returnOutput []byte - switch t.Type.T { - case StringTy, BytesTy: // variable arrays are written at the end of the return bytes - // parse offset from which we should start reading - offset := int(binary.BigEndian.Uint64(output[index+24 : index+32])) - if offset+32 > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32) - } - // parse the size up until we should be reading - size := int(binary.BigEndian.Uint64(output[offset+24 : offset+32])) - if offset+32+size > len(output) { - return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+size) - } - - // get the bytes for this return value - returnOutput = output[offset+32 : offset+32+size] - default: - returnOutput = output[index : index+32] - } - - // convert the bytes to whatever is specified by the ABI. - switch t.Type.T { - case IntTy, UintTy: - return readInteger(t.Type.Kind, returnOutput), nil - case BoolTy: - return !allZero(returnOutput), nil - case AddressTy: - return common.BytesToAddress(returnOutput), nil - case HashTy: - return common.BytesToHash(returnOutput), nil - case BytesTy, FixedBytesTy, FunctionTy: - return returnOutput, nil - case StringTy: - return string(returnOutput), nil - } - return nil, fmt.Errorf("abi: unknown type %v", t.Type.T) -} - // these variable are used to determine certain types during type assertion for // assignment. var ( diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index a45bd6cc0..a3aa9446e 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -48,412 +48,6 @@ func pad(input []byte, size int, left bool) []byte { return common.RightPadBytes(input, size) } -func TestTypeCheck(t *testing.T) { - for i, test := range []struct { - typ string - input interface{} - err string - }{ - {"uint", big.NewInt(1), ""}, - {"int", big.NewInt(1), ""}, - {"uint30", big.NewInt(1), ""}, - {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"}, - {"uint16", uint16(1), ""}, - {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"}, - {"uint16[]", []uint16{1, 2, 3}, ""}, - {"uint16[]", [3]uint16{1, 2, 3}, ""}, - {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type []uint16 as argument"}, - {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"}, - {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, - {"uint16[3]", []uint16{1, 2, 3}, ""}, - {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, - {"address[]", []common.Address{{1}}, ""}, - {"address[1]", []common.Address{{1}}, ""}, - {"address[1]", [1]common.Address{{1}}, ""}, - {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"}, - {"bytes32", [32]byte{}, ""}, - {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"}, - {"bytes32", common.Hash{1}, ""}, - {"bytes31", [31]byte{}, ""}, - {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"}, - {"bytes", []byte{0, 1}, ""}, - {"bytes", [2]byte{0, 1}, ""}, - {"bytes", common.Hash{1}, ""}, - {"string", "hello world", ""}, - {"bytes32[]", [][32]byte{{}}, ""}, - {"function", [24]byte{}, ""}, - } { - typ, err := NewType(test.typ) - if err != nil { - t.Fatal("unexpected parse error:", err) - } - - err = typeCheck(typ, reflect.ValueOf(test.input)) - if err != nil && len(test.err) == 0 { - t.Errorf("%d failed. Expected no err but got: %v", i, err) - continue - } - if err == nil && len(test.err) != 0 { - t.Errorf("%d failed. Expected err: %v but got none", i, test.err) - continue - } - - if err != nil && len(test.err) != 0 && err.Error() != test.err { - t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err) - } - } -} - -func TestSimpleMethodUnpack(t *testing.T) { - for i, test := range []struct { - def string // definition of the **output** ABI params - marshalledOutput []byte // evm return data - expectedOut interface{} // the expected output - outVar string // the output variable (e.g. uint32, *big.Int, etc) - err string // empty or error if expected - }{ - { - `[ { "type": "uint32" } ]`, - pad([]byte{1}, 32, true), - uint32(1), - "uint32", - "", - }, - { - `[ { "type": "uint32" } ]`, - pad([]byte{1}, 32, true), - nil, - "uint16", - "abi: cannot unmarshal uint32 in to uint16", - }, - { - `[ { "type": "uint17" } ]`, - pad([]byte{1}, 32, true), - nil, - "uint16", - "abi: cannot unmarshal *big.Int in to uint16", - }, - { - `[ { "type": "uint17" } ]`, - pad([]byte{1}, 32, true), - big.NewInt(1), - "*big.Int", - "", - }, - - { - `[ { "type": "int32" } ]`, - pad([]byte{1}, 32, true), - int32(1), - "int32", - "", - }, - { - `[ { "type": "int32" } ]`, - pad([]byte{1}, 32, true), - nil, - "int16", - "abi: cannot unmarshal int32 in to int16", - }, - { - `[ { "type": "int17" } ]`, - pad([]byte{1}, 32, true), - nil, - "int16", - "abi: cannot unmarshal *big.Int in to int16", - }, - { - `[ { "type": "int17" } ]`, - pad([]byte{1}, 32, true), - big.NewInt(1), - "*big.Int", - "", - }, - - { - `[ { "type": "address" } ]`, - pad(pad([]byte{1}, 20, false), 32, true), - common.Address{1}, - "address", - "", - }, - { - `[ { "type": "bytes32" } ]`, - pad([]byte{1}, 32, false), - pad([]byte{1}, 32, false), - "bytes", - "", - }, - { - `[ { "type": "bytes32" } ]`, - pad([]byte{1}, 32, false), - pad([]byte{1}, 32, false), - "hash", - "", - }, - { - `[ { "type": "bytes32" } ]`, - pad([]byte{1}, 32, false), - pad([]byte{1}, 32, false), - "interface", - "", - }, - { - `[ { "type": "function" } ]`, - pad([]byte{1}, 32, false), - [24]byte{1}, - "function", - "", - }, - } { - abiDefinition := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def) - abi, err := JSON(strings.NewReader(abiDefinition)) - if err != nil { - t.Errorf("%d failed. %v", i, err) - continue - } - - var outvar interface{} - switch test.outVar { - case "uint8": - var v uint8 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "uint16": - var v uint16 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "uint32": - var v uint32 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "uint64": - var v uint64 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "int8": - var v int8 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "int16": - var v int16 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "int32": - var v int32 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "int64": - var v int64 - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "*big.Int": - var v *big.Int - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "address": - var v common.Address - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "bytes": - var v []byte - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "hash": - var v common.Hash - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "function": - var v [24]byte - err = abi.Unpack(&v, "method", test.marshalledOutput) - outvar = v - case "interface": - err = abi.Unpack(&outvar, "method", test.marshalledOutput) - default: - t.Errorf("unsupported type '%v' please add it to the switch statement in this test", test.outVar) - continue - } - - if err != nil && len(test.err) == 0 { - t.Errorf("%d failed. Expected no err but got: %v", i, err) - continue - } - if err == nil && len(test.err) != 0 { - t.Errorf("%d failed. Expected err: %v but got none", i, test.err) - continue - } - if err != nil && len(test.err) != 0 && err.Error() != test.err { - t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err) - continue - } - - if err == nil { - // bit of an ugly hack for hash type but I don't feel like finding a proper solution - if test.outVar == "hash" { - tmp := outvar.(common.Hash) // without assignment it's unaddressable - outvar = tmp[:] - } - - if !reflect.DeepEqual(test.expectedOut, outvar) { - t.Errorf("%d failed. Output error: expected %v, got %v", i, test.expectedOut, outvar) - } - } - } -} - -func TestUnpackSetInterfaceSlice(t *testing.T) { - var ( - var1 = new(uint8) - var2 = new(uint8) - ) - out := []interface{}{var1, var2} - abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint8"}, {"type":"uint8"}]}]`)) - if err != nil { - t.Fatal(err) - } - marshalledReturn := append(pad([]byte{1}, 32, true), pad([]byte{2}, 32, true)...) - err = abi.Unpack(&out, "ints", marshalledReturn) - if err != nil { - t.Fatal(err) - } - if *var1 != 1 { - t.Error("expected var1 to be 1, got", *var1) - } - if *var2 != 2 { - t.Error("expected var2 to be 2, got", *var2) - } - - out = []interface{}{var1} - err = abi.Unpack(&out, "ints", marshalledReturn) - - expErr := "abi: cannot marshal in to slices of unequal size (require: 2, got: 1)" - if err == nil || err.Error() != expErr { - t.Error("expected err:", expErr, "Got:", err) - } -} - -func TestUnpackSetInterfaceArrayOutput(t *testing.T) { - var ( - var1 = new([1]uint32) - var2 = new([1]uint32) - ) - out := []interface{}{var1, var2} - abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint32[1]"}, {"type":"uint32[1]"}]}]`)) - if err != nil { - t.Fatal(err) - } - marshalledReturn := append(pad([]byte{1}, 32, true), pad([]byte{2}, 32, true)...) - err = abi.Unpack(&out, "ints", marshalledReturn) - if err != nil { - t.Fatal(err) - } - - if *var1 != [1]uint32{1} { - t.Error("expected var1 to be [1], got", *var1) - } - if *var2 != [1]uint32{2} { - t.Error("expected var2 to be [2], got", *var2) - } -} - -func TestPack(t *testing.T) { - for i, test := range []struct { - typ string - - input interface{} - output []byte - }{ - {"uint16", uint16(2), pad([]byte{2}, 32, true)}, - {"uint16[]", []uint16{1, 2}, formatSliceOutput([]byte{1}, []byte{2})}, - {"bytes20", [20]byte{1}, pad([]byte{1}, 32, false)}, - {"uint256[]", []*big.Int{big.NewInt(1), big.NewInt(2)}, formatSliceOutput([]byte{1}, []byte{2})}, - {"address[]", []common.Address{{1}, {2}}, formatSliceOutput(pad([]byte{1}, 20, false), pad([]byte{2}, 20, false))}, - {"bytes32[]", []common.Hash{{1}, {2}}, formatSliceOutput(pad([]byte{1}, 32, false), pad([]byte{2}, 32, false))}, - {"function", [24]byte{1}, pad([]byte{1}, 32, false)}, - } { - typ, err := NewType(test.typ) - if err != nil { - t.Fatal("unexpected parse error:", err) - } - - output, err := typ.pack(reflect.ValueOf(test.input)) - if err != nil { - t.Fatal("unexpected pack error:", err) - } - - if !bytes.Equal(output, test.output) { - t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output) - } - } -} - -func TestMethodPack(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) - if err != nil { - t.Fatal(err) - } - - sig := abi.Methods["slice"].Id() - sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) - sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) - - packed, err := abi.Pack("slice", []uint32{1, 2}) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(packed, sig) { - t.Errorf("expected %x got %x", sig, packed) - } - - var addrA, addrB = common.Address{1}, common.Address{2} - sig = abi.Methods["sliceAddress"].Id() - sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) - sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) - sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) - sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) - - packed, err = abi.Pack("sliceAddress", []common.Address{addrA, addrB}) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(packed, sig) { - t.Errorf("expected %x got %x", sig, packed) - } - - var addrC, addrD = common.Address{3}, common.Address{4} - sig = abi.Methods["sliceMultiAddress"].Id() - sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...) - sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...) - sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) - sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) - sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) - sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) - sig = append(sig, common.LeftPadBytes(addrC[:], 32)...) - sig = append(sig, common.LeftPadBytes(addrD[:], 32)...) - - packed, err = abi.Pack("sliceMultiAddress", []common.Address{addrA, addrB}, []common.Address{addrC, addrD}) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(packed, sig) { - t.Errorf("expected %x got %x", sig, packed) - } - - sig = abi.Methods["slice256"].Id() - sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) - sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) - - packed, err = abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)}) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(packed, sig) { - t.Errorf("expected %x got %x", sig, packed) - } -} - const jsondata = ` [ { "type" : "function", "name" : "balance", "constant" : true }, @@ -843,399 +437,3 @@ func TestBareEvents(t *testing.T) { } } } - -func TestMultiReturnWithStruct(t *testing.T) { - const definition = `[ - { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` - - abi, err := JSON(strings.NewReader(definition)) - if err != nil { - t.Fatal(err) - } - - // using buff to make the code readable - buff := new(bytes.Buffer) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) - stringOut := "hello" - buff.Write(common.RightPadBytes([]byte(stringOut), 32)) - - var inter struct { - Int *big.Int - String string - } - err = abi.Unpack(&inter, "multi", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if inter.Int == nil || inter.Int.Cmp(big.NewInt(1)) != 0 { - t.Error("expected Int to be 1 got", inter.Int) - } - - if inter.String != stringOut { - t.Error("expected String to be", stringOut, "got", inter.String) - } - - var reversed struct { - String string - Int *big.Int - } - - err = abi.Unpack(&reversed, "multi", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if reversed.Int == nil || reversed.Int.Cmp(big.NewInt(1)) != 0 { - t.Error("expected Int to be 1 got", reversed.Int) - } - - if reversed.String != stringOut { - t.Error("expected String to be", stringOut, "got", reversed.String) - } -} - -func TestMultiReturnWithSlice(t *testing.T) { - const definition = `[ - { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` - - abi, err := JSON(strings.NewReader(definition)) - if err != nil { - t.Fatal(err) - } - - // using buff to make the code readable - buff := new(bytes.Buffer) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) - stringOut := "hello" - buff.Write(common.RightPadBytes([]byte(stringOut), 32)) - - var inter []interface{} - err = abi.Unpack(&inter, "multi", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if len(inter) != 2 { - t.Fatal("expected 2 results got", len(inter)) - } - - if num, ok := inter[0].(*big.Int); !ok || num.Cmp(big.NewInt(1)) != 0 { - t.Error("expected index 0 to be 1 got", num) - } - - if str, ok := inter[1].(string); !ok || str != stringOut { - t.Error("expected index 1 to be", stringOut, "got", str) - } -} - -func TestMarshalArrays(t *testing.T) { - const definition = `[ - { "name" : "bytes32", "constant" : false, "outputs": [ { "type": "bytes32" } ] }, - { "name" : "bytes10", "constant" : false, "outputs": [ { "type": "bytes10" } ] } - ]` - - abi, err := JSON(strings.NewReader(definition)) - if err != nil { - t.Fatal(err) - } - - output := common.LeftPadBytes([]byte{1}, 32) - - var bytes10 [10]byte - err = abi.Unpack(&bytes10, "bytes32", output) - if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" { - t.Error("expected error or bytes32 not be assignable to bytes10:", err) - } - - var bytes32 [32]byte - err = abi.Unpack(&bytes32, "bytes32", output) - if err != nil { - t.Error("didn't expect error:", err) - } - if !bytes.Equal(bytes32[:], output) { - t.Error("expected bytes32[31] to be 1 got", bytes32[31]) - } - - type ( - B10 [10]byte - B32 [32]byte - ) - - var b10 B10 - err = abi.Unpack(&b10, "bytes32", output) - if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" { - t.Error("expected error or bytes32 not be assignable to bytes10:", err) - } - - var b32 B32 - err = abi.Unpack(&b32, "bytes32", output) - if err != nil { - t.Error("didn't expect error:", err) - } - if !bytes.Equal(b32[:], output) { - t.Error("expected bytes32[31] to be 1 got", bytes32[31]) - } - - output[10] = 1 - var shortAssignLong [32]byte - err = abi.Unpack(&shortAssignLong, "bytes10", output) - if err != nil { - t.Error("didn't expect error:", err) - } - if !bytes.Equal(output, shortAssignLong[:]) { - t.Errorf("expected %x to be %x", shortAssignLong, output) - } -} - -func TestUnmarshal(t *testing.T) { - const definition = `[ - { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] }, - { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] }, - { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] }, - { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] }, - { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, - { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] }, - { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] }, - { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] }, - { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` - - abi, err := JSON(strings.NewReader(definition)) - if err != nil { - t.Fatal(err) - } - buff := new(bytes.Buffer) - - // marshal int - var Int *big.Int - err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) - if err != nil { - t.Error(err) - } - - if Int == nil || Int.Cmp(big.NewInt(1)) != 0 { - t.Error("expected Int to be 1 got", Int) - } - - // marshal bool - var Bool bool - err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) - if err != nil { - t.Error(err) - } - - if !Bool { - t.Error("expected Bool to be true") - } - - // marshal dynamic bytes max length 32 - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - bytesOut := common.RightPadBytes([]byte("hello"), 32) - buff.Write(bytesOut) - - var Bytes []byte - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(Bytes, bytesOut) { - t.Errorf("expected %x got %x", bytesOut, Bytes) - } - - // marshall dynamic bytes max length 64 - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) - bytesOut = common.RightPadBytes([]byte("hello"), 64) - buff.Write(bytesOut) - - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(Bytes, bytesOut) { - t.Errorf("expected %x got %x", bytesOut, Bytes) - } - - // marshall dynamic bytes max length 63 - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000003f")) - bytesOut = common.RightPadBytes([]byte("hello"), 63) - buff.Write(bytesOut) - - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(Bytes, bytesOut) { - t.Errorf("expected %x got %x", bytesOut, Bytes) - } - - // marshal dynamic bytes output empty - err = abi.Unpack(&Bytes, "bytes", nil) - if err == nil { - t.Error("expected error") - } - - // marshal dynamic bytes length 5 - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) - buff.Write(common.RightPadBytes([]byte("hello"), 32)) - - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) - if err != nil { - t.Error(err) - } - - if !bytes.Equal(Bytes, []byte("hello")) { - t.Errorf("expected %x got %x", bytesOut, Bytes) - } - - // marshal dynamic bytes length 5 - buff.Reset() - buff.Write(common.RightPadBytes([]byte("hello"), 32)) - - var hash common.Hash - err = abi.Unpack(&hash, "fixed", buff.Bytes()) - if err != nil { - t.Error(err) - } - - helloHash := common.BytesToHash(common.RightPadBytes([]byte("hello"), 32)) - if hash != helloHash { - t.Errorf("Expected %x to equal %x", hash, helloHash) - } - - // marshal error - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) - if err == nil { - t.Error("expected error") - } - - err = abi.Unpack(&Bytes, "multi", make([]byte, 64)) - if err == nil { - t.Error("expected error") - } - - // marshal mixed bytes - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) - fixed := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001") - buff.Write(fixed) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - bytesOut = common.RightPadBytes([]byte("hello"), 32) - buff.Write(bytesOut) - - var out []interface{} - err = abi.Unpack(&out, "mixedBytes", buff.Bytes()) - if err != nil { - t.Fatal("didn't expect error:", err) - } - - if !bytes.Equal(bytesOut, out[0].([]byte)) { - t.Errorf("expected %x, got %x", bytesOut, out[0]) - } - - if !bytes.Equal(fixed, out[1].([]byte)) { - t.Errorf("expected %x, got %x", fixed, out[1]) - } - - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003")) - // marshal int array - var intArray [3]*big.Int - err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes()) - if err != nil { - t.Error(err) - } - var testAgainstIntArray [3]*big.Int - testAgainstIntArray[0] = big.NewInt(1) - testAgainstIntArray[1] = big.NewInt(2) - testAgainstIntArray[2] = big.NewInt(3) - - for i, Int := range intArray { - if Int.Cmp(testAgainstIntArray[i]) != 0 { - t.Errorf("expected %v, got %v", testAgainstIntArray[i], Int) - } - } - // marshal address slice - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) // offset - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size - buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) - - var outAddr []common.Address - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) - if err != nil { - t.Fatal("didn't expect error:", err) - } - - if len(outAddr) != 1 { - t.Fatal("expected 1 item, got", len(outAddr)) - } - - if outAddr[0] != (common.Address{1}) { - t.Errorf("expected %x, got %x", common.Address{1}, outAddr[0]) - } - - // marshal multiple address slice - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // offset - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // offset - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size - buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // size - buff.Write(common.Hex2Bytes("0000000000000000000000000200000000000000000000000000000000000000")) - buff.Write(common.Hex2Bytes("0000000000000000000000000300000000000000000000000000000000000000")) - - var outAddrStruct struct { - A []common.Address - B []common.Address - } - err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes()) - if err != nil { - t.Fatal("didn't expect error:", err) - } - - if len(outAddrStruct.A) != 1 { - t.Fatal("expected 1 item, got", len(outAddrStruct.A)) - } - - if outAddrStruct.A[0] != (common.Address{1}) { - t.Errorf("expected %x, got %x", common.Address{1}, outAddrStruct.A[0]) - } - - if len(outAddrStruct.B) != 2 { - t.Fatal("expected 1 item, got", len(outAddrStruct.B)) - } - - if outAddrStruct.B[0] != (common.Address{2}) { - t.Errorf("expected %x, got %x", common.Address{2}, outAddrStruct.B[0]) - } - if outAddrStruct.B[1] != (common.Address{3}) { - t.Errorf("expected %x, got %x", common.Address{3}, outAddrStruct.B[1]) - } - - // marshal invalid address slice - buff.Reset() - buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100")) - - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) - if err == nil { - t.Fatal("expected error:", err) - } -} diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 159fca136..7ac8b5820 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -90,7 +90,7 @@ func (b *SimulatedBackend) Rollback() { func (b *SimulatedBackend) rollback() { blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.database, 1, func(int, *core.BlockGen) {}) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) + b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) } // CodeAt returns the code associated with a certain account in the blockchain. @@ -279,7 +279,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa block.AddTx(tx) }) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) + b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) return nil } diff --git a/accounts/abi/error.go b/accounts/abi/error.go index 67739c21d..420acf418 100644 --- a/accounts/abi/error.go +++ b/accounts/abi/error.go @@ -17,10 +17,15 @@ package abi import ( + "errors" "fmt" "reflect" ) +var ( + errBadBool = errors.New("abi: improperly encoded boolean value") +) + // formatSliceString formats the reflection kind with the given slice size // and returns a formatted string representation. func formatSliceString(kind reflect.Kind, sliceSize int) string { diff --git a/accounts/abi/method.go b/accounts/abi/method.go index d56f3bc3d..62b3d2957 100644 --- a/accounts/abi/method.go +++ b/accounts/abi/method.go @@ -39,7 +39,7 @@ type Method struct { Outputs []Argument } -func (m Method) pack(method Method, args ...interface{}) ([]byte, error) { +func (method Method) pack(args ...interface{}) ([]byte, error) { // Make sure arguments match up and pack them if len(args) != len(method.Inputs) { return nil, fmt.Errorf("argument count mismatch: %d for %d", len(args), len(method.Inputs)) diff --git a/accounts/abi/numbers.go b/accounts/abi/numbers.go index 10afa6511..5d3efff52 100644 --- a/accounts/abi/numbers.go +++ b/accounts/abi/numbers.go @@ -62,19 +62,6 @@ func U256(n *big.Int) []byte { return math.PaddedBigBytes(math.U256(n), 32) } -// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation -func packNum(value reflect.Value) []byte { - switch kind := value.Kind(); kind { - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return U256(new(big.Int).SetUint64(value.Uint())) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return U256(big.NewInt(value.Int())) - case reflect.Ptr: - return U256(value.Interface().(*big.Int)) - } - return nil -} - // checks whether the given reflect value is signed. This also works for slices with a number type func isSigned(v reflect.Value) bool { switch v.Type() { diff --git a/accounts/abi/numbers_test.go b/accounts/abi/numbers_test.go index 44afe8647..b9ff5aef1 100644 --- a/accounts/abi/numbers_test.go +++ b/accounts/abi/numbers_test.go @@ -18,7 +18,6 @@ package abi import ( "bytes" - "math" "math/big" "reflect" "testing" @@ -34,43 +33,6 @@ func TestNumberTypes(t *testing.T) { } } -func TestPackNumber(t *testing.T) { - tests := []struct { - value reflect.Value - packed []byte - }{ - // Protocol limits - {reflect.ValueOf(0), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - {reflect.ValueOf(1), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, - {reflect.ValueOf(-1), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}}, - - // Type corner cases - {reflect.ValueOf(uint8(math.MaxUint8)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255}}, - {reflect.ValueOf(uint16(math.MaxUint16)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255}}, - {reflect.ValueOf(uint32(math.MaxUint32)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255}}, - {reflect.ValueOf(uint64(math.MaxUint64)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255}}, - - {reflect.ValueOf(int8(math.MaxInt8)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127}}, - {reflect.ValueOf(int16(math.MaxInt16)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255}}, - {reflect.ValueOf(int32(math.MaxInt32)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255, 255, 255}}, - {reflect.ValueOf(int64(math.MaxInt64)), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255, 255, 255, 255, 255, 255, 255}}, - - {reflect.ValueOf(int8(math.MinInt8)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128}}, - {reflect.ValueOf(int16(math.MinInt16)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0}}, - {reflect.ValueOf(int32(math.MinInt32)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0, 0, 0}}, - {reflect.ValueOf(int64(math.MinInt64)), []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128, 0, 0, 0, 0, 0, 0, 0}}, - } - for i, tt := range tests { - packed := packNum(tt.value) - if !bytes.Equal(packed, tt.packed) { - t.Errorf("test %d: pack mismatch: have %x, want %x", i, packed, tt.packed) - } - } - if packed := packNum(reflect.ValueOf("string")); packed != nil { - t.Errorf("expected 'string' to pack to nil. got %x instead", packed) - } -} - func TestSigned(t *testing.T) { if isSigned(reflect.ValueOf(uint(10))) { t.Error("signed") diff --git a/accounts/abi/packing.go b/accounts/abi/pack.go index 1d7f85e2b..4d8a3f031 100644 --- a/accounts/abi/packing.go +++ b/accounts/abi/pack.go @@ -17,6 +17,7 @@ package abi import ( + "math/big" "reflect" "github.com/ethereum/go-ethereum/common" @@ -59,8 +60,20 @@ func packElement(t Type, reflectValue reflect.Value) []byte { if reflectValue.Kind() == reflect.Array { reflectValue = mustArrayToByteSlice(reflectValue) } - return common.RightPadBytes(reflectValue.Bytes(), 32) } panic("abi: fatal error") } + +// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation +func packNum(value reflect.Value) []byte { + switch kind := value.Kind(); kind { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return U256(new(big.Int).SetUint64(value.Uint())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return U256(big.NewInt(value.Int())) + case reflect.Ptr: + return U256(value.Interface().(*big.Int)) + } + return nil +} diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go new file mode 100644 index 000000000..c6cfb56ea --- /dev/null +++ b/accounts/abi/pack_test.go @@ -0,0 +1,441 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package abi + +import ( + "bytes" + "math" + "math/big" + "reflect" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func TestPack(t *testing.T) { + for i, test := range []struct { + typ string + + input interface{} + output []byte + }{ + { + "uint8", + uint8(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint8[]", + []uint8{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint16", + uint16(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint16[]", + []uint16{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint32", + uint32(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint32[]", + []uint32{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint64", + uint64(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint64[]", + []uint64{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint256", + big.NewInt(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "uint256[]", + []*big.Int{big.NewInt(1), big.NewInt(2)}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int8", + int8(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int8[]", + []int8{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int16", + int16(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int16[]", + []int16{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int32", + int32(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int32[]", + []int32{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int64", + int64(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int64[]", + []int64{1, 2}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int256", + big.NewInt(2), + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "int256[]", + []*big.Int{big.NewInt(1), big.NewInt(2)}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), + }, + { + "bytes1", + [1]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes2", + [2]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes3", + [3]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes4", + [4]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes5", + [5]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes6", + [6]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes7", + [7]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes8", + [8]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes9", + [9]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes10", + [10]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes11", + [11]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes12", + [12]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes13", + [13]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes14", + [14]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes15", + [15]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes16", + [16]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes17", + [17]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes18", + [18]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes19", + [19]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes20", + [20]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes21", + [21]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes22", + [22]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes23", + [23]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes24", + [24]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes24", + [24]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes25", + [25]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes26", + [26]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes27", + [27]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes28", + [28]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes29", + [29]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes30", + [30]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes31", + [31]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "bytes32", + [32]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "address[]", + []common.Address{{1}, {2}}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"), + }, + { + "bytes32[]", + []common.Hash{{1}, {2}}, + common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000201000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000"), + }, + { + "function", + [24]byte{1}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + "string", + "foobar", + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"), + }, + } { + typ, err := NewType(test.typ) + if err != nil { + t.Fatal("unexpected parse error:", err) + } + + output, err := typ.pack(reflect.ValueOf(test.input)) + if err != nil { + t.Fatal("unexpected pack error:", err) + } + + if !bytes.Equal(output, test.output) { + t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output) + } + } +} + +func TestMethodPack(t *testing.T) { + abi, err := JSON(strings.NewReader(jsondata2)) + if err != nil { + t.Fatal(err) + } + + sig := abi.Methods["slice"].Id() + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + + packed, err := abi.Pack("slice", []uint32{1, 2}) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + var addrA, addrB = common.Address{1}, common.Address{2} + sig = abi.Methods["sliceAddress"].Id() + sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) + + packed, err = abi.Pack("sliceAddress", []common.Address{addrA, addrB}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + var addrC, addrD = common.Address{3}, common.Address{4} + sig = abi.Methods["sliceMultiAddress"].Id() + sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrC[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrD[:], 32)...) + + packed, err = abi.Pack("sliceMultiAddress", []common.Address{addrA, addrB}, []common.Address{addrC, addrD}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + sig = abi.Methods["slice256"].Id() + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + + packed, err = abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)}) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } +} + +func TestPackNumber(t *testing.T) { + tests := []struct { + value reflect.Value + packed []byte + }{ + // Protocol limits + {reflect.ValueOf(0), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")}, + {reflect.ValueOf(1), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")}, + {reflect.ValueOf(-1), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")}, + + // Type corner cases + {reflect.ValueOf(uint8(math.MaxUint8)), common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ff")}, + {reflect.ValueOf(uint16(math.MaxUint16)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000ffff")}, + {reflect.ValueOf(uint32(math.MaxUint32)), common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000ffffffff")}, + {reflect.ValueOf(uint64(math.MaxUint64)), common.Hex2Bytes("000000000000000000000000000000000000000000000000ffffffffffffffff")}, + + {reflect.ValueOf(int8(math.MaxInt8)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000007f")}, + {reflect.ValueOf(int16(math.MaxInt16)), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000007fff")}, + {reflect.ValueOf(int32(math.MaxInt32)), common.Hex2Bytes("000000000000000000000000000000000000000000000000000000007fffffff")}, + {reflect.ValueOf(int64(math.MaxInt64)), common.Hex2Bytes("0000000000000000000000000000000000000000000000007fffffffffffffff")}, + + {reflect.ValueOf(int8(math.MinInt8)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff80")}, + {reflect.ValueOf(int16(math.MinInt16)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8000")}, + {reflect.ValueOf(int32(math.MinInt32)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffff80000000")}, + {reflect.ValueOf(int64(math.MinInt64)), common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffff8000000000000000")}, + } + for i, tt := range tests { + packed := packNum(tt.value) + if !bytes.Equal(packed, tt.packed) { + t.Errorf("test %d: pack mismatch: have %x, want %x", i, packed, tt.packed) + } + } + if packed := packNum(reflect.ValueOf("string")); packed != nil { + t.Errorf("expected 'string' to pack to nil. got %x instead", packed) + } +} diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index 7970ba8ac..8fa75df07 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -32,30 +32,30 @@ func indirect(v reflect.Value) reflect.Value { // reflectIntKind returns the reflect using the given size and // unsignedness. -func reflectIntKind(unsigned bool, size int) reflect.Kind { +func reflectIntKindAndType(unsigned bool, size int) (reflect.Kind, reflect.Type) { switch size { case 8: if unsigned { - return reflect.Uint8 + return reflect.Uint8, uint8_t } - return reflect.Int8 + return reflect.Int8, int8_t case 16: if unsigned { - return reflect.Uint16 + return reflect.Uint16, uint16_t } - return reflect.Int16 + return reflect.Int16, int16_t case 32: if unsigned { - return reflect.Uint32 + return reflect.Uint32, uint32_t } - return reflect.Int32 + return reflect.Int32, int32_t case 64: if unsigned { - return reflect.Uint64 + return reflect.Uint64, uint64_t } - return reflect.Int64 + return reflect.Int64, int64_t } - return reflect.Ptr + return reflect.Ptr, big_t } // mustArrayToBytesSlice creates a new byte slice with the exact same size as value diff --git a/accounts/abi/type.go b/accounts/abi/type.go index f2832aef5..5f20babb3 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -33,7 +33,7 @@ const ( FixedBytesTy BytesTy HashTy - FixedpointTy + FixedPointTy FunctionTy ) @@ -126,13 +126,11 @@ func NewType(t string) (typ Type, err error) { switch varType { case "int": - typ.Kind = reflectIntKind(false, varSize) - typ.Type = big_t + typ.Kind, typ.Type = reflectIntKindAndType(false, varSize) typ.Size = varSize typ.T = IntTy case "uint": - typ.Kind = reflectIntKind(true, varSize) - typ.Type = ubig_t + typ.Kind, typ.Type = reflectIntKindAndType(true, varSize) typ.Size = varSize typ.T = UintTy case "bool": diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go index 155806459..984a5bb4c 100644 --- a/accounts/abi/type_test.go +++ b/accounts/abi/type_test.go @@ -17,8 +17,11 @@ package abi import ( + "math/big" "reflect" "testing" + + "github.com/ethereum/go-ethereum/common" ) // typeWithoutStringer is a alias for the Type type which simply doesn't implement @@ -31,26 +34,44 @@ func TestTypeRegexp(t *testing.T) { blob string kind Type }{ - {"int", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}}, - {"int8", Type{Kind: reflect.Int8, Type: big_t, Size: 8, T: IntTy, stringKind: "int8"}}, + {"bool", Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}}, + {"bool[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Bool, T: BoolTy, Elem: &Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}}, + {"bool[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Bool, T: BoolTy, Elem: &Type{Kind: reflect.Bool, T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}}, + {"int8", Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}}, + {"int16", Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}}, + {"int32", Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}}, + {"int64", Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}}, {"int256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}}, - {"int[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, - {"int[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, - {"int32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}}, - {"int32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: big_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, - {"uint", Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}}, - {"uint8", Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}}, - {"uint256", Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}}, - {"uint[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, - {"uint[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: ubig_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, - {"uint32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint32, Type: ubig_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: big_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, - {"uint32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint32, Type: ubig_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: big_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, - {"bytes", Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}}, - {"bytes32", Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}}, - {"bytes[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, - {"bytes[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}}, - {"bytes32[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, - {"bytes32[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: ubig_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, + {"int8[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}}, + {"int8[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}}, + {"int16[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}}, + {"int16[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}}, + {"int32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}}, + {"int32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, + {"int64[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}}, + {"int64[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}}, + {"int256[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, + {"int256[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, + {"uint8", Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}}, + {"uint16", Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}}, + {"uint32", Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}}, + {"uint64", Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}}, + {"uint256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}}, + {"uint8[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}}, + {"uint8[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}}, + {"uint16[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}}, + {"uint16[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}}, + {"uint32[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, + {"uint32[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, + {"uint64[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}}, + {"uint64[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}}, + {"uint256[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, + {"uint256[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, + {"bytes32", Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}}, + {"bytes[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, + {"bytes[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsSlice: true, SliceSize: -1, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}}, + {"bytes32[]", Type{IsSlice: true, SliceSize: -1, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, + {"bytes32[2]", Type{IsArray: true, SliceSize: 2, Elem: &Type{IsArray: true, SliceSize: 32, Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, T: FixedBytesTy, stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, {"string", Type{Kind: reflect.String, Size: -1, T: StringTy, stringKind: "string"}}, {"string[]", Type{IsSlice: true, SliceSize: -1, Kind: reflect.String, T: StringTy, Size: -1, Elem: &Type{Kind: reflect.String, T: StringTy, Size: -1, stringKind: "string"}, stringKind: "string[]"}}, {"string[2]", Type{IsArray: true, SliceSize: 2, Kind: reflect.String, T: StringTy, Size: -1, Elem: &Type{Kind: reflect.String, T: StringTy, Size: -1, stringKind: "string"}, stringKind: "string[2]"}}, @@ -76,3 +97,59 @@ func TestTypeRegexp(t *testing.T) { } } } + +func TestTypeCheck(t *testing.T) { + for i, test := range []struct { + typ string + input interface{} + err string + }{ + {"uint", big.NewInt(1), ""}, + {"int", big.NewInt(1), ""}, + {"uint30", big.NewInt(1), ""}, + {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"}, + {"uint16", uint16(1), ""}, + {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"}, + {"uint16[]", []uint16{1, 2, 3}, ""}, + {"uint16[]", [3]uint16{1, 2, 3}, ""}, + {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type []uint16 as argument"}, + {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"}, + {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, + {"uint16[3]", []uint16{1, 2, 3}, ""}, + {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, + {"address[]", []common.Address{{1}}, ""}, + {"address[1]", []common.Address{{1}}, ""}, + {"address[1]", [1]common.Address{{1}}, ""}, + {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"}, + {"bytes32", [32]byte{}, ""}, + {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"}, + {"bytes32", common.Hash{1}, ""}, + {"bytes31", [31]byte{}, ""}, + {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"}, + {"bytes", []byte{0, 1}, ""}, + {"bytes", [2]byte{0, 1}, ""}, + {"bytes", common.Hash{1}, ""}, + {"string", "hello world", ""}, + {"bytes32[]", [][32]byte{{}}, ""}, + {"function", [24]byte{}, ""}, + } { + typ, err := NewType(test.typ) + if err != nil { + t.Fatal("unexpected parse error:", err) + } + + err = typeCheck(typ, reflect.ValueOf(test.input)) + if err != nil && len(test.err) == 0 { + t.Errorf("%d failed. Expected no err but got: %v", i, err) + continue + } + if err == nil && len(test.err) != 0 { + t.Errorf("%d failed. Expected err: %v but got none", i, test.err) + continue + } + + if err != nil && len(test.err) != 0 && err.Error() != test.err { + t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err) + } + } +} diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go new file mode 100644 index 000000000..fc41c88ac --- /dev/null +++ b/accounts/abi/unpack.go @@ -0,0 +1,235 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package abi + +import ( + "encoding/binary" + "fmt" + "math/big" + "reflect" + + "github.com/ethereum/go-ethereum/common" +) + +// toGoSliceType parses the input and casts it to the proper slice defined by the ABI +// argument in T. +func toGoSlice(i int, t Argument, output []byte) (interface{}, error) { + index := i * 32 + // The slice must, at very least be large enough for the index+32 which is exactly the size required + // for the [offset in output, size of offset]. + if index+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), index+32) + } + elem := t.Type.Elem + + // first we need to create a slice of the type + var refSlice reflect.Value + switch elem.T { + case IntTy, UintTy, BoolTy: + // create a new reference slice matching the element type + switch t.Type.Kind { + case reflect.Bool: + refSlice = reflect.ValueOf([]bool(nil)) + case reflect.Uint8: + refSlice = reflect.ValueOf([]uint8(nil)) + case reflect.Uint16: + refSlice = reflect.ValueOf([]uint16(nil)) + case reflect.Uint32: + refSlice = reflect.ValueOf([]uint32(nil)) + case reflect.Uint64: + refSlice = reflect.ValueOf([]uint64(nil)) + case reflect.Int8: + refSlice = reflect.ValueOf([]int8(nil)) + case reflect.Int16: + refSlice = reflect.ValueOf([]int16(nil)) + case reflect.Int32: + refSlice = reflect.ValueOf([]int32(nil)) + case reflect.Int64: + refSlice = reflect.ValueOf([]int64(nil)) + default: + refSlice = reflect.ValueOf([]*big.Int(nil)) + } + case AddressTy: // address must be of slice Address + refSlice = reflect.ValueOf([]common.Address(nil)) + case HashTy: // hash must be of slice hash + refSlice = reflect.ValueOf([]common.Hash(nil)) + case FixedBytesTy: + refSlice = reflect.ValueOf([][]byte(nil)) + default: // no other types are supported + return nil, fmt.Errorf("abi: unsupported slice type %v", elem.T) + } + + var slice []byte + var size int + var offset int + if t.Type.IsSlice { + // get the offset which determines the start of this array ... + offset = int(binary.BigEndian.Uint64(output[index+24 : index+32])) + if offset+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32) + } + + slice = output[offset:] + // ... starting with the size of the array in elements ... + size = int(binary.BigEndian.Uint64(slice[24:32])) + slice = slice[32:] + // ... and make sure that we've at the very least the amount of bytes + // available in the buffer. + if size*32 > len(slice) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), offset+32+size*32) + } + + // reslice to match the required size + slice = slice[:size*32] + } else if t.Type.IsArray { + //get the number of elements in the array + size = t.Type.SliceSize + + //check to make sure array size matches up + if index+32*size > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go array: offset %d would go over slice boundary (len=%d)", len(output), index+32*size) + } + //slice is there for a fixed amount of times + slice = output[index : index+size*32] + } + + for i := 0; i < size; i++ { + var ( + inter interface{} // interface type + returnOutput = slice[i*32 : i*32+32] // the return output + err error + ) + // set inter to the correct type (cast) + switch elem.T { + case IntTy, UintTy: + inter = readInteger(t.Type.Kind, returnOutput) + case BoolTy: + inter, err = readBool(returnOutput) + if err != nil { + return nil, err + } + case AddressTy: + inter = common.BytesToAddress(returnOutput) + case HashTy: + inter = common.BytesToHash(returnOutput) + case FixedBytesTy: + inter = returnOutput + } + // append the item to our reflect slice + refSlice = reflect.Append(refSlice, reflect.ValueOf(inter)) + } + + // return the interface + return refSlice.Interface(), nil +} + +func readInteger(kind reflect.Kind, b []byte) interface{} { + switch kind { + case reflect.Uint8: + return uint8(b[len(b)-1]) + case reflect.Uint16: + return binary.BigEndian.Uint16(b[len(b)-2:]) + case reflect.Uint32: + return binary.BigEndian.Uint32(b[len(b)-4:]) + case reflect.Uint64: + return binary.BigEndian.Uint64(b[len(b)-8:]) + case reflect.Int8: + return int8(b[len(b)-1]) + case reflect.Int16: + return int16(binary.BigEndian.Uint16(b[len(b)-2:])) + case reflect.Int32: + return int32(binary.BigEndian.Uint32(b[len(b)-4:])) + case reflect.Int64: + return int64(binary.BigEndian.Uint64(b[len(b)-8:])) + default: + return new(big.Int).SetBytes(b) + } +} + +func readBool(word []byte) (bool, error) { + if len(word) != 32 { + return false, fmt.Errorf("abi: fatal error: incorrect word length") + } + + for i, b := range word { + if b != 0 && i != 31 { + return false, errBadBool + } + } + switch word[31] { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, errBadBool + } + +} + +// toGoType parses the input and casts it to the proper type defined by the ABI +// argument in T. +func toGoType(i int, t Argument, output []byte) (interface{}, error) { + // we need to treat slices differently + if (t.Type.IsSlice || t.Type.IsArray) && t.Type.T != BytesTy && t.Type.T != StringTy && t.Type.T != FixedBytesTy && t.Type.T != FunctionTy { + return toGoSlice(i, t, output) + } + + index := i * 32 + if index+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32) + } + + // Parse the given index output and check whether we need to read + // a different offset and length based on the type (i.e. string, bytes) + var returnOutput []byte + switch t.Type.T { + case StringTy, BytesTy: // variable arrays are written at the end of the return bytes + // parse offset from which we should start reading + offset := int(binary.BigEndian.Uint64(output[index+24 : index+32])) + if offset+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32) + } + // parse the size up until we should be reading + size := int(binary.BigEndian.Uint64(output[offset+24 : offset+32])) + if offset+32+size > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+size) + } + + // get the bytes for this return value + returnOutput = output[offset+32 : offset+32+size] + default: + returnOutput = output[index : index+32] + } + + // convert the bytes to whatever is specified by the ABI. + switch t.Type.T { + case IntTy, UintTy: + return readInteger(t.Type.Kind, returnOutput), nil + case BoolTy: + return readBool(returnOutput) + case AddressTy: + return common.BytesToAddress(returnOutput), nil + case HashTy: + return common.BytesToHash(returnOutput), nil + case BytesTy, FixedBytesTy, FunctionTy: + return returnOutput, nil + case StringTy: + return string(returnOutput), nil + } + return nil, fmt.Errorf("abi: unknown type %v", t.Type.T) +} diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go new file mode 100644 index 000000000..8e3afee4e --- /dev/null +++ b/accounts/abi/unpack_test.go @@ -0,0 +1,681 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package abi + +import ( + "bytes" + "fmt" + "math/big" + "reflect" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func TestSimpleMethodUnpack(t *testing.T) { + for i, test := range []struct { + def string // definition of the **output** ABI params + marshalledOutput []byte // evm return data + expectedOut interface{} // the expected output + outVar string // the output variable (e.g. uint32, *big.Int, etc) + err string // empty or error if expected + }{ + { + `[ { "type": "bool" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + bool(true), + "bool", + "", + }, + { + `[ { "type": "uint32" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + uint32(1), + "uint32", + "", + }, + { + `[ { "type": "uint32" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + nil, + "uint16", + "abi: cannot unmarshal uint32 in to uint16", + }, + { + `[ { "type": "uint17" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + nil, + "uint16", + "abi: cannot unmarshal *big.Int in to uint16", + }, + { + `[ { "type": "uint17" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + big.NewInt(1), + "*big.Int", + "", + }, + + { + `[ { "type": "int32" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + int32(1), + "int32", + "", + }, + { + `[ { "type": "int32" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + nil, + "int16", + "abi: cannot unmarshal int32 in to int16", + }, + { + `[ { "type": "int17" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + nil, + "int16", + "abi: cannot unmarshal *big.Int in to int16", + }, + { + `[ { "type": "int17" } ]`, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), + big.NewInt(1), + "*big.Int", + "", + }, + + { + `[ { "type": "address" } ]`, + common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"), + common.Address{1}, + "address", + "", + }, + { + `[ { "type": "bytes32" } ]`, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + "bytes", + "", + }, + { + `[ { "type": "bytes32" } ]`, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + "hash", + "", + }, + { + `[ { "type": "bytes32" } ]`, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + "interface", + "", + }, + { + `[ { "type": "function" } ]`, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + [24]byte{1}, + "function", + "", + }, + } { + abiDefinition := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def) + abi, err := JSON(strings.NewReader(abiDefinition)) + if err != nil { + t.Errorf("%d failed. %v", i, err) + continue + } + + var outvar interface{} + switch test.outVar { + case "bool": + var v bool + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "uint8": + var v uint8 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "uint16": + var v uint16 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "uint32": + var v uint32 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "uint64": + var v uint64 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "int8": + var v int8 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "int16": + var v int16 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "int32": + var v int32 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "int64": + var v int64 + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "*big.Int": + var v *big.Int + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "address": + var v common.Address + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "bytes": + var v []byte + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "hash": + var v common.Hash + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v.Bytes()[:] + case "function": + var v [24]byte + err = abi.Unpack(&v, "method", test.marshalledOutput) + outvar = v + case "interface": + err = abi.Unpack(&outvar, "method", test.marshalledOutput) + default: + t.Errorf("unsupported type '%v' please add it to the switch statement in this test", test.outVar) + continue + } + + if err != nil && len(test.err) == 0 { + t.Errorf("%d failed. Expected no err but got: %v", i, err) + continue + } + if err == nil && len(test.err) != 0 { + t.Errorf("%d failed. Expected err: %v but got none", i, test.err) + continue + } + if err != nil && len(test.err) != 0 && err.Error() != test.err { + t.Errorf("%d failed. Expected err: '%v' got err: '%v'", i, test.err, err) + continue + } + + if err == nil { + if !reflect.DeepEqual(test.expectedOut, outvar) { + t.Errorf("%d failed. Output error: expected %v, got %v", i, test.expectedOut, outvar) + } + } + } +} + +func TestUnpackSetInterfaceSlice(t *testing.T) { + var ( + var1 = new(uint8) + var2 = new(uint8) + ) + out := []interface{}{var1, var2} + abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint8"}, {"type":"uint8"}]}]`)) + if err != nil { + t.Fatal(err) + } + marshalledReturn := append(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")...) + err = abi.Unpack(&out, "ints", marshalledReturn) + if err != nil { + t.Fatal(err) + } + if *var1 != 1 { + t.Error("expected var1 to be 1, got", *var1) + } + if *var2 != 2 { + t.Error("expected var2 to be 2, got", *var2) + } + + out = []interface{}{var1} + err = abi.Unpack(&out, "ints", marshalledReturn) + + expErr := "abi: cannot marshal in to slices of unequal size (require: 2, got: 1)" + if err == nil || err.Error() != expErr { + t.Error("expected err:", expErr, "Got:", err) + } +} + +func TestUnpackSetInterfaceArrayOutput(t *testing.T) { + var ( + var1 = new([1]uint32) + var2 = new([1]uint32) + ) + out := []interface{}{var1, var2} + abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint32[1]"}, {"type":"uint32[1]"}]}]`)) + if err != nil { + t.Fatal(err) + } + marshalledReturn := append(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")...) + err = abi.Unpack(&out, "ints", marshalledReturn) + if err != nil { + t.Fatal(err) + } + + if *var1 != [1]uint32{1} { + t.Error("expected var1 to be [1], got", *var1) + } + if *var2 != [1]uint32{2} { + t.Error("expected var2 to be [2], got", *var2) + } +} + +func TestMultiReturnWithStruct(t *testing.T) { + const definition = `[ + { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + // using buff to make the code readable + buff := new(bytes.Buffer) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + stringOut := "hello" + buff.Write(common.RightPadBytes([]byte(stringOut), 32)) + + var inter struct { + Int *big.Int + String string + } + err = abi.Unpack(&inter, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if inter.Int == nil || inter.Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", inter.Int) + } + + if inter.String != stringOut { + t.Error("expected String to be", stringOut, "got", inter.String) + } + + var reversed struct { + String string + Int *big.Int + } + + err = abi.Unpack(&reversed, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if reversed.Int == nil || reversed.Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", reversed.Int) + } + + if reversed.String != stringOut { + t.Error("expected String to be", stringOut, "got", reversed.String) + } +} + +func TestMultiReturnWithSlice(t *testing.T) { + const definition = `[ + { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + // using buff to make the code readable + buff := new(bytes.Buffer) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + stringOut := "hello" + buff.Write(common.RightPadBytes([]byte(stringOut), 32)) + + var inter []interface{} + err = abi.Unpack(&inter, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if len(inter) != 2 { + t.Fatal("expected 2 results got", len(inter)) + } + + if num, ok := inter[0].(*big.Int); !ok || num.Cmp(big.NewInt(1)) != 0 { + t.Error("expected index 0 to be 1 got", num) + } + + if str, ok := inter[1].(string); !ok || str != stringOut { + t.Error("expected index 1 to be", stringOut, "got", str) + } +} + +func TestMarshalArrays(t *testing.T) { + const definition = `[ + { "name" : "bytes32", "constant" : false, "outputs": [ { "type": "bytes32" } ] }, + { "name" : "bytes10", "constant" : false, "outputs": [ { "type": "bytes10" } ] } + ]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + output := common.LeftPadBytes([]byte{1}, 32) + + var bytes10 [10]byte + err = abi.Unpack(&bytes10, "bytes32", output) + if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" { + t.Error("expected error or bytes32 not be assignable to bytes10:", err) + } + + var bytes32 [32]byte + err = abi.Unpack(&bytes32, "bytes32", output) + if err != nil { + t.Error("didn't expect error:", err) + } + if !bytes.Equal(bytes32[:], output) { + t.Error("expected bytes32[31] to be 1 got", bytes32[31]) + } + + type ( + B10 [10]byte + B32 [32]byte + ) + + var b10 B10 + err = abi.Unpack(&b10, "bytes32", output) + if err == nil || err.Error() != "abi: cannot unmarshal src (len=32) in to dst (len=10)" { + t.Error("expected error or bytes32 not be assignable to bytes10:", err) + } + + var b32 B32 + err = abi.Unpack(&b32, "bytes32", output) + if err != nil { + t.Error("didn't expect error:", err) + } + if !bytes.Equal(b32[:], output) { + t.Error("expected bytes32[31] to be 1 got", bytes32[31]) + } + + output[10] = 1 + var shortAssignLong [32]byte + err = abi.Unpack(&shortAssignLong, "bytes10", output) + if err != nil { + t.Error("didn't expect error:", err) + } + if !bytes.Equal(output, shortAssignLong[:]) { + t.Errorf("expected %x to be %x", shortAssignLong, output) + } +} + +func TestUnmarshal(t *testing.T) { + const definition = `[ + { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] }, + { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] }, + { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] }, + { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] }, + { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, + { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] }, + { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] }, + { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] }, + { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + buff := new(bytes.Buffer) + + // marshal int + var Int *big.Int + err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + if err != nil { + t.Error(err) + } + + if Int == nil || Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", Int) + } + + // marshal bool + var Bool bool + err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + if err != nil { + t.Error(err) + } + + if !Bool { + t.Error("expected Bool to be true") + } + + // marshal dynamic bytes max length 32 + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + bytesOut := common.RightPadBytes([]byte("hello"), 32) + buff.Write(bytesOut) + + var Bytes []byte + err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshall dynamic bytes max length 64 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + bytesOut = common.RightPadBytes([]byte("hello"), 64) + buff.Write(bytesOut) + + err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshall dynamic bytes max length 63 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000003f")) + bytesOut = common.RightPadBytes([]byte("hello"), 63) + buff.Write(bytesOut) + + err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshal dynamic bytes output empty + err = abi.Unpack(&Bytes, "bytes", nil) + if err == nil { + t.Error("expected error") + } + + // marshal dynamic bytes length 5 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + buff.Write(common.RightPadBytes([]byte("hello"), 32)) + + err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, []byte("hello")) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshal dynamic bytes length 5 + buff.Reset() + buff.Write(common.RightPadBytes([]byte("hello"), 32)) + + var hash common.Hash + err = abi.Unpack(&hash, "fixed", buff.Bytes()) + if err != nil { + t.Error(err) + } + + helloHash := common.BytesToHash(common.RightPadBytes([]byte("hello"), 32)) + if hash != helloHash { + t.Errorf("Expected %x to equal %x", hash, helloHash) + } + + // marshal error + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + if err == nil { + t.Error("expected error") + } + + err = abi.Unpack(&Bytes, "multi", make([]byte, 64)) + if err == nil { + t.Error("expected error") + } + + // marshal mixed bytes + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + fixed := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001") + buff.Write(fixed) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + bytesOut = common.RightPadBytes([]byte("hello"), 32) + buff.Write(bytesOut) + + var out []interface{} + err = abi.Unpack(&out, "mixedBytes", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if !bytes.Equal(bytesOut, out[0].([]byte)) { + t.Errorf("expected %x, got %x", bytesOut, out[0]) + } + + if !bytes.Equal(fixed, out[1].([]byte)) { + t.Errorf("expected %x, got %x", fixed, out[1]) + } + + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003")) + // marshal int array + var intArray [3]*big.Int + err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes()) + if err != nil { + t.Error(err) + } + var testAgainstIntArray [3]*big.Int + testAgainstIntArray[0] = big.NewInt(1) + testAgainstIntArray[1] = big.NewInt(2) + testAgainstIntArray[2] = big.NewInt(3) + + for i, Int := range intArray { + if Int.Cmp(testAgainstIntArray[i]) != 0 { + t.Errorf("expected %v, got %v", testAgainstIntArray[i], Int) + } + } + // marshal address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) + + var outAddr []common.Address + err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if len(outAddr) != 1 { + t.Fatal("expected 1 item, got", len(outAddr)) + } + + if outAddr[0] != (common.Address{1}) { + t.Errorf("expected %x, got %x", common.Address{1}, outAddr[0]) + } + + // marshal multiple address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000200000000000000000000000000000000000000")) + buff.Write(common.Hex2Bytes("0000000000000000000000000300000000000000000000000000000000000000")) + + var outAddrStruct struct { + A []common.Address + B []common.Address + } + err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if len(outAddrStruct.A) != 1 { + t.Fatal("expected 1 item, got", len(outAddrStruct.A)) + } + + if outAddrStruct.A[0] != (common.Address{1}) { + t.Errorf("expected %x, got %x", common.Address{1}, outAddrStruct.A[0]) + } + + if len(outAddrStruct.B) != 2 { + t.Fatal("expected 1 item, got", len(outAddrStruct.B)) + } + + if outAddrStruct.B[0] != (common.Address{2}) { + t.Errorf("expected %x, got %x", common.Address{2}, outAddrStruct.B[0]) + } + if outAddrStruct.B[1] != (common.Address{3}) { + t.Errorf("expected %x, got %x", common.Address{3}, outAddrStruct.B[1]) + } + + // marshal invalid address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100")) + + err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + if err == nil { + t.Fatal("expected error:", err) + } +} diff --git a/build/ci.go b/build/ci.go index 47b1dc780..6a52077d4 100644 --- a/build/ci.go +++ b/build/ci.go @@ -175,7 +175,7 @@ func doInstall(cmdline []string) { // Check Go version. People regularly open issues about compilation // failure with outdated Go. This should save them the trouble. - if runtime.Version() < "go1.7" && !strings.HasPrefix(runtime.Version(), "devel") { + if runtime.Version() < "go1.7" && !strings.Contains(runtime.Version(), "devel") { log.Println("You have Go version", runtime.Version()) log.Println("go-ethereum requires at least Go version 1.7 and cannot") log.Println("be compiled with an earlier version. Please upgrade your Go installation.") diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 2ce0920f6..3f95a0c93 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -98,8 +98,8 @@ func runCmd(ctx *cli.Context) error { _, statedb = gen.ToBlock() chainConfig = gen.Config } else { - var db, _ = ethdb.NewMemDatabase() - statedb, _ = state.New(common.Hash{}, db) + db, _ := ethdb.NewMemDatabase() + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) } if ctx.GlobalString(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) @@ -188,7 +188,7 @@ func runCmd(ctx *cli.Context) error { execTime := time.Since(tstart) if ctx.GlobalBool(DumpFlag.Name) { - statedb.Commit(true) + statedb.IntermediateRoot(true) fmt.Println(string(statedb.Dump())) } diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index ab0e92f21..12bc1d7c6 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -312,7 +312,7 @@ func dump(ctx *cli.Context) error { fmt.Println("{}") utils.Fatalf("block not found") } else { - state, err := state.New(block.Root(), chainDb) + state, err := state.New(block.Root(), state.NewDatabase(chainDb)) if err != nil { utils.Fatalf("could not create new state: %v", err) } diff --git a/cmd/puppeth/wizard_faucet.go b/cmd/puppeth/wizard_faucet.go index 66ec98c73..51c4e2f7f 100644 --- a/cmd/puppeth/wizard_faucet.go +++ b/cmd/puppeth/wizard_faucet.go @@ -165,8 +165,7 @@ func (w *wizard) deployFaucet() { } // Load up the credential needed to release funds if infos.node.keyJSON != "" { - var key keystore.Key - if err := json.Unmarshal([]byte(infos.node.keyJSON), &key); err != nil { + if key, err := keystore.DecryptKey([]byte(infos.node.keyJSON), infos.node.keyPass); err != nil { infos.node.keyJSON, infos.node.keyPass = "", "" } else { fmt.Println() diff --git a/common/hexutil/hexutil.go b/common/hexutil/hexutil.go index 6b128ae36..582a67c22 100644 --- a/common/hexutil/hexutil.go +++ b/common/hexutil/hexutil.go @@ -32,7 +32,6 @@ package hexutil import ( "encoding/hex" - "errors" "fmt" "math/big" "strconv" @@ -41,17 +40,23 @@ import ( const uintBits = 32 << (uint64(^uint(0)) >> 63) var ( - ErrEmptyString = errors.New("empty hex string") - ErrMissingPrefix = errors.New("missing 0x prefix for hex data") - ErrSyntax = errors.New("invalid hex") - ErrEmptyNumber = errors.New("hex number has no digits after 0x") - ErrLeadingZero = errors.New("hex number has leading zero digits after 0x") - ErrOddLength = errors.New("hex string has odd length") - ErrUint64Range = errors.New("hex number does not fit into 64 bits") - ErrUintRange = fmt.Errorf("hex number does not fit into %d bits", uintBits) - ErrBig256Range = errors.New("hex number does not fit into 256 bits") + ErrEmptyString = &decError{"empty hex string"} + ErrSyntax = &decError{"invalid hex string"} + ErrMissingPrefix = &decError{"hex string without 0x prefix"} + ErrOddLength = &decError{"hex string of odd length"} + ErrEmptyNumber = &decError{"hex string \"0x\""} + ErrLeadingZero = &decError{"hex number with leading zero digits"} + ErrUint64Range = &decError{"hex number > 64 bits"} + ErrUintRange = &decError{fmt.Sprintf("hex number > %d bits", uintBits)} + ErrBig256Range = &decError{"hex number > 256 bits"} ) +type decError struct{ msg string } + +func (err decError) Error() string { + return string(err.msg) +} + // Decode decodes a hex string with 0x prefix. func Decode(input string) ([]byte, error) { if len(input) == 0 { diff --git a/common/hexutil/json.go b/common/hexutil/json.go index 1bc1d014c..943288fad 100644 --- a/common/hexutil/json.go +++ b/common/hexutil/json.go @@ -18,15 +18,19 @@ package hexutil import ( "encoding/hex" - "errors" + "encoding/json" "fmt" "math/big" + "reflect" "strconv" ) var ( - textZero = []byte(`0x0`) - errNonString = errors.New("cannot unmarshal non-string as hex data") + textZero = []byte(`0x0`) + bytesT = reflect.TypeOf(Bytes(nil)) + bigT = reflect.TypeOf((*Big)(nil)) + uintT = reflect.TypeOf(Uint(0)) + uint64T = reflect.TypeOf(Uint64(0)) ) // Bytes marshals/unmarshals as a JSON string with 0x prefix. @@ -44,9 +48,9 @@ func (b Bytes) MarshalText() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (b *Bytes) UnmarshalJSON(input []byte) error { if !isString(input) { - return errNonString + return errNonString(bytesT) } - return b.UnmarshalText(input[1 : len(input)-1]) + return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), bytesT) } // UnmarshalText implements encoding.TextUnmarshaler. @@ -69,6 +73,16 @@ func (b Bytes) String() string { return Encode(b) } +// UnmarshalFixedJSON decodes the input as a string with 0x prefix. The length of out +// determines the required input length. This function is commonly used to implement the +// UnmarshalJSON method for fixed-size types. +func UnmarshalFixedJSON(typ reflect.Type, input, out []byte) error { + if !isString(input) { + return errNonString(typ) + } + return wrapTypeError(UnmarshalFixedText(typ.String(), input[1:len(input)-1], out), typ) +} + // UnmarshalFixedText decodes the input as a string with 0x prefix. The length of out // determines the required input length. This function is commonly used to implement the // UnmarshalText method for fixed-size types. @@ -127,9 +141,9 @@ func (b Big) MarshalText() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (b *Big) UnmarshalJSON(input []byte) error { if !isString(input) { - return errNonString + return errNonString(bigT) } - return b.UnmarshalText(input[1 : len(input)-1]) + return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), bigT) } // UnmarshalText implements encoding.TextUnmarshaler @@ -189,9 +203,9 @@ func (b Uint64) MarshalText() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (b *Uint64) UnmarshalJSON(input []byte) error { if !isString(input) { - return errNonString + return errNonString(uint64T) } - return b.UnmarshalText(input[1 : len(input)-1]) + return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), uint64T) } // UnmarshalText implements encoding.TextUnmarshaler @@ -233,9 +247,9 @@ func (b Uint) MarshalText() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (b *Uint) UnmarshalJSON(input []byte) error { if !isString(input) { - return errNonString + return errNonString(uintT) } - return b.UnmarshalText(input[1 : len(input)-1]) + return wrapTypeError(b.UnmarshalText(input[1:len(input)-1]), uintT) } // UnmarshalText implements encoding.TextUnmarshaler. @@ -295,3 +309,14 @@ func checkNumberText(input []byte) (raw []byte, err error) { } return input, nil } + +func wrapTypeError(err error, typ reflect.Type) error { + if _, ok := err.(*decError); ok { + return &json.UnmarshalTypeError{Value: err.Error(), Type: typ} + } + return err +} + +func errNonString(typ reflect.Type) error { + return &json.UnmarshalTypeError{Value: "non-string", Type: typ} +} diff --git a/common/hexutil/json_test.go b/common/hexutil/json_test.go index e4e827491..8a6b8643a 100644 --- a/common/hexutil/json_test.go +++ b/common/hexutil/json_test.go @@ -62,12 +62,12 @@ var errJSONEOF = errors.New("unexpected end of JSON input") var unmarshalBytesTests = []unmarshalTest{ // invalid encoding {input: "", wantErr: errJSONEOF}, - {input: "null", wantErr: errNonString}, - {input: "10", wantErr: errNonString}, - {input: `"0"`, wantErr: ErrMissingPrefix}, - {input: `"0x0"`, wantErr: ErrOddLength}, - {input: `"0xxx"`, wantErr: ErrSyntax}, - {input: `"0x01zz01"`, wantErr: ErrSyntax}, + {input: "null", wantErr: errNonString(bytesT)}, + {input: "10", wantErr: errNonString(bytesT)}, + {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, bytesT)}, + {input: `"0x0"`, wantErr: wrapTypeError(ErrOddLength, bytesT)}, + {input: `"0xxx"`, wantErr: wrapTypeError(ErrSyntax, bytesT)}, + {input: `"0x01zz01"`, wantErr: wrapTypeError(ErrSyntax, bytesT)}, // valid encoding {input: `""`, want: referenceBytes("")}, @@ -127,16 +127,16 @@ func TestMarshalBytes(t *testing.T) { var unmarshalBigTests = []unmarshalTest{ // invalid encoding {input: "", wantErr: errJSONEOF}, - {input: "null", wantErr: errNonString}, - {input: "10", wantErr: errNonString}, - {input: `"0"`, wantErr: ErrMissingPrefix}, - {input: `"0x"`, wantErr: ErrEmptyNumber}, - {input: `"0x01"`, wantErr: ErrLeadingZero}, - {input: `"0xx"`, wantErr: ErrSyntax}, - {input: `"0x1zz01"`, wantErr: ErrSyntax}, + {input: "null", wantErr: errNonString(bigT)}, + {input: "10", wantErr: errNonString(bigT)}, + {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, bigT)}, + {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, bigT)}, + {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, bigT)}, + {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, bigT)}, + {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, bigT)}, { input: `"0x10000000000000000000000000000000000000000000000000000000000000000"`, - wantErr: ErrBig256Range, + wantErr: wrapTypeError(ErrBig256Range, bigT), }, // valid encoding @@ -208,14 +208,14 @@ func TestMarshalBig(t *testing.T) { var unmarshalUint64Tests = []unmarshalTest{ // invalid encoding {input: "", wantErr: errJSONEOF}, - {input: "null", wantErr: errNonString}, - {input: "10", wantErr: errNonString}, - {input: `"0"`, wantErr: ErrMissingPrefix}, - {input: `"0x"`, wantErr: ErrEmptyNumber}, - {input: `"0x01"`, wantErr: ErrLeadingZero}, - {input: `"0xfffffffffffffffff"`, wantErr: ErrUint64Range}, - {input: `"0xx"`, wantErr: ErrSyntax}, - {input: `"0x1zz01"`, wantErr: ErrSyntax}, + {input: "null", wantErr: errNonString(uint64T)}, + {input: "10", wantErr: errNonString(uint64T)}, + {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, uint64T)}, + {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, uint64T)}, + {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, uint64T)}, + {input: `"0xfffffffffffffffff"`, wantErr: wrapTypeError(ErrUint64Range, uint64T)}, + {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, uint64T)}, + {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, uint64T)}, // valid encoding {input: `""`, want: uint64(0)}, @@ -298,15 +298,15 @@ var ( var unmarshalUintTests = []unmarshalTest{ // invalid encoding {input: "", wantErr: errJSONEOF}, - {input: "null", wantErr: errNonString}, - {input: "10", wantErr: errNonString}, - {input: `"0"`, wantErr: ErrMissingPrefix}, - {input: `"0x"`, wantErr: ErrEmptyNumber}, - {input: `"0x01"`, wantErr: ErrLeadingZero}, - {input: `"0x100000000"`, want: uint(maxUint33bits), wantErr32bit: ErrUintRange}, - {input: `"0xfffffffffffffffff"`, wantErr: ErrUintRange}, - {input: `"0xx"`, wantErr: ErrSyntax}, - {input: `"0x1zz01"`, wantErr: ErrSyntax}, + {input: "null", wantErr: errNonString(uintT)}, + {input: "10", wantErr: errNonString(uintT)}, + {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, uintT)}, + {input: `"0x"`, wantErr: wrapTypeError(ErrEmptyNumber, uintT)}, + {input: `"0x01"`, wantErr: wrapTypeError(ErrLeadingZero, uintT)}, + {input: `"0x100000000"`, want: uint(maxUint33bits), wantErr32bit: wrapTypeError(ErrUintRange, uintT)}, + {input: `"0xfffffffffffffffff"`, wantErr: wrapTypeError(ErrUintRange, uintT)}, + {input: `"0xx"`, wantErr: wrapTypeError(ErrSyntax, uintT)}, + {input: `"0x1zz01"`, wantErr: wrapTypeError(ErrSyntax, uintT)}, // valid encoding {input: `""`, want: uint(0)}, @@ -317,7 +317,7 @@ var unmarshalUintTests = []unmarshalTest{ {input: `"0x1122aaff"`, want: uint(0x1122aaff)}, {input: `"0xbbb"`, want: uint(0xbbb)}, {input: `"0xffffffff"`, want: uint(0xffffffff)}, - {input: `"0xffffffffffffffff"`, want: uint(maxUint64bits), wantErr32bit: ErrUintRange}, + {input: `"0xffffffffffffffff"`, want: uint(maxUint64bits), wantErr32bit: wrapTypeError(ErrUintRange, uintT)}, } func TestUnmarshalUint(t *testing.T) { diff --git a/common/types.go b/common/types.go index 05288bf46..803726634 100644 --- a/common/types.go +++ b/common/types.go @@ -31,6 +31,11 @@ const ( AddressLength = 20 ) +var ( + hashT = reflect.TypeOf(Hash{}) + addressT = reflect.TypeOf(Address{}) +) + // Hash represents the 32 byte Keccak256 hash of arbitrary data. type Hash [HashLength]byte @@ -72,6 +77,11 @@ func (h *Hash) UnmarshalText(input []byte) error { return hexutil.UnmarshalFixedText("Hash", input, h[:]) } +// UnmarshalJSON parses a hash in hex syntax. +func (h *Hash) UnmarshalJSON(input []byte) error { + return hexutil.UnmarshalFixedJSON(hashT, input, h[:]) +} + // MarshalText returns the hex representation of h. func (h Hash) MarshalText() ([]byte, error) { return hexutil.Bytes(h[:]).MarshalText() @@ -194,6 +204,11 @@ func (a *Address) UnmarshalText(input []byte) error { return hexutil.UnmarshalFixedText("Address", input, a[:]) } +// UnmarshalJSON parses a hash in hex syntax. +func (a *Address) UnmarshalJSON(input []byte) error { + return hexutil.UnmarshalFixedJSON(addressT, input, a[:]) +} + // UnprefixedHash allows marshaling an Address without 0x prefix. type UnprefixedAddress Address diff --git a/common/types_test.go b/common/types_test.go index 9f9d8b767..154c33063 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -21,8 +21,6 @@ import ( "math/big" "strings" "testing" - - "github.com/ethereum/go-ethereum/common/hexutil" ) func TestBytesConversion(t *testing.T) { @@ -43,10 +41,10 @@ func TestHashJsonValidation(t *testing.T) { Size int Error string }{ - {"", 62, hexutil.ErrMissingPrefix.Error()}, - {"0x", 66, "hex string has length 66, want 64 for Hash"}, - {"0x", 63, hexutil.ErrOddLength.Error()}, - {"0x", 0, "hex string has length 0, want 64 for Hash"}, + {"", 62, "json: cannot unmarshal hex string without 0x prefix into Go value of type common.Hash"}, + {"0x", 66, "hex string has length 66, want 64 for common.Hash"}, + {"0x", 63, "json: cannot unmarshal hex string of odd length into Go value of type common.Hash"}, + {"0x", 0, "hex string has length 0, want 64 for common.Hash"}, {"0x", 64, ""}, {"0X", 64, ""}, } diff --git a/core/block_validator.go b/core/block_validator.go index 4f85df007..e9cfd0482 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -52,16 +52,10 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin // validated at this point. func (v *BlockValidator) ValidateBody(block *types.Block) error { // Check whether the block's known, and if not, that it's linkable - if v.bc.HasBlock(block.Hash()) { - if _, err := state.New(block.Root(), v.bc.chainDb); err == nil { - return ErrKnownBlock - } + if v.bc.HasBlockAndState(block.Hash()) { + return ErrKnownBlock } - parent := v.bc.GetBlock(block.ParentHash(), block.NumberU64()-1) - if parent == nil { - return consensus.ErrUnknownAncestor - } - if _, err := state.New(parent.Root(), v.bc.chainDb); err != nil { + if !v.bc.HasBlockAndState(block.ParentHash()) { return consensus.ErrUnknownAncestor } // Header validity is known at this point, check the uncles and transactions diff --git a/core/blockchain.go b/core/blockchain.go index 073b91bab..6772ea284 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -92,7 +92,7 @@ type BlockChain struct { currentBlock *types.Block // Current head of the block chain currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) - stateCache *state.StateDB // State database to reuse between imports (contains state cache) + stateCache state.Database // State database to reuse between imports (contains state cache) bodyCache *lru.Cache // Cache for the most recent block bodies bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format blockCache *lru.Cache // Cache for the most recent entire blocks @@ -125,6 +125,7 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co bc := &BlockChain{ config: config, chainDb: chainDb, + stateCache: state.NewDatabase(chainDb), eventMux: mux, quit: make(chan struct{}), bodyCache: bodyCache, @@ -190,7 +191,7 @@ func (bc *BlockChain) loadLastState() error { return bc.Reset() } // Make sure the state associated with the block is available - if _, err := state.New(currentBlock.Root(), bc.chainDb); err != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { // Dangling block without a state associated, init from scratch log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) return bc.Reset() @@ -214,12 +215,6 @@ func (bc *BlockChain) loadLastState() error { bc.currentFastBlock = block } } - // Initialize a statedb cache to ensure singleton account bloom filter generation - statedb, err := state.New(bc.currentBlock.Root(), bc.chainDb) - if err != nil { - return err - } - bc.stateCache = statedb // Issue a status log for the user headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) @@ -261,7 +256,7 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()) } if bc.currentBlock != nil { - if _, err := state.New(bc.currentBlock.Root(), bc.chainDb); err != nil { + if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil { // Rewound state missing, rolled back to before pivot, reset to genesis bc.currentBlock = nil } @@ -384,7 +379,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) { // StateAt returns a new mutable state based on a particular point in time. func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return bc.stateCache.New(root) + return state.New(root, bc.stateCache) } // Reset purges the entire blockchain, restoring it to its genesis state. @@ -531,7 +526,7 @@ func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool { return false } // Ensure the associated state is also present - _, err := state.New(block.Root(), bc.chainDb) + _, err := bc.stateCache.OpenTrie(block.Root()) return err == nil } @@ -959,31 +954,30 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { } // Create a new statedb using the parent block and report an // error if it fails. - switch { - case i == 0: - err = bc.stateCache.Reset(bc.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) - default: - err = bc.stateCache.Reset(chain[i-1].Root()) + var parent *types.Block + if i == 0 { + parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1) + } else { + parent = chain[i-1] } + state, err := state.New(parent.Root(), bc.stateCache) if err != nil { - bc.reportBlock(block, nil, err) return i, err } // Process block using the parent state as reference point. - receipts, logs, usedGas, err := bc.processor.Process(block, bc.stateCache, bc.vmConfig) + receipts, logs, usedGas, err := bc.processor.Process(block, state, bc.vmConfig) if err != nil { bc.reportBlock(block, receipts, err) return i, err } // Validate the state using the default validator - err = bc.Validator().ValidateState(block, bc.GetBlock(block.ParentHash(), block.NumberU64()-1), bc.stateCache, receipts, usedGas) + err = bc.Validator().ValidateState(block, parent, state, receipts, usedGas) if err != nil { bc.reportBlock(block, receipts, err) return i, err } // Write state changes to database - _, err = bc.stateCache.Commit(bc.config.IsEIP158(block.Number())) - if err != nil { + if _, err = state.CommitTo(bc.chainDb, bc.config.IsEIP158(block.Number())); err != nil { return i, err } @@ -1021,7 +1015,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { return i, err } // Write hash preimages - if err := WritePreimages(bc.chainDb, block.NumberU64(), bc.stateCache.Preimages()); err != nil { + if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil { return i, err } case SideStatTy: @@ -1079,7 +1073,7 @@ func (st *insertStats) report(chain []*types.Block, index int) { } log.Info("Imported new chain segment", context...) - *st = insertStats{startTime: now, lastIndex: index} + *st = insertStats{startTime: now, lastIndex: index + 1} } } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 7505208e1..371522ab7 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -131,7 +131,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { } return err } - statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.chainDb) + statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache) if err != nil { return err } @@ -148,7 +148,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { blockchain.mu.Lock() WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) WriteBlock(blockchain.chainDb, block) - statedb.Commit(false) + statedb.CommitTo(blockchain.chainDb, false) blockchain.mu.Unlock() } return nil @@ -1131,7 +1131,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[0]}); err != nil { t.Fatal(err) } - if !blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); !st.Exist(theAddr) { t.Error("expected account to exist") } @@ -1139,7 +1139,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[1]}); err != nil { t.Fatal(err) } - if blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); st.Exist(theAddr) { t.Error("account should not exist") } @@ -1147,7 +1147,7 @@ func TestEIP161AccountRemoval(t *testing.T) { if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { t.Fatal(err) } - if blockchain.stateCache.Exist(theAddr) { + if st, _ := blockchain.State(); st.Exist(theAddr) { t.Error("account should not exist") } } diff --git a/core/chain_makers.go b/core/chain_makers.go index cc14f8fb8..38a69d42a 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -181,7 +181,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat gen(i, b) } ethash.AccumulateRewards(statedb, h, b.uncles) - root, err := statedb.Commit(config.IsEIP158(h.Number)) + root, err := statedb.CommitTo(db, config.IsEIP158(h.Number)) if err != nil { panic(fmt.Sprintf("state write error: %v", err)) } @@ -189,7 +189,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat return types.NewBlock(h, b.txs, b.uncles, b.receipts), b.receipts } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), db) + statedb, err := state.New(parent.Root(), state.NewDatabase(db)) if err != nil { panic(err) } diff --git a/core/genesis.go b/core/genesis.go index 947a53c70..5815d5901 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -176,7 +176,7 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { // ToBlock creates the block and state of a genesis specification. func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) for addr, account := range g.Alloc { statedb.AddBalance(addr, account.Balance) statedb.SetCode(addr, account.Code) diff --git a/core/state/database.go b/core/state/database.go new file mode 100644 index 000000000..946625e76 --- /dev/null +++ b/core/state/database.go @@ -0,0 +1,154 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" + lru "github.com/hashicorp/golang-lru" +) + +// Trie cache generation limit after which to evic trie nodes from memory. +var MaxTrieCacheGen = uint16(120) + +const ( + // Number of past tries to keep. This value is chosen such that + // reasonable chain reorg depths will hit an existing trie. + maxPastTries = 12 + + // Number of codehash->size associations to keep. + codeSizeCacheSize = 100000 +) + +// Database wraps access to tries and contract code. +type Database interface { + // Accessing tries: + // OpenTrie opens the main account trie. + // OpenStorageTrie opens the storage trie of an account. + OpenTrie(root common.Hash) (Trie, error) + OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + // Accessing contract code: + ContractCode(addrHash, codeHash common.Hash) ([]byte, error) + ContractCodeSize(addrHash, codeHash common.Hash) (int, error) + // CopyTrie returns an independent copy of the given trie. + CopyTrie(Trie) Trie +} + +// Trie is a Ethereum Merkle Trie. +type Trie interface { + TryGet(key []byte) ([]byte, error) + TryUpdate(key, value []byte) error + TryDelete(key []byte) error + CommitTo(trie.DatabaseWriter) (common.Hash, error) + Hash() common.Hash + NodeIterator(startKey []byte) trie.NodeIterator + GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed +} + +// NewDatabase creates a backing store for state. The returned database is safe for +// concurrent use and retains cached trie nodes in memory. +func NewDatabase(db ethdb.Database) Database { + csc, _ := lru.New(codeSizeCacheSize) + return &cachingDB{db: db, codeSizeCache: csc} +} + +type cachingDB struct { + db ethdb.Database + mu sync.Mutex + pastTries []*trie.SecureTrie + codeSizeCache *lru.Cache +} + +func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { + db.mu.Lock() + defer db.mu.Unlock() + + for i := len(db.pastTries) - 1; i >= 0; i-- { + if db.pastTries[i].Hash() == root { + return cachedTrie{db.pastTries[i].Copy(), db}, nil + } + } + tr, err := trie.NewSecure(root, db.db, MaxTrieCacheGen) + if err != nil { + return nil, err + } + return cachedTrie{tr, db}, nil +} + +func (db *cachingDB) pushTrie(t *trie.SecureTrie) { + db.mu.Lock() + defer db.mu.Unlock() + + if len(db.pastTries) >= maxPastTries { + copy(db.pastTries, db.pastTries[1:]) + db.pastTries[len(db.pastTries)-1] = t + } else { + db.pastTries = append(db.pastTries, t) + } +} + +func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { + return trie.NewSecure(root, db.db, 0) +} + +func (db *cachingDB) CopyTrie(t Trie) Trie { + switch t := t.(type) { + case cachedTrie: + return cachedTrie{t.SecureTrie.Copy(), db} + case *trie.SecureTrie: + return t.Copy() + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + +func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { + code, err := db.db.Get(codeHash[:]) + if err == nil { + db.codeSizeCache.Add(codeHash, len(code)) + } + return code, err +} + +func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { + if cached, ok := db.codeSizeCache.Get(codeHash); ok { + return cached.(int), nil + } + code, err := db.ContractCode(addrHash, codeHash) + if err == nil { + db.codeSizeCache.Add(codeHash, len(code)) + } + return len(code), err +} + +// cachedTrie inserts its trie into a cachingDB on commit. +type cachedTrie struct { + *trie.SecureTrie + db *cachingDB +} + +func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) { + root, err := m.SecureTrie.CommitTo(dbw) + if err == nil { + m.db.pushTrie(m.SecureTrie) + } + return root, err +} diff --git a/core/state/dump.go b/core/state/dump.go index ffa1a7283..46e612850 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -41,7 +41,7 @@ type Dump struct { func (self *StateDB) RawDump() Dump { dump := Dump{ - Root: common.Bytes2Hex(self.trie.Root()), + Root: fmt.Sprintf("%x", self.trie.Hash()), Accounts: make(map[string]DumpAccount), } diff --git a/core/state/iterator.go b/core/state/iterator.go index a8a2722ae..6a5c73d3d 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -19,7 +19,6 @@ package state import ( "bytes" "fmt" - "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" @@ -105,16 +104,11 @@ func (it *NodeIterator) step() error { return nil } // Otherwise we've reached an account node, initiate data iteration - var account struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte - } + var account Account if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } - dataTrie, err := trie.New(account.Root, it.state.db) + dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root) if err != nil { return err } @@ -124,7 +118,8 @@ func (it *NodeIterator) step() error { } if !bytes.Equal(account.CodeHash, emptyCodeHash) { it.codeHash = common.BytesToHash(account.CodeHash) - it.code, err = it.state.db.Get(account.CodeHash) + addrHash := common.BytesToHash(it.stateIt.LeafKey()) + it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash)) if err != nil { return fmt.Errorf("code %x: %v", account.CodeHash, err) } diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index aa9c5b728..ff66ba7a9 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -21,13 +21,12 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethdb" ) // Tests that the node iterator indeed walks over the entire database contents. func TestNodeIteratorCoverage(t *testing.T) { // Create some arbitrary test state to iterate - db, root, _ := makeTestState() + db, mem, root, _ := makeTestState() state, err := New(root, db) if err != nil { @@ -40,13 +39,14 @@ func TestNodeIteratorCoverage(t *testing.T) { hashes[it.Hash] = struct{}{} } } + // Cross check the hashes and the database itself for hash := range hashes { - if _, err := db.Get(hash.Bytes()); err != nil { + if _, err := mem.Get(hash.Bytes()); err != nil { t.Errorf("failed to retrieve reported node %x: %v", hash, err) } } - for _, key := range db.(*ethdb.MemDatabase).Keys() { + for _, key := range mem.Keys() { if bytes.HasPrefix(key, []byte("secure-key-")) { continue } diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index ea5737a08..1cfdd3a89 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -27,7 +27,7 @@ var addr = common.BytesToAddress([]byte("test")) func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() - statedb, _ := New(common.Hash{}, db) + statedb, _ := New(common.Hash{}, NewDatabase(db)) ms := ManageState(statedb) ms.StateDB.SetNonce(addr, 100) ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr)) diff --git a/core/state/state_object.go b/core/state/state_object.go index dcad9d068..b2378c69c 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -62,9 +62,10 @@ func (self Storage) Copy() Storage { // Account values can be accessed and modified through the object. // Finally, call CommitTrie to write the modified storage trie into a database. type stateObject struct { - address common.Address // Ethereum address of this account - data Account - db *StateDB + address common.Address + addrHash common.Hash // hash of ethereum address of the account + data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -74,8 +75,8 @@ type stateObject struct { dbErr error // Write caches. - trie *trie.SecureTrie // storage trie, which becomes non-nil on first access - code Code // contract bytecode, which gets set when code is loaded + trie Trie // storage trie, which becomes non-nil on first access + code Code // contract bytecode, which gets set when code is loaded cachedStorage Storage // Storage entry cache to avoid duplicate reads dirtyStorage Storage // Storage entries that need to be flushed to disk @@ -112,7 +113,15 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &stateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} + return &stateObject{ + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + cachedStorage: make(Storage), + dirtyStorage: make(Storage), + onDirty: onDirty, + } } // EncodeRLP implements rlp.Encoder. @@ -148,12 +157,12 @@ func (c *stateObject) touch() { c.touched = true } -func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie { +func (c *stateObject) getTrie(db Database) Trie { if c.trie == nil { var err error - c.trie, err = trie.NewSecure(c.data.Root, db, 0) + c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root) if err != nil { - c.trie, _ = trie.NewSecure(common.Hash{}, db, 0) + c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{}) c.setError(fmt.Errorf("can't create storage trie: %v", err)) } } @@ -161,13 +170,18 @@ func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie { } // GetState returns a value in account storage. -func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash { +func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { value, exists := self.cachedStorage[key] if exists { return value } // Load from DB in case it is missing. - if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 { + enc, err := self.getTrie(db).TryGet(key[:]) + if err != nil { + self.setError(err) + return common.Hash{} + } + if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { self.setError(err) @@ -181,7 +195,7 @@ func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash } // SetState updates a value in account storage. -func (self *stateObject) SetState(db trie.Database, key, value common.Hash) { +func (self *stateObject) SetState(db Database, key, value common.Hash) { self.db.journal = append(self.db.journal, storageChange{ account: &self.address, key: key, @@ -201,30 +215,30 @@ func (self *stateObject) setState(key, value common.Hash) { } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie { +func (self *stateObject) updateTrie(db Database) Trie { tr := self.getTrie(db) for key, value := range self.dirtyStorage { delete(self.dirtyStorage, key) if (value == common.Hash{}) { - tr.Delete(key[:]) + self.setError(tr.TryDelete(key[:])) continue } // Encoding []byte cannot fail, ok to ignore the error. v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - tr.Update(key[:], v) + self.setError(tr.TryUpdate(key[:], v)) } return tr } // UpdateRoot sets the trie root to the current root hash of -func (self *stateObject) updateRoot(db trie.Database) { +func (self *stateObject) updateRoot(db Database) { self.updateTrie(db) self.data.Root = self.trie.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. -func (self *stateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { +func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error { self.updateTrie(db) if self.dbErr != nil { return self.dbErr @@ -282,9 +296,7 @@ func (c *stateObject) ReturnGas(gas *big.Int) {} func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { stateObject := newObject(db, self.address, self.data, onDirty) if self.trie != nil { - // A shallow copy makes the two tries independent. - cpy := *self.trie - stateObject.trie = &cpy + stateObject.trie = db.db.CopyTrie(self.trie) } stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() @@ -305,14 +317,14 @@ func (c *stateObject) Address() common.Address { } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db trie.Database) []byte { +func (self *stateObject) Code(db Database) []byte { if self.code != nil { return self.code } if bytes.Equal(self.CodeHash(), emptyCodeHash) { return nil } - code, err := db.Get(self.CodeHash()) + code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) if err != nil { self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) } diff --git a/core/state/state_test.go b/core/state/state_test.go index 3bc63c148..bbae3685b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -21,14 +21,14 @@ import ( "math/big" "testing" - checker "gopkg.in/check.v1" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" + checker "gopkg.in/check.v1" ) type StateSuite struct { + db *ethdb.MemDatabase state *StateDB } @@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) { // write some of them to the trie s.state.updateStateObject(obj1) s.state.updateStateObject(obj2) - s.state.Commit(false) + s.state.CommitTo(s.db, false) // check that dump contains the state objects that are in trie got := string(s.state.Dump()) @@ -87,23 +87,20 @@ func (s *StateSuite) TestDump(c *checker.C) { } func (s *StateSuite) SetUpTest(c *checker.C) { - db, _ := ethdb.NewMemDatabase() - s.state, _ = New(common.Hash{}, db) + s.db, _ = ethdb.NewMemDatabase() + s.state, _ = New(common.Hash{}, NewDatabase(s.db)) } -func TestNull(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - +func (s *StateSuite) TestNull(c *checker.C) { address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") - state.CreateAccount(address) + s.state.CreateAccount(address) //value := common.FromHex("0x823140710bf13990e4500136726d8b55") var value common.Hash - state.SetState(address, common.Hash{}, value) - state.Commit(false) - value = state.GetState(address, common.Hash{}) + s.state.SetState(address, common.Hash{}, value) + s.state.CommitTo(s.db, false) + value = s.state.GetState(address, common.Hash{}) if !common.EmptyHash(value) { - t.Errorf("expected empty hash. got %x", value) + c.Errorf("expected empty hash. got %x", value) } } @@ -129,17 +126,15 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { c.Assert(data1, checker.DeepEquals, res) } -func TestSnapshotEmpty(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - state.RevertToSnapshot(state.Snapshot()) +func (s *StateSuite) TestSnapshotEmpty(c *checker.C) { + s.state.RevertToSnapshot(s.state.Snapshot()) } // use testing instead of checker because checker does not support // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, NewDatabase(db)) stateobjaddr0 := toAddr([]byte("so0")) stateobjaddr1 := toAddr([]byte("so1")) @@ -160,7 +155,7 @@ func TestSnapshot2(t *testing.T) { so0.deleted = false state.setStateObject(so0) - root, _ := state.Commit(false) + root, _ := state.CommitTo(db, false) state.Reset(root) // and one with deleted == true @@ -182,8 +177,8 @@ func TestSnapshot2(t *testing.T) { so0Restored := state.getStateObject(stateobjaddr0) // Update lazily-loaded values before comparing. - so0Restored.GetState(db, storageaddr) - so0Restored.Code(db) + so0Restored.GetState(state.db, storageaddr) + so0Restored.Code(state.db) // non-deleted is equal (restored) compareStateObjects(so0Restored, so0, t) diff --git a/core/state/statedb.go b/core/state/statedb.go index 05869a0c8..694374f82 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -26,23 +26,9 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" - lru "github.com/hashicorp/golang-lru" -) - -// Trie cache generation limit after which to evic trie nodes from memory. -var MaxTrieCacheGen = uint16(120) - -const ( - // Number of past tries to keep. This value is chosen such that - // reasonable chain reorg depths will hit an existing trie. - maxPastTries = 12 - - // Number of codehash->size associations to keep. - codeSizeCacheSize = 100000 ) type revision struct { @@ -56,16 +42,21 @@ type revision struct { // * Contracts // * Accounts type StateDB struct { - db ethdb.Database - trie *trie.SecureTrie - pastTries []*trie.SecureTrie - codeSizeCache *lru.Cache + db Database + trie Trie // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject stateObjectsDirty map[common.Address]struct{} stateObjectsDestructed map[common.Address]struct{} + // DB error. + // State objects are used by the consensus core and VM which are + // unable to deal with database-level errors. Any error that occurs + // during a database read is memoized here and will eventually be returned + // by StateDB.Commit. + dbErr error + // The refund counter, also used by state transitioning. refund *big.Int @@ -86,16 +77,14 @@ type StateDB struct { } // Create a new state from a given trie -func New(root common.Hash, db ethdb.Database) (*StateDB, error) { - tr, err := trie.NewSecure(root, db, MaxTrieCacheGen) +func New(root common.Hash, db Database) (*StateDB, error) { + tr, err := db.OpenTrie(root) if err != nil { return nil, err } - csc, _ := lru.New(codeSizeCacheSize) return &StateDB{ db: db, trie: tr, - codeSizeCache: csc, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), stateObjectsDestructed: make(map[common.Address]struct{}), @@ -105,36 +94,21 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { }, nil } -// New creates a new statedb by reusing any journalled tries to avoid costly -// disk io. -func (self *StateDB) New(root common.Hash) (*StateDB, error) { - self.lock.Lock() - defer self.lock.Unlock() - - tr, err := self.openTrie(root) - if err != nil { - return nil, err +// setError remembers the first non-nil error it is called with. +func (self *StateDB) setError(err error) { + if self.dbErr == nil { + self.dbErr = err } - return &StateDB{ - db: self.db, - trie: tr, - codeSizeCache: self.codeSizeCache, - stateObjects: make(map[common.Address]*stateObject), - stateObjectsDirty: make(map[common.Address]struct{}), - stateObjectsDestructed: make(map[common.Address]struct{}), - refund: new(big.Int), - logs: make(map[common.Hash][]*types.Log), - preimages: make(map[common.Hash][]byte), - }, nil +} + +func (self *StateDB) Error() error { + return self.dbErr } // Reset clears out all emphemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. func (self *StateDB) Reset(root common.Hash) error { - self.lock.Lock() - defer self.lock.Unlock() - - tr, err := self.openTrie(root) + tr, err := self.db.OpenTrie(root) if err != nil { return err } @@ -149,34 +123,9 @@ func (self *StateDB) Reset(root common.Hash) error { self.logSize = 0 self.preimages = make(map[common.Hash][]byte) self.clearJournalAndRefund() - return nil } -// openTrie creates a trie. It uses an existing trie if one is available -// from the journal if available. -func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) { - for i := len(self.pastTries) - 1; i >= 0; i-- { - if self.pastTries[i].Hash() == root { - tr := *self.pastTries[i] - return &tr, nil - } - } - return trie.NewSecure(root, self.db, MaxTrieCacheGen) -} - -func (self *StateDB) pushTrie(t *trie.SecureTrie) { - self.lock.Lock() - defer self.lock.Unlock() - - if len(self.pastTries) >= maxPastTries { - copy(self.pastTries, self.pastTries[1:]) - self.pastTries[len(self.pastTries)-1] = t - } else { - self.pastTries = append(self.pastTries, t) - } -} - func (self *StateDB) AddLog(log *types.Log) { self.journal = append(self.journal, addLogChange{txhash: self.thash}) @@ -254,10 +203,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.getStateObject(addr) if stateObject != nil { - code := stateObject.Code(self.db) - key := common.BytesToHash(stateObject.CodeHash()) - self.codeSizeCache.Add(key, len(code)) - return code + return stateObject.Code(self.db) } return nil } @@ -267,13 +213,12 @@ func (self *StateDB) GetCodeSize(addr common.Address) int { if stateObject == nil { return 0 } - key := common.BytesToHash(stateObject.CodeHash()) - if cached, ok := self.codeSizeCache.Get(key); ok { - return cached.(int) + if stateObject.code != nil { + return len(stateObject.code) } - size := len(stateObject.Code(self.db)) - if stateObject.dbErr == nil { - self.codeSizeCache.Add(key, size) + size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) + if err != nil { + self.setError(err) } return size } @@ -296,7 +241,7 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { // StorageTrie returns the storage trie of an account. // The return value is a copy and is nil for non-existent accounts. -func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie { +func (self *StateDB) StorageTrie(a common.Address) Trie { stateObject := self.getStateObject(a) if stateObject == nil { return nil @@ -394,14 +339,14 @@ func (self *StateDB) updateStateObject(stateObject *stateObject) { if err != nil { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } - self.trie.Update(addr[:], data) + self.setError(self.trie.TryUpdate(addr[:], data)) } // deleteStateObject removes the given object from the state trie. func (self *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - self.trie.Delete(addr[:]) + self.setError(self.trie.TryDelete(addr[:])) } // Retrieve a state object given my the address. Returns nil if not found. @@ -415,8 +360,9 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje } // Load the object from the database. - enc := self.trie.Get(addr[:]) + enc, err := self.trie.TryGet(addr[:]) if len(enc) == 0 { + self.setError(err) return nil } var data Account @@ -512,8 +458,6 @@ func (self *StateDB) Copy() *StateDB { state := &StateDB{ db: self.db, trie: self.trie, - pastTries: self.pastTries, - codeSizeCache: self.codeSizeCache, stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), stateObjectsDestructed: make(map[common.Address]struct{}, len(self.stateObjectsDestructed)), @@ -636,23 +580,6 @@ func (s *StateDB) DeleteSuicides() { } } -// Commit commits all state changes to the database. -func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { - root, batch := s.CommitBatch(deleteEmptyObjects) - return root, batch.Write() -} - -// CommitBatch commits all state changes to a write batch but does not -// execute the batch. It is used to validate state changes against -// the root hash stored in a block. -func (s *StateDB) CommitBatch(deleteEmptyObjects bool) (root common.Hash, batch ethdb.Batch) { - batch = s.db.NewBatch() - root, _ = s.CommitTo(batch, deleteEmptyObjects) - - log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) - return root, batch -} - func (s *StateDB) clearJournalAndRefund() { s.journal = nil s.validRevisions = s.validRevisions[:0] @@ -690,8 +617,6 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro } // Write trie changes. root, err = s.trie.CommitTo(dbw) - if err == nil { - s.pushTrie(s.trie) - } + log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) return root, err } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 72b638f97..b2bd18e65 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -28,6 +28,8 @@ import ( "testing" "testing/quick" + check "gopkg.in/check.v1" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" @@ -38,7 +40,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, NewDatabase(db)) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -66,8 +68,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb, _ := ethdb.NewMemDatabase() finalDb, _ := ethdb.NewMemDatabase() - transState, _ := New(common.Hash{}, transDb) - finalState, _ := New(common.Hash{}, finalDb) + transState, _ := New(common.Hash{}, NewDatabase(transDb)) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -95,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) { } // Commit and cross check the databases. - if _, err := transState.Commit(false); err != nil { + if _, err := transState.CommitTo(transDb, false); err != nil { t.Fatalf("failed to commit transition state: %v", err) } - if _, err := finalState.Commit(false); err != nil { + if _, err := finalState.CommitTo(finalDb, false); err != nil { t.Fatalf("failed to commit final state: %v", err) } for _, key := range finalDb.Keys() { @@ -282,7 +284,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db, _ = ethdb.NewMemDatabase() - state, _ = New(common.Hash{}, db) + state, _ = New(common.Hash{}, NewDatabase(db)) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -297,7 +299,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, db) + checkstate, _ := New(common.Hash{}, NewDatabase(db)) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -354,21 +356,19 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return nil } -func TestTouchDelete(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - state, _ := New(common.Hash{}, db) - state.GetOrNewStateObject(common.Address{}) - root, _ := state.Commit(false) - state.Reset(root) +func (s *StateSuite) TestTouchDelete(c *check.C) { + s.state.GetOrNewStateObject(common.Address{}) + root, _ := s.state.CommitTo(s.db, false) + s.state.Reset(root) - snapshot := state.Snapshot() - state.AddBalance(common.Address{}, new(big.Int)) - if len(state.stateObjectsDirty) != 1 { - t.Fatal("expected one dirty state object") + snapshot := s.state.Snapshot() + s.state.AddBalance(common.Address{}, new(big.Int)) + if len(s.state.stateObjectsDirty) != 1 { + c.Fatal("expected one dirty state object") } - state.RevertToSnapshot(snapshot) - if len(state.stateObjectsDirty) != 0 { - t.Fatal("expected no dirty state object") + s.state.RevertToSnapshot(snapshot) + if len(s.state.stateObjectsDirty) != 0 { + c.Fatal("expected no dirty state object") } } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 108ebb320..06c572ea6 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -36,9 +36,10 @@ type testAccount struct { } // makeTestState create a sample test state to test node-wise reconstruction. -func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { +func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) { // Create an empty state - db, _ := ethdb.NewMemDatabase() + mem, _ := ethdb.NewMemDatabase() + db := NewDatabase(mem) state, _ := New(common.Hash{}, db) // Fill it with some arbitrary data @@ -60,17 +61,17 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { state.updateStateObject(obj) accounts = append(accounts, acc) } - root, _ := state.Commit(false) + root, _ := state.CommitTo(mem, false) // Return the generated state - return db, root, accounts + return db, mem, root, accounts } // checkStateAccounts cross references a reconstructed state with an expected // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { // Check root availability and state contents - state, err := New(root, db) + state, err := New(root, NewDatabase(db)) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } @@ -90,13 +91,28 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou } } -// checkStateConsistency checks that all nodes in a state trie are indeed present. +// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present. +func checkTrieConsistency(db ethdb.Database, root common.Hash) error { + if v, _ := db.Get(root[:]); v == nil { + return nil // Consider a non existent state consistent. + } + trie, err := trie.New(root, db) + if err != nil { + return err + } + it := trie.NodeIterator(nil) + for it.Next(true) { + } + return it.Error() +} + +// checkStateConsistency checks that all data of a state root is present. func checkStateConsistency(db ethdb.Database, root common.Hash) error { // Create and iterate a state trie rooted in a sub-node if _, err := db.Get(root.Bytes()); err != nil { - return nil // Consider a non existent state consistent + return nil // Consider a non existent state consistent. } - state, err := New(root, db) + state, err := New(root, NewDatabase(db)) if err != nil { return err } @@ -122,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t, func testIterativeStateSync(t *testing.T, batch int) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -132,7 +148,7 @@ func testIterativeStateSync(t *testing.T, batch int) { for len(queue) > 0 { results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -154,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) { // partial results are returned, and the others sent only later. func TestIterativeDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -165,7 +181,7 @@ func TestIterativeDelayedStateSync(t *testing.T) { // Sync only half of the scheduled nodes results := make([]trie.SyncResult, len(queue)/2+1) for i, hash := range queue[:len(results)] { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -191,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS func testIterativeRandomStateSync(t *testing.T, batch int) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -205,7 +221,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // Fetch all the queued nodes in a random order results := make([]trie.SyncResult, 0, len(queue)) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -231,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // partial results are returned (Even those randomly), others sent only later. func TestIterativeRandomDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -247,7 +263,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { for hash := range queue { delete(queue, hash) - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -276,7 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { // the database. func TestIncompleteStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcMem, srcRoot, srcAccounts := makeTestState() + + checkTrieConsistency(srcMem, srcRoot) // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -288,7 +306,7 @@ func TestIncompleteStateSync(t *testing.T) { // Fetch a batch of state nodes results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcMem.Get(hash.Bytes()) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -304,21 +322,18 @@ func TestIncompleteStateSync(t *testing.T) { for _, result := range results { added = append(added, result.Hash) } - // Check that all known sub-tries in the synced state is complete - for _, root := range added { - // Skim through the accounts and make sure the root hash is not a code node - codeHash := false + // Check that all known sub-tries added so far are complete or missing entirely. + checkSubtries: + for _, hash := range added { for _, acc := range srcAccounts { - if root == crypto.Keccak256Hash(acc.code) { - codeHash = true - break + if hash == crypto.Keccak256Hash(acc.code) { + continue checkSubtries // skip trie check of code nodes. } } - // If the root is a real trie node, check consistency - if !codeHash { - if err := checkStateConsistency(dstDb, root); err != nil { - t.Fatalf("state inconsistent: %v", err) - } + // Can't use checkStateConsistency here because subtrie keys may have odd + // length and crash in LeafKey. + if err := checkTrieConsistency(dstDb, hash); err != nil { + t.Fatalf("state inconsistent: %v", err) } } // Fetch the next batch to retrieve diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 4e28522e9..4903bc3ca 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -44,7 +44,7 @@ func pricedTransaction(nonce uint64, gaslimit, gasprice *big.Int, key *ecdsa.Pri func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) key, _ := crypto.GenerateKey() newPool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) @@ -95,7 +95,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) { key, _ = crypto.GenerateKey() address = crypto.PubkeyToAddress(key.PublicKey) mux = new(event.TypeMux) - statedb, _ = state.New(common.Hash{}, db) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) trigger = false ) @@ -114,7 +114,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) { // a state change between those fetches. stdb := statedb if trigger { - statedb, _ = state.New(common.Hash{}, db) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) // simulate that the new head block included tx0 and tx1 statedb.SetNonce(address, 2) statedb.SetBalance(address, new(big.Int).SetUint64(params.Ether)) @@ -292,7 +292,7 @@ func TestTransactionChainFork(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool.currentState = func() (*state.StateDB, error) { return statedb, nil } currentState, _ := pool.currentState() currentState.AddBalance(addr, big.NewInt(100000000000000)) @@ -318,7 +318,7 @@ func TestTransactionDoubleNonce(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool.currentState = func() (*state.StateDB, error) { return statedb, nil } currentState, _ := pool.currentState() currentState.AddBalance(addr, big.NewInt(100000000000000)) @@ -628,7 +628,7 @@ func TestTransactionQueueGlobalLimiting(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -783,7 +783,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -835,7 +835,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -868,7 +868,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { // Create the pool to test the limit enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -913,7 +913,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { func TestTransactionPoolRepricing(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -1006,7 +1006,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() @@ -1091,7 +1091,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) { func TestTransactionReplacement(t *testing.T) { // Create the pool to test the pricing enforcement with db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool.resetState() diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 24ad6caa5..761ca4450 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -17,7 +17,6 @@ package vm import ( - gmath "math" "math/big" "github.com/ethereum/go-ethereum/common" @@ -28,15 +27,20 @@ import ( // memoryGasCosts calculates the quadratic gas for memory expansion. It does so // only for the memory region that is expanded, not the total memory. func memoryGasCost(mem *Memory, newMemSize uint64) (uint64, error) { - // The maximum that will fit in a uint64 is max_word_count - 1 - // anything above that will result in an overflow. - if newMemSize > gmath.MaxUint64-32 { - return 0, errGasUintOverflow - } if newMemSize == 0 { return 0, nil } + // The maximum that will fit in a uint64 is max_word_count - 1 + // anything above that will result in an overflow. + // Additionally, a newMemSize which results in a + // newMemSizeWords larger than 0x7ffffffff will cause the square operation + // to overflow. + // The constant 0xffffffffe0 is the highest number that can be used without + // overflowing the gas calculation + if newMemSize > 0xffffffffe0 { + return 0, errGasUintOverflow + } newMemSizeWords := toWordSize(newMemSize) newMemSize = newMemSizeWords * 32 diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index 1ee909e92..1b91aee56 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -16,24 +16,20 @@ package vm -import ( - "math" - "testing" -) +import "testing" func TestMemoryGasCost(t *testing.T) { - size := uint64(math.MaxUint64 - 64) - _, err := memoryGasCost(&Memory{}, size) + //size := uint64(math.MaxUint64 - 64) + size := uint64(0xffffffffe0) + v, err := memoryGasCost(&Memory{}, size) if err != nil { t.Error("didn't expect error:", err) } - - _, err = memoryGasCost(&Memory{}, size+32) - if err != nil { - t.Error("didn't expect error:", err) + if v != 36028899963961341 { + t.Errorf("Expected: 36028899963961341, got %d", v) } - _, err = memoryGasCost(&Memory{}, size+33) + _, err = memoryGasCost(&Memory{}, size+1) if err == nil { t.Error("expected error") } diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index ae428aeab..03c42c561 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -1,6 +1,7 @@ package vm import ( + "fmt" "math/big" "testing" @@ -40,3 +41,244 @@ func TestByteOp(t *testing.T) { } } } + +func opBenchmark(bench *testing.B, op func(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error), args ...string) { + var ( + env = NewEVM(Context{}, nil, params.TestChainConfig, Config{EnableJit: false, ForceJit: false}) + stack = newstack() + ) + // convert args + byteArgs := make([][]byte, len(args)) + for i, arg := range args { + byteArgs[i] = common.Hex2Bytes(arg) + } + pc := uint64(0) + bench.ResetTimer() + for i := 0; i < bench.N; i++ { + for _, arg := range byteArgs { + a := new(big.Int).SetBytes(arg) + stack.push(a) + } + op(&pc, env, nil, nil, stack) + stack.pop() + } +} + +func precompiledBenchmark(addr, input, expected string, gas uint64, bench *testing.B) { + + contract := NewContract(AccountRef(common.HexToAddress("1337")), + nil, new(big.Int), gas) + + p := PrecompiledContracts[common.HexToAddress(addr)] + in := common.Hex2Bytes(input) + var ( + res []byte + err error + ) + data := make([]byte, len(in)) + bench.ResetTimer() + for i := 0; i < bench.N; i++ { + contract.Gas = gas + copy(data, in) + res, err = RunPrecompiledContract(p, data, contract) + } + bench.StopTimer() + //Check if it is correct + if err != nil { + bench.Error(err) + return + } + if common.Bytes2Hex(res) != expected { + bench.Error(fmt.Sprintf("Expected %v, got %v", expected, common.Bytes2Hex(res))) + return + } +} + +func BenchmarkPrecompiledEcdsa(bench *testing.B) { + var ( + addr = "01" + inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02" + exp = "000000000000000000000000ceaccac640adf55b2028469bd36ba501f28b699d" + gas = uint64(4000000) + ) + precompiledBenchmark(addr, inp, exp, gas, bench) +} +func BenchmarkPrecompiledSha256(bench *testing.B) { + var ( + addr = "02" + inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02" + exp = "811c7003375852fabd0d362e40e68607a12bdabae61a7d068fe5fdd1dbbf2a5d" + gas = uint64(4000000) + ) + precompiledBenchmark(addr, inp, exp, gas, bench) +} +func BenchmarkPrecompiledRipeMD(bench *testing.B) { + var ( + addr = "03" + inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02" + exp = "0000000000000000000000009215b8d9882ff46f0dfde6684d78e831467f65e6" + gas = uint64(4000000) + ) + precompiledBenchmark(addr, inp, exp, gas, bench) +} +func BenchmarkPrecompiledIdentity(bench *testing.B) { + var ( + addr = "04" + inp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02" + exp = "38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e000000000000000000000000000000000000000000000000000000000000001b38d18acb67d25c8bb9942764b62f18e17054f66a817bd4295423adf9ed98873e789d1dd423d25f0772d2748d60f7e4b81bb14d086eba8e8e8efb6dcff8a4ae02" + gas = uint64(4000000) + ) + precompiledBenchmark(addr, inp, exp, gas, bench) +} +func BenchmarkOpAdd(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opAdd, x, y) + +} +func BenchmarkOpSub(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSub, x, y) + +} +func BenchmarkOpMul(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opMul, x, y) + +} +func BenchmarkOpDiv(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opDiv, x, y) + +} +func BenchmarkOpSdiv(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSdiv, x, y) + +} +func BenchmarkOpMod(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opMod, x, y) + +} +func BenchmarkOpSmod(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSmod, x, y) + +} +func BenchmarkOpExp(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opExp, x, y) + +} +func BenchmarkOpSignExtend(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSignExtend, x, y) + +} +func BenchmarkOpLt(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opLt, x, y) + +} +func BenchmarkOpGt(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opGt, x, y) + +} +func BenchmarkOpSlt(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSlt, x, y) + +} +func BenchmarkOpSgt(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opSgt, x, y) + +} +func BenchmarkOpEq(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opEq, x, y) + +} +func BenchmarkOpAnd(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opAnd, x, y) + +} +func BenchmarkOpOr(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opOr, x, y) + +} +func BenchmarkOpXor(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opXor, x, y) + +} +func BenchmarkOpByte(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opByte, x, y) + +} + +func BenchmarkOpAddmod(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + z := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opAddmod, x, y, z) + +} +func BenchmarkOpMulmod(b *testing.B) { + x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + z := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff" + + opBenchmark(b, opMulmod, x, y, z) + +} + +//func BenchmarkOpSha3(b *testing.B) { +// x := "0" +// y := "32" +// +// opBenchmark(b,opSha3, x, y) +// +// +//} diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index aa386a995..44cde4f70 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -102,7 +102,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.State == nil { db, _ := ethdb.NewMemDatabase() - cfg.State, _ = state.New(common.Hash{}, db) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) } var ( address = common.StringToAddress("contract") @@ -133,7 +133,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { if cfg.State == nil { db, _ := ethdb.NewMemDatabase() - cfg.State, _ = state.New(common.Hash{}, db) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) } var ( vmenv = NewEnv(cfg, cfg.State) diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 7f40770d2..2c4dc5026 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -95,7 +95,7 @@ func TestExecute(t *testing.T) { func TestCall(t *testing.T) { db, _ := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, db) + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) address := common.HexToAddress("0x0a") state.SetCode(address, []byte{ byte(vm.PUSH1), 10, diff --git a/eth/api.go b/eth/api.go index 81570988c..0d90759b6 100644 --- a/eth/api.go +++ b/eth/api.go @@ -637,7 +637,7 @@ func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common return storageRangeAt(st, keyStart, maxResult), nil } -func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult { +func storageRangeAt(st state.Trie, start []byte, maxResult int) StorageRangeResult { it := trie.NewIterator(st.NodeIterator(start)) result := StorageRangeResult{Storage: storageMap{}} for i := 0; i < maxResult && it.Next(); i++ { diff --git a/eth/api_backend.go b/eth/api_backend.go index fe108d272..166b5084d 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) @@ -81,11 +80,11 @@ func (b *EthApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb return b.eth.blockchain.GetBlockByNumber(uint64(blockNr)), nil } -func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { +func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) { // Pending state is only known by the miner if blockNr == rpc.PendingBlockNumber { block, state := b.eth.miner.Pending() - return EthApiState{state}, block.Header(), nil + return state, block.Header(), nil } // Otherwise resolve the block number and return its state header, err := b.HeaderByNumber(ctx, blockNr) @@ -93,7 +92,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc. return nil, nil, err } stateDb, err := b.eth.BlockChain().StateAt(header.Root) - return EthApiState{stateDb}, header, err + return stateDb, header, err } func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { @@ -108,14 +107,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - statedb := state.(EthApiState).state - from := statedb.GetOrNewStateObject(msg.From()) - from.SetBalance(math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From(), math.MaxBig256) vmError := func() error { return nil } context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) - return vm.NewEVM(context, statedb, b.eth.chainConfig, vmCfg), vmError, nil + return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), vmError, nil } func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { @@ -200,23 +197,3 @@ func (b *EthApiBackend) EventMux() *event.TypeMux { func (b *EthApiBackend) AccountManager() *accounts.Manager { return b.eth.AccountManager() } - -type EthApiState struct { - state *state.StateDB -} - -func (s EthApiState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) { - return s.state.GetBalance(addr), nil -} - -func (s EthApiState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) { - return s.state.GetCode(addr), nil -} - -func (s EthApiState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) { - return s.state.GetState(a, b), nil -} - -func (s EthApiState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - return s.state.GetNonce(addr), nil -} diff --git a/eth/api_test.go b/eth/api_test.go index f8d2e9c76..49ce38688 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -32,7 +32,7 @@ func TestStorageRangeAt(t *testing.T) { // Create a state where account 0x010000... has a few storage entries. var ( db, _ = ethdb.NewMemDatabase() - state, _ = state.New(common.Hash{}, db) + state, _ = state.New(common.Hash{}, state.NewDatabase(db)) addr = common.Address{0x01} keys = []common.Hash{ // hashes of Keys of storage common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), diff --git a/eth/bind.go b/eth/bind.go index e5abd8617..0385db1f9 100644 --- a/eth/bind.go +++ b/eth/bind.go @@ -54,14 +54,12 @@ func NewContractBackend(apiBackend ethapi.Backend) *ContractBackend { // CodeAt retrieves any code associated with the contract from the local API. func (b *ContractBackend) CodeAt(ctx context.Context, contract common.Address, blockNum *big.Int) ([]byte, error) { - out, err := b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum)) - return common.FromHex(out), err + return b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum)) } // CodeAt retrieves any code associated with the contract from the local API. func (b *ContractBackend) PendingCodeAt(ctx context.Context, contract common.Address) ([]byte, error) { - out, err := b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber) - return common.FromHex(out), err + return b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber) } // ContractCall implements bind.ContractCaller executing an Ethereum contract diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 267a0def9..1fb5a0910 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -657,7 +657,7 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) } if index > 0 { - if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil { + if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil { t.Fatalf("state reconstruction failed: %v", err) } } diff --git a/eth/handler_test.go b/eth/handler_test.go index 413ed2bff..ca9c9e1b4 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -374,7 +374,7 @@ func testGetNodeData(t *testing.T, protocol int) { } accounts := []common.Address{testBank, acc1Addr, acc2Addr} for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), statedb) + trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb)) for j, acc := range accounts { state, _ := pm.blockchain.State() diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index da5dc5d58..c22c56dfb 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -447,8 +447,8 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add if state == nil || err != nil { return nil, err } - - return state.GetBalance(ctx, address) + b := state.GetBalance(address) + return b, state.Error() } // GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all @@ -529,31 +529,25 @@ func (s *PublicBlockChainAPI) GetUncleCountByBlockHash(ctx context.Context, bloc } // GetCode returns the code stored at the given address in the state for the given block number. -func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (string, error) { +func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if state == nil || err != nil { - return "", err - } - res, err := state.GetCode(ctx, address) - if len(res) == 0 || err != nil { // backwards compatibility - return "0x", err + return nil, err } - return common.ToHex(res), nil + code := state.GetCode(address) + return code, state.Error() } // GetStorageAt returns the storage from the state at the given address, key and // block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block // numbers are also allowed. -func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (string, error) { +func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if state == nil || err != nil { - return "0x", err - } - res, err := state.GetState(ctx, address, common.HexToHash(key)) - if err != nil { - return "0x", err + return nil, err } - return res.Hex(), nil + res := state.GetState(address, common.HexToHash(key)) + return res[:], state.Error() } // callmsg is the message type used for call transitions. @@ -978,11 +972,8 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr if state == nil || err != nil { return nil, err } - nonce, err := state.GetNonce(ctx, address) - if err != nil { - return nil, err - } - return (*hexutil.Uint64)(&nonce), nil + nonce := state.GetNonce(address) + return (*hexutil.Uint64)(&nonce), state.Error() } // getTransactionBlockData fetches the meta data for the given transaction from the chain database. This is useful to diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 68b5069d0..d122b7915 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/downloader" @@ -47,11 +48,12 @@ type Backend interface { SetHead(number uint64) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error) - StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (State, *types.Header, error) + StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + // TxPool API SendTx(ctx context.Context, signedTx *types.Transaction) error RemoveTx(txHash common.Hash) @@ -65,13 +67,6 @@ type Backend interface { CurrentBlock() *types.Block } -type State interface { - GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) - GetCode(ctx context.Context, addr common.Address) ([]byte, error) - GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) - GetNonce(ctx context.Context, addr common.Address) (uint64, error) -} - func GetAPIs(apiBackend Backend) []rpc.API { nonceLock := new(AddrLocker) return []rpc.API{ diff --git a/les/api_backend.go b/les/api_backend.go index 7d69046de..7a3c2447c 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -24,13 +24,13 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" @@ -70,12 +70,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb return b.GetBlock(ctx, header.Hash()) } -func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { +func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) { header, err := b.HeaderByNumber(ctx, blockNr) if header == nil || err != nil { return nil, nil, err } - return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil + return light.NewState(ctx, header, b.eth.odr), header, nil } func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { @@ -90,18 +90,10 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - stateDb := state.(*light.LightState).Copy() - addr := msg.From() - from, err := stateDb.GetOrNewStateObject(ctx, addr) - if err != nil { - return nil, nil, err - } - from.SetBalance(math.MaxBig256) - - vmstate := light.NewVMState(ctx, stateDb) +func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) - return vm.NewEVM(context, vmstate, b.eth.chainConfig, vmCfg), vmstate.Error, nil + return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), state.Error, nil } func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/les/odr_test.go b/les/odr_test.go index 7b34996ce..3a0fd6738 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -75,24 +75,23 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} - var res []byte + var ( + res []byte + st *state.StateDB + err error + ) for _, addr := range acc { if bc != nil { header := bc.GetHeaderByHash(bhash) - st, err := state.New(header.Root, db) - if err == nil { - bal := st.GetBalance(addr) - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } + st, err = state.New(header.Root, state.NewDatabase(db)) } else { header := lc.GetHeaderByHash(bhash) - st := light.NewLightState(light.StateTrieID(header), lc.Odr()) - bal, err := st.GetBalance(ctx, addr) - if err == nil { - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } + st = light.NewState(ctx, header, lc.Odr()) + } + if err == nil { + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) } } @@ -115,7 +114,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai data[35] = byte(i) if bc != nil { header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, db) + statedb, err := state.New(header.Root, state.NewDatabase(db)) if err == nil { from := statedb.GetOrNewStateObject(testBankAddress) @@ -133,23 +132,15 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai } } else { header := lc.GetHeaderByHash(bhash) - state := light.NewLightState(light.StateTrieID(header), lc.Odr()) - vmstate := light.NewVMState(ctx, state) - from, err := state.GetOrNewStateObject(ctx, testBankAddress) - if err == nil { - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)} - - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) - - //vmenv := light.NewEnv(ctx, state, config, lc, msg, header, vm.Config{}) - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - if vmstate.Error() == nil { - res = append(res, ret...) - } + state := light.NewState(ctx, header, lc.Odr()) + state.SetBalance(testBankAddress, math.MaxBig256) + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)} + context := core.NewEVMContext(msg, header, lc, nil) + vmenv := vm.NewEVM(context, state, config, vm.Config{}) + gp := new(core.GasPool).AddGas(math.MaxBig256) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + if state.Error() == nil { + res = append(res, ret...) } } } diff --git a/les/request_test.go b/les/request_test.go index 3add5f20d..6b594462d 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -62,7 +62,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr return nil } sti := light.StateTrieID(header) - ci := light.StorageTrieID(sti, testContractAddr, common.Hash{}) + ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{}) return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} } diff --git a/light/lightchain.go b/light/lightchain.go index 5b7e57041..87436f4a5 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -180,11 +180,6 @@ func (self *LightChain) Status() (td *big.Int, currentBlock common.Hash, genesis return self.GetTd(hash, header.Number.Uint64()), hash, self.genesisBlock.Hash() } -// State returns a new mutable state based on the current HEAD block. -func (self *LightChain) State() *LightState { - return NewLightState(StateTrieID(self.hc.CurrentHeader()), self.odr) -} - // Reset purges the entire blockchain, restoring it to its genesis state. func (bc *LightChain) Reset() { bc.ResetWithGenesisBlock(bc.genesisBlock) diff --git a/light/odr.go b/light/odr.go index ca6364f28..d19a488f6 100644 --- a/light/odr.go +++ b/light/odr.go @@ -34,7 +34,7 @@ import ( // service is not required. var NoOdr = context.Background() -// OdrBackend is an interface to a backend service that handles ODR retrievals +// OdrBackend is an interface to a backend service that handles ODR retrievals type type OdrBackend interface { Database() ethdb.Database Retrieve(ctx context.Context, req OdrRequest) error @@ -66,11 +66,11 @@ func StateTrieID(header *types.Header) *TrieID { // StorageTrieID returns a TrieID for a contract storage trie at a given account // of a given state trie. It also requires the root hash of the trie for // checking Merkle proofs. -func StorageTrieID(state *TrieID, addr common.Address, root common.Hash) *TrieID { +func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID { return &TrieID{ BlockHash: state.BlockHash, BlockNumber: state.BlockNumber, - AccKey: crypto.Keccak256(addr[:]), + AccKey: addrHash[:], Root: root, } } @@ -102,7 +102,7 @@ func storeProof(db ethdb.Database, proof []rlp.RawValue) { // CodeRequest is the ODR request type for retrieving contract code type CodeRequest struct { OdrRequest - Id *TrieID + Id *TrieID // references storage trie of the account Hash common.Hash Data []byte } diff --git a/light/odr_test.go b/light/odr_test.go index 576e3abc9..544b64eff 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -86,11 +86,11 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { return nil } -type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte +type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) -func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, odrGetBlock) } -func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var block *types.Block if bc != nil { block = bc.GetBlockByHash(bhash) @@ -98,15 +98,15 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc block, _ = lc.GetBlockByHash(ctx, bhash) } if block == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(block) - return rlp + return rlp, nil } -func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } -func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) @@ -114,43 +114,37 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) } if receipts == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(receipts) - return rlp + return rlp, nil } -func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, odrAccounts) } -func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} + var st *state.StateDB + if bc == nil { + header := lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) + } else { + header := bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + var res []byte for _, addr := range acc { - if bc != nil { - header := bc.GetHeaderByHash(bhash) - st, err := state.New(header.Root, db) - if err == nil { - bal := st.GetBalance(addr) - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } else { - header := lc.GetHeaderByHash(bhash) - st := NewLightState(StateTrieID(header), lc.Odr()) - bal, err := st.GetBalance(ctx, addr) - if err == nil { - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) } - - return res + return res, st.Error() } -func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) } type callmsg struct { types.Message @@ -158,50 +152,42 @@ type callmsg struct { func (callmsg) CheckNonce() bool { return false } -func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") - config := params.TestChainConfig var res []byte for i := 0; i < 3; i++ { data[35] = byte(i) - if bc != nil { - header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, db) - if err == nil { - from := statedb.GetOrNewStateObject(testBankAddress) - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, bc, nil) - vmenv := vm.NewEVM(context, statedb, config, vm.Config{}) - - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - res = append(res, ret...) - } + var ( + st *state.StateDB + header *types.Header + chain core.ChainContext + ) + if bc == nil { + chain = lc + header = lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) } else { - header := lc.GetHeaderByHash(bhash) - state := NewLightState(StateTrieID(header), lc.Odr()) - vmstate := NewVMState(ctx, state) - from, err := state.GetOrNewStateObject(ctx, testBankAddress) - if err == nil { - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - if vmstate.Error() == nil { - res = append(res, ret...) - } - } + chain = bc + header = bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + + // Perform read-only call. + st.SetBalance(testBankAddress, math.MaxBig256) + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} + context := core.NewEVMContext(msg, header, chain, nil) + vmenv := vm.NewEVM(context, st, config, vm.Config{}) + gp := new(core.GasPool).AddGas(math.MaxBig256) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + res = append(res, ret...) + if st.Error() != nil { + return res, st.Error() } } - return res + return res, nil } func testChainGen(i int, block *core.BlockGen) { @@ -245,7 +231,7 @@ func testChainGen(i int, block *core.BlockGen) { } } -func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { +func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { var ( evmux = new(event.TypeMux) sdb, _ = ethdb.NewMemDatabase() @@ -258,46 +244,58 @@ func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { - panic(err) + t.Fatal(err) } odr := &testOdr{sdb: sdb, ldb: ldb} - lightchain, _ := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + lightchain, err := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + if err != nil { + t.Fatal(err) + } headers := make([]*types.Header, len(gchain)) for i, block := range gchain { headers[i] = block.Header() } if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil { - panic(err) + t.Fatal(err) } - test := func(expFail uint64) { + test := func(expFail int) { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := core.GetCanonicalHash(sdb, i) - b1 := fn(NoOdr, sdb, blockchain, nil, bhash) + b1, err := fn(NoOdr, sdb, blockchain, nil, bhash) + if err != nil { + t.Fatalf("error in full-node test for block %d: %v", i, err) + } ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - b2 := fn(ctx, ldb, nil, lightchain, bhash) + + exp := i < uint64(expFail) + b2, err := fn(ctx, ldb, nil, lightchain, bhash) + if err != nil && exp { + t.Errorf("error in ODR test for block %d: %v", i, err) + } eq := bytes.Equal(b1, b2) - exp := i < expFail if exp && !eq { - t.Errorf("odr mismatch") - } - if !exp && eq { - t.Errorf("unexpected odr match") + t.Errorf("ODR test output for block %d doesn't match full node", i) } } } - odr.disable = true // expect retrievals to fail (except genesis block) without a les peer - test(expFail) - odr.disable = false - // expect all retrievals to pass - test(5) + t.Log("checking without ODR") odr.disable = true + test(1) + + // expect all retrievals to pass with ODR enabled + t.Log("checking with ODR") + odr.disable = false + test(len(gchain)) + // still expect all retrievals to pass, now data should be cached locally - test(5) + t.Log("checking without ODR, should be cached") + odr.disable = true + test(len(gchain)) } diff --git a/light/odr_util.go b/light/odr_util.go index d7f8458f1..fcdfdb82c 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -106,25 +106,6 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo return common.Hash{}, err } -// retrieveContractCode tries to retrieve the contract code of the given account -// with the given hash from the network (id points to the storage trie belonging -// to the same account) -func retrieveContractCode(ctx context.Context, odr OdrBackend, id *TrieID, hash common.Hash) ([]byte, error) { - if hash == sha3_nil { - return nil, nil - } - res, _ := odr.Database().Get(hash[:]) - if res != nil { - return res, nil - } - r := &CodeRequest{Id: id, Hash: hash} - if err := odr.Retrieve(ctx, r); err != nil { - return nil, err - } else { - return r.Data, nil - } -} - // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) { if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil { diff --git a/light/state.go b/light/state.go deleted file mode 100644 index b184dc3a5..000000000 --- a/light/state.go +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "context" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" -) - -// LightState is a memory representation of a state. -// This version is ODR capable, caching only the already accessed part of the -// state, retrieving unknown parts on-demand from the ODR backend. Changes are -// never stored in the local database, only in the memory objects. -type LightState struct { - odr OdrBackend - trie *LightTrie - id *TrieID - stateObjects map[string]*StateObject - refund *big.Int -} - -// NewLightState creates a new LightState with the specified root. -// Note that the creation of a light state is always successful, even if the -// root is non-existent. In that case, ODR retrieval will always be unsuccessful -// and every operation will return with an error or wait for the context to be -// cancelled. -func NewLightState(id *TrieID, odr OdrBackend) *LightState { - var tr *LightTrie - if id != nil { - tr = NewLightTrie(id, odr, true) - } - return &LightState{ - odr: odr, - trie: tr, - id: id, - stateObjects: make(map[string]*StateObject), - refund: new(big.Int), - } -} - -// AddRefund adds an amount to the refund value collected during a vm execution -func (self *LightState) AddRefund(gas *big.Int) { - self.refund.Add(self.refund, gas) -} - -// HasAccount returns true if an account exists at the given address -func (self *LightState) HasAccount(ctx context.Context, addr common.Address) (bool, error) { - so, err := self.GetStateObject(ctx, addr) - return so != nil, err -} - -// GetBalance retrieves the balance from the given address or 0 if the account does -// not exist -func (self *LightState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return common.Big0, err - } - if stateObject != nil { - return stateObject.balance, nil - } - - return common.Big0, nil -} - -// GetNonce returns the nonce at the given address or 0 if the account does -// not exist -func (self *LightState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return 0, err - } - if stateObject != nil { - return stateObject.nonce, nil - } - return 0, nil -} - -// GetCode returns the contract code at the given address or nil if the account -// does not exist -func (self *LightState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err != nil { - return nil, err - } - if stateObject != nil { - return stateObject.code, nil - } - return nil, nil -} - -// GetState returns the contract storage value at storage address b from the -// contract address a or common.Hash{} if the account does not exist -func (self *LightState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) { - stateObject, err := self.GetStateObject(ctx, a) - if err == nil && stateObject != nil { - return stateObject.GetState(ctx, b) - } - return common.Hash{}, err -} - -// HasSuicided returns true if the given account has been marked for deletion -// or false if the account does not exist -func (self *LightState) HasSuicided(ctx context.Context, addr common.Address) (bool, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err == nil && stateObject != nil { - return stateObject.remove, nil - } - return false, err -} - -/* - * SETTERS - */ - -// AddBalance adds the given amount to the balance of the specified account -func (self *LightState) AddBalance(ctx context.Context, addr common.Address, amount *big.Int) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.AddBalance(amount) - } - return err -} - -// SubBalance adds the given amount to the balance of the specified account -func (self *LightState) SubBalance(ctx context.Context, addr common.Address, amount *big.Int) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SubBalance(amount) - } - return err -} - -// SetNonce sets the nonce of the specified account -func (self *LightState) SetNonce(ctx context.Context, addr common.Address, nonce uint64) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetNonce(nonce) - } - return err -} - -// SetCode sets the contract code at the specified account -func (self *LightState) SetCode(ctx context.Context, addr common.Address, code []byte) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetCode(crypto.Keccak256Hash(code), code) - } - return err -} - -// SetState sets the storage value at storage address key of the account addr -func (self *LightState) SetState(ctx context.Context, addr common.Address, key common.Hash, value common.Hash) error { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.SetState(key, value) - } - return err -} - -// Delete marks an account to be removed and clears its balance -func (self *LightState) Suicide(ctx context.Context, addr common.Address) (bool, error) { - stateObject, err := self.GetOrNewStateObject(ctx, addr) - if err == nil && stateObject != nil { - stateObject.MarkForDeletion() - stateObject.balance = new(big.Int) - - return true, nil - } - - return false, err -} - -// -// Get, set, new state object methods -// - -// GetStateObject returns the state object of the given account or nil if the -// account does not exist -func (self *LightState) GetStateObject(ctx context.Context, addr common.Address) (stateObject *StateObject, err error) { - stateObject = self.stateObjects[addr.Str()] - if stateObject != nil { - if stateObject.deleted { - stateObject = nil - } - return stateObject, nil - } - data, err := self.trie.Get(ctx, addr[:]) - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, nil - } - - stateObject, err = DecodeObject(ctx, self.id, addr, self.odr, []byte(data)) - if err != nil { - return nil, err - } - - self.SetStateObject(stateObject) - - return stateObject, nil -} - -// SetStateObject sets the state object of the given account -func (self *LightState) SetStateObject(object *StateObject) { - self.stateObjects[object.Address().Str()] = object -} - -// GetOrNewStateObject returns the state object of the given account or creates a -// new one if the account does not exist -func (self *LightState) GetOrNewStateObject(ctx context.Context, addr common.Address) (*StateObject, error) { - stateObject, err := self.GetStateObject(ctx, addr) - if err == nil && (stateObject == nil || stateObject.deleted) { - stateObject, err = self.CreateStateObject(ctx, addr) - } - return stateObject, err -} - -// newStateObject creates a state object whether it exists in the state or not -func (self *LightState) newStateObject(addr common.Address) *StateObject { - stateObject := NewStateObject(addr, self.odr) - self.stateObjects[addr.Str()] = stateObject - - return stateObject -} - -// CreateStateObject creates creates a new state object and takes ownership. -// This is different from "NewStateObject" -func (self *LightState) CreateStateObject(ctx context.Context, addr common.Address) (*StateObject, error) { - // Get previous (if any) - so, err := self.GetStateObject(ctx, addr) - if err != nil { - return nil, err - } - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.balance = so.balance - } - - return newSo, nil -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (self *LightState) ForEachStorage(ctx context.Context, addr common.Address, cb func(key, value common.Hash) bool) error { - so, err := self.GetStateObject(ctx, addr) - if err != nil { - return err - } - - if so == nil { - return nil - } - - for h, v := range so.storage { - cb(h, v) - } - return nil -} - -// -// Setting, copying of the state methods -// - -// Copy creates a copy of the state -func (self *LightState) Copy() *LightState { - // ignore error - we assume state-to-be-copied always exists - state := NewLightState(nil, self.odr) - state.trie = self.trie - state.id = self.id - for k, stateObject := range self.stateObjects { - if stateObject.dirty { - state.stateObjects[k] = stateObject.Copy() - } - } - - state.refund.Set(self.refund) - return state -} - -// Set copies the contents of the given state onto this state, overwriting -// its contents -func (self *LightState) Set(state *LightState) { - self.trie = state.trie - self.stateObjects = state.stateObjects - self.refund = state.refund -} - -// GetRefund returns the refund value collected during a vm execution -func (self *LightState) GetRefund() *big.Int { - return self.refund -} diff --git a/light/state_object.go b/light/state_object.go deleted file mode 100644 index a54ea1d9f..000000000 --- a/light/state_object.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "bytes" - "context" - "fmt" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" -) - -var emptyCodeHash = crypto.Keccak256(nil) - -// Code represents a contract code in binary form -type Code []byte - -// String returns a string representation of the code -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") -} - -// Storage is a memory map cache of a contract storage -type Storage map[common.Hash]common.Hash - -// String returns a string representation of the storage cache -func (self Storage) String() (str string) { - for key, value := range self { - str += fmt.Sprintf("%X : %X\n", key, value) - } - - return -} - -// Copy copies the contents of a storage cache -func (self Storage) Copy() Storage { - cpy := make(Storage) - for key, value := range self { - cpy[key] = value - } - - return cpy -} - -// StateObject is a memory representation of an account or contract and its storage. -// This version is ODR capable, caching only the already accessed part of the -// storage, retrieving unknown parts on-demand from the ODR backend. Changes are -// never stored in the local database, only in the memory objects. -type StateObject struct { - odr OdrBackend - trie *LightTrie - - // Address belonging to this account - address common.Address - // The balance of the account - balance *big.Int - // The nonce of the account - nonce uint64 - // The code hash if code is present (i.e. a contract) - codeHash []byte - // The code for this account - code Code - // Cached storage (flushed when updated) - storage Storage - - // Mark for deletion - // When an object is marked for deletion it will be delete from the trie - // during the "update" phase of the state transition - remove bool - deleted bool - dirty bool -} - -// NewStateObject creates a new StateObject of the specified account address -func NewStateObject(address common.Address, odr OdrBackend) *StateObject { - object := &StateObject{ - odr: odr, - address: address, - balance: new(big.Int), - dirty: true, - codeHash: emptyCodeHash, - storage: make(Storage), - } - object.trie = NewLightTrie(&TrieID{}, odr, true) - return object -} - -// MarkForDeletion marks an account to be removed -func (self *StateObject) MarkForDeletion() { - self.remove = true - self.dirty = true -} - -// getAddr gets the storage value at the given address from the trie -func (c *StateObject) getAddr(ctx context.Context, addr common.Hash) (common.Hash, error) { - var ret []byte - val, err := c.trie.Get(ctx, addr[:]) - if err != nil { - return common.Hash{}, err - } - rlp.DecodeBytes(val, &ret) - return common.BytesToHash(ret), nil -} - -// Storage returns the storage cache object of the account -func (self *StateObject) Storage() Storage { - return self.storage -} - -// GetState returns the storage value at the given address from either the cache -// or the trie -func (self *StateObject) GetState(ctx context.Context, key common.Hash) (common.Hash, error) { - value, exists := self.storage[key] - if !exists { - var err error - value, err = self.getAddr(ctx, key) - if err != nil { - return common.Hash{}, err - } - if (value != common.Hash{}) { - self.storage[key] = value - } - } - - return value, nil -} - -// SetState sets the storage value at the given address -func (self *StateObject) SetState(k, value common.Hash) { - self.storage[k] = value - self.dirty = true -} - -// AddBalance adds the given amount to the account balance -func (c *StateObject) AddBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Add(c.balance, amount)) -} - -// SubBalance subtracts the given amount from the account balance -func (c *StateObject) SubBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Sub(c.balance, amount)) -} - -// SetBalance sets the account balance to the given amount -func (c *StateObject) SetBalance(amount *big.Int) { - c.balance = amount - c.dirty = true -} - -// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures -func (c *StateObject) ReturnGas(gas *big.Int) {} - -// Copy creates a copy of the state object -func (self *StateObject) Copy() *StateObject { - stateObject := NewStateObject(self.Address(), self.odr) - stateObject.balance.Set(self.balance) - stateObject.codeHash = common.CopyBytes(self.codeHash) - stateObject.nonce = self.nonce - stateObject.trie = self.trie - stateObject.code = self.code - stateObject.storage = self.storage.Copy() - stateObject.remove = self.remove - stateObject.dirty = self.dirty - stateObject.deleted = self.deleted - - return stateObject -} - -// -// Attribute accessors -// - -// empty returns whether the account is considered empty. -func (self *StateObject) empty() bool { - return self.nonce == 0 && self.balance.Sign() == 0 && bytes.Equal(self.codeHash, emptyCodeHash) -} - -// Balance returns the account balance -func (self *StateObject) Balance() *big.Int { - return self.balance -} - -// Address returns the address of the contract/account -func (self *StateObject) Address() common.Address { - return self.address -} - -// Code returns the contract code -func (self *StateObject) Code() []byte { - return self.code -} - -// SetCode sets the contract code -func (self *StateObject) SetCode(hash common.Hash, code []byte) { - self.code = code - self.codeHash = hash[:] - self.dirty = true -} - -// SetNonce sets the account nonce -func (self *StateObject) SetNonce(nonce uint64) { - self.nonce = nonce - self.dirty = true -} - -// Nonce returns the account nonce -func (self *StateObject) Nonce() uint64 { - return self.nonce -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { - for h, v := range self.storage { - cb(h, v) - } -} - -// Never called, but must be present to allow StateObject to be used -// as a vm.Account interface that also satisfies the vm.ContractRef -// interface. Interfaces are awesome. -func (self *StateObject) Value() *big.Int { - panic("Value on StateObject should never be called") -} - -// Encoding - -type extStateObject struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte -} - -// DecodeObject decodes an RLP-encoded state object. -func DecodeObject(ctx context.Context, stateID *TrieID, address common.Address, odr OdrBackend, data []byte) (*StateObject, error) { - var ( - obj = &StateObject{address: address, odr: odr, storage: make(Storage)} - ext extStateObject - err error - ) - if err = rlp.DecodeBytes(data, &ext); err != nil { - return nil, err - } - trieID := StorageTrieID(stateID, address, ext.Root) - obj.trie = NewLightTrie(trieID, odr, true) - if !bytes.Equal(ext.CodeHash, emptyCodeHash) { - if obj.code, err = retrieveContractCode(ctx, obj.odr, trieID, common.BytesToHash(ext.CodeHash)); err != nil { - return nil, fmt.Errorf("can't find code for hash %x: %v", ext.CodeHash, err) - } - } - obj.nonce = ext.Nonce - obj.balance = ext.Balance - obj.codeHash = ext.CodeHash - return obj, nil -} diff --git a/light/state_test.go b/light/state_test.go deleted file mode 100644 index e776efec8..000000000 --- a/light/state_test.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "bytes" - "context" - "math/big" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/ethdb" -) - -func makeTestState() (common.Hash, ethdb.Database) { - sdb, _ := ethdb.NewMemDatabase() - st, _ := state.New(common.Hash{}, sdb) - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - for j := byte(0); j < 100; j++ { - st.SetState(addr, common.Hash{j}, common.Hash{i, j}) - } - st.SetNonce(addr, 100) - st.AddBalance(addr, big.NewInt(int64(i))) - st.SetCode(addr, []byte{i, i, i}) - } - root, _ := st.Commit(false) - return root, sdb -} - -func TestLightStateOdr(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - addr := common.Address{100} - _, err := ls.CreateStateObject(ctx, addr) - if err != nil { - t.Fatalf("Error creating state object: %v", err) - } - err = ls.SetCode(ctx, addr, []byte{100, 100, 100}) - if err != nil { - t.Fatalf("Error setting code: %v", err) - } - err = ls.AddBalance(ctx, addr, big.NewInt(1100)) - if err != nil { - t.Fatalf("Error adding balance to acc[100]: %v", err) - } - for j := byte(0); j < 101; j++ { - err = ls.SetState(ctx, addr, common.Hash{j}, common.Hash{100, j}) - if err != nil { - t.Fatalf("Error setting storage of acc[100]: %v", err) - } - } - err = ls.SetNonce(ctx, addr, 100) - if err != nil { - t.Fatalf("Error setting nonce for acc[100]: %v", err) - } - - for i := byte(0); i < 101; i++ { - addr := common.Address{i} - - bal, err := ls.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance of acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+1000 { - t.Fatalf("Incorrect balance at acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - - nonce, err := ls.GetNonce(ctx, addr) - if err != nil { - t.Fatalf("Error getting nonce of acc[%d]: %v", i, err) - } - if nonce != 100 { - t.Fatalf("Incorrect nonce at acc[%d]: expected %v, got %v", i, 100, nonce) - } - - code, err := ls.GetCode(ctx, addr) - exp := []byte{i, i, i} - if err != nil { - t.Fatalf("Error getting code of acc[%d]: %v", i, err) - } - if !bytes.Equal(code, exp) { - t.Fatalf("Incorrect code at acc[%d]: expected %v, got %v", i, exp, code) - } - - for j := byte(0); j < 101; j++ { - exp := common.Hash{i, j} - val, err := ls.GetState(ctx, addr, common.Hash{j}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[%d]: %v", i, j, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[%d]: expected %04x, got %04x", i, j, exp, val) - } - } - } -} - -func TestLightStateSetCopy(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - ls2 := ls.Copy() - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - err := ls2.AddBalance(ctx, addr, big.NewInt(1000)) - if err != nil { - t.Fatalf("Error adding balance to acc[%d]: %v", i, err) - } - err = ls2.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 200}) - if err != nil { - t.Fatalf("Error setting storage of acc[%d]: %v", i, err) - } - } - - lsx := ls.Copy() - ls.Set(ls2) - ls2.Set(lsx) - - for i := byte(0); i < 100; i++ { - addr := common.Address{i} - // check balance in ls - bal, err := ls.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance to acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+2000 { - t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - // check balance in ls2 - bal, err = ls2.GetBalance(ctx, addr) - if err != nil { - t.Fatalf("Error getting balance to acc[%d]: %v", i, err) - } - if bal.Int64() != int64(i)+1000 { - t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64()) - } - // check storage in ls - exp := common.Hash{i, 200} - val, err := ls.GetState(ctx, addr, common.Hash{100}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val) - } - // check storage in ls2 - exp = common.Hash{i, 100} - val, err = ls2.GetState(ctx, addr, common.Hash{100}) - if err != nil { - t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err) - } - if val != exp { - t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val) - } - } -} - -func TestLightStateDelete(t *testing.T) { - root, sdb := makeTestState() - header := &types.Header{Root: root, Number: big.NewInt(0)} - core.WriteHeader(sdb, header) - ldb, _ := ethdb.NewMemDatabase() - odr := &testOdr{sdb: sdb, ldb: ldb} - ls := NewLightState(StateTrieID(header), odr) - ctx := context.Background() - - addr := common.Address{42} - - b, err := ls.HasAccount(ctx, addr) - if err != nil { - t.Fatalf("HasAccount error: %v", err) - } - if !b { - t.Fatalf("HasAccount returned false, expected true") - } - - b, err = ls.HasSuicided(ctx, addr) - if err != nil { - t.Fatalf("HasSuicided error: %v", err) - } - if b { - t.Fatalf("HasSuicided returned true, expected false") - } - - ls.Suicide(ctx, addr) - - b, err = ls.HasSuicided(ctx, addr) - if err != nil { - t.Fatalf("HasSuicided error: %v", err) - } - if !b { - t.Fatalf("HasSuicided returned false, expected true") - } -} diff --git a/light/trie.go b/light/trie.go index 2988a16cf..7502b6e5d 100644 --- a/light/trie.go +++ b/light/trie.go @@ -18,99 +18,216 @@ package light import ( "context" + "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" ) -// LightTrie is an ODR-capable wrapper around trie.SecureTrie -type LightTrie struct { - trie *trie.SecureTrie +func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { + state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr)) + return state +} + +func NewStateDatabase(ctx context.Context, head *types.Header, odr OdrBackend) state.Database { + return &odrDatabase{ctx, StateTrieID(head), odr} +} + +type odrDatabase struct { + ctx context.Context + id *TrieID + backend OdrBackend +} + +func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) { + return &odrTrie{db: db, id: db.id}, nil +} + +func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, error) { + return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil +} + +func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { + switch t := t.(type) { + case *odrTrie: + cpy := &odrTrie{db: t.db, id: t.id} + if t.trie != nil { + cpytrie := *t.trie + cpy.trie = &cpytrie + } + return cpy + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + +func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { + if codeHash == sha3_nil { + return nil, nil + } + if code, err := db.backend.Database().Get(codeHash[:]); err == nil { + return code, nil + } + id := *db.id + id.AccKey = addrHash[:] + req := &CodeRequest{Id: &id, Hash: codeHash} + err := db.backend.Retrieve(db.ctx, req) + return req.Data, err +} + +func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { + code, err := db.ContractCode(addrHash, codeHash) + return len(code), err +} + +type odrTrie struct { + db *odrDatabase id *TrieID - odr OdrBackend - db ethdb.Database -} - -// NewLightTrie creates a new LightTrie instance. It doesn't instantly try to -// access the db or network and retrieve the root node, it only initializes its -// encapsulated SecureTrie at the first actual operation. -func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie { - return &LightTrie{ - // SecureTrie is initialized before first request - id: id, - odr: odr, - db: odr.Database(), + trie *trie.Trie +} + +func (t *odrTrie) TryGet(key []byte) ([]byte, error) { + key = crypto.Keccak256(key) + var res []byte + err := t.do(key, func() (err error) { + res, err = t.trie.TryGet(key) + return err + }) + return res, err +} + +func (t *odrTrie) TryUpdate(key, value []byte) error { + key = crypto.Keccak256(key) + return t.do(key, func() error { + return t.trie.TryDelete(key) + }) +} + +func (t *odrTrie) TryDelete(key []byte) error { + key = crypto.Keccak256(key) + return t.do(key, func() error { + return t.trie.TryDelete(key) + }) +} + +func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) { + if t.trie == nil { + return t.id.Root, nil + } + return t.trie.CommitTo(db) +} + +func (t *odrTrie) Hash() common.Hash { + if t.trie == nil { + return t.id.Root } + return t.trie.Hash() +} + +func (t *odrTrie) NodeIterator(startkey []byte) trie.NodeIterator { + return newNodeIterator(t, startkey) } -// retrieveKey retrieves a single key, returns true and stores nodes in local -// database if successful -func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool { - r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)} - return t.odr.Retrieve(ctx, r) == nil +func (t *odrTrie) GetKey(sha []byte) []byte { + return nil } // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError -func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error { - err := fn() - for err != nil { +func (t *odrTrie) do(key []byte, fn func() error) error { + for { + var err error + if t.trie == nil { + t.trie, err = trie.New(t.id.Root, t.db.backend.Database()) + } + if err == nil { + err = fn() + } if _, ok := err.(*trie.MissingNodeError); !ok { return err } - if !t.retrieveKey(ctx, key) { - break + r := &TrieRequest{Id: t.id, Key: key} + if err := t.db.backend.Retrieve(t.db.ctx, r); err != nil { + return fmt.Errorf("can't fetch trie key %x: %v", key, err) } - err = fn() } - return err } -// Get returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) - } - if err == nil { - res, err = t.trie.TryGet(key) - } - return +type nodeIterator struct { + trie.NodeIterator + t *odrTrie + err error +} + +func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator { + it := &nodeIterator{t: t} + // Open the actual non-ODR trie if that hasn't happened yet. + if t.trie == nil { + it.do(func() error { + t, err := trie.New(t.id.Root, t.db.backend.Database()) + if err == nil { + it.t.trie = t + } + return err + }) + } + it.do(func() error { + it.NodeIterator = it.t.trie.NodeIterator(startkey) + return it.NodeIterator.Error() }) - return + return it } -// Update associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) - } - if err == nil { - err = t.trie.TryUpdate(key, value) - } - return +func (it *nodeIterator) Next(descend bool) bool { + var ok bool + it.do(func() error { + ok = it.NodeIterator.Next(descend) + return it.NodeIterator.Error() }) - return + return ok } -// Delete removes any existing value for key from the trie. -func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) { - err = t.do(ctx, key, func() (err error) { - if t.trie == nil { - t.trie, err = trie.NewSecure(t.id.Root, t.db, 0) +// do runs fn and attempts to fill in missing nodes by retrieving. +func (it *nodeIterator) do(fn func() error) { + var lasthash common.Hash + for { + it.err = fn() + missing, ok := it.err.(*trie.MissingNodeError) + if !ok { + return } - if err == nil { - err = t.trie.TryDelete(key) + if missing.NodeHash == lasthash { + it.err = fmt.Errorf("retrieve loop for trie node %x", missing.NodeHash) + return } - return - }) - return + lasthash = missing.NodeHash + r := &TrieRequest{Id: it.t.id, Key: nibblesToKey(missing.Path)} + if it.err = it.t.db.backend.Retrieve(it.t.db.ctx, r); it.err != nil { + return + } + } +} + +func (it *nodeIterator) Error() error { + if it.err != nil { + return it.err + } + return it.NodeIterator.Error() +} + +func nibblesToKey(nib []byte) []byte { + if len(nib) > 0 && nib[len(nib)-1] == 0x10 { + nib = nib[:len(nib)-1] // drop terminator + } + if len(nib)&1 == 1 { + nib = append(nib, 0) // make even + } + key := make([]byte, len(nib)/2) + for bi, ni := 0, 0; ni < len(nib); bi, ni = bi+1, ni+2 { + key[bi] = nib[ni]<<4 | nib[ni+1] + } + return key } diff --git a/light/trie_test.go b/light/trie_test.go new file mode 100644 index 000000000..9b2cf7c2b --- /dev/null +++ b/light/trie_test.go @@ -0,0 +1,83 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package light + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" +) + +func TestNodeIterator(t *testing.T) { + var ( + fulldb, _ = ethdb.NewMemDatabase() + lightdb, _ = ethdb.NewMemDatabase() + gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}} + genesis = gspec.MustCommit(fulldb) + ) + gspec.MustCommit(lightdb) + blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), new(event.TypeMux), vm.Config{}) + gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, fulldb, 4, testChainGen) + if _, err := blockchain.InsertChain(gchain); err != nil { + panic(err) + } + + ctx := context.Background() + odr := &testOdr{sdb: fulldb, ldb: lightdb} + head := blockchain.CurrentHeader() + lightTrie, _ := NewStateDatabase(ctx, head, odr).OpenTrie(head.Root) + fullTrie, _ := state.NewDatabase(fulldb).OpenTrie(head.Root) + if err := diffTries(fullTrie, lightTrie); err != nil { + t.Fatal(err) + } +} + +func diffTries(t1, t2 state.Trie) error { + i1 := trie.NewIterator(t1.NodeIterator(nil)) + i2 := trie.NewIterator(t2.NodeIterator(nil)) + for i1.Next() && i2.Next() { + if !bytes.Equal(i1.Key, i2.Key) { + spew.Dump(i2) + return fmt.Errorf("tries have different keys %x, %x", i1.Key, i2.Key) + } + if !bytes.Equal(i2.Value, i2.Value) { + return fmt.Errorf("tries differ at key %x", i1.Key) + } + } + switch { + case i1.Err != nil: + return fmt.Errorf("full trie iterator error: %v", i1.Err) + case i2.Err != nil: + return fmt.Errorf("light trie iterator error: %v", i1.Err) + case i1.Next(): + return fmt.Errorf("full trie iterator has more k/v pairs") + case i2.Next(): + return fmt.Errorf("light trie iterator has more k/v pairs") + } + return nil +} diff --git a/light/txpool.go b/light/txpool.go index 7276874b8..0430b280f 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -100,17 +101,18 @@ func NewTxPool(config *params.ChainConfig, eventMux *event.TypeMux, chain *Light } // currentState returns the light state of the current head header -func (pool *TxPool) currentState() *LightState { - return NewLightState(StateTrieID(pool.chain.CurrentHeader()), pool.odr) +func (pool *TxPool) currentState(ctx context.Context) *state.StateDB { + return NewState(ctx, pool.chain.CurrentHeader(), pool.odr) } // GetNonce returns the "pending" nonce of a given address. It always queries // the nonce belonging to the latest header too in order to detect if another // client using the same key sent a transaction. func (pool *TxPool) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { - nonce, err := pool.currentState().GetNonce(ctx, addr) - if err != nil { - return 0, err + state := pool.currentState(ctx) + nonce := state.GetNonce(addr) + if state.Error() != nil { + return 0, state.Error() } sn, ok := pool.nonce[addr] if ok && sn > nonce { @@ -357,13 +359,9 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error return core.ErrInvalidSender } // Last but not least check for nonce errors - currentState := pool.currentState() - if n, err := currentState.GetNonce(ctx, from); err == nil { - if n > tx.Nonce() { - return core.ErrNonceTooLow - } - } else { - return err + currentState := pool.currentState(ctx) + if n := currentState.GetNonce(from); n > tx.Nonce() { + return core.ErrNonceTooLow } // Check the transaction doesn't exceed the current @@ -382,12 +380,8 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error // Transactor should have enough funds to cover the costs // cost == V + GP * GL - if b, err := currentState.GetBalance(ctx, from); err == nil { - if b.Cmp(tx.Cost()) < 0 { - return core.ErrInsufficientFunds - } - } else { - return err + if b := currentState.GetBalance(from); b.Cmp(tx.Cost()) < 0 { + return core.ErrInsufficientFunds } // Should supply enough intrinsic gas @@ -395,7 +389,7 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error return core.ErrIntrinsicGas } - return nil + return currentState.Error() } // add validates a new transaction and sets its state pending if processable. diff --git a/light/vm_env.go b/light/vm_env.go deleted file mode 100644 index 54aa12875..000000000 --- a/light/vm_env.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package light - -import ( - "context" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" -) - -// VMState is a wrapper for the light state that holds the actual context and -// passes it to any state operation that requires it. -type VMState struct { - ctx context.Context - state *LightState - snapshots []*LightState - err error -} - -func NewVMState(ctx context.Context, state *LightState) *VMState { - return &VMState{ctx: ctx, state: state} -} - -func (s *VMState) Error() error { - return s.err -} - -func (s *VMState) AddLog(log *types.Log) {} - -func (s *VMState) AddPreimage(hash common.Hash, preimage []byte) {} - -// errHandler handles and stores any state error that happens during execution. -func (s *VMState) errHandler(err error) { - if err != nil && s.err == nil { - s.err = err - } -} - -func (self *VMState) Snapshot() int { - self.snapshots = append(self.snapshots, self.state.Copy()) - return len(self.snapshots) - 1 -} - -func (self *VMState) RevertToSnapshot(idx int) { - self.state.Set(self.snapshots[idx]) - self.snapshots = self.snapshots[:idx] -} - -// CreateAccount creates creates a new account object and takes ownership. -func (s *VMState) CreateAccount(addr common.Address) { - _, err := s.state.CreateStateObject(s.ctx, addr) - s.errHandler(err) -} - -// AddBalance adds the given amount to the balance of the specified account -func (s *VMState) AddBalance(addr common.Address, amount *big.Int) { - err := s.state.AddBalance(s.ctx, addr, amount) - s.errHandler(err) -} - -// SubBalance adds the given amount to the balance of the specified account -func (s *VMState) SubBalance(addr common.Address, amount *big.Int) { - err := s.state.SubBalance(s.ctx, addr, amount) - s.errHandler(err) -} - -// ForEachStorage calls a callback function for every key/value pair found -// in the local storage cache. Note that unlike core/state.StateObject, -// light.StateObject only returns cached values and doesn't download the -// entire storage tree. -func (s *VMState) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) { - err := s.state.ForEachStorage(s.ctx, addr, cb) - s.errHandler(err) -} - -// GetBalance retrieves the balance from the given address or 0 if the account does -// not exist -func (s *VMState) GetBalance(addr common.Address) *big.Int { - res, err := s.state.GetBalance(s.ctx, addr) - s.errHandler(err) - return res -} - -// GetNonce returns the nonce at the given address or 0 if the account does -// not exist -func (s *VMState) GetNonce(addr common.Address) uint64 { - res, err := s.state.GetNonce(s.ctx, addr) - s.errHandler(err) - return res -} - -// SetNonce sets the nonce of the specified account -func (s *VMState) SetNonce(addr common.Address, nonce uint64) { - err := s.state.SetNonce(s.ctx, addr, nonce) - s.errHandler(err) -} - -// GetCode returns the contract code at the given address or nil if the account -// does not exist -func (s *VMState) GetCode(addr common.Address) []byte { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return res -} - -// GetCodeHash returns the contract code hash at the given address -func (s *VMState) GetCodeHash(addr common.Address) common.Hash { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return crypto.Keccak256Hash(res) -} - -// GetCodeSize returns the contract code size at the given address -func (s *VMState) GetCodeSize(addr common.Address) int { - res, err := s.state.GetCode(s.ctx, addr) - s.errHandler(err) - return len(res) -} - -// SetCode sets the contract code at the specified account -func (s *VMState) SetCode(addr common.Address, code []byte) { - err := s.state.SetCode(s.ctx, addr, code) - s.errHandler(err) -} - -// AddRefund adds an amount to the refund value collected during a vm execution -func (s *VMState) AddRefund(gas *big.Int) { - s.state.AddRefund(gas) -} - -// GetRefund returns the refund value collected during a vm execution -func (s *VMState) GetRefund() *big.Int { - return s.state.GetRefund() -} - -// GetState returns the contract storage value at storage address b from the -// contract address a or common.Hash{} if the account does not exist -func (s *VMState) GetState(a common.Address, b common.Hash) common.Hash { - res, err := s.state.GetState(s.ctx, a, b) - s.errHandler(err) - return res -} - -// SetState sets the storage value at storage address key of the account addr -func (s *VMState) SetState(addr common.Address, key common.Hash, value common.Hash) { - err := s.state.SetState(s.ctx, addr, key, value) - s.errHandler(err) -} - -// Suicide marks an account to be removed and clears its balance -func (s *VMState) Suicide(addr common.Address) bool { - res, err := s.state.Suicide(s.ctx, addr) - s.errHandler(err) - return res -} - -// Exist returns true if an account exists at the given address -func (s *VMState) Exist(addr common.Address) bool { - res, err := s.state.HasAccount(s.ctx, addr) - s.errHandler(err) - return res -} - -// Empty returns true if the account at the given address is considered empty -func (s *VMState) Empty(addr common.Address) bool { - so, err := s.state.GetStateObject(s.ctx, addr) - s.errHandler(err) - return so == nil || so.empty() -} - -// HasSuicided returns true if the given account has been marked for deletion -// or false if the account does not exist -func (s *VMState) HasSuicided(addr common.Address) bool { - res, err := s.state.HasSuicided(s.ctx, addr) - s.errHandler(err) - return res -} diff --git a/miner/worker.go b/miner/worker.go index 803015390..e44514755 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -274,7 +274,7 @@ func (self *worker) wait() { } go self.mux.Post(core.NewMinedBlockEvent{Block: block}) } else { - work.state.Commit(self.config.IsEIP158(block.Number())) + work.state.CommitTo(self.chainDb, self.config.IsEIP158(block.Number())) stat, err := self.chain.WriteBlock(block) if err != nil { log.Error("Failed writing block to chain", "err", err) diff --git a/tests/block_test_util.go b/tests/block_test_util.go index b9678a77b..24d4672b6 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -204,7 +204,7 @@ func runBlockTest(homesteadBlock, daoForkBlock, gasPriceFork *big.Int, test *Blo // InsertPreState populates the given database with the genesis // accounts defined by the test. func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) { - statedb, err := state.New(common.Hash{}, db) + statedb, err := state.New(common.Hash{}, state.NewDatabase(db)) if err != nil { return nil, err } @@ -232,7 +232,7 @@ func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) { } } - root, err := statedb.Commit(false) + root, err := statedb.CommitTo(db, false) if err != nil { return nil, fmt.Errorf("error writing state: %v", err) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index c1892cdcc..58acdd488 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -20,7 +20,6 @@ import ( "bytes" "fmt" "io" - "math/big" "strconv" "strings" "testing" @@ -99,7 +98,7 @@ func benchStateTest(chainConfig *params.ChainConfig, test VmTest, env map[string statedb := makePreState(db, test.Pre) b.StartTimer() - RunState(chainConfig, statedb, env, test.Exec) + RunState(chainConfig, statedb, db, env, test.Exec) } func runStateTests(chainConfig *params.ChainConfig, tests map[string]VmTest, skipTests []string) error { @@ -143,16 +142,9 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error { env["currentTimestamp"] = test.Env.CurrentTimestamp.(string) } - var ( - ret []byte - // gas *big.Int - // err error - logs []*types.Log - ) + ret, logs, root, _ := RunState(chainConfig, statedb, db, env, test.Transaction) - ret, logs, _, _ = RunState(chainConfig, statedb, env, test.Transaction) - - // Compare expected and actual return + // Return value: var rexp []byte if strings.HasPrefix(test.Out, "#") { n, _ := strconv.Atoi(test.Out[1:]) @@ -163,61 +155,43 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error { if !bytes.Equal(rexp, ret) { return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret) } - - // check post state + // Post state content: for addr, account := range test.Post { address := common.HexToAddress(addr) if !statedb.Exist(address) { return fmt.Errorf("did not find expected post-state account: %s", addr) } - if balance := statedb.GetBalance(address); balance.Cmp(math.MustParseBig256(account.Balance)) != 0 { return fmt.Errorf("(%x) balance failed. Expected: %v have: %v\n", address[:4], math.MustParseBig256(account.Balance), balance) } - if nonce := statedb.GetNonce(address); nonce != math.MustParseUint64(account.Nonce) { return fmt.Errorf("(%x) nonce failed. Expected: %v have: %v\n", address[:4], account.Nonce, nonce) } - for addr, value := range account.Storage { v := statedb.GetState(address, common.HexToHash(addr)) vexp := common.HexToHash(value) - if v != vexp { return fmt.Errorf("storage failed:\n%x: %s:\nexpected: %x\nhave: %x\n(%v %v)\n", address[:4], addr, vexp, v, vexp.Big(), v.Big()) } } } - - root, _ := statedb.Commit(false) + // Root: if common.HexToHash(test.PostStateRoot) != root { return fmt.Errorf("Post state root error. Expected: %s have: %x", test.PostStateRoot, root) } - - // check logs - if len(test.Logs) > 0 { - if err := checkLogs(test.Logs, logs); err != nil { - return err - } - } - - return nil + // Logs: + return checkLogs(test.Logs, logs) } -func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, env, tx map[string]string) ([]byte, []*types.Log, *big.Int, error) { +func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, db ethdb.Database, env, tx map[string]string) ([]byte, []*types.Log, common.Hash, error) { environment, msg := NewEVMEnvironment(false, chainConfig, statedb, env, tx) gaspool := new(core.GasPool).AddGas(math.MustParseBig256(env["currentGasLimit"])) - root, _ := statedb.Commit(false) - statedb.Reset(root) - snapshot := statedb.Snapshot() - - ret, gasUsed, err := core.ApplyMessage(environment, msg, gaspool) + ret, _, err := core.ApplyMessage(environment, msg, gaspool) if err != nil { statedb.RevertToSnapshot(snapshot) } - statedb.Commit(chainConfig.IsEIP158(environment.Context.BlockNumber)) - - return ret, statedb.Logs(), gasUsed, err + root, _ := statedb.CommitTo(db, chainConfig.IsEIP158(environment.Context.BlockNumber)) + return ret, statedb.Logs(), root, err } diff --git a/tests/util.go b/tests/util.go index a3a9a1f64..ff02679ec 100644 --- a/tests/util.go +++ b/tests/util.go @@ -48,7 +48,6 @@ func init() { } func checkLogs(tlog []Log, logs []*types.Log) error { - if len(tlog) != len(logs) { return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs)) } else { @@ -106,10 +105,14 @@ func (self Log) Topics() [][]byte { } func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { - statedb, _ := state.New(common.Hash{}, db) + sdb := state.NewDatabase(db) + statedb, _ := state.New(common.Hash{}, sdb) for addr, account := range accounts { insertAccount(statedb, addr, account) } + // Commit and re-open to start with a clean state. + root, _ := statedb.CommitTo(db, false) + statedb, _ = state.New(root, sdb) return statedb } diff --git a/trie/proof.go b/trie/proof.go index 1f8f76b1b..298f648c4 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -125,7 +125,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value } func get(tn node, key []byte) ([]byte, node) { - for len(key) > 0 { + for { switch n := tn.(type) { case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { @@ -140,9 +140,10 @@ func get(tn node, key []byte) ([]byte, node) { return key, n case nil: return key, nil + case valueNode: + return nil, n default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - return nil, tn.(valueNode) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 37d1d4b09..20c303f31 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,6 +156,11 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } +func (t *SecureTrie) Copy() *SecureTrie { + cpy := *t + return &cpy +} + // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // starts at the key after the given start key. func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { |