diff options
Diffstat (limited to 'rpc/websocket.go')
-rw-r--r-- | rpc/websocket.go | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/rpc/websocket.go b/rpc/websocket.go index 4214fc86a..a6e1cec28 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -17,8 +17,10 @@ package rpc import ( + "bytes" "context" "crypto/tls" + "encoding/json" "fmt" "net" "net/http" @@ -32,6 +34,23 @@ import ( "gopkg.in/fatih/set.v0" ) +// websocketJSONCodec is a custom JSON codec with payload size enforcement and +// special number parsing. +var websocketJSONCodec = websocket.Codec{ + // Marshal is the stock JSON marshaller used by the websocket library too. + Marshal: func(v interface{}) ([]byte, byte, error) { + msg, err := json.Marshal(v) + return msg, websocket.TextFrame, err + }, + // Unmarshal is a specialized unmarshaller to properly convert numbers. + Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { + dec := json.NewDecoder(bytes.NewReader(msg)) + dec.UseNumber() + + return dec.Decode(v) + }, +} + // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. @@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { return websocket.Server{ Handshake: wsHandshakeValidator(allowedOrigins), Handler: func(conn *websocket.Conn) { - srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) + // Create a custom encode/decode pair to enforce payload size and number encoding + conn.MaxPayloadBytes = maxRequestContentLength + + encoder := func(v interface{}) error { + return websocketJSONCodec.Send(conn, v) + } + decoder := func(v interface{}) error { + return websocketJSONCodec.Receive(conn, v) + } + srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) }, } } |