aboutsummaryrefslogtreecommitdiffstats
path: root/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'rpc')
-rw-r--r--rpc/http.go10
-rw-r--r--rpc/http_test.go2
-rw-r--r--rpc/json.go46
-rw-r--r--rpc/websocket.go30
4 files changed, 66 insertions, 22 deletions
diff --git a/rpc/http.go b/rpc/http.go
index a46d8c2b3..9805d69b6 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -27,16 +27,16 @@ import (
"mime"
"net"
"net/http"
+ "strings"
"sync"
"time"
"github.com/rs/cors"
- "strings"
)
const (
- contentType = "application/json"
- maxHTTPRequestContentLength = 1024 * 128
+ contentType = "application/json"
+ maxRequestContentLength = 1024 * 128
)
var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
@@ -182,8 +182,8 @@ func validateRequest(r *http.Request) (int, error) {
if r.Method == http.MethodPut || r.Method == http.MethodDelete {
return http.StatusMethodNotAllowed, errors.New("method not allowed")
}
- if r.ContentLength > maxHTTPRequestContentLength {
- err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength)
+ if r.ContentLength > maxRequestContentLength {
+ err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength)
return http.StatusRequestEntityTooLarge, err
}
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
diff --git a/rpc/http_test.go b/rpc/http_test.go
index aed84f683..b3f694d8a 100644
--- a/rpc/http_test.go
+++ b/rpc/http_test.go
@@ -32,7 +32,7 @@ func TestHTTPErrorResponseWithPut(t *testing.T) {
}
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
- body := make([]rune, maxHTTPRequestContentLength+1)
+ body := make([]rune, maxRequestContentLength+1)
testHTTPErrorResponse(t,
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
}
diff --git a/rpc/json.go b/rpc/json.go
index 2e7fd599e..837011f51 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -76,13 +76,13 @@ type jsonNotification struct {
// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It
// also has support for parsing arguments and serializing (result) objects.
type jsonCodec struct {
- closer sync.Once // close closed channel once
- closed chan interface{} // closed on Close
- decMu sync.Mutex // guards d
- d *json.Decoder // decodes incoming requests
- encMu sync.Mutex // guards e
- e *json.Encoder // encodes responses
- rw io.ReadWriteCloser // connection
+ closer sync.Once // close closed channel once
+ closed chan interface{} // closed on Close
+ decMu sync.Mutex // guards the decoder
+ decode func(v interface{}) error // decoder to allow multiple transports
+ encMu sync.Mutex // guards the encoder
+ encode func(v interface{}) error // encoder to allow multiple transports
+ rw io.ReadWriteCloser // connection
}
func (err *jsonError) Error() string {
@@ -96,11 +96,29 @@ func (err *jsonError) ErrorCode() int {
return err.Code
}
-// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0
+// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
+// on explicitly given encoding and decoding methods.
+func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec {
+ return &jsonCodec{
+ closed: make(chan interface{}),
+ encode: encode,
+ decode: decode,
+ rw: rwc,
+ }
+}
+
+// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec {
- d := json.NewDecoder(rwc)
- d.UseNumber()
- return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc}
+ enc := json.NewEncoder(rwc)
+ dec := json.NewDecoder(rwc)
+ dec.UseNumber()
+
+ return &jsonCodec{
+ closed: make(chan interface{}),
+ encode: enc.Encode,
+ decode: dec.Decode,
+ rw: rwc,
+ }
}
// isBatch returns true when the first non-whitespace characters is '['
@@ -123,14 +141,12 @@ func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) {
defer c.decMu.Unlock()
var incomingMsg json.RawMessage
- if err := c.d.Decode(&incomingMsg); err != nil {
+ if err := c.decode(&incomingMsg); err != nil {
return nil, false, &invalidRequestError{err.Error()}
}
-
if isBatch(incomingMsg) {
return parseBatchRequest(incomingMsg)
}
-
return parseRequest(incomingMsg)
}
@@ -338,7 +354,7 @@ func (c *jsonCodec) Write(res interface{}) error {
c.encMu.Lock()
defer c.encMu.Unlock()
- return c.e.Encode(res)
+ return c.encode(res)
}
// Close the underlying connection
diff --git a/rpc/websocket.go b/rpc/websocket.go
index 4214fc86a..a6e1cec28 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -17,8 +17,10 @@
package rpc
import (
+ "bytes"
"context"
"crypto/tls"
+ "encoding/json"
"fmt"
"net"
"net/http"
@@ -32,6 +34,23 @@ import (
"gopkg.in/fatih/set.v0"
)
+// websocketJSONCodec is a custom JSON codec with payload size enforcement and
+// special number parsing.
+var websocketJSONCodec = websocket.Codec{
+ // Marshal is the stock JSON marshaller used by the websocket library too.
+ Marshal: func(v interface{}) ([]byte, byte, error) {
+ msg, err := json.Marshal(v)
+ return msg, websocket.TextFrame, err
+ },
+ // Unmarshal is a specialized unmarshaller to properly convert numbers.
+ Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
+ dec := json.NewDecoder(bytes.NewReader(msg))
+ dec.UseNumber()
+
+ return dec.Decode(v)
+ },
+}
+
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
@@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return websocket.Server{
Handshake: wsHandshakeValidator(allowedOrigins),
Handler: func(conn *websocket.Conn) {
- srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
+ // Create a custom encode/decode pair to enforce payload size and number encoding
+ conn.MaxPayloadBytes = maxRequestContentLength
+
+ encoder := func(v interface{}) error {
+ return websocketJSONCodec.Send(conn, v)
+ }
+ decoder := func(v interface{}) error {
+ return websocketJSONCodec.Receive(conn, v)
+ }
+ srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
},
}
}