aboutsummaryrefslogtreecommitdiffstats
path: root/rpc/websocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'rpc/websocket.go')
-rw-r--r--rpc/websocket.go30
1 files changed, 29 insertions, 1 deletions
diff --git a/rpc/websocket.go b/rpc/websocket.go
index 4214fc86a..a6e1cec28 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -17,8 +17,10 @@
package rpc
import (
+ "bytes"
"context"
"crypto/tls"
+ "encoding/json"
"fmt"
"net"
"net/http"
@@ -32,6 +34,23 @@ import (
"gopkg.in/fatih/set.v0"
)
+// websocketJSONCodec is a custom JSON codec with payload size enforcement and
+// special number parsing.
+var websocketJSONCodec = websocket.Codec{
+ // Marshal is the stock JSON marshaller used by the websocket library too.
+ Marshal: func(v interface{}) ([]byte, byte, error) {
+ msg, err := json.Marshal(v)
+ return msg, websocket.TextFrame, err
+ },
+ // Unmarshal is a specialized unmarshaller to properly convert numbers.
+ Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
+ dec := json.NewDecoder(bytes.NewReader(msg))
+ dec.UseNumber()
+
+ return dec.Decode(v)
+ },
+}
+
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
@@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return websocket.Server{
Handshake: wsHandshakeValidator(allowedOrigins),
Handler: func(conn *websocket.Conn) {
- srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
+ // Create a custom encode/decode pair to enforce payload size and number encoding
+ conn.MaxPayloadBytes = maxRequestContentLength
+
+ encoder := func(v interface{}) error {
+ return websocketJSONCodec.Send(conn, v)
+ }
+ decoder := func(v interface{}) error {
+ return websocketJSONCodec.Receive(conn, v)
+ }
+ srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
},
}
}