diff options
Diffstat (limited to 'rpc')
-rw-r--r-- | rpc/client.go | 5 | ||||
-rw-r--r-- | rpc/http.go | 2 | ||||
-rw-r--r-- | rpc/ipc_unix.go | 2 | ||||
-rw-r--r-- | rpc/json.go | 17 | ||||
-rw-r--r-- | rpc/websocket.go | 199 | ||||
-rw-r--r-- | rpc/websocket_test.go | 243 |
6 files changed, 311 insertions, 157 deletions
diff --git a/rpc/client.go b/rpc/client.go index 2053f5406..4b65d0042 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -41,9 +41,8 @@ var ( const ( // Timeouts - tcpKeepAliveInterval = 30 * time.Second - defaultDialTimeout = 10 * time.Second // used if context has no deadline - subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls + defaultDialTimeout = 10 * time.Second // used if context has no deadline + subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) const ( diff --git a/rpc/http.go b/rpc/http.go index 518b3b874..e8f2cfda7 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -36,7 +36,7 @@ import ( ) const ( - maxRequestContentLength = 1024 * 512 + maxRequestContentLength = 1024 * 1024 * 5 contentType = "application/json" ) diff --git a/rpc/ipc_unix.go b/rpc/ipc_unix.go index 022d480b5..f4690cc0a 100644 --- a/rpc/ipc_unix.go +++ b/rpc/ipc_unix.go @@ -50,5 +50,5 @@ func ipcListen(endpoint string) (net.Listener, error) { // newIPCConnection will connect to a Unix socket on the given endpoint. func newIPCConnection(ctx context.Context, endpoint string) (net.Conn, error) { - return dialContext(ctx, "unix", endpoint) + return new(net.Dialer).DialContext(ctx, "unix", endpoint) } diff --git a/rpc/json.go b/rpc/json.go index b2e8c7bab..75c221038 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -141,6 +141,11 @@ type Conn interface { SetWriteDeadline(time.Time) error } +type deadlineCloser interface { + io.Closer + SetWriteDeadline(time.Time) error +} + // ConnRemoteAddr wraps the RemoteAddr operation, which returns a description // of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this // description is used in log messages. @@ -165,12 +170,10 @@ type jsonCodec struct { decode func(v interface{}) error // decoder to allow multiple transports encMu sync.Mutex // guards the encoder encode func(v interface{}) error // encoder to allow multiple transports - conn Conn + conn deadlineCloser } -// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based -// on explicitly given encoding and decoding methods. -func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec { +func newCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { codec := &jsonCodec{ closed: make(chan interface{}), encode: encode, @@ -183,12 +186,14 @@ func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec { return codec } -// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0. +// NewJSONCodec creates a codec that reads from the given connection. If conn implements +// ConnRemoteAddr, log messages will use it to include the remote address of the +// connection. func NewJSONCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewCodec(conn, enc.Encode, dec.Decode) + return newCodec(conn, enc.Encode, dec.Decode) } func (c *jsonCodec) RemoteAddr() string { diff --git a/rpc/websocket.go b/rpc/websocket.go index c5383667d..1632d6af4 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -17,40 +17,32 @@ package rpc import ( - "bytes" "context" - "crypto/tls" "encoding/base64" - "encoding/json" - "errors" "fmt" - "net" "net/http" "net/url" "os" "strings" - "time" + "sync" mapset "github.com/deckarep/golang-set" "github.com/ethereum/go-ethereum/log" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) -// websocketJSONCodec is a custom JSON codec with payload size enforcement and -// special number parsing. -var websocketJSONCodec = websocket.Codec{ - // Marshal is the stock JSON marshaller used by the websocket library too. - Marshal: func(v interface{}) ([]byte, byte, error) { - msg, err := json.Marshal(v) - return msg, websocket.TextFrame, err - }, - // Unmarshal is a specialized unmarshaller to properly convert numbers. - Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { - dec := json.NewDecoder(bytes.NewReader(msg)) - dec.UseNumber() - - return dec.Decode(v) - }, +const ( + wsReadBuffer = 1024 + wsWriteBuffer = 1024 +) + +var wsBufferPool = new(sync.Pool) + +// NewWSServer creates a new websocket RPC server around an API provider. +// +// Deprecated: use Server.WebsocketHandler +func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { + return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} } // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. @@ -58,49 +50,27 @@ var websocketJSONCodec = websocket.Codec{ // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { - return websocket.Server{ - Handshake: wsHandshakeValidator(allowedOrigins), - Handler: func(conn *websocket.Conn) { - codec := newWebsocketCodec(conn) - s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) - }, - } -} - -func newWebsocketCodec(conn *websocket.Conn) ServerCodec { - // Create a custom encode/decode pair to enforce payload size and number encoding - conn.MaxPayloadBytes = maxRequestContentLength - encoder := func(v interface{}) error { - return websocketJSONCodec.Send(conn, v) - } - decoder := func(v interface{}) error { - return websocketJSONCodec.Receive(conn, v) - } - rpcconn := Conn(conn) - if conn.IsServerConn() { - // Override remote address with the actual socket address because - // package websocket crashes if there is no request origin. - addr := conn.Request().RemoteAddr - if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil { - // Add origin if present. - addr += "(" + wsaddr.URL.String() + ")" + var upgrader = websocket.Upgrader{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + CheckOrigin: wsHandshakeValidator(allowedOrigins), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Debug("WebSocket upgrade failed", "err", err) + return } - rpcconn = connWithRemoteAddr{conn, addr} - } - return NewCodec(rpcconn, encoder, decoder) -} - -// NewWSServer creates a new websocket RPC server around an API provider. -// -// Deprecated: use Server.WebsocketHandler -func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { - return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} + codec := newWebsocketCodec(conn) + s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) + }) } // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted. -func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { +func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { origins := mapset.NewSet() allowAllOrigins := false @@ -112,7 +82,6 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http origins.Add(strings.ToLower(origin)) } } - // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") @@ -120,52 +89,39 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http origins.Add("http://" + strings.ToLower(hostname)) } } - log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) - f := func(cfg *websocket.Config, req *http.Request) error { + f := func(req *http.Request) bool { // Skip origin verification if no Origin header is present. The origin check // is supposed to protect against browser based attacks. Browsers always set // Origin. Non-browser software can put anything in origin and checking it doesn't // provide additional security. if _, ok := req.Header["Origin"]; !ok { - return nil + return true } // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) if allowAllOrigins || origins.Contains(origin) { - return nil + return true } log.Warn("Rejected WebSocket connection", "origin", origin) - return errors.New("origin not allowed") + return false } return f } -func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { - if origin == "" { - var err error - if origin, err = os.Hostname(); err != nil { - return nil, err - } - if strings.HasPrefix(endpoint, "wss") { - origin = "https://" + strings.ToLower(origin) - } else { - origin = "http://" + strings.ToLower(origin) - } - } - config, err := websocket.NewConfig(endpoint, origin) - if err != nil { - return nil, err - } +type wsHandshakeError struct { + err error + status string +} - if config.Location.User != nil { - b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) - config.Header.Add("Authorization", "Basic "+b64auth) - config.Location.User = nil +func (e wsHandshakeError) Error() string { + s := e.err.Error() + if e.status != "" { + s += " (HTTP status " + e.status + ")" } - return config, nil + return s } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -174,65 +130,46 @@ func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { - config, err := wsGetConfig(endpoint, origin) + endpoint, header, err := wsClientHeaders(endpoint, origin) if err != nil { return nil, err } - + dialer := websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { - conn, err := wsDialContext(ctx, config) + conn, resp, err := dialer.DialContext(ctx, endpoint, header) if err != nil { - return nil, err + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr } return newWebsocketCodec(conn), nil }) } -func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { - var conn net.Conn - var err error - switch config.Location.Scheme { - case "ws": - conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) - case "wss": - dialer := contextDialer(ctx) - conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) - default: - err = websocket.ErrBadScheme - } +func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { + endpointURL, err := url.Parse(endpoint) if err != nil { - return nil, err + return endpoint, nil, err } - ws, err := websocket.NewClient(config, conn) - if err != nil { - conn.Close() - return nil, err + header := make(http.Header) + if origin != "" { + header.Add("origin", origin) } - return ws, err -} - -var wsPortMap = map[string]string{"ws": "80", "wss": "443"} - -func wsDialAddress(location *url.URL) string { - if _, ok := wsPortMap[location.Scheme]; ok { - if _, _, err := net.SplitHostPort(location.Host); err != nil { - return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) - } + if endpointURL.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) + header.Add("authorization", "Basic "+b64auth) + endpointURL.User = nil } - return location.Host -} - -func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { - d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} - return d.DialContext(ctx, network, addr) + return endpointURL.String(), header, nil } -func contextDialer(ctx context.Context) *net.Dialer { - dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} - if deadline, ok := ctx.Deadline(); ok { - dialer.Deadline = deadline - } else { - dialer.Deadline = time.Now().Add(defaultDialTimeout) - } - return dialer +func newWebsocketCodec(conn *websocket.Conn) ServerCodec { + conn.SetReadLimit(maxRequestContentLength) + return newCodec(conn, conn.WriteJSON, conn.ReadJSON) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 0ce9534b5..9dc108479 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -16,31 +16,244 @@ package rpc -import "testing" +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" -func TestWSGetConfigNoAuth(t *testing.T) { - config, err := wsGetConfig("ws://example.com:1234", "") + "github.com/gorilla/websocket" +) + +func TestWebsocketClientHeaders(t *testing.T) { + t.Parallel() + + endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") if err != nil { t.Fatalf("wsGetConfig failed: %s", err) } - if config.Location.User != nil { - t.Fatalf("User should have been stripped from the URL") + if endpoint != "wss://example.com:1234" { + t.Fatal("User should have been stripped from the URL") + } + if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { + t.Fatal("Basic auth header is incorrect") } - if config.Location.Hostname() != "example.com" || - config.Location.Port() != "1234" || config.Location.Scheme != "ws" { - t.Fatalf("Unexpected URL: %s", config.Location) + if header.Get("origin") != "https://example.com" { + t.Fatal("Origin not set") } } -func TestWSGetConfigWithBasicAuth(t *testing.T) { - config, err := wsGetConfig("wss://testuser:test-PASS_01@example.com:1234", "") +// This test checks that the server rejects connections from disallowed origins. +func TestWebsocketOriginCheck(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") + if err == nil { + client.Close() + t.Fatal("no error for wrong origin") + } + wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} + if !reflect.DeepEqual(err, wantErr) { + t.Fatalf("wrong error for wrong origin: %q", err) + } + + // Connections without origin header should work. + client, err = DialWebsocket(context.Background(), wsURL, "") if err != nil { - t.Fatalf("wsGetConfig failed: %s", err) + t.Fatal("error for empty origin") } - if config.Location.User != nil { - t.Fatal("User should have been stripped from the URL") + client.Close() +} + +// This test checks whether calls exceeding the request size limit are rejected. +func TestWebsocketLargeCall(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + client, err := DialWebsocket(context.Background(), wsURL, "") + if err != nil { + t.Fatalf("can't dial: %v", err) } - if config.Header.Get("Authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { - t.Fatal("Basic auth header is incorrect") + defer client.Close() + + // This call sends slightly less than the limit and should work. + var result Result + arg := strings.Repeat("x", maxRequestContentLength-200) + if err := client.Call(&result, "test_echo", arg, 1); err != nil { + t.Fatalf("valid call didn't work: %v", err) + } + if result.String != arg { + t.Fatal("wrong string echoed") + } + + // This call sends twice the allowed size and shouldn't work. + arg = strings.Repeat("x", maxRequestContentLength*2) + err = client.Call(&result, "test_echo", arg) + if err == nil { + t.Fatal("no error for too large call") + } +} + +// This test checks that client handles WebSocket ping frames correctly. +func TestClientWebsocketPing(t *testing.T) { + t.Parallel() + + var ( + sendPing = make(chan struct{}) + server = wsPingTestServer(t, sendPing) + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + ) + defer cancel() + defer server.Shutdown(ctx) + + client, err := DialContext(ctx, "ws://"+server.Addr) + if err != nil { + t.Fatalf("client dial error: %v", err) + } + resultChan := make(chan int) + sub, err := client.EthSubscribe(ctx, resultChan, "foo") + if err != nil { + t.Fatalf("client subscribe error: %v", err) + } + + // Wait for the context's deadline to be reached before proceeding. + // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 + <-ctx.Done() + close(sendPing) + + // Wait for the subscription result. + timeout := time.NewTimer(5 * time.Second) + for { + select { + case err := <-sub.Err(): + t.Error("client subscription error:", err) + case result := <-resultChan: + t.Log("client got result:", result) + return + case <-timeout.C: + t.Error("didn't get any result within the test timeout") + return + } + } +} + +// wsPingTestServer runs a WebSocket server which accepts a single subscription request. +// When a value arrives on sendPing, the server sends a ping frame, waits for a matching +// pong and finally delivers a single subscription result. +func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { + var srv http.Server + shutdown := make(chan struct{}) + srv.RegisterOnShutdown(func() { + close(shutdown) + }) + srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Upgrade to WebSocket. + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("server WS upgrade error: %v", err) + return + } + defer conn.Close() + + // Handle the connection. + wsPingTestHandler(t, conn, shutdown, sendPing) + }) + + // Start the server. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + srv.Addr = listener.Addr().String() + go srv.Serve(listener) + return &srv +} + +func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { + // Canned responses for the eth_subscribe call in TestClientWebsocketPing. + const ( + subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` + subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` + ) + + // Handle subscribe request. + if _, _, err := conn.ReadMessage(); err != nil { + t.Errorf("server read error: %v", err) + return + } + if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { + t.Errorf("server write error: %v", err) + return + } + + // Read from the connection to process control messages. + var pongCh = make(chan string) + conn.SetPongHandler(func(d string) error { + t.Logf("server got pong: %q", d) + pongCh <- d + return nil + }) + go func() { + for { + typ, msg, err := conn.ReadMessage() + if err != nil { + return + } + t.Logf("server got message (%d): %q", typ, msg) + } + }() + + // Write messages. + var ( + sendResponse <-chan time.Time + wantPong string + ) + for { + select { + case _, open := <-sendPing: + if !open { + sendPing = nil + } + t.Logf("server sending ping") + conn.WriteMessage(websocket.PingMessage, []byte("ping")) + wantPong = "ping" + case data := <-pongCh: + if wantPong == "" { + t.Errorf("unexpected pong") + } else if data != wantPong { + t.Errorf("got pong with wrong data %q", data) + } + wantPong = "" + sendResponse = time.NewTimer(200 * time.Millisecond).C + case <-sendResponse: + t.Logf("server sending response") + conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) + sendResponse = nil + case <-shutdown: + conn.Close() + return + } } } |