diff options
Diffstat (limited to 'rpc/v2')
-rw-r--r-- | rpc/v2/server.go | 21 | ||||
-rw-r--r-- | rpc/v2/server_test.go | 40 | ||||
-rw-r--r-- | rpc/v2/types.go | 2 | ||||
-rw-r--r-- | rpc/v2/utils.go | 25 |
4 files changed, 70 insertions, 18 deletions
diff --git a/rpc/v2/server.go b/rpc/v2/server.go index ff6b69015..4c04f04d2 100644 --- a/rpc/v2/server.go +++ b/rpc/v2/server.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" + "golang.org/x/net/context" ) // NewServer will create a new server instance with no registered handlers. @@ -120,6 +121,9 @@ func (s *Server) ServeCodec(codec ServerCodec) { codec.Close() }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for { reqs, batch, err := s.readRequest(codec) if err != nil { @@ -129,9 +133,9 @@ func (s *Server) ServeCodec(codec ServerCodec) { } if batch { - go s.execBatch(codec, reqs) + go s.execBatch(ctx, codec, reqs) } else { - go s.exec(codec, reqs[0]) + go s.exec(ctx, codec, reqs[0]) } } } @@ -220,7 +224,7 @@ func (s *Server) unsubscribe(subid string) bool { } // handle executes a request and returns the response from the callback. -func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { +func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) interface{} { if req.err != nil { return codec.CreateErrorResponse(&req.id, req.err) } @@ -255,6 +259,9 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { } arguments := []reflect.Value{req.callb.rcvr} + if req.callb.hasCtx { + arguments = append(arguments, reflect.ValueOf(ctx)) + } if len(req.args) > 0 { arguments = append(arguments, req.args...) } @@ -277,12 +284,12 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { } // exec executes the given request and writes the result back using the codec. -func (s *Server) exec(codec ServerCodec, req *serverRequest) { +func (s *Server) exec(ctx context.Context, codec ServerCodec, req *serverRequest) { var response interface{} if req.err != nil { response = codec.CreateErrorResponse(&req.id, req.err) } else { - response = s.handle(codec, req) + response = s.handle(ctx, codec, req) } if err := codec.Write(response); err != nil { @@ -293,13 +300,13 @@ func (s *Server) exec(codec ServerCodec, req *serverRequest) { // execBatch executes the given requests and writes the result back using the codec. It will only write the response // back when the last request is processed. -func (s *Server) execBatch(codec ServerCodec, requests []*serverRequest) { +func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*serverRequest) { responses := make([]interface{}, len(requests)) for i, req := range requests { if req.err != nil { responses[i] = codec.CreateErrorResponse(&req.id, req.err) } else { - responses[i] = s.handle(codec, req) + responses[i] = s.handle(ctx, codec, req) } } diff --git a/rpc/v2/server_test.go b/rpc/v2/server_test.go index f4f77672f..f250c184f 100644 --- a/rpc/v2/server_test.go +++ b/rpc/v2/server_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + + "golang.org/x/net/context" ) type Service struct{} @@ -27,6 +29,10 @@ func (s *Service) Echo(str string, i int, args *Args) Result { return Result{str, i, args} } +func (s *Service) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result { + return Result{str, i, args} +} + func (s *Service) Rets() (string, error) { return "", nil } @@ -64,8 +70,8 @@ func TestServerRegisterName(t *testing.T) { t.Fatalf("Expected service calc to be registered") } - if len(svc.callbacks) != 3 { - t.Errorf("Expected 3 callbacks for service 'calc', got %d", len(svc.callbacks)) + if len(svc.callbacks) != 4 { + t.Errorf("Expected 4 callbacks for service 'calc', got %d", len(svc.callbacks)) } if len(svc.subscriptions) != 1 { @@ -217,3 +223,33 @@ func TestServerMethodExecution(t *testing.T) { t.Fatalf("expected %s, got %s\n", expected, codec.output) } } + +func TestServerMethodWithCtx(t *testing.T) { + server := NewServer() + service := new(Service) + + if err := server.RegisterName("test", service); err != nil { + t.Fatalf("%v", err) + } + + id := int64(12345) + req := jsonRequest{ + Method: "echoWithCtx", + Version: "2.0", + Id: &id, + } + args := []interface{}{"string arg", 1122, &Args{"qwerty"}} + req.Payload, _ = json.Marshal(&args) + + input, _ := json.Marshal(&req) + codec := &ServerTestCodec{input: input, closer: make(chan interface{})} + go server.ServeCodec(codec) + + <-codec.closer + + expected := `{"jsonrpc":"2.0","id":12345,"result":{"String":"string arg","Int":1122,"Args":{"S":"qwerty"}}}` + + if expected != codec.output { + t.Fatalf("expected %s, got %s\n", expected, codec.output) + } +} diff --git a/rpc/v2/types.go b/rpc/v2/types.go index d538e0a3f..8e638726f 100644 --- a/rpc/v2/types.go +++ b/rpc/v2/types.go @@ -22,7 +22,6 @@ import ( "math/big" "reflect" "strings" - "sync" "github.com/ethereum/go-ethereum/event" @@ -41,6 +40,7 @@ type callback struct { rcvr reflect.Value // receiver of method method reflect.Method // callback argTypes []reflect.Type // input argument types + hasCtx bool // method's first argument is a context (not included in argTypes) errPos int // err return idx, of -1 when method cannot return error isSubscribe bool // indication if the callback is a subscription } diff --git a/rpc/v2/utils.go b/rpc/v2/utils.go index a564b2473..ca37924a3 100644 --- a/rpc/v2/utils.go +++ b/rpc/v2/utils.go @@ -24,6 +24,8 @@ import ( "reflect" "unicode" "unicode/utf8" + + "golang.org/x/net/context" ) // Is this an exported - upper case - name? @@ -107,6 +109,8 @@ func isBlockNumber(t reflect.Type) bool { return t == blockNumberType } +var contextType = reflect.TypeOf(new(context.Context)).Elem() + // suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria // for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server // documentation for a summary of these criteria. @@ -129,12 +133,19 @@ METHODS: h.method = method h.errPos = -1 + firstArg := 1 + numIn := mtype.NumIn() + if numIn >= 2 && mtype.In(1) == contextType { + h.hasCtx = true + firstArg = 2 + } + if h.isSubscribe { - h.argTypes = make([]reflect.Type, mtype.NumIn()-1) // skip rcvr type - for i := 1; i < mtype.NumIn(); i++ { + h.argTypes = make([]reflect.Type, numIn-firstArg) // skip rcvr type + for i := firstArg; i < numIn; i++ { argType := mtype.In(i) if isExportedOrBuiltinType(argType) { - h.argTypes[i-1] = argType + h.argTypes[i-firstArg] = argType } else { continue METHODS } @@ -144,17 +155,15 @@ METHODS: continue METHODS } - numIn := mtype.NumIn() - // determine method arguments, ignore first arg since it's the receiver type // Arguments must be exported or builtin types - h.argTypes = make([]reflect.Type, numIn-1) - for i := 1; i < numIn; i++ { + h.argTypes = make([]reflect.Type, numIn-firstArg) + for i := firstArg; i < numIn; i++ { argType := mtype.In(i) if !isExportedOrBuiltinType(argType) { continue METHODS } - h.argTypes[i-1] = argType + h.argTypes[i-firstArg] = argType } // check that all returned values are exported or builtin types |