diff options
-rw-r--r-- | accounts/abi/abi.go | 150 | ||||
-rw-r--r-- | accounts/abi/abi_test.go | 233 | ||||
-rw-r--r-- | accounts/abi/type.go | 10 | ||||
-rw-r--r-- | core/vm/runtime/runtime.go | 53 | ||||
-rw-r--r-- | core/vm/runtime/runtime_test.go | 46 | ||||
-rw-r--r-- | core/vm_env.go | 36 |
6 files changed, 480 insertions, 48 deletions
diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index b84fd463a..2dc8039f5 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -20,11 +20,10 @@ import ( "encoding/json" "fmt" "io" - "math" + "reflect" + "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/logger/glog" ) // Executer is an executer method for performing state executions. It takes one @@ -101,52 +100,143 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { } // toGoType parses the input and casts it to the proper type defined by the ABI -// argument in t. -func toGoType(t Argument, input []byte) interface{} { +// argument in T. +func toGoType(i int, t Argument, output []byte) (interface{}, error) { + index := i * 32 + + if index+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32) + } + + // Parse the given index output and check whether we need to read + // a different offset and length based on the type (i.e. string, bytes) + var returnOutput []byte + switch t.Type.T { + case StringTy, BytesTy: // variable arrays are written at the end of the return bytes + // parse offset from which we should start reading + offset := int(common.BytesToBig(output[index : index+32]).Uint64()) + if offset+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32) + } + // parse the size up until we should be reading + size := int(common.BytesToBig(output[offset : offset+32]).Uint64()) + if offset+32+size > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), offset+32+size) + } + + // get the bytes for this return value + returnOutput = output[offset+32 : offset+32+size] + default: + returnOutput = output[index : index+32] + } + + // cast bytes to abi return type switch t.Type.T { case IntTy: - return common.BytesToBig(input) + return common.BytesToBig(returnOutput), nil case UintTy: - return common.BytesToBig(input) + return common.BytesToBig(returnOutput), nil case BoolTy: - return common.BytesToBig(input).Uint64() > 0 + return common.BytesToBig(returnOutput).Uint64() > 0, nil case AddressTy: - return common.BytesToAddress(input) + return common.BytesToAddress(returnOutput), nil case HashTy: - return common.BytesToHash(input) + return common.BytesToHash(returnOutput), nil + case BytesTy, FixedBytesTy: + return returnOutput, nil + case StringTy: + return string(returnOutput), nil } - return nil + return nil, fmt.Errorf("abi: unknown type %v", t.Type.T) } -// Call executes a call and attemps to parse the return values and returns it as -// an interface. It uses the executer method to perform the actual call since -// the abi knows nothing of the lower level calling mechanism. +// Call will unmarshal the output of the call in v. It will return an error if +// invalid type is given or if the output is too short to conform to the ABI +// spec. // -// Call supports all abi types and includes multiple return values. When only -// one item is returned a single interface{} will be returned, if a contract -// method returns multiple values an []interface{} slice is returned. -func (abi ABI) Call(executer Executer, name string, args ...interface{}) interface{} { +// Call supports all of the available types and accepts a struct or an interface +// slice if the return is a tuple. +func (abi ABI) Call(executer Executer, v interface{}, name string, args ...interface{}) error { callData, err := abi.Pack(name, args...) if err != nil { - glog.V(logger.Debug).Infoln("pack error:", err) - return nil + return err } - output := executer(callData) + return abi.unmarshal(v, name, executer(callData)) +} - method := abi.Methods[name] - ret := make([]interface{}, int(math.Max(float64(len(method.Outputs)), float64(len(output)/32)))) - for i := 0; i < len(ret); i += 32 { - index := i / 32 - ret[index] = toGoType(method.Outputs[index], output[i:i+32]) +var interSlice = reflect.TypeOf([]interface{}{}) + +// unmarshal output in v according to the abi specification +func (abi ABI) unmarshal(v interface{}, name string, output []byte) error { + var method = abi.Methods[name] + + if len(output) == 0 { + return fmt.Errorf("abi: unmarshalling empty output") } - // return single interface - if len(ret) == 1 { - return ret[0] + value := reflect.ValueOf(v).Elem() + typ := value.Type() + + if len(method.Outputs) > 1 { + switch value.Kind() { + // struct will match named return values to the struct's field + // names + case reflect.Struct: + for i := 0; i < len(method.Outputs); i++ { + marshalledValue, err := toGoType(i, method.Outputs[i], output) + if err != nil { + return err + } + reflectValue := reflect.ValueOf(marshalledValue) + + for j := 0; j < typ.NumField(); j++ { + field := typ.Field(j) + // TODO read tags: `abi:"fieldName"` + if field.Name == strings.ToUpper(method.Outputs[i].Name[:1])+method.Outputs[i].Name[1:] { + if field.Type.AssignableTo(reflectValue.Type()) { + value.Field(j).Set(reflectValue) + break + } else { + return fmt.Errorf("abi: cannot unmarshal %v in to %v", field.Type, reflectValue.Type()) + } + } + } + } + case reflect.Slice: + if !value.Type().AssignableTo(interSlice) { + return fmt.Errorf("abi: cannot marshal tuple in to slice %T (only []interface{} is supported)", v) + } + + // create a new slice and start appending the unmarshalled + // values to the new interface slice. + z := reflect.MakeSlice(typ, 0, len(method.Outputs)) + for i := 0; i < len(method.Outputs); i++ { + marshalledValue, err := toGoType(i, method.Outputs[i], output) + if err != nil { + return err + } + z = reflect.Append(z, reflect.ValueOf(marshalledValue)) + } + value.Set(z) + default: + return fmt.Errorf("abi: cannot unmarshal tuple in to %v", typ) + } + + } else { + marshalledValue, err := toGoType(0, method.Outputs[0], output) + if err != nil { + return err + } + reflectValue := reflect.ValueOf(marshalledValue) + if typ.AssignableTo(reflectValue.Type()) { + value.Set(reflectValue) + } else { + return fmt.Errorf("abi: cannot unmarshal %v in to %v", reflectValue.Type(), value.Type()) + } } - return ret + return nil } func (abi *ABI) UnmarshalJSON(data []byte) error { diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 000c118f8..bb0143d21 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -394,6 +394,7 @@ func TestBytes(t *testing.T) { } } +/* func TestReturn(t *testing.T) { const definition = `[ { "type" : "function", "name" : "balance", "const" : true, "inputs" : [], "outputs" : [ { "name": "", "type": "hash" } ] }, @@ -422,6 +423,7 @@ func TestReturn(t *testing.T) { t.Errorf("expected type common.Address, got %T", r) } } +*/ func TestDefaultFunctionParsing(t *testing.T) { const definition = `[{ "name" : "balance" }]` @@ -458,3 +460,234 @@ func TestBareEvents(t *testing.T) { t.Error("expected 'name' event to be present") } } + +func TestMultiReturnWithStruct(t *testing.T) { + const definition = `[ + { "name" : "multi", "const" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + // using buff to make the code readable + buff := new(bytes.Buffer) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + stringOut := "hello" + buff.Write(common.RightPadBytes([]byte(stringOut), 32)) + + var inter struct { + Int *big.Int + String string + } + err = abi.unmarshal(&inter, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if inter.Int == nil || inter.Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", inter.Int) + } + + if inter.String != stringOut { + t.Error("expected String to be", stringOut, "got", inter.String) + } + + var reversed struct { + String string + Int *big.Int + } + + err = abi.unmarshal(&reversed, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if reversed.Int == nil || reversed.Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", reversed.Int) + } + + if reversed.String != stringOut { + t.Error("expected String to be", stringOut, "got", reversed.String) + } +} + +func TestMultiReturnWithSlice(t *testing.T) { + const definition = `[ + { "name" : "multi", "const" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + // using buff to make the code readable + buff := new(bytes.Buffer) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + stringOut := "hello" + buff.Write(common.RightPadBytes([]byte(stringOut), 32)) + + var inter []interface{} + err = abi.unmarshal(&inter, "multi", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if len(inter) != 2 { + t.Fatal("expected 2 results got", len(inter)) + } + + if num, ok := inter[0].(*big.Int); !ok || num.Cmp(big.NewInt(1)) != 0 { + t.Error("expected index 0 to be 1 got", num) + } + + if str, ok := inter[1].(string); !ok || str != stringOut { + t.Error("expected index 1 to be", stringOut, "got", str) + } +} + +func TestUnmarshal(t *testing.T) { + const definition = `[ + { "name" : "int", "const" : false, "outputs": [ { "type": "uint256" } ] }, + { "name" : "bool", "const" : false, "outputs": [ { "type": "bool" } ] }, + { "name" : "bytes", "const" : false, "outputs": [ { "type": "bytes" } ] }, + { "name" : "multi", "const" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, + { "name" : "mixedBytes", "const" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` + + abi, err := JSON(strings.NewReader(definition)) + if err != nil { + t.Fatal(err) + } + + // marshal int + var Int *big.Int + err = abi.unmarshal(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + if err != nil { + t.Error(err) + } + + if Int == nil || Int.Cmp(big.NewInt(1)) != 0 { + t.Error("expected Int to be 1 got", Int) + } + + // marshal bool + var Bool bool + err = abi.unmarshal(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + if err != nil { + t.Error(err) + } + + if !Bool { + t.Error("expected Bool to be true") + } + + // marshal dynamic bytes max length 32 + buff := new(bytes.Buffer) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + bytesOut := common.RightPadBytes([]byte("hello"), 32) + buff.Write(bytesOut) + + var Bytes []byte + err = abi.unmarshal(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshall dynamic bytes max length 64 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + bytesOut = common.RightPadBytes([]byte("hello"), 64) + buff.Write(bytesOut) + + err = abi.unmarshal(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshall dynamic bytes max length 63 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000003f")) + bytesOut = common.RightPadBytes([]byte("hello"), 63) + buff.Write(bytesOut) + + err = abi.unmarshal(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, bytesOut) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshal dynamic bytes output empty + err = abi.unmarshal(&Bytes, "bytes", nil) + if err == nil { + t.Error("expected error") + } + + // marshal dynamic bytes length 5 + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) + buff.Write(common.RightPadBytes([]byte("hello"), 32)) + + err = abi.unmarshal(&Bytes, "bytes", buff.Bytes()) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(Bytes, []byte("hello")) { + t.Errorf("expected %x got %x", bytesOut, Bytes) + } + + // marshal error + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + err = abi.unmarshal(&Bytes, "bytes", buff.Bytes()) + if err == nil { + t.Error("expected error") + } + + err = abi.unmarshal(&Bytes, "multi", make([]byte, 64)) + if err == nil { + t.Error("expected error") + } + + // marshal mixed bytes + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) + fixed := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001") + buff.Write(fixed) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) + bytesOut = common.RightPadBytes([]byte("hello"), 32) + buff.Write(bytesOut) + + var out []interface{} + err = abi.unmarshal(&out, "mixedBytes", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if !bytes.Equal(bytesOut, out[0].([]byte)) { + t.Errorf("expected %x, got %x", bytesOut, out[0]) + } + + if !bytes.Equal(fixed, out[1].([]byte)) { + t.Errorf("expected %x, got %x", fixed, out[1]) + } +} diff --git a/accounts/abi/type.go b/accounts/abi/type.go index 32f761ef0..6fb2950ba 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -29,8 +29,11 @@ const ( IntTy byte = iota UintTy BoolTy + StringTy SliceTy AddressTy + FixedBytesTy + BytesTy HashTy RealTy ) @@ -118,6 +121,7 @@ func NewType(t string) (typ Type, err error) { typ.T = UintTy case "bool": typ.Kind = reflect.Bool + typ.T = BoolTy case "real": // TODO typ.Kind = reflect.Invalid case "address": @@ -128,6 +132,7 @@ func NewType(t string) (typ Type, err error) { case "string": typ.Kind = reflect.String typ.Size = -1 + typ.T = StringTy if vsize > 0 { typ.Size = 32 } @@ -140,6 +145,11 @@ func NewType(t string) (typ Type, err error) { typ.Kind = reflect.Slice typ.Type = byte_ts typ.Size = vsize + if vsize == 0 { + typ.T = BytesTy + } else { + typ.T = FixedBytesTy + } default: return Type{}, fmt.Errorf("unsupported arg type: %s", t) } diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index dd3aa1b0b..1fa06e980 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -41,6 +41,7 @@ type Config struct { DisableJit bool // "disable" so it's enabled by default Debug bool + State *state.StateDB GetHashFn func(n uint64) common.Hash } @@ -94,12 +95,14 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { vm.EnableJit = !cfg.DisableJit vm.Debug = cfg.Debug + if cfg.State == nil { + db, _ := ethdb.NewMemDatabase() + cfg.State, _ = state.New(common.Hash{}, db) + } var ( - db, _ = ethdb.NewMemDatabase() - statedb, _ = state.New(common.Hash{}, db) - vmenv = NewEnv(cfg, statedb) - sender = statedb.CreateAccount(cfg.Origin) - receiver = statedb.CreateAccount(common.StringToAddress("contract")) + vmenv = NewEnv(cfg, cfg.State) + sender = cfg.State.CreateAccount(cfg.Origin) + receiver = cfg.State.CreateAccount(common.StringToAddress("contract")) ) // set the receiver's (the executing contract) code for execution. receiver.SetCode(code) @@ -117,5 +120,43 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.Debug { vm.StdErrFormat(vmenv.StructLogs()) } - return ret, statedb, err + return ret, cfg.State, err +} + +// Call executes the code given by the contract's address. It will return the +// EVM's return value or an error if it failed. +// +// Call, unlike Execute, requires a config and also requires the State field to +// be set. +func Call(address common.Address, input []byte, cfg *Config) ([]byte, error) { + setDefaults(cfg) + + // defer the call to setting back the original values + defer func(debug, forceJit, enableJit bool) { + vm.Debug = debug + vm.ForceJit = forceJit + vm.EnableJit = enableJit + }(vm.Debug, vm.ForceJit, vm.EnableJit) + + vm.ForceJit = !cfg.DisableJit + vm.EnableJit = !cfg.DisableJit + vm.Debug = cfg.Debug + + vmenv := NewEnv(cfg, cfg.State) + + sender := cfg.State.GetOrNewStateObject(cfg.Origin) + // Call the code with the given configuration. + ret, err := vmenv.Call( + sender, + address, + input, + cfg.GasLimit, + cfg.GasPrice, + cfg.Value, + ) + + if cfg.Debug { + vm.StdErrFormat(vmenv.StructLogs()) + } + return ret, err } diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 773a0163e..e5183052f 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -17,12 +17,15 @@ package runtime import ( + "math/big" "strings" "testing" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" ) func TestDefaults(t *testing.T) { @@ -71,6 +74,49 @@ func TestEnvironment(t *testing.T) { }, nil, nil) } +func TestExecute(t *testing.T) { + ret, _, err := Execute([]byte{ + byte(vm.PUSH1), 10, + byte(vm.PUSH1), 0, + byte(vm.MSTORE), + byte(vm.PUSH1), 32, + byte(vm.PUSH1), 0, + byte(vm.RETURN), + }, nil, nil) + if err != nil { + t.Fatal("didn't expect error", err) + } + + num := common.BytesToBig(ret) + if num.Cmp(big.NewInt(10)) != 0 { + t.Error("Expected 10, got", num) + } +} + +func TestCall(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, db) + address := common.HexToAddress("0x0a") + state.SetCode(address, []byte{ + byte(vm.PUSH1), 10, + byte(vm.PUSH1), 0, + byte(vm.MSTORE), + byte(vm.PUSH1), 32, + byte(vm.PUSH1), 0, + byte(vm.RETURN), + }) + + ret, err := Call(address, nil, &Config{State: state}) + if err != nil { + t.Fatal("didn't expect error", err) + } + + num := common.BytesToBig(ret) + if num.Cmp(big.NewInt(10)) != 0 { + t.Error("Expected 10, got", num) + } +} + func TestRestoreDefaults(t *testing.T) { Execute(nil, nil, &Config{Debug: true}) if vm.ForceJit { diff --git a/core/vm_env.go b/core/vm_env.go index c8b50debc..1c787e982 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -25,6 +25,21 @@ import ( "github.com/ethereum/go-ethereum/core/vm" ) +// GetHashFn returns a function for which the VM env can query block hashes thru +// up to the limit defined by the Yellow Paper and uses the given block chain +// to query for information. +func GetHashFn(ref common.Hash, chain *BlockChain) func(n uint64) common.Hash { + return func(n uint64) common.Hash { + for block := chain.GetBlock(ref); block != nil; block = chain.GetBlock(block.ParentHash()) { + if block.NumberU64() == n { + return block.Hash() + } + } + + return common.Hash{} + } +} + type VMEnv struct { state *state.StateDB header *types.Header @@ -32,17 +47,20 @@ type VMEnv struct { depth int chain *BlockChain typ vm.Type + + getHashFn func(uint64) common.Hash // structured logging logs []vm.StructLog } func NewEnv(state *state.StateDB, chain *BlockChain, msg Message, header *types.Header) *VMEnv { return &VMEnv{ - chain: chain, - state: state, - header: header, - msg: msg, - typ: vm.StdVmTy, + chain: chain, + state: state, + header: header, + msg: msg, + typ: vm.StdVmTy, + getHashFn: GetHashFn(header.ParentHash, chain), } } @@ -59,13 +77,7 @@ func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) VmType() vm.Type { return self.typ } func (self *VMEnv) SetVmType(t vm.Type) { self.typ = t } func (self *VMEnv) GetHash(n uint64) common.Hash { - for block := self.chain.GetBlock(self.header.ParentHash); block != nil; block = self.chain.GetBlock(block.ParentHash()) { - if block.NumberU64() == n { - return block.Hash() - } - } - - return common.Hash{} + return self.getHashFn(n) } func (self *VMEnv) AddLog(log *vm.Log) { |