diff options
-rw-r--r-- | accounts/abi/abi.go | 79 | ||||
-rw-r--r-- | accounts/abi/abi_test.go | 205 | ||||
-rw-r--r-- | accounts/abi/numbers.go | 2 | ||||
-rw-r--r-- | accounts/abi/numbers_test.go | 4 | ||||
-rw-r--r-- | accounts/abi/type.go | 188 | ||||
-rw-r--r-- | core/canary.go | 51 | ||||
-rw-r--r-- | eth/api.go | 206 | ||||
-rw-r--r-- | eth/backend.go | 2 | ||||
-rw-r--r-- | eth/downloader/api.go | 75 | ||||
-rw-r--r-- | eth/filters/api.go | 79 | ||||
-rw-r--r-- | miner/worker.go | 20 | ||||
-rw-r--r-- | node/node.go | 2 | ||||
-rw-r--r-- | rpc/doc.go | 28 | ||||
-rw-r--r-- | rpc/http.go | 2 | ||||
-rw-r--r-- | rpc/inproc.go | 2 | ||||
-rw-r--r-- | rpc/json.go | 24 | ||||
-rw-r--r-- | rpc/notification.go | 288 | ||||
-rw-r--r-- | rpc/notification_test.go | 119 | ||||
-rw-r--r-- | rpc/server.go | 228 | ||||
-rw-r--r-- | rpc/server_test.go | 16 | ||||
-rw-r--r-- | rpc/types.go | 54 | ||||
-rw-r--r-- | rpc/utils.go | 26 | ||||
-rw-r--r-- | rpc/websocket.go | 3 |
23 files changed, 1154 insertions, 549 deletions
diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 91f9700d9..9ef7c0f0d 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "io" + "math/big" "reflect" "strings" @@ -63,9 +64,8 @@ func (abi ABI) pack(method Method, args ...interface{}) ([]byte, error) { return nil, fmt.Errorf("`%s` %v", method.Name, err) } - // check for a string or bytes input type - switch input.Type.T { - case StringTy, BytesTy: + // check for a slice type (string, bytes, slice) + if input.Type.T == StringTy || input.Type.T == BytesTy || input.Type.IsSlice { // calculate the offset offset := len(method.Inputs)*32 + len(variableInput) // set the offset @@ -73,7 +73,7 @@ func (abi ABI) pack(method Method, args ...interface{}) ([]byte, error) { // Append the packed output to the variable input. The variable input // will be appended at the end of the input. variableInput = append(variableInput, packed...) - default: + } else { // append the packed value to the input ret = append(ret, packed...) } @@ -117,11 +117,80 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { return append(method.Id(), arguments...), nil } +// toGoSliceType prses the input and casts it to the proper slice defined by the ABI +// argument in T. +func toGoSlice(i int, t Argument, output []byte) (interface{}, error) { + index := i * 32 + // The slice must, at very least be large enough for the index+32 which is exactly the size required + // for the [offset in output, size of offset]. + if index+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), index+32) + } + + // first we need to create a slice of the type + var refSlice reflect.Value + switch t.Type.T { + case IntTy, UintTy, BoolTy: // int, uint, bool can all be of type big int. + refSlice = reflect.ValueOf([]*big.Int(nil)) + case AddressTy: // address must be of slice Address + refSlice = reflect.ValueOf([]common.Address(nil)) + case HashTy: // hash must be of slice hash + refSlice = reflect.ValueOf([]common.Hash(nil)) + default: // no other types are supported + return nil, fmt.Errorf("abi: unsupported slice type %v", t.Type.T) + } + // get the offset which determines the start of this array ... + offset := int(common.BytesToBig(output[index : index+32]).Uint64()) + if offset+32 > len(output) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: offset %d would go over slice boundary (len=%d)", len(output), offset+32) + } + + slice := output[offset:] + // ... starting with the size of the array in elements ... + size := int(common.BytesToBig(slice[:32]).Uint64()) + slice = slice[32:] + // ... and make sure that we've at the very least the amount of bytes + // available in the buffer. + if size*32 > len(slice) { + return nil, fmt.Errorf("abi: cannot marshal in to go slice: insufficient size output %d require %d", len(output), offset+32+size*32) + } + + // reslice to match the required size + slice = slice[:(size * 32)] + for i := 0; i < size; i++ { + var ( + inter interface{} // interface type + returnOutput = slice[i*32 : i*32+32] // the return output + ) + + // set inter to the correct type (cast) + switch t.Type.T { + case IntTy, UintTy: + inter = common.BytesToBig(returnOutput) + case BoolTy: + inter = common.BytesToBig(returnOutput).Uint64() > 0 + case AddressTy: + inter = common.BytesToAddress(returnOutput) + case HashTy: + inter = common.BytesToHash(returnOutput) + } + // append the item to our reflect slice + refSlice = reflect.Append(refSlice, reflect.ValueOf(inter)) + } + + // return the interface + return refSlice.Interface(), nil +} + // toGoType parses the input and casts it to the proper type defined by the ABI // argument in T. func toGoType(i int, t Argument, output []byte) (interface{}, error) { - index := i * 32 + // we need to treat slices differently + if t.Type.IsSlice { + return toGoSlice(i, t, output) + } + index := i * 32 if index+32 > len(output) { return nil, fmt.Errorf("abi: cannot marshal in to go type: length insufficient %d require %d", len(output), index+32) } diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 66d2e1b39..a1b3e62d9 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -49,7 +49,9 @@ const jsondata2 = ` { "type" : "function", "name" : "foo", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" } ] }, { "type" : "function", "name" : "bar", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" }, { "name" : "string", "type" : "uint16" } ] }, { "type" : "function", "name" : "slice", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] }, - { "type" : "function", "name" : "slice256", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] } + { "type" : "function", "name" : "slice256", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] }, + { "type" : "function", "name" : "sliceAddress", "const" : false, "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] }, + { "type" : "function", "name" : "sliceMultiAddress", "const" : false, "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] } ]` func TestType(t *testing.T) { @@ -57,7 +59,7 @@ func TestType(t *testing.T) { if err != nil { t.Error(err) } - if typ.Kind != reflect.Ptr { + if typ.Kind != reflect.Uint { t.Error("expected uint32 to have kind Ptr") } @@ -65,10 +67,10 @@ func TestType(t *testing.T) { if err != nil { t.Error(err) } - if typ.Kind != reflect.Slice { - t.Error("expected uint32[] to have type slice") + if !typ.IsSlice { + t.Error("expected uint32[] to be slice") } - if typ.Type != ubig_ts { + if typ.Type != ubig_t { t.Error("expcted uith32[] to have type uint64") } @@ -76,13 +78,13 @@ func TestType(t *testing.T) { if err != nil { t.Error(err) } - if typ.Kind != reflect.Slice { - t.Error("expected uint32[2] to have kind slice") + if !typ.IsSlice { + t.Error("expected uint32[2] to be slice") } - if typ.Type != ubig_ts { + if typ.Type != ubig_t { t.Error("expcted uith32[2] to have type uint64") } - if typ.Size != 2 { + if typ.SliceSize != 2 { t.Error("expected uint32[2] to have a size of 2") } } @@ -147,10 +149,6 @@ func TestTestNumbers(t *testing.T) { t.Errorf("expected send( ptr ) to throw, requires *big.Int instead of *int") } - if _, err := abi.Pack("send", 1000); err != nil { - t.Error("expected send(1000) to cast to big") - } - if _, err := abi.Pack("test", uint32(1000)); err != nil { t.Error(err) } @@ -202,17 +200,7 @@ func TestTestSlice(t *testing.T) { t.FailNow() } - addr := make([]byte, 20) - if _, err := abi.Pack("address", addr); err != nil { - t.Error(err) - } - - addr = make([]byte, 21) - if _, err := abi.Pack("address", addr); err == nil { - t.Error("expected address of 21 width to throw") - } - - slice := make([]byte, 2) + slice := make([]uint64, 2) if _, err := abi.Pack("uint64[2]", slice); err != nil { t.Error(err) } @@ -222,16 +210,18 @@ func TestTestSlice(t *testing.T) { } } -func TestTestAddress(t *testing.T) { +func TestImplicitTypeCasts(t *testing.T) { abi, err := JSON(strings.NewReader(jsondata2)) if err != nil { t.Error(err) t.FailNow() } - addr := make([]byte, 20) - if _, err := abi.Pack("address", addr); err != nil { - t.Error(err) + slice := make([]uint8, 2) + _, err = abi.Pack("uint64[2]", slice) + expStr := "`uint64[2]` abi: cannot use type uint8 as type uint64" + if err.Error() != expStr { + t.Errorf("expected %v, got %v", expStr, err) } } @@ -310,44 +300,69 @@ func TestPackSlice(t *testing.T) { } sig := crypto.Keccak256([]byte("slice(uint32[2])"))[:4] - sig = append(sig, make([]byte, 64)...) - sig[35] = 1 - sig[67] = 2 + sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) packed, err := abi.Pack("slice", []uint32{1, 2}) if err != nil { t.Error(err) - t.FailNow() } if !bytes.Equal(packed, sig) { t.Errorf("expected %x got %x", sig, packed) } -} -func TestPackSliceBig(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) + var addrA, addrB = common.Address{1}, common.Address{2} + sig = abi.Methods["sliceAddress"].Id() + sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) + + packed, err = abi.Pack("sliceAddress", []common.Address{addrA, addrB}) if err != nil { - t.Error(err) - t.FailNow() + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) } - sig := crypto.Keccak256([]byte("slice256(uint256[2])"))[:4] - sig = append(sig, make([]byte, 64)...) - sig[35] = 1 - sig[67] = 2 + var addrC, addrD = common.Address{3}, common.Address{4} + sig = abi.Methods["sliceMultiAddress"].Id() + sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrB[:], 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrC[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrD[:], 32)...) - packed, err := abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)}) + packed, err = abi.Pack("sliceMultiAddress", []common.Address{addrA, addrB}, []common.Address{addrC, addrD}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + sig = crypto.Keccak256([]byte("slice256(uint256[2])"))[:4] + sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + + packed, err = abi.Pack("slice256", []*big.Int{big.NewInt(1), big.NewInt(2)}) if err != nil { t.Error(err) - t.FailNow() } if !bytes.Equal(packed, sig) { t.Errorf("expected %x got %x", sig, packed) } } - func ExampleJSON() { const definition = `[{"constant":true,"inputs":[{"name":"","type":"address"}],"name":"isBar","outputs":[{"name":"","type":"bool"}],"type":"function"}]` @@ -370,7 +385,7 @@ func TestInputVariableInputLength(t *testing.T) { { "type" : "function", "name" : "strOne", "const" : true, "inputs" : [ { "name" : "str", "type" : "string" } ] }, { "type" : "function", "name" : "bytesOne", "const" : true, "inputs" : [ { "name" : "str", "type" : "bytes" } ] }, { "type" : "function", "name" : "strTwo", "const" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "str1", "type" : "string" } ] } -]` + ]` abi, err := JSON(strings.NewReader(definition)) if err != nil { @@ -493,35 +508,6 @@ func TestInputVariableInputLength(t *testing.T) { } } -func TestBytes(t *testing.T) { - const definition = `[ - { "type" : "function", "name" : "balance", "const" : true, "inputs" : [ { "name" : "address", "type" : "bytes20" } ] }, - { "type" : "function", "name" : "send", "const" : false, "inputs" : [ { "name" : "amount", "type" : "uint256" } ] } -]` - - abi, err := JSON(strings.NewReader(definition)) - if err != nil { - t.Fatal(err) - } - ok := make([]byte, 20) - _, err = abi.Pack("balance", ok) - if err != nil { - t.Error(err) - } - - toosmall := make([]byte, 19) - _, err = abi.Pack("balance", toosmall) - if err != nil { - t.Error(err) - } - - toobig := make([]byte, 21) - _, err = abi.Pack("balance", toobig) - if err == nil { - t.Error("expected error") - } -} - func TestDefaultFunctionParsing(t *testing.T) { const definition = `[{ "name" : "balance" }]` @@ -713,12 +699,15 @@ func TestUnmarshal(t *testing.T) { { "name" : "bytes", "const" : false, "outputs": [ { "type": "bytes" } ] }, { "name" : "fixed", "const" : false, "outputs": [ { "type": "bytes32" } ] }, { "name" : "multi", "const" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, + { "name" : "addressSliceSingle", "const" : false, "outputs": [ { "type": "address[]" } ] }, + { "name" : "addressSliceDouble", "const" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] }, { "name" : "mixedBytes", "const" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` abi, err := JSON(strings.NewReader(definition)) if err != nil { t.Fatal(err) } + buff := new(bytes.Buffer) // marshal int var Int *big.Int @@ -743,7 +732,6 @@ func TestUnmarshal(t *testing.T) { } // marshal dynamic bytes max length 32 - buff := new(bytes.Buffer) buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) bytesOut := common.RightPadBytes([]byte("hello"), 32) @@ -862,4 +850,71 @@ func TestUnmarshal(t *testing.T) { if !bytes.Equal(fixed, out[1].([]byte)) { t.Errorf("expected %x, got %x", fixed, out[1]) } + + // marshal address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) + + var outAddr []common.Address + err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if len(outAddr) != 1 { + t.Fatal("expected 1 item, got", len(outAddr)) + } + + if outAddr[0] != (common.Address{1}) { + t.Errorf("expected %x, got %x", common.Address{1}, outAddr[0]) + } + + // marshal multiple address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // offset + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // size + buff.Write(common.Hex2Bytes("0000000000000000000000000200000000000000000000000000000000000000")) + buff.Write(common.Hex2Bytes("0000000000000000000000000300000000000000000000000000000000000000")) + + var outAddrStruct struct { + A []common.Address + B []common.Address + } + err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes()) + if err != nil { + t.Fatal("didn't expect error:", err) + } + + if len(outAddrStruct.A) != 1 { + t.Fatal("expected 1 item, got", len(outAddrStruct.A)) + } + + if outAddrStruct.A[0] != (common.Address{1}) { + t.Errorf("expected %x, got %x", common.Address{1}, outAddrStruct.A[0]) + } + + if len(outAddrStruct.B) != 2 { + t.Fatal("expected 1 item, got", len(outAddrStruct.B)) + } + + if outAddrStruct.B[0] != (common.Address{2}) { + t.Errorf("expected %x, got %x", common.Address{2}, outAddrStruct.B[0]) + } + if outAddrStruct.B[1] != (common.Address{3}) { + t.Errorf("expected %x, got %x", common.Address{3}, outAddrStruct.B[1]) + } + + // marshal invalid address slice + buff.Reset() + buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100")) + + err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + if err == nil { + t.Fatal("expected error:", err) + } } diff --git a/accounts/abi/numbers.go b/accounts/abi/numbers.go index 02609d567..084701de5 100644 --- a/accounts/abi/numbers.go +++ b/accounts/abi/numbers.go @@ -117,8 +117,6 @@ func packNum(value reflect.Value, to byte) []byte { // checks whether the given reflect value is signed. This also works for slices with a number type func isSigned(v reflect.Value) bool { switch v.Type() { - case ubig_ts, big_ts, big_t, ubig_t: - return true case int_ts, int8_ts, int16_ts, int32_ts, int64_ts, int_t, int8_t, int16_t, int32_t, int64_t: return true } diff --git a/accounts/abi/numbers_test.go b/accounts/abi/numbers_test.go index 78dc57543..6590e41a6 100644 --- a/accounts/abi/numbers_test.go +++ b/accounts/abi/numbers_test.go @@ -81,8 +81,4 @@ func TestSigned(t *testing.T) { if !isSigned(reflect.ValueOf(int(10))) { t.Error() } - - if !isSigned(reflect.ValueOf(big.NewInt(10))) { - t.Error() - } } diff --git a/accounts/abi/type.go b/accounts/abi/type.go index c08b744f7..5a5a5ac49 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -40,6 +40,9 @@ const ( // Type is the reflection of the supported argument type type Type struct { + IsSlice bool + SliceSize int + Kind reflect.Kind Type reflect.Type Size int @@ -47,6 +50,11 @@ type Type struct { stringKind string // holds the unparsed string for deriving signatures } +var ( + fullTypeRegex = regexp.MustCompile("([a-zA-Z0-9]+)(\\[([0-9]*)?\\])?") + typeRegex = regexp.MustCompile("([a-zA-Z]+)([0-9]*)?") +) + // NewType returns a fully parsed Type given by the input string or an error if it can't be parsed. // // Strings can be in the format of: @@ -61,98 +69,87 @@ type Type struct { // address int256 uint256 real[2] func NewType(t string) (typ Type, err error) { // 1. full string 2. type 3. (opt.) is slice 4. (opt.) size - freg, err := regexp.Compile("([a-zA-Z0-9]+)(\\[([0-9]*)?\\])?") - if err != nil { - return Type{}, err - } - res := freg.FindAllStringSubmatch(t, -1)[0] - var ( - isslice bool - size int - ) + // parse the full representation of the abi-type definition; including: + // * full string + // * type + // * is slice + // * slice size + res := fullTypeRegex.FindAllStringSubmatch(t, -1)[0] + + // check if type is slice and parse type. switch { case res[3] != "": // err is ignored. Already checked for number through the regexp - size, _ = strconv.Atoi(res[3]) - isslice = true + typ.SliceSize, _ = strconv.Atoi(res[3]) + typ.IsSlice = true case res[2] != "": - isslice = true - size = -1 + typ.IsSlice, typ.SliceSize = true, -1 case res[0] == "": - return Type{}, fmt.Errorf("type parse error for `%s`", t) + return Type{}, fmt.Errorf("abi: type parse error: %s", t) } - treg, err := regexp.Compile("([a-zA-Z]+)([0-9]*)?") - if err != nil { - return Type{}, err + // parse the type and size of the abi-type. + parsedType := typeRegex.FindAllStringSubmatch(res[1], -1)[0] + // varSize is the size of the variable + var varSize int + if len(parsedType[2]) > 0 { + var err error + varSize, err = strconv.Atoi(parsedType[2]) + if err != nil { + return Type{}, fmt.Errorf("abi: error parsing variable size: %v", err) + } } - - parsedType := treg.FindAllStringSubmatch(res[1], -1)[0] - vsize, _ := strconv.Atoi(parsedType[2]) - vtype := parsedType[1] - // substitute canonical representation - if vsize == 0 && (vtype == "int" || vtype == "uint") { - vsize = 256 + // varType is the parsed abi type + varType := parsedType[1] + // substitute canonical integer + if varSize == 0 && (varType == "int" || varType == "uint") { + varSize = 256 t += "256" } - if isslice { - typ.Kind = reflect.Slice - typ.Size = size - switch vtype { - case "int": - typ.Type = big_ts - case "uint": - typ.Type = ubig_ts - default: - return Type{}, fmt.Errorf("unsupported arg slice type: %s", t) - } - } else { - switch vtype { - case "int": - typ.Kind = reflect.Ptr - typ.Type = big_t - typ.Size = 256 - typ.T = IntTy - case "uint": - typ.Kind = reflect.Ptr - typ.Type = ubig_t - typ.Size = 256 - typ.T = UintTy - case "bool": - typ.Kind = reflect.Bool - typ.T = BoolTy - case "real": // TODO - typ.Kind = reflect.Invalid - case "address": - typ.Kind = reflect.Slice - typ.Type = address_t - typ.Size = 20 - typ.T = AddressTy - case "string": - typ.Kind = reflect.String - typ.Size = -1 - typ.T = StringTy - if vsize > 0 { - typ.Size = 32 - } - case "hash": - typ.Kind = reflect.Slice + switch varType { + case "int": + typ.Kind = reflect.Int + typ.Type = big_t + typ.Size = varSize + typ.T = IntTy + case "uint": + typ.Kind = reflect.Uint + typ.Type = ubig_t + typ.Size = varSize + typ.T = UintTy + case "bool": + typ.Kind = reflect.Bool + typ.T = BoolTy + case "real": // TODO + typ.Kind = reflect.Invalid + case "address": + typ.Type = address_t + typ.Size = 20 + typ.T = AddressTy + case "string": + typ.Kind = reflect.String + typ.Size = -1 + typ.T = StringTy + if varSize > 0 { typ.Size = 32 - typ.Type = hash_t - typ.T = HashTy - case "bytes": - typ.Kind = reflect.Slice - typ.Type = byte_ts - typ.Size = vsize - if vsize == 0 { - typ.T = BytesTy - } else { - typ.T = FixedBytesTy - } - default: - return Type{}, fmt.Errorf("unsupported arg type: %s", t) } + case "hash": + typ.Kind = reflect.Array + typ.Size = 32 + typ.Type = hash_t + typ.T = HashTy + case "bytes": + typ.Kind = reflect.Array + typ.Type = byte_ts + typ.Size = varSize + if varSize == 0 { + typ.T = BytesTy + } else { + typ.T = FixedBytesTy + } + default: + return Type{}, fmt.Errorf("unsupported arg type: %s", t) } typ.stringKind = t @@ -180,14 +177,26 @@ func (t Type) pack(v interface{}) ([]byte, error) { value := reflect.ValueOf(v) switch kind := value.Kind(); kind { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // check input is unsigned if t.Type != ubig_t { - return nil, fmt.Errorf("type mismatch: %s for %T", t.Type, v) + return nil, fmt.Errorf("abi: type mismatch: %s for %T", t.Type, v) + } + + // no implicit type casting + if int(value.Type().Size()*8) != t.Size { + return nil, fmt.Errorf("abi: cannot use type %T as type uint%d", v, t.Size) } + return packNum(value, t.T), nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if t.Type != ubig_t { return nil, fmt.Errorf("type mismatch: %s for %T", t.Type, v) } + + // no implicit type casting + if int(value.Type().Size()*8) != t.Size { + return nil, fmt.Errorf("abi: cannot use type %T as type uint%d", v, t.Size) + } return packNum(value, t.T), nil case reflect.Ptr: // If the value is a ptr do a assign check (only used by @@ -203,30 +212,29 @@ func (t Type) pack(v interface{}) ([]byte, error) { return packBytesSlice([]byte(value.String()), value.Len()), nil case reflect.Slice: - // if the param is a bytes type, pack the slice up as a string + // Byte slice is a special case, it gets treated as a single value if t.T == BytesTy { return packBytesSlice(value.Bytes(), value.Len()), nil } - if t.Size > -1 && value.Len() > t.Size { + if t.SliceSize > -1 && value.Len() > t.SliceSize { return nil, fmt.Errorf("%v out of bound. %d for %d", value.Kind(), value.Len(), t.Size) } - // Address is a special slice. The slice acts as one rather than a list of elements. - if t.T == AddressTy { - return common.LeftPadBytes(v.([]byte), 32), nil - } - // Signed / Unsigned check - if (t.T != IntTy && isSigned(value)) || (t.T == UintTy && isSigned(value)) { + if value.Type() == big_t && (t.T != IntTy && isSigned(value)) || (t.T == UintTy && isSigned(value)) { return nil, fmt.Errorf("slice of incompatible types.") } var packed []byte for i := 0; i < value.Len(); i++ { - packed = append(packed, packNum(value.Index(i), t.T)...) + val, err := t.pack(value.Index(i).Interface()) + if err != nil { + return nil, err + } + packed = append(packed, val...) } - return packed, nil + return packBytesSlice(packed, value.Len()), nil case reflect.Bool: if value.Bool() { return common.LeftPadBytes(common.Big1.Bytes(), 32), nil diff --git a/core/canary.go b/core/canary.go deleted file mode 100644 index 69db18e58..000000000 --- a/core/canary.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. - -package core - -import ( - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" -) - -var ( - jeff = common.HexToAddress("959c33de5961820567930eccce51ea715c496f85") - vitalik = common.HexToAddress("c8158da0b567a8cc898991c2c2a073af67dc03a9") - christoph = common.HexToAddress("7a19a893f91d5b6e2cdf941b6acbba2cbcf431ee") - gav = common.HexToAddress("539dd9aaf45c3feb03f9c004f4098bd3268fef6b") -) - -// Canary will check the 0'd address of the 4 contracts above. -// If two or more are set to anything other than a 0 the canary -// dies a horrible death. -func Canary(statedb *state.StateDB) bool { - var r int - if (statedb.GetState(jeff, common.Hash{}).Big().Cmp(big.NewInt(0)) > 0) { - r++ - } - if (statedb.GetState(gav, common.Hash{}).Big().Cmp(big.NewInt(0)) > 0) { - r++ - } - if (statedb.GetState(christoph, common.Hash{}).Big().Cmp(big.NewInt(0)) > 0) { - r++ - } - if (statedb.GetState(vitalik, common.Hash{}).Big().Cmp(big.NewInt(0)) > 0) { - r++ - } - return r > 1 -} diff --git a/eth/api.go b/eth/api.go index a257639ba..676191fc2 100644 --- a/eth/api.go +++ b/eth/api.go @@ -28,6 +28,8 @@ import ( "sync" "time" + "golang.org/x/net/context" + "github.com/ethereum/ethash" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" @@ -457,16 +459,46 @@ func (s *PrivateAccountAPI) LockAccount(addr common.Address) bool { // It offers only methods that operate on public data that is freely available to anyone. type PublicBlockChainAPI struct { config *core.ChainConfig - bc *core.BlockChain - chainDb ethdb.Database - eventMux *event.TypeMux - am *accounts.Manager - miner *miner.Miner + bc *core.BlockChain + chainDb ethdb.Database + eventMux *event.TypeMux + muNewBlockSubscriptions sync.Mutex // protects newBlocksSubscriptions + newBlockSubscriptions map[string]func(core.ChainEvent) error // callbacks for new block subscriptions + am *accounts.Manager + miner *miner.Miner } // NewPublicBlockChainAPI creates a new Etheruem blockchain API. func NewPublicBlockChainAPI(config *core.ChainConfig, bc *core.BlockChain, m *miner.Miner, chainDb ethdb.Database, eventMux *event.TypeMux, am *accounts.Manager) *PublicBlockChainAPI { - return &PublicBlockChainAPI{config: config, bc: bc, miner: m, chainDb: chainDb, eventMux: eventMux, am: am} + api := &PublicBlockChainAPI{ + config: config, + bc: bc, + miner: m, + chainDb: chainDb, + eventMux: eventMux, + am: am, + newBlockSubscriptions: make(map[string]func(core.ChainEvent) error), + } + + go api.subscriptionLoop() + + return api +} + +// subscriptionLoop reads events from the global event mux and creates notifications for the matched subscriptions. +func (s *PublicBlockChainAPI) subscriptionLoop() { + sub := s.eventMux.Subscribe(core.ChainEvent{}) + for event := range sub.Chan() { + if chainEvent, ok := event.Data.(core.ChainEvent); ok { + s.muNewBlockSubscriptions.Lock() + for id, notifyOf := range s.newBlockSubscriptions { + if notifyOf(chainEvent) == rpc.ErrNotificationNotFound { + delete(s.newBlockSubscriptions, id) + } + } + s.muNewBlockSubscriptions.Unlock() + } + } } // BlockNumber returns the block number of the chain head. @@ -564,20 +596,36 @@ type NewBlocksArgs struct { // NewBlocks triggers a new block event each time a block is appended to the chain. It accepts an argument which allows // the caller to specify whether the output should contain transactions and in what format. -func (s *PublicBlockChainAPI) NewBlocks(args NewBlocksArgs) (rpc.Subscription, error) { - sub := s.eventMux.Subscribe(core.ChainEvent{}) +func (s *PublicBlockChainAPI) NewBlocks(ctx context.Context, args NewBlocksArgs) (rpc.Subscription, error) { + notifier, supported := ctx.Value(rpc.NotifierContextKey).(rpc.Notifier) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } - output := func(rawBlock interface{}) interface{} { - if event, ok := rawBlock.(core.ChainEvent); ok { - notification, err := s.rpcOutputBlock(event.Block, args.IncludeTransactions, args.TransactionDetails) - if err == nil { - return notification - } + // create a subscription that will remove itself when unsubscribed/cancelled + subscription, err := notifier.NewSubscription(func(subId string) { + s.muNewBlockSubscriptions.Lock() + delete(s.newBlockSubscriptions, subId) + s.muNewBlockSubscriptions.Unlock() + }) + + if err != nil { + return nil, err + } + + // add a callback that is called on chain events which will format the block and notify the client + s.muNewBlockSubscriptions.Lock() + s.newBlockSubscriptions[subscription.ID()] = func(e core.ChainEvent) error { + if notification, err := s.rpcOutputBlock(e.Block, args.IncludeTransactions, args.TransactionDetails); err == nil { + return subscription.Notify(notification) + } else { + glog.V(logger.Warn).Info("unable to format block %v\n", err) } - return rawBlock + return nil } + s.muNewBlockSubscriptions.Unlock() - return rpc.NewSubscriptionWithOutputFormat(sub, output), nil + return subscription, nil } // GetCode returns the code stored at the given address in the state for the given block number. @@ -821,26 +869,75 @@ func newRPCTransaction(b *types.Block, txHash common.Hash) (*RPCTransaction, err // PublicTransactionPoolAPI exposes methods for the RPC interface type PublicTransactionPoolAPI struct { - eventMux *event.TypeMux - chainDb ethdb.Database - gpo *GasPriceOracle - bc *core.BlockChain - miner *miner.Miner - am *accounts.Manager - txPool *core.TxPool - txMu sync.Mutex + eventMux *event.TypeMux + chainDb ethdb.Database + gpo *GasPriceOracle + bc *core.BlockChain + miner *miner.Miner + am *accounts.Manager + txPool *core.TxPool + txMu sync.Mutex + muPendingTxSubs sync.Mutex + pendingTxSubs map[string]rpc.Subscription } // NewPublicTransactionPoolAPI creates a new RPC service with methods specific for the transaction pool. func NewPublicTransactionPoolAPI(e *Ethereum) *PublicTransactionPoolAPI { - return &PublicTransactionPoolAPI{ - eventMux: e.EventMux(), - gpo: NewGasPriceOracle(e), - chainDb: e.ChainDb(), - bc: e.BlockChain(), - am: e.AccountManager(), - txPool: e.TxPool(), - miner: e.Miner(), + api := &PublicTransactionPoolAPI{ + eventMux: e.EventMux(), + gpo: NewGasPriceOracle(e), + chainDb: e.ChainDb(), + bc: e.BlockChain(), + am: e.AccountManager(), + txPool: e.TxPool(), + miner: e.Miner(), + pendingTxSubs: make(map[string]rpc.Subscription), + } + + go api.subscriptionLoop() + + return api +} + +// subscriptionLoop listens for events on the global event mux and creates notifications for subscriptions. +func (s *PublicTransactionPoolAPI) subscriptionLoop() { + sub := s.eventMux.Subscribe(core.TxPreEvent{}) + accountTimeout := time.NewTicker(10 * time.Second) + + // only publish pending tx signed by one of the accounts in the node + accountSet := set.New() + accounts, _ := s.am.Accounts() + for _, acc := range accounts { + accountSet.Add(acc.Address) + } + + for { + select { + case event := <-sub.Chan(): + if event == nil { + continue + } + tx := event.Data.(core.TxPreEvent) + if from, err := tx.Tx.FromFrontier(); err == nil { + if accountSet.Has(from) { + s.muPendingTxSubs.Lock() + for id, sub := range s.pendingTxSubs { + if sub.Notify(tx.Tx.Hash()) == rpc.ErrNotificationNotFound { + delete(s.pendingTxSubs, id) + } + } + s.muPendingTxSubs.Unlock() + } + } + case <-accountTimeout.C: + // refresh account list when accounts are added/removed from the node. + if accounts, err := s.am.Accounts(); err == nil { + accountSet.Clear() + for _, acc := range accounts { + accountSet.Add(acc.Address) + } + } + } } } @@ -1275,40 +1372,27 @@ func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, err // NewPendingTransaction creates a subscription that is triggered each time a transaction enters the transaction pool // and is send from one of the transactions this nodes manages. -func (s *PublicTransactionPoolAPI) NewPendingTransactions() (rpc.Subscription, error) { - sub := s.eventMux.Subscribe(core.TxPreEvent{}) - - accounts, err := s.am.Accounts() - if err != nil { - return rpc.Subscription{}, err +func (s *PublicTransactionPoolAPI) NewPendingTransactions(ctx context.Context) (rpc.Subscription, error) { + notifier, supported := ctx.Value(rpc.NotifierContextKey).(rpc.Notifier) + if !supported { + return nil, rpc.ErrNotificationsUnsupported } - accountSet := set.New() - for _, account := range accounts { - accountSet.Add(account.Address) - } - accountSetLastUpdates := time.Now() - output := func(transaction interface{}) interface{} { - if time.Since(accountSetLastUpdates) > (time.Duration(2) * time.Second) { - if accounts, err = s.am.Accounts(); err != nil { - accountSet.Clear() - for _, account := range accounts { - accountSet.Add(account.Address) - } - accountSetLastUpdates = time.Now() - } - } + subscription, err := notifier.NewSubscription(func(id string) { + s.muPendingTxSubs.Lock() + delete(s.pendingTxSubs, id) + s.muPendingTxSubs.Unlock() + }) - tx := transaction.(core.TxPreEvent) - if from, err := tx.Tx.FromFrontier(); err == nil { - if accountSet.Has(from) { - return tx.Tx.Hash() - } - } - return nil + if err != nil { + return nil, err } - return rpc.NewSubscriptionWithOutputFormat(sub, output), nil + s.muPendingTxSubs.Lock() + s.pendingTxSubs[subscription.ID()] = subscription + s.muPendingTxSubs.Unlock() + + return subscription, nil } // Resend accepts an existing transaction and a new gas price and limit. It will remove the given transaction from the diff --git a/eth/backend.go b/eth/backend.go index 26af7ff91..20f516610 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -306,7 +306,7 @@ func (s *Ethereum) APIs() []rpc.API { }, { Namespace: "eth", Version: "1.0", - Service: downloader.NewPublicDownloaderAPI(s.Downloader()), + Service: downloader.NewPublicDownloaderAPI(s.Downloader(), s.EventMux()), Public: true, }, { Namespace: "miner", diff --git a/eth/downloader/api.go b/eth/downloader/api.go index 13d0ed46e..576b33f1d 100644 --- a/eth/downloader/api.go +++ b/eth/downloader/api.go @@ -17,18 +17,55 @@ package downloader import ( + "sync" + + "golang.org/x/net/context" + + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" ) // PublicDownloaderAPI provides an API which gives information about the current synchronisation status. // It offers only methods that operates on data that can be available to anyone without security risks. type PublicDownloaderAPI struct { - d *Downloader + d *Downloader + mux *event.TypeMux + muSyncSubscriptions sync.Mutex + syncSubscriptions map[string]rpc.Subscription } // NewPublicDownloaderAPI create a new PublicDownloaderAPI. -func NewPublicDownloaderAPI(d *Downloader) *PublicDownloaderAPI { - return &PublicDownloaderAPI{d} +func NewPublicDownloaderAPI(d *Downloader, m *event.TypeMux) *PublicDownloaderAPI { + api := &PublicDownloaderAPI{d: d, mux: m, syncSubscriptions: make(map[string]rpc.Subscription)} + + go api.run() + + return api +} + +func (api *PublicDownloaderAPI) run() { + sub := api.mux.Subscribe(StartEvent{}, DoneEvent{}, FailedEvent{}) + + for event := range sub.Chan() { + var notification interface{} + + switch event.Data.(type) { + case StartEvent: + result := &SyncingResult{Syncing: true} + result.Status.Origin, result.Status.Current, result.Status.Height, result.Status.Pulled, result.Status.Known = api.d.Progress() + notification = result + case DoneEvent, FailedEvent: + notification = false + } + + api.muSyncSubscriptions.Lock() + for id, sub := range api.syncSubscriptions { + if sub.Notify(notification) == rpc.ErrNotificationNotFound { + delete(api.syncSubscriptions, id) + } + } + api.muSyncSubscriptions.Unlock() + } } // Progress gives progress indications when the node is synchronising with the Ethereum network. @@ -47,19 +84,25 @@ type SyncingResult struct { } // Syncing provides information when this nodes starts synchronising with the Ethereum network and when it's finished. -func (s *PublicDownloaderAPI) Syncing() (rpc.Subscription, error) { - sub := s.d.mux.Subscribe(StartEvent{}, DoneEvent{}, FailedEvent{}) +func (api *PublicDownloaderAPI) Syncing(ctx context.Context) (rpc.Subscription, error) { + notifier, supported := ctx.Value(rpc.NotifierContextKey).(rpc.Notifier) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } - output := func(event interface{}) interface{} { - switch event.(type) { - case StartEvent: - result := &SyncingResult{Syncing: true} - result.Status.Origin, result.Status.Current, result.Status.Height, result.Status.Pulled, result.Status.Known = s.d.Progress() - return result - case DoneEvent, FailedEvent: - return false - } - return nil + subscription, err := notifier.NewSubscription(func(id string) { + api.muSyncSubscriptions.Lock() + delete(api.syncSubscriptions, id) + api.muSyncSubscriptions.Unlock() + }) + + if err != nil { + return nil, err } - return rpc.NewSubscriptionWithOutputFormat(sub, output), nil + + api.muSyncSubscriptions.Lock() + api.syncSubscriptions[subscription.ID()] = subscription + api.muSyncSubscriptions.Unlock() + + return subscription, nil } diff --git a/eth/filters/api.go b/eth/filters/api.go index e6a1ce3ab..956660363 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -17,15 +17,13 @@ package filters import ( - "sync" - "time" - "crypto/rand" "encoding/hex" - "errors" - "encoding/json" + "errors" "fmt" + "sync" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -33,6 +31,8 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" + + "golang.org/x/net/context" ) var ( @@ -202,7 +202,7 @@ func (s *PublicFilterAPI) NewPendingTransactionFilter() (string, error) { } // newLogFilter creates a new log filter. -func (s *PublicFilterAPI) newLogFilter(earliest, latest int64, addresses []common.Address, topics [][]common.Hash) (int, error) { +func (s *PublicFilterAPI) newLogFilter(earliest, latest int64, addresses []common.Address, topics [][]common.Hash, callback func(log *vm.Log, removed bool)) (int, error) { s.logMu.Lock() defer s.logMu.Unlock() @@ -219,17 +219,70 @@ func (s *PublicFilterAPI) newLogFilter(earliest, latest int64, addresses []commo filter.SetAddresses(addresses) filter.SetTopics(topics) filter.LogCallback = func(log *vm.Log, removed bool) { - s.logMu.Lock() - defer s.logMu.Unlock() - - if queue := s.logQueue[id]; queue != nil { - queue.add(vmlog{log, removed}) + if callback != nil { + callback(log, removed) + } else { + s.logMu.Lock() + defer s.logMu.Unlock() + if queue := s.logQueue[id]; queue != nil { + queue.add(vmlog{log, removed}) + } } } return id, nil } +func (s *PublicFilterAPI) Logs(ctx context.Context, args NewFilterArgs) (rpc.Subscription, error) { + notifier, supported := ctx.Value(rpc.NotifierContextKey).(rpc.Notifier) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } + + var ( + externalId string + subscription rpc.Subscription + err error + ) + + if externalId, err = newFilterId(); err != nil { + return nil, err + } + + // uninstall filter when subscription is unsubscribed/cancelled + if subscription, err = notifier.NewSubscription(func(string) { + s.UninstallFilter(externalId) + }); err != nil { + return nil, err + } + + notifySubscriber := func(log *vm.Log, removed bool) { + rpcLog := toRPCLogs(vm.Logs{log}, removed) + if err := subscription.Notify(rpcLog); err != nil { + subscription.Cancel() + } + } + + // from and to block number are not used since subscriptions don't allow you to travel to "time" + var id int + if len(args.Addresses) > 0 { + id, err = s.newLogFilter(-1, -1, args.Addresses, args.Topics, notifySubscriber) + } else { + id, err = s.newLogFilter(-1, -1, nil, args.Topics, notifySubscriber) + } + + if err != nil { + subscription.Cancel() + return nil, err + } + + s.filterMapMu.Lock() + s.filterMapping[externalId] = id + s.filterMapMu.Unlock() + + return subscription, err +} + // NewFilterArgs represents a request to create a new filter. type NewFilterArgs struct { FromBlock rpc.BlockNumber @@ -364,9 +417,9 @@ func (s *PublicFilterAPI) NewFilter(args NewFilterArgs) (string, error) { var id int if len(args.Addresses) > 0 { - id, err = s.newLogFilter(args.FromBlock.Int64(), args.ToBlock.Int64(), args.Addresses, args.Topics) + id, err = s.newLogFilter(args.FromBlock.Int64(), args.ToBlock.Int64(), args.Addresses, args.Topics, nil) } else { - id, err = s.newLogFilter(args.FromBlock.Int64(), args.ToBlock.Int64(), nil, args.Topics) + id, err = s.newLogFilter(args.FromBlock.Int64(), args.ToBlock.Int64(), nil, args.Topics, nil) } if err != nil { return "", err diff --git a/miner/worker.go b/miner/worker.go index a5e2516fe..c5fb82b45 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -351,19 +351,15 @@ func (self *worker) wait() { } } +// push sends a new work task to currently live miner agents. func (self *worker) push(work *Work) { - if atomic.LoadInt32(&self.mining) == 1 { - if core.Canary(work.state) { - glog.Infoln("Toxicity levels rising to deadly levels. Your canary has died. You can go back or continue down the mineshaft --more--") - glog.Infoln("You turn back and abort mining") - return - } - // push new work to agents - for agent := range self.agents { - atomic.AddInt32(&self.atWork, 1) - if agent.Work() != nil { - agent.Work() <- work - } + if atomic.LoadInt32(&self.mining) != 1 { + return + } + for agent := range self.agents { + atomic.AddInt32(&self.atWork, 1) + if ch := agent.Work(); ch != nil { + ch <- work } } } diff --git a/node/node.go b/node/node.go index 7d3a10874..62cc3895b 100644 --- a/node/node.go +++ b/node/node.go @@ -303,7 +303,7 @@ func (n *Node) startIPC(apis []rpc.API) error { glog.V(logger.Error).Infof("IPC accept failed: %v", err) continue } - go handler.ServeCodec(rpc.NewJSONCodec(conn)) + go handler.ServeCodec(rpc.NewJSONCodec(conn), rpc.OptionMethodInvocation | rpc.OptionSubscriptions) } }() // All listeners booted successfully diff --git a/rpc/doc.go b/rpc/doc.go index a2506ad58..c9dba3270 100644 --- a/rpc/doc.go +++ b/rpc/doc.go @@ -68,35 +68,19 @@ The package also supports the publish subscribe pattern through the use of subsc A method that is considered eligible for notifications must satisfy the following criteria: - object must be exported - method must be exported + - first method argument type must be context.Context - method argument(s) must be exported or builtin types - method must return the tuple Subscription, error - An example method: - func (s *BlockChainService) Head() (Subscription, error) { - sub := s.bc.eventMux.Subscribe(ChainHeadEvent{}) - return v2.NewSubscription(sub), nil - } - -This method will push all raised ChainHeadEvents to subscribed clients. If the client is only -interested in every N'th block it is possible to add a criteria. - - func (s *BlockChainService) HeadFiltered(nth uint64) (Subscription, error) { - sub := s.bc.eventMux.Subscribe(ChainHeadEvent{}) - - criteria := func(event interface{}) bool { - chainHeadEvent := event.(ChainHeadEvent) - if chainHeadEvent.Block.NumberU64() % nth == 0 { - return true - } - return false - } - - return v2.NewSubscriptionFiltered(sub, criteria), nil + func (s *BlockChainService) NewBlocks(ctx context.Context) (Subscription, error) { + ... } Subscriptions are deleted when: - the user sends an unsubscribe request - - the connection which was used to create the subscription is closed + - the connection which was used to create the subscription is closed. This can be initiated + by the client and server. The server will close the connection on an write error or when + the queue of buffered notifications gets too big. */ package rpc diff --git a/rpc/http.go b/rpc/http.go index af3d29014..dd1ec2c01 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -126,7 +126,7 @@ func newJSONHTTPHandler(srv *Server) http.HandlerFunc { // a single request. codec := NewJSONCodec(&httpReadWriteNopCloser{r.Body, w}) defer codec.Close() - srv.ServeSingleRequest(codec) + srv.ServeSingleRequest(codec, OptionMethodInvocation) } } diff --git a/rpc/inproc.go b/rpc/inproc.go index 3cfbea71c..250f5c787 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -39,7 +39,7 @@ func (c *inProcClient) Close() { // RPC server. func NewInProcRPCClient(handler *Server) Client { p1, p2 := net.Pipe() - go handler.ServeCodec(NewJSONCodec(p1)) + go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions) return &inProcClient{handler, p2, json.NewEncoder(p2), json.NewDecoder(p2)} } diff --git a/rpc/json.go b/rpc/json.go index 1ed943c00..a0bfcac04 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -22,7 +22,7 @@ import ( "io" "reflect" "strings" - "sync/atomic" + "sync" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" @@ -81,19 +81,20 @@ type jsonNotification struct { // jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has support for parsing arguments // and serializing (result) objects. type jsonCodec struct { - closed chan interface{} - isClosed int32 - d *json.Decoder - e *json.Encoder - req JSONRequest - rw io.ReadWriteCloser + closed chan interface{} + closer sync.Once + d *json.Decoder + muEncoder sync.Mutex + e *json.Encoder + req JSONRequest + rw io.ReadWriteCloser } // NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0 func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec { d := json.NewDecoder(rwc) d.UseNumber() - return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc, isClosed: 0} + return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc} } // isBatch returns true when the first non-whitespace characters is '[' @@ -326,15 +327,18 @@ func (c *jsonCodec) CreateNotification(subid string, event interface{}) interfac // Write message to client func (c *jsonCodec) Write(res interface{}) error { + c.muEncoder.Lock() + defer c.muEncoder.Unlock() + return c.e.Encode(res) } // Close the underlying connection func (c *jsonCodec) Close() { - if atomic.CompareAndSwapInt32(&c.isClosed, 0, 1) { + c.closer.Do(func() { close(c.closed) c.rw.Close() - } + }) } // Closed returns a channel which will be closed when Close is called diff --git a/rpc/notification.go b/rpc/notification.go new file mode 100644 index 000000000..146d785c9 --- /dev/null +++ b/rpc/notification.go @@ -0,0 +1,288 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package rpc + +import ( + "errors" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" +) + +var ( + // ErrNotificationsUnsupported is returned when the connection doesn't support notifications + ErrNotificationsUnsupported = errors.New("notifications not supported") + + // ErrNotificationNotFound is returned when the notification for the given id is not found + ErrNotificationNotFound = errors.New("notification not found") + + // errNotifierStopped is returned when the notifier is stopped (e.g. codec is closed) + errNotifierStopped = errors.New("unable to send notification") + + // errNotificationQueueFull is returns when there are too many notifications in the queue + errNotificationQueueFull = errors.New("too many pending notifications") +) + +// unsubSignal is a signal that the subscription is unsubscribed. It is used to flush buffered +// notifications that might be pending in the internal queue. +var unsubSignal = new(struct{}) + +// UnsubscribeCallback defines a callback that is called when a subcription ends. +// It receives the subscription id as argument. +type UnsubscribeCallback func(id string) + +// notification is a helper object that holds event data for a subscription +type notification struct { + sub *bufferedSubscription // subscription id + data interface{} // event data +} + +// A Notifier type describes the interface for objects that can send create subscriptions +type Notifier interface { + // Create a new subscription. The given callback is called when this subscription + // is cancelled (e.g. client send an unsubscribe, connection closed). + NewSubscription(UnsubscribeCallback) (Subscription, error) + // Cancel subscription + Unsubscribe(id string) error +} + +// Subscription defines the interface for objects that can notify subscribers +type Subscription interface { + // Inform client of an event + Notify(data interface{}) error + // Unique identifier + ID() string + // Cancel subscription + Cancel() error +} + +// bufferedSubscription is a subscription that uses a bufferedNotifier to send +// notifications to subscribers. +type bufferedSubscription struct { + id string + unsubOnce sync.Once // call unsub method once + unsub UnsubscribeCallback // called on Unsubscribed + notifier *bufferedNotifier // forward notifications to + pending chan interface{} // closed when active + flushed chan interface{} // closed when all buffered notifications are send + lastNotification time.Time // last time a notification was send +} + +// ID returns the subscription identifier that the client uses to refer to this instance. +func (s *bufferedSubscription) ID() string { + return s.id +} + +// Cancel informs the notifier that this subscription is cancelled by the API +func (s *bufferedSubscription) Cancel() error { + return s.notifier.Unsubscribe(s.id) +} + +// Notify the subscriber of a particular event. +func (s *bufferedSubscription) Notify(data interface{}) error { + return s.notifier.send(s.id, data) +} + +// bufferedNotifier is a notifier that queues notifications in an internal queue and +// send them as fast as possible to the client from this queue. It will stop if the +// queue grows past a given size. +type bufferedNotifier struct { + codec ServerCodec // underlying connection + mu sync.Mutex // guard internal state + subscriptions map[string]*bufferedSubscription // keep track of subscriptions associated with codec + queueSize int // max number of items in queue + queue chan *notification // notification queue + stopped bool // indication if this notifier is ordered to stop +} + +// newBufferedNotifier returns a notifier that queues notifications in an internal queue +// from which notifications are send as fast as possible to the client. If the queue size +// limit is reached (client is unable to keep up) it will stop and closes the codec. +func newBufferedNotifier(codec ServerCodec, size int) *bufferedNotifier { + notifier := &bufferedNotifier{ + codec: codec, + subscriptions: make(map[string]*bufferedSubscription), + queue: make(chan *notification, size), + queueSize: size, + } + + go notifier.run() + + return notifier +} + +// NewSubscription creates a new subscription that forwards events to this instance internal +// queue. The given callback is called when the subscription is unsubscribed/cancelled. +func (n *bufferedNotifier) NewSubscription(callback UnsubscribeCallback) (Subscription, error) { + id, err := newSubscriptionID() + if err != nil { + return nil, err + } + + n.mu.Lock() + defer n.mu.Unlock() + + if n.stopped { + return nil, errNotifierStopped + } + + sub := &bufferedSubscription{ + id: id, + unsub: callback, + notifier: n, + pending: make(chan interface{}), + flushed: make(chan interface{}), + lastNotification: time.Now(), + } + + n.subscriptions[id] = sub + + return sub, nil +} + +// Remove the given subscription. If subscription is not found notificationNotFoundErr is returned. +func (n *bufferedNotifier) Unsubscribe(subid string) error { + n.mu.Lock() + sub, found := n.subscriptions[subid] + n.mu.Unlock() + + if found { + // send the unsubscribe signal, this will cause the notifier not to accept new events + // for this subscription and will close the flushed channel after the last (buffered) + // notification was send to the client. + if err := n.send(subid, unsubSignal); err != nil { + return err + } + + // wait for confirmation that all (buffered) events are send for this subscription. + // this ensures that the unsubscribe method response is not send before all buffered + // events for this subscription are send. + <-sub.flushed + + return nil + } + + return ErrNotificationNotFound +} + +// Send enques the given data for the subscription with public ID on the internal queue. t returns +// an error when the notifier is stopped or the queue is full. If data is the unsubscribe signal it +// will remove the subscription with the given id from the subscription collection. +func (n *bufferedNotifier) send(id string, data interface{}) error { + n.mu.Lock() + defer n.mu.Unlock() + + if n.stopped { + return errNotifierStopped + } + + var ( + subscription *bufferedSubscription + found bool + ) + + // check if subscription is associated with this connection, it might be cancelled + // (subscribe/connection closed) + if subscription, found = n.subscriptions[id]; !found { + glog.V(logger.Error).Infof("received notification for unknown subscription %s\n", id) + return ErrNotificationNotFound + } + + // received the unsubscribe signal. Add it to the queue to make sure any pending notifications + // for this subscription are send. When the run loop receives this singal it will signal that + // all pending subscriptions are flushed and that the confirmation of the unsubscribe can be + // send to the user. Remove the subscriptions to make sure new notifications are not accepted. + if data == unsubSignal { + delete(n.subscriptions, id) + if subscription.unsub != nil { + subscription.unsubOnce.Do(func() { subscription.unsub(id) }) + } + } + + subscription.lastNotification = time.Now() + + if len(n.queue) >= n.queueSize { + glog.V(logger.Warn).Infoln("too many buffered notifications -> close connection") + n.codec.Close() + return errNotificationQueueFull + } + + n.queue <- ¬ification{subscription, data} + return nil +} + +// run reads notifications from the internal queue and sends them to the client. In case of an +// error, or when the codec is closed it will cancel all active subscriptions and returns. +func (n *bufferedNotifier) run() { + defer func() { + n.mu.Lock() + defer n.mu.Unlock() + + n.stopped = true + close(n.queue) + + // on exit call unsubscribe callback + for id, sub := range n.subscriptions { + if sub.unsub != nil { + sub.unsubOnce.Do(func() { sub.unsub(id) }) + } + close(sub.flushed) + delete(n.subscriptions, id) + } + }() + + for { + select { + case notification := <-n.queue: + // It can happen that an event is raised before the RPC server was able to send the sub + // id to the client. Therefore subscriptions are marked as pending until the sub id was + // send. The RPC server will activate the subscription by closing the pending chan. + <-notification.sub.pending + + if notification.data == unsubSignal { + // unsubSignal is the last accepted message for this subscription. Raise the signal + // that all buffered notifications are sent by closing the flushed channel. This + // indicates that the response for the unsubscribe can be send to the client. + close(notification.sub.flushed) + } else { + msg := n.codec.CreateNotification(notification.sub.id, notification.data) + if err := n.codec.Write(msg); err != nil { + n.codec.Close() + // unable to send notification to client, unsubscribe all subscriptions + glog.V(logger.Warn).Infof("unable to send notification - %v\n", err) + return + } + } + case <-n.codec.Closed(): // connection was closed + glog.V(logger.Debug).Infoln("codec closed, stop subscriptions") + return + } + } +} + +// Marks the subscription as active. This will causes the notifications for this subscription to be +// forwarded to the client. +func (n *bufferedNotifier) activate(subid string) { + n.mu.Lock() + defer n.mu.Unlock() + + if sub, found := n.subscriptions[subid]; found { + close(sub.pending) + } +} diff --git a/rpc/notification_test.go b/rpc/notification_test.go new file mode 100644 index 000000000..8d2add81c --- /dev/null +++ b/rpc/notification_test.go @@ -0,0 +1,119 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. +package rpc + +import ( + "encoding/json" + "net" + "testing" + "time" + + "golang.org/x/net/context" +) + +type NotificationTestService struct{} + +var ( + unsubCallbackCalled = false +) + +func (s *NotificationTestService) Unsubscribe(subid string) { + unsubCallbackCalled = true +} + +func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (Subscription, error) { + notifier, supported := ctx.Value(NotifierContextKey).(Notifier) + if !supported { + return nil, ErrNotificationsUnsupported + } + + // by explicitly creating an subscription we make sure that the subscription id is send back to the client + // before the first subscription.Notify is called. Otherwise the events might be send before the response + // for the eth_subscribe method. + subscription, err := notifier.NewSubscription(s.Unsubscribe) + if err != nil { + return nil, err + } + + go func() { + for i := 0; i < n; i++ { + if err := subscription.Notify(val + i); err != nil { + return + } + } + }() + + return subscription, nil +} + +func TestNotifications(t *testing.T) { + server := NewServer() + service := &NotificationTestService{} + + if err := server.RegisterName("eth", service); err != nil { + t.Fatalf("unable to register test service %v", err) + } + + clientConn, serverConn := net.Pipe() + + go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions) + + out := json.NewEncoder(clientConn) + in := json.NewDecoder(clientConn) + + n := 5 + val := 12345 + request := map[string]interface{}{ + "id": 1, + "method": "eth_subscribe", + "version": "2.0", + "params": []interface{}{"someSubscription", n, val}, + } + + // create subscription + if err := out.Encode(request); err != nil { + t.Fatal(err) + } + + var subid string + response := JSONSuccessResponse{Result: subid} + if err := in.Decode(&response); err != nil { + t.Fatal(err) + } + + var ok bool + if subid, ok = response.Result.(string); !ok { + t.Fatalf("expected subscription id, got %T", response.Result) + } + + for i := 0; i < n; i++ { + var notification jsonNotification + if err := in.Decode(¬ification); err != nil { + t.Fatalf("%v", err) + } + + if int(notification.Params.Result.(float64)) != val+i { + t.Fatalf("expected %d, got %d", val+i, notification.Params.Result) + } + } + + clientConn.Close() // causes notification unsubscribe callback to be called + time.Sleep(1 * time.Second) + + if !unsubCallbackCalled { + t.Error("unsubscribe callback not called after closing connection") + } +} diff --git a/rpc/server.go b/rpc/server.go index 22448f8e3..cf90eba02 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "golang.org/x/net/context" @@ -33,10 +32,26 @@ import ( const ( stopPendingRequestTimeout = 3 * time.Second // give pending requests stopPendingRequestTimeout the time to finish when the server is stopped + // NotifierContextKey is the key where the notifier associated with the codec is stored in the context + NotifierContextKey = 1 + + notificationBufferSize = 10000 // max buffered notifications before codec is closed + DefaultIPCApis = "admin,eth,debug,miner,net,shh,txpool,personal,web3" DefaultHTTPApis = "eth,net,web3" ) +// CodecOption specifies which type of messages this codec supports +type CodecOption int + +const ( + // OptionMethodInvocation is an indication that the codec supports RPC method calls + OptionMethodInvocation CodecOption = 1 << iota + + // OptionSubscriptions is an indication that the codec suports RPC notifications + OptionSubscriptions = 1 << iota // support pub sub +) + // NewServer will create a new server instance with no registered handlers. func NewServer() *Server { server := &Server{ @@ -63,7 +78,7 @@ type RPCService struct { // Modules returns the list of RPC services with their version number func (s *RPCService) Modules() map[string]string { modules := make(map[string]string) - for name, _ := range s.server.services { + for name := range s.server.services { modules[name] = "1.0" } return modules @@ -92,7 +107,7 @@ func (s *Server) RegisterName(name string, rcvr interface{}) error { if regsvc, present := s.services[name]; present { methods, subscriptions := suitableCallbacks(rcvrVal, svc.typ) if len(methods) == 0 && len(subscriptions) == 0 { - return fmt.Errorf("Service doesn't have any suitable methods/subscriptions to expose") + return fmt.Errorf("Service %T doesn't have any suitable methods/subscriptions to expose", rcvr) } for _, m := range methods { @@ -109,7 +124,7 @@ func (s *Server) RegisterName(name string, rcvr interface{}) error { svc.callbacks, svc.subscriptions = suitableCallbacks(rcvrVal, svc.typ) if len(svc.callbacks) == 0 && len(svc.subscriptions) == 0 { - return fmt.Errorf("Service doesn't have any suitable methods/subscriptions to expose") + return fmt.Errorf("Service %T doesn't have any suitable methods/subscriptions to expose", rcvr) } s.services[svc.name] = svc @@ -117,12 +132,23 @@ func (s *Server) RegisterName(name string, rcvr interface{}) error { return nil } +// hasOption returns true if option is included in options, otherwise false +func hasOption(option CodecOption, options []CodecOption) bool { + for _, o := range options { + if option == o { + return true + } + } + return false +} + // serveRequest will reads requests from the codec, calls the RPC callback and // writes the response to the given codec. +// // If singleShot is true it will process a single request, otherwise it will handle // requests until the codec returns an error when reading a request (in most cases // an EOF). It executes requests in parallel when singleShot is false. -func (s *Server) serveRequest(codec ServerCodec, singleShot bool) error { +func (s *Server) serveRequest(codec ServerCodec, singleShot bool, options CodecOption) error { defer func() { if err := recover(); err != nil { const size = 64 << 10 @@ -141,6 +167,12 @@ func (s *Server) serveRequest(codec ServerCodec, singleShot bool) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // if the codec supports notification include a notifier that callbacks can use + // to send notification to clients. It is thight to the codec/connection. If the + // connection is closed the notifier will stop and cancels all active subscriptions. + if options&OptionSubscriptions == OptionSubscriptions { + ctx = context.WithValue(ctx, NotifierContextKey, newBufferedNotifier(codec, notificationBufferSize)) + } s.codecsMu.Lock() if atomic.LoadInt32(&s.run) != 1 { // server stopped s.codecsMu.Unlock() @@ -193,20 +225,16 @@ func (s *Server) serveRequest(codec ServerCodec, singleShot bool) error { // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes the // response back using the given codec. It will block until the codec is closed or the server is // stopped. In either case the codec is closed. -// -// This server will: -// 1. allow for asynchronous and parallel request execution -// 2. supports notifications (pub/sub) -// 3. supports request batches -func (s *Server) ServeCodec(codec ServerCodec) { +func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { defer codec.Close() - s.serveRequest(codec, false) + s.serveRequest(codec, false, options) } // ServeSingleRequest reads and processes a single RPC request from the given codec. It will not -// close the codec unless a non-recoverable error has occurred. -func (s *Server) ServeSingleRequest(codec ServerCodec) { - s.serveRequest(codec, true) +// close the codec unless a non-recoverable error has occurred. Note, this method will return after +// a single request has been processed! +func (s *Server) ServeSingleRequest(codec ServerCodec, options CodecOption) { + s.serveRequest(codec, true, options) } // Stop will stop reading new requests, wait for stopPendingRequestTimeout to allow pending requests to finish, @@ -225,122 +253,64 @@ func (s *Server) Stop() { } } -// sendNotification will create a notification from the given event by serializing member fields of the event. -// It will then send the notification to the client, when it fails the codec is closed. When the event has multiple -// fields an array of values is returned. -func sendNotification(codec ServerCodec, subid string, event interface{}) { - notification := codec.CreateNotification(subid, event) - - if err := codec.Write(notification); err != nil { - codec.Close() - } -} - -// createSubscription will register a new subscription and waits for raised events. When an event is raised it will: -// 1. test if the event is raised matches the criteria the user has (optionally) specified -// 2. create a notification of the event and send it the client when it matches the criteria -// It will unsubscribe the subscription when the socket is closed or the subscription is unsubscribed by the user. -func (s *Server) createSubscription(c ServerCodec, req *serverRequest) (string, error) { - args := []reflect.Value{req.callb.rcvr} - if len(req.args) > 0 { - args = append(args, req.args...) - } - - subid, err := newSubscriptionId() - if err != nil { - return "", err - } - +// createSubscription will call the subscription callback and returns the subscription id or error. +func (s *Server) createSubscription(ctx context.Context, c ServerCodec, req *serverRequest) (string, error) { + // subscription have as first argument the context following optional arguments + args := []reflect.Value{req.callb.rcvr, reflect.ValueOf(ctx)} + args = append(args, req.args...) reply := req.callb.method.Func.Call(args) - if reply[1].IsNil() { // no error - if subscription, ok := reply[0].Interface().(Subscription); ok { - s.muSubcriptions.Lock() - s.subscriptions[subid] = subscription - s.muSubcriptions.Unlock() - go func() { - cases := []reflect.SelectCase{ - reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(subscription.Chan())}, // new event - reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.Closed())}, // connection closed - } - - for { - idx, notification, recvOk := reflect.Select(cases) - switch idx { - case 0: // new event, or channel closed - if recvOk { // send notification - if event, ok := notification.Interface().(*event.Event); ok { - if subscription.match == nil || subscription.match(event.Data) { - sendNotification(c, subid, subscription.format(event.Data)) - } - } - } else { // user send an eth_unsubscribe request - return - } - case 1: // connection closed - s.unsubscribe(subid) - return - } - } - }() - } else { // unable to create subscription - s.muSubcriptions.Lock() - delete(s.subscriptions, subid) - s.muSubcriptions.Unlock() - } - } else { - return "", fmt.Errorf("Unable to create subscription") + if !reply[1].IsNil() { // subscription creation failed + return "", reply[1].Interface().(error) } - return subid, nil -} - -// unsubscribe calls the Unsubscribe method on the subscription and removes a subscription from the subscription -// registry. -func (s *Server) unsubscribe(subid string) bool { - s.muSubcriptions.Lock() - defer s.muSubcriptions.Unlock() - if sub, ok := s.subscriptions[subid]; ok { - sub.Unsubscribe() - delete(s.subscriptions, subid) - return true - } - return false + return reply[0].Interface().(Subscription).ID(), nil } // handle executes a request and returns the response from the callback. -func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) interface{} { +func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) (interface{}, func()) { if req.err != nil { - return codec.CreateErrorResponse(&req.id, req.err) + return codec.CreateErrorResponse(&req.id, req.err), nil } - if req.isUnsubscribe { // first param must be the subscription id + if req.isUnsubscribe { // cancel subscription, first param must be the subscription id if len(req.args) >= 1 && req.args[0].Kind() == reflect.String { + notifier, supported := ctx.Value(NotifierContextKey).(*bufferedNotifier) + if !supported { // interface doesn't support subscriptions (e.g. http) + return codec.CreateErrorResponse(&req.id, &callbackError{ErrNotificationsUnsupported.Error()}), nil + } + subid := req.args[0].String() - if s.unsubscribe(subid) { - return codec.CreateResponse(req.id, true) - } else { - return codec.CreateErrorResponse(&req.id, - &callbackError{fmt.Sprintf("subscription '%s' not found", subid)}) + if err := notifier.Unsubscribe(subid); err != nil { + return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil } + + return codec.CreateResponse(req.id, true), nil } - return codec.CreateErrorResponse(&req.id, &invalidParamsError{"Expected subscription id as argument"}) + return codec.CreateErrorResponse(&req.id, &invalidParamsError{"Expected subscription id as first argument"}), nil } if req.callb.isSubscribe { - subid, err := s.createSubscription(codec, req) + subid, err := s.createSubscription(ctx, codec, req) if err != nil { - return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}) + return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil + } + + // active the subscription after the sub id was successful sent to the client + activateSub := func() { + notifier, _ := ctx.Value(NotifierContextKey).(*bufferedNotifier) + notifier.activate(subid) } - return codec.CreateResponse(req.id, subid) + + return codec.CreateResponse(req.id, subid), activateSub } - // regular RPC call + // regular RPC call, prepare arguments if len(req.args) != len(req.callb.argTypes) { rpcErr := &invalidParamsError{fmt.Sprintf("%s%s%s expects %d parameters, got %d", req.svcname, serviceMethodSeparator, req.callb.method.Name, len(req.callb.argTypes), len(req.args))} - return codec.CreateErrorResponse(&req.id, rpcErr) + return codec.CreateErrorResponse(&req.id, rpcErr), nil } arguments := []reflect.Value{req.callb.rcvr} @@ -351,45 +321,56 @@ func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverReque arguments = append(arguments, req.args...) } + // execute RPC method and return result reply := req.callb.method.Func.Call(arguments) - if len(reply) == 0 { - return codec.CreateResponse(req.id, nil) + return codec.CreateResponse(req.id, nil), nil } if req.callb.errPos >= 0 { // test if method returned an error if !reply[req.callb.errPos].IsNil() { e := reply[req.callb.errPos].Interface().(error) res := codec.CreateErrorResponse(&req.id, &callbackError{e.Error()}) - return res + return res, nil } } - return codec.CreateResponse(req.id, reply[0].Interface()) + return codec.CreateResponse(req.id, reply[0].Interface()), nil } // exec executes the given request and writes the result back using the codec. func (s *Server) exec(ctx context.Context, codec ServerCodec, req *serverRequest) { var response interface{} + var callback func() if req.err != nil { response = codec.CreateErrorResponse(&req.id, req.err) } else { - response = s.handle(ctx, codec, req) + response, callback = s.handle(ctx, codec, req) } + if err := codec.Write(response); err != nil { glog.V(logger.Error).Infof("%v\n", err) codec.Close() } + + // when request was a subscribe request this allows these subscriptions to be actived + if callback != nil { + callback() + } } -// execBatch executes the given requests and writes the result back using the codec. It will only write the response -// back when the last request is processed. +// execBatch executes the given requests and writes the result back using the codec. +// It will only write the response back when the last request is processed. func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*serverRequest) { responses := make([]interface{}, len(requests)) + var callbacks []func() for i, req := range requests { if req.err != nil { responses[i] = codec.CreateErrorResponse(&req.id, req.err) } else { - responses[i] = s.handle(ctx, codec, req) + var callback func() + if responses[i], callback = s.handle(ctx, codec, req); callback != nil { + callbacks = append(callbacks, callback) + } } } @@ -397,11 +378,16 @@ func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*s glog.V(logger.Error).Infof("%v\n", err) codec.Close() } + + // when request holds one of more subscribe requests this allows these subscriptions to be actived + for _, c := range callbacks { + c() + } } -// readRequest requests the next (batch) request from the codec. It will return the collection of requests, an -// indication if the request was a batch, the invalid request identifier and an error when the request could not be -// read/parsed. +// readRequest requests the next (batch) request from the codec. It will return the collection +// of requests, an indication if the request was a batch, the invalid request identifier and an +// error when the request could not be read/parsed. func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, RPCError) { reqs, batch, err := codec.ReadRequestHeaders() if err != nil { @@ -417,7 +403,7 @@ func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, RPCErro if r.isPubSub && r.method == unsubscribeMethod { requests[i] = &serverRequest{id: r.id, isUnsubscribe: true} - argTypes := []reflect.Type{reflect.TypeOf("")} + argTypes := []reflect.Type{reflect.TypeOf("")} // expect subscription id as first arg if args, err := codec.ParseRequestArguments(argTypes, r.params); err == nil { requests[i].args = args } else { @@ -426,12 +412,12 @@ func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, RPCErro continue } - if svc, ok = s.services[r.service]; !ok { + if svc, ok = s.services[r.service]; !ok { // rpc method isn't available requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}} continue } - if r.isPubSub { // eth_subscribe + if r.isPubSub { // eth_subscribe, r.method contains the subscription method name if callb, ok := svc.subscriptions[r.method]; ok { requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb} if r.params != nil && len(callb.argTypes) > 0 { @@ -449,7 +435,7 @@ func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, RPCErro continue } - if callb, ok := svc.callbacks[r.method]; ok { + if callb, ok := svc.callbacks[r.method]; ok { // lookup RPC method requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb} if r.params != nil && len(callb.argTypes) > 0 { if args, err := codec.ParseRequestArguments(callb.argTypes, r.params); err == nil { diff --git a/rpc/server_test.go b/rpc/server_test.go index 5b91fe42a..c60db38df 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -65,8 +65,12 @@ func (s *Service) InvalidRets3() (string, string, error) { return "", "", nil } -func (s *Service) Subscription() (Subscription, error) { - return NewSubscription(nil), nil +func (s *Service) Subscription(ctx context.Context) (Subscription, error) { + return nil, nil +} + +func (s *Service) SubsriptionWithArgs(ctx context.Context, a, b int) (Subscription, error) { + return nil, nil } func TestServerRegisterName(t *testing.T) { @@ -90,8 +94,8 @@ func TestServerRegisterName(t *testing.T) { t.Errorf("Expected 4 callbacks for service 'calc', got %d", len(svc.callbacks)) } - if len(svc.subscriptions) != 1 { - t.Errorf("Expected 1 subscription for service 'calc', got %d", len(svc.subscriptions)) + if len(svc.subscriptions) != 2 { + t.Errorf("Expected 2 subscriptions for service 'calc', got %d", len(svc.subscriptions)) } } @@ -229,7 +233,7 @@ func TestServerMethodExecution(t *testing.T) { input, _ := json.Marshal(&req) codec := &ServerTestCodec{input: input, closer: make(chan interface{})} - go server.ServeCodec(codec) + go server.ServeCodec(codec, OptionMethodInvocation) <-codec.closer @@ -259,7 +263,7 @@ func TestServerMethodWithCtx(t *testing.T) { input, _ := json.Marshal(&req) codec := &ServerTestCodec{input: input, closer: make(chan interface{})} - go server.ServeCodec(codec) + go server.ServeCodec(codec, OptionMethodInvocation) <-codec.closer diff --git a/rpc/types.go b/rpc/types.go index f268d84db..596fdf264 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -24,7 +24,6 @@ import ( "strings" "sync" - "github.com/ethereum/go-ethereum/event" "gopkg.in/fatih/set.v0" ) @@ -66,10 +65,10 @@ type serverRequest struct { err RPCError } -type serviceRegistry map[string]*service // collection of services -type callbacks map[string]*callback // collection of RPC callbacks -type subscriptions map[string]*callback // collection of subscription callbacks -type subscriptionRegistry map[string]Subscription // collection of subscriptions +type serviceRegistry map[string]*service // collection of services +type callbacks map[string]*callback // collection of RPC callbacks +type subscriptions map[string]*callback // collection of subscription callbacks +type subscriptionRegistry map[string]*callback // collection of subscription callbacks // Server represents a RPC server type Server struct { @@ -123,51 +122,6 @@ type ServerCodec interface { Closed() <-chan interface{} } -// SubscriptionMatcher returns true if the given value matches the criteria specified by the user -type SubscriptionMatcher func(interface{}) bool - -// SubscriptionOutputFormat accepts event data and has the ability to format the data before it is send to the client -type SubscriptionOutputFormat func(interface{}) interface{} - -// defaultSubscriptionOutputFormatter returns data and is used as default output format for notifications -func defaultSubscriptionOutputFormatter(data interface{}) interface{} { - return data -} - -// Subscription is used by the server to send notifications to the client -type Subscription struct { - sub event.Subscription - match SubscriptionMatcher - format SubscriptionOutputFormat -} - -// NewSubscription create a new RPC subscription -func NewSubscription(sub event.Subscription) Subscription { - return Subscription{sub, nil, defaultSubscriptionOutputFormatter} -} - -// NewSubscriptionWithOutputFormat create a new RPC subscription which a custom notification output format -func NewSubscriptionWithOutputFormat(sub event.Subscription, formatter SubscriptionOutputFormat) Subscription { - return Subscription{sub, nil, formatter} -} - -// NewSubscriptionFiltered will create a new subscription. For each raised event the given matcher is -// called. If it returns true the event is send as notification to the client, otherwise it is ignored. -func NewSubscriptionFiltered(sub event.Subscription, match SubscriptionMatcher) Subscription { - return Subscription{sub, match, defaultSubscriptionOutputFormatter} -} - -// Chan returns the channel where new events will be published. It's up the user to call the matcher to -// determine if the events are interesting for the client. -func (s *Subscription) Chan() <-chan *event.Event { - return s.sub.Chan() -} - -// Unsubscribe will end the subscription and closes the event channel -func (s *Subscription) Unsubscribe() { - s.sub.Unsubscribe() -} - // HexNumber serializes a number to hex format using the "%#x" format type HexNumber big.Int diff --git a/rpc/utils.go b/rpc/utils.go index fa114284d..d43c50495 100644 --- a/rpc/utils.go +++ b/rpc/utils.go @@ -45,6 +45,16 @@ func isExportedOrBuiltinType(t reflect.Type) bool { return isExported(t.Name()) || t.PkgPath() == "" } +var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() + +// isContextType returns an indication if the given t is of context.Context or *context.Context type +func isContextType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t == contextType +} + var errorType = reflect.TypeOf((*error)(nil)).Elem() // Implements this type the error interface @@ -57,6 +67,7 @@ func isErrorType(t reflect.Type) bool { var subscriptionType = reflect.TypeOf((*Subscription)(nil)).Elem() +// isSubscriptionType returns an indication if the given t is of Subscription or *Subscription type func isSubscriptionType(t reflect.Type) bool { for t.Kind() == reflect.Ptr { t = t.Elem() @@ -64,12 +75,17 @@ func isSubscriptionType(t reflect.Type) bool { return t == subscriptionType } -// isPubSub tests whether the given method return the pair (v2.Subscription, error) +// isPubSub tests whether the given method has as as first argument a context.Context +// and returns the pair (Subscription, error) func isPubSub(methodType reflect.Type) bool { - if methodType.NumOut() != 2 { + // numIn(0) is the receiver type + if methodType.NumIn() < 2 || methodType.NumOut() != 2 { return false } - return isSubscriptionType(methodType.Out(0)) && isErrorType(methodType.Out(1)) + + return isContextType(methodType.In(1)) && + isSubscriptionType(methodType.Out(0)) && + isErrorType(methodType.Out(1)) } // formatName will convert to first character to lower case @@ -110,8 +126,6 @@ func isBlockNumber(t reflect.Type) bool { return t == blockNumberType } -var contextType = reflect.TypeOf(new(context.Context)).Elem() - // suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria // for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server // documentation for a summary of these criteria. @@ -205,7 +219,7 @@ METHODS: return callbacks, subscriptions } -func newSubscriptionId() (string, error) { +func newSubscriptionID() (string, error) { var subid [16]byte n, _ := rand.Read(subid[:]) if n != 16 { diff --git a/rpc/websocket.go b/rpc/websocket.go index 92615494e..499eedabe 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -93,7 +93,8 @@ func NewWSServer(cors string, handler *Server) *http.Server { Handler: websocket.Server{ Handshake: wsHandshakeValidator(strings.Split(cors, ",")), Handler: func(conn *websocket.Conn) { - handler.ServeCodec(NewJSONCodec(&wsReaderWriterCloser{conn})) + handler.ServeCodec(NewJSONCodec(&wsReaderWriterCloser{conn}), + OptionMethodInvocation|OptionSubscriptions) }, }, } |