aboutsummaryrefslogtreecommitdiffstats
path: root/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'rpc')
-rw-r--r--rpc/client.go5
-rw-r--r--rpc/http.go2
-rw-r--r--rpc/ipc_unix.go2
-rw-r--r--rpc/json.go17
-rw-r--r--rpc/websocket.go199
-rw-r--r--rpc/websocket_test.go243
6 files changed, 311 insertions, 157 deletions
diff --git a/rpc/client.go b/rpc/client.go
index 2053f5406..4b65d0042 100644
--- a/rpc/client.go
+++ b/rpc/client.go
@@ -41,9 +41,8 @@ var (
const (
// Timeouts
- tcpKeepAliveInterval = 30 * time.Second
- defaultDialTimeout = 10 * time.Second // used if context has no deadline
- subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls
+ defaultDialTimeout = 10 * time.Second // used if context has no deadline
+ subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls
)
const (
diff --git a/rpc/http.go b/rpc/http.go
index 518b3b874..e8f2cfda7 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -36,7 +36,7 @@ import (
)
const (
- maxRequestContentLength = 1024 * 512
+ maxRequestContentLength = 1024 * 1024 * 5
contentType = "application/json"
)
diff --git a/rpc/ipc_unix.go b/rpc/ipc_unix.go
index 022d480b5..f4690cc0a 100644
--- a/rpc/ipc_unix.go
+++ b/rpc/ipc_unix.go
@@ -50,5 +50,5 @@ func ipcListen(endpoint string) (net.Listener, error) {
// newIPCConnection will connect to a Unix socket on the given endpoint.
func newIPCConnection(ctx context.Context, endpoint string) (net.Conn, error) {
- return dialContext(ctx, "unix", endpoint)
+ return new(net.Dialer).DialContext(ctx, "unix", endpoint)
}
diff --git a/rpc/json.go b/rpc/json.go
index b2e8c7bab..75c221038 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -141,6 +141,11 @@ type Conn interface {
SetWriteDeadline(time.Time) error
}
+type deadlineCloser interface {
+ io.Closer
+ SetWriteDeadline(time.Time) error
+}
+
// ConnRemoteAddr wraps the RemoteAddr operation, which returns a description
// of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this
// description is used in log messages.
@@ -165,12 +170,10 @@ type jsonCodec struct {
decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports
- conn Conn
+ conn deadlineCloser
}
-// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
-// on explicitly given encoding and decoding methods.
-func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec {
+func newCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
codec := &jsonCodec{
closed: make(chan interface{}),
encode: encode,
@@ -183,12 +186,14 @@ func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec {
return codec
}
-// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
+// NewJSONCodec creates a codec that reads from the given connection. If conn implements
+// ConnRemoteAddr, log messages will use it to include the remote address of the
+// connection.
func NewJSONCodec(conn Conn) ServerCodec {
enc := json.NewEncoder(conn)
dec := json.NewDecoder(conn)
dec.UseNumber()
- return NewCodec(conn, enc.Encode, dec.Decode)
+ return newCodec(conn, enc.Encode, dec.Decode)
}
func (c *jsonCodec) RemoteAddr() string {
diff --git a/rpc/websocket.go b/rpc/websocket.go
index c5383667d..1632d6af4 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -17,40 +17,32 @@
package rpc
import (
- "bytes"
"context"
- "crypto/tls"
"encoding/base64"
- "encoding/json"
- "errors"
"fmt"
- "net"
"net/http"
"net/url"
"os"
"strings"
- "time"
+ "sync"
mapset "github.com/deckarep/golang-set"
"github.com/ethereum/go-ethereum/log"
- "golang.org/x/net/websocket"
+ "github.com/gorilla/websocket"
)
-// websocketJSONCodec is a custom JSON codec with payload size enforcement and
-// special number parsing.
-var websocketJSONCodec = websocket.Codec{
- // Marshal is the stock JSON marshaller used by the websocket library too.
- Marshal: func(v interface{}) ([]byte, byte, error) {
- msg, err := json.Marshal(v)
- return msg, websocket.TextFrame, err
- },
- // Unmarshal is a specialized unmarshaller to properly convert numbers.
- Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
- dec := json.NewDecoder(bytes.NewReader(msg))
- dec.UseNumber()
-
- return dec.Decode(v)
- },
+const (
+ wsReadBuffer = 1024
+ wsWriteBuffer = 1024
+)
+
+var wsBufferPool = new(sync.Pool)
+
+// NewWSServer creates a new websocket RPC server around an API provider.
+//
+// Deprecated: use Server.WebsocketHandler
+func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
+ return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
}
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
@@ -58,49 +50,27 @@ var websocketJSONCodec = websocket.Codec{
// allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*".
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
- return websocket.Server{
- Handshake: wsHandshakeValidator(allowedOrigins),
- Handler: func(conn *websocket.Conn) {
- codec := newWebsocketCodec(conn)
- s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
- },
- }
-}
-
-func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
- // Create a custom encode/decode pair to enforce payload size and number encoding
- conn.MaxPayloadBytes = maxRequestContentLength
- encoder := func(v interface{}) error {
- return websocketJSONCodec.Send(conn, v)
- }
- decoder := func(v interface{}) error {
- return websocketJSONCodec.Receive(conn, v)
- }
- rpcconn := Conn(conn)
- if conn.IsServerConn() {
- // Override remote address with the actual socket address because
- // package websocket crashes if there is no request origin.
- addr := conn.Request().RemoteAddr
- if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil {
- // Add origin if present.
- addr += "(" + wsaddr.URL.String() + ")"
+ var upgrader = websocket.Upgrader{
+ ReadBufferSize: wsReadBuffer,
+ WriteBufferSize: wsWriteBuffer,
+ WriteBufferPool: wsBufferPool,
+ CheckOrigin: wsHandshakeValidator(allowedOrigins),
+ }
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ log.Debug("WebSocket upgrade failed", "err", err)
+ return
}
- rpcconn = connWithRemoteAddr{conn, addr}
- }
- return NewCodec(rpcconn, encoder, decoder)
-}
-
-// NewWSServer creates a new websocket RPC server around an API provider.
-//
-// Deprecated: use Server.WebsocketHandler
-func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
- return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
+ codec := newWebsocketCodec(conn)
+ s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
+ })
}
// wsHandshakeValidator returns a handler that verifies the origin during the
// websocket upgrade process. When a '*' is specified as an allowed origins all
// connections are accepted.
-func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
+func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
origins := mapset.NewSet()
allowAllOrigins := false
@@ -112,7 +82,6 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
origins.Add(strings.ToLower(origin))
}
}
-
// allow localhost if no allowedOrigins are specified.
if len(origins.ToSlice()) == 0 {
origins.Add("http://localhost")
@@ -120,52 +89,39 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
origins.Add("http://" + strings.ToLower(hostname))
}
}
-
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
- f := func(cfg *websocket.Config, req *http.Request) error {
+ f := func(req *http.Request) bool {
// Skip origin verification if no Origin header is present. The origin check
// is supposed to protect against browser based attacks. Browsers always set
// Origin. Non-browser software can put anything in origin and checking it doesn't
// provide additional security.
if _, ok := req.Header["Origin"]; !ok {
- return nil
+ return true
}
// Verify origin against whitelist.
origin := strings.ToLower(req.Header.Get("Origin"))
if allowAllOrigins || origins.Contains(origin) {
- return nil
+ return true
}
log.Warn("Rejected WebSocket connection", "origin", origin)
- return errors.New("origin not allowed")
+ return false
}
return f
}
-func wsGetConfig(endpoint, origin string) (*websocket.Config, error) {
- if origin == "" {
- var err error
- if origin, err = os.Hostname(); err != nil {
- return nil, err
- }
- if strings.HasPrefix(endpoint, "wss") {
- origin = "https://" + strings.ToLower(origin)
- } else {
- origin = "http://" + strings.ToLower(origin)
- }
- }
- config, err := websocket.NewConfig(endpoint, origin)
- if err != nil {
- return nil, err
- }
+type wsHandshakeError struct {
+ err error
+ status string
+}
- if config.Location.User != nil {
- b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String()))
- config.Header.Add("Authorization", "Basic "+b64auth)
- config.Location.User = nil
+func (e wsHandshakeError) Error() string {
+ s := e.err.Error()
+ if e.status != "" {
+ s += " (HTTP status " + e.status + ")"
}
- return config, nil
+ return s
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@@ -174,65 +130,46 @@ func wsGetConfig(endpoint, origin string) (*websocket.Config, error) {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
- config, err := wsGetConfig(endpoint, origin)
+ endpoint, header, err := wsClientHeaders(endpoint, origin)
if err != nil {
return nil, err
}
-
+ dialer := websocket.Dialer{
+ ReadBufferSize: wsReadBuffer,
+ WriteBufferSize: wsWriteBuffer,
+ WriteBufferPool: wsBufferPool,
+ }
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
- conn, err := wsDialContext(ctx, config)
+ conn, resp, err := dialer.DialContext(ctx, endpoint, header)
if err != nil {
- return nil, err
+ hErr := wsHandshakeError{err: err}
+ if resp != nil {
+ hErr.status = resp.Status
+ }
+ return nil, hErr
}
return newWebsocketCodec(conn), nil
})
}
-func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
- var conn net.Conn
- var err error
- switch config.Location.Scheme {
- case "ws":
- conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
- case "wss":
- dialer := contextDialer(ctx)
- conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
- default:
- err = websocket.ErrBadScheme
- }
+func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
+ endpointURL, err := url.Parse(endpoint)
if err != nil {
- return nil, err
+ return endpoint, nil, err
}
- ws, err := websocket.NewClient(config, conn)
- if err != nil {
- conn.Close()
- return nil, err
+ header := make(http.Header)
+ if origin != "" {
+ header.Add("origin", origin)
}
- return ws, err
-}
-
-var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
-
-func wsDialAddress(location *url.URL) string {
- if _, ok := wsPortMap[location.Scheme]; ok {
- if _, _, err := net.SplitHostPort(location.Host); err != nil {
- return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
- }
+ if endpointURL.User != nil {
+ b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
+ header.Add("authorization", "Basic "+b64auth)
+ endpointURL.User = nil
}
- return location.Host
-}
-
-func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
- d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
- return d.DialContext(ctx, network, addr)
+ return endpointURL.String(), header, nil
}
-func contextDialer(ctx context.Context) *net.Dialer {
- dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
- if deadline, ok := ctx.Deadline(); ok {
- dialer.Deadline = deadline
- } else {
- dialer.Deadline = time.Now().Add(defaultDialTimeout)
- }
- return dialer
+func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
+ conn.SetReadLimit(maxRequestContentLength)
+ return newCodec(conn, conn.WriteJSON, conn.ReadJSON)
}
diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go
index 0ce9534b5..9dc108479 100644
--- a/rpc/websocket_test.go
+++ b/rpc/websocket_test.go
@@ -16,31 +16,244 @@
package rpc
-import "testing"
+import (
+ "context"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
-func TestWSGetConfigNoAuth(t *testing.T) {
- config, err := wsGetConfig("ws://example.com:1234", "")
+ "github.com/gorilla/websocket"
+)
+
+func TestWebsocketClientHeaders(t *testing.T) {
+ t.Parallel()
+
+ endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com")
if err != nil {
t.Fatalf("wsGetConfig failed: %s", err)
}
- if config.Location.User != nil {
- t.Fatalf("User should have been stripped from the URL")
+ if endpoint != "wss://example.com:1234" {
+ t.Fatal("User should have been stripped from the URL")
+ }
+ if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" {
+ t.Fatal("Basic auth header is incorrect")
}
- if config.Location.Hostname() != "example.com" ||
- config.Location.Port() != "1234" || config.Location.Scheme != "ws" {
- t.Fatalf("Unexpected URL: %s", config.Location)
+ if header.Get("origin") != "https://example.com" {
+ t.Fatal("Origin not set")
}
}
-func TestWSGetConfigWithBasicAuth(t *testing.T) {
- config, err := wsGetConfig("wss://testuser:test-PASS_01@example.com:1234", "")
+// This test checks that the server rejects connections from disallowed origins.
+func TestWebsocketOriginCheck(t *testing.T) {
+ t.Parallel()
+
+ var (
+ srv = newTestServer()
+ httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}))
+ wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
+ )
+ defer srv.Stop()
+ defer httpsrv.Close()
+
+ client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com")
+ if err == nil {
+ client.Close()
+ t.Fatal("no error for wrong origin")
+ }
+ wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"}
+ if !reflect.DeepEqual(err, wantErr) {
+ t.Fatalf("wrong error for wrong origin: %q", err)
+ }
+
+ // Connections without origin header should work.
+ client, err = DialWebsocket(context.Background(), wsURL, "")
if err != nil {
- t.Fatalf("wsGetConfig failed: %s", err)
+ t.Fatal("error for empty origin")
}
- if config.Location.User != nil {
- t.Fatal("User should have been stripped from the URL")
+ client.Close()
+}
+
+// This test checks whether calls exceeding the request size limit are rejected.
+func TestWebsocketLargeCall(t *testing.T) {
+ t.Parallel()
+
+ var (
+ srv = newTestServer()
+ httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
+ wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
+ )
+ defer srv.Stop()
+ defer httpsrv.Close()
+
+ client, err := DialWebsocket(context.Background(), wsURL, "")
+ if err != nil {
+ t.Fatalf("can't dial: %v", err)
}
- if config.Header.Get("Authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" {
- t.Fatal("Basic auth header is incorrect")
+ defer client.Close()
+
+ // This call sends slightly less than the limit and should work.
+ var result Result
+ arg := strings.Repeat("x", maxRequestContentLength-200)
+ if err := client.Call(&result, "test_echo", arg, 1); err != nil {
+ t.Fatalf("valid call didn't work: %v", err)
+ }
+ if result.String != arg {
+ t.Fatal("wrong string echoed")
+ }
+
+ // This call sends twice the allowed size and shouldn't work.
+ arg = strings.Repeat("x", maxRequestContentLength*2)
+ err = client.Call(&result, "test_echo", arg)
+ if err == nil {
+ t.Fatal("no error for too large call")
+ }
+}
+
+// This test checks that client handles WebSocket ping frames correctly.
+func TestClientWebsocketPing(t *testing.T) {
+ t.Parallel()
+
+ var (
+ sendPing = make(chan struct{})
+ server = wsPingTestServer(t, sendPing)
+ ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
+ )
+ defer cancel()
+ defer server.Shutdown(ctx)
+
+ client, err := DialContext(ctx, "ws://"+server.Addr)
+ if err != nil {
+ t.Fatalf("client dial error: %v", err)
+ }
+ resultChan := make(chan int)
+ sub, err := client.EthSubscribe(ctx, resultChan, "foo")
+ if err != nil {
+ t.Fatalf("client subscribe error: %v", err)
+ }
+
+ // Wait for the context's deadline to be reached before proceeding.
+ // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798
+ <-ctx.Done()
+ close(sendPing)
+
+ // Wait for the subscription result.
+ timeout := time.NewTimer(5 * time.Second)
+ for {
+ select {
+ case err := <-sub.Err():
+ t.Error("client subscription error:", err)
+ case result := <-resultChan:
+ t.Log("client got result:", result)
+ return
+ case <-timeout.C:
+ t.Error("didn't get any result within the test timeout")
+ return
+ }
+ }
+}
+
+// wsPingTestServer runs a WebSocket server which accepts a single subscription request.
+// When a value arrives on sendPing, the server sends a ping frame, waits for a matching
+// pong and finally delivers a single subscription result.
+func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server {
+ var srv http.Server
+ shutdown := make(chan struct{})
+ srv.RegisterOnShutdown(func() {
+ close(shutdown)
+ })
+ srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Upgrade to WebSocket.
+ upgrader := websocket.Upgrader{
+ CheckOrigin: func(r *http.Request) bool { return true },
+ }
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("server WS upgrade error: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ // Handle the connection.
+ wsPingTestHandler(t, conn, shutdown, sendPing)
+ })
+
+ // Start the server.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("can't listen:", err)
+ }
+ srv.Addr = listener.Addr().String()
+ go srv.Serve(listener)
+ return &srv
+}
+
+func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) {
+ // Canned responses for the eth_subscribe call in TestClientWebsocketPing.
+ const (
+ subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}`
+ subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}`
+ )
+
+ // Handle subscribe request.
+ if _, _, err := conn.ReadMessage(); err != nil {
+ t.Errorf("server read error: %v", err)
+ return
+ }
+ if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil {
+ t.Errorf("server write error: %v", err)
+ return
+ }
+
+ // Read from the connection to process control messages.
+ var pongCh = make(chan string)
+ conn.SetPongHandler(func(d string) error {
+ t.Logf("server got pong: %q", d)
+ pongCh <- d
+ return nil
+ })
+ go func() {
+ for {
+ typ, msg, err := conn.ReadMessage()
+ if err != nil {
+ return
+ }
+ t.Logf("server got message (%d): %q", typ, msg)
+ }
+ }()
+
+ // Write messages.
+ var (
+ sendResponse <-chan time.Time
+ wantPong string
+ )
+ for {
+ select {
+ case _, open := <-sendPing:
+ if !open {
+ sendPing = nil
+ }
+ t.Logf("server sending ping")
+ conn.WriteMessage(websocket.PingMessage, []byte("ping"))
+ wantPong = "ping"
+ case data := <-pongCh:
+ if wantPong == "" {
+ t.Errorf("unexpected pong")
+ } else if data != wantPong {
+ t.Errorf("got pong with wrong data %q", data)
+ }
+ wantPong = ""
+ sendResponse = time.NewTimer(200 * time.Millisecond).C
+ case <-sendResponse:
+ t.Logf("server sending response")
+ conn.WriteMessage(websocket.TextMessage, []byte(subNotify))
+ sendResponse = nil
+ case <-shutdown:
+ conn.Close()
+ return
+ }
}
}