diff options
Diffstat (limited to 'rpc/websocket.go')
-rw-r--r-- | rpc/websocket.go | 160 |
1 files changed, 65 insertions, 95 deletions
diff --git a/rpc/websocket.go b/rpc/websocket.go index fe9354d94..fc3cd0709 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -17,36 +17,39 @@ package rpc import ( + "crypto/tls" "fmt" + "net" "net/http" + "net/url" "os" "strings" - "sync" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" + "golang.org/x/net/context" "golang.org/x/net/websocket" "gopkg.in/fatih/set.v0" ) -// wsReaderWriterCloser reads and write payloads from and to a websocket connection. -type wsReaderWriterCloser struct { - c *websocket.Conn -} - -// Read will read incoming payload data into p. -func (rw *wsReaderWriterCloser) Read(p []byte) (int, error) { - return rw.c.Read(p) -} - -// Write writes p to the websocket. -func (rw *wsReaderWriterCloser) Write(p []byte) (int, error) { - return rw.c.Write(p) +// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. +// +// allowedOrigins should be a comma-separated list of allowed origin URLs. +// To allow connections with any origin, pass "*". +func (srv *Server) WebsocketHandler(allowedOrigins string) http.Handler { + return websocket.Server{ + Handshake: wsHandshakeValidator(strings.Split(allowedOrigins, ",")), + Handler: func(conn *websocket.Conn) { + srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) + }, + } } -// Close closes the websocket connection. -func (rw *wsReaderWriterCloser) Close() error { - return rw.c.Close() +// NewWSServer creates a new websocket RPC server around an API provider. +// +// Deprecated: use Server.WebsocketHandler +func NewWSServer(allowedOrigins string, srv *Server) *http.Server { + return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} } // wsHandshakeValidator returns a handler that verifies the origin during the @@ -87,96 +90,63 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http return f } -// NewWSServer creates a new websocket RPC server around an API provider. -func NewWSServer(allowedOrigins string, handler *Server) *http.Server { - return &http.Server{ - Handler: websocket.Server{ - Handshake: wsHandshakeValidator(strings.Split(allowedOrigins, ",")), - Handler: func(conn *websocket.Conn) { - handler.ServeCodec(NewJSONCodec(&wsReaderWriterCloser{conn}), - OptionMethodInvocation|OptionSubscriptions) - }, - }, +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + if origin == "" { + var err error + if origin, err = os.Hostname(); err != nil { + return nil, err + } + if strings.HasPrefix(endpoint, "wss") { + origin = "https://" + strings.ToLower(origin) + } else { + origin = "http://" + strings.ToLower(origin) + } + } + config, err := websocket.NewConfig(endpoint, origin) + if err != nil { + return nil, err } -} - -// wsClient represents a RPC client that communicates over websockets with a -// RPC server. -type wsClient struct { - endpoint string - connMu sync.Mutex - conn *websocket.Conn -} -// NewWSClientj creates a new RPC client that communicates with a RPC server -// that is listening on the given endpoint using JSON encoding. -func NewWSClient(endpoint string) (Client, error) { - return &wsClient{endpoint: endpoint}, nil + return newClient(ctx, func(ctx context.Context) (net.Conn, error) { + return wsDialContext(ctx, config) + }) } -// connection will return a websocket connection to the RPC server. It will -// (re)connect when necessary. -func (client *wsClient) connection() (*websocket.Conn, error) { - if client.conn != nil { - return client.conn, nil +func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { + var conn net.Conn + var err error + switch config.Location.Scheme { + case "ws": + conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) + case "wss": + dialer := contextDialer(ctx) + conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) + default: + err = websocket.ErrBadScheme } - - origin, err := os.Hostname() if err != nil { return nil, err } - - origin = "http://" + origin - client.conn, err = websocket.Dial(client.endpoint, "", origin) - - return client.conn, err -} - -// SupportedModules is the collection of modules the RPC server offers. -func (client *wsClient) SupportedModules() (map[string]string, error) { - return SupportedModules(client) -} - -// Send writes the JSON serialized msg to the websocket. It will create a new -// websocket connection to the server if the client is currently not connected. -func (client *wsClient) Send(msg interface{}) (err error) { - client.connMu.Lock() - defer client.connMu.Unlock() - - var conn *websocket.Conn - if conn, err = client.connection(); err == nil { - if err = websocket.JSON.Send(conn, msg); err != nil { - client.conn.Close() - client.conn = nil - } + ws, err := websocket.NewClient(config, conn) + if err != nil { + conn.Close() + return nil, err } - - return err + return ws, err } -// Recv reads a JSON message from the websocket and unmarshals it into msg. -func (client *wsClient) Recv(msg interface{}) (err error) { - client.connMu.Lock() - defer client.connMu.Unlock() +var wsPortMap = map[string]string{"ws": "80", "wss": "443"} - var conn *websocket.Conn - if conn, err = client.connection(); err == nil { - if err = websocket.JSON.Receive(conn, msg); err != nil { - client.conn.Close() - client.conn = nil +func wsDialAddress(location *url.URL) string { + if _, ok := wsPortMap[location.Scheme]; ok { + if _, _, err := net.SplitHostPort(location.Host); err != nil { + return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) } } - return -} - -// Close closes the underlaying websocket connection. -func (client *wsClient) Close() { - client.connMu.Lock() - defer client.connMu.Unlock() - - if client.conn != nil { - client.conn.Close() - client.conn = nil - } - + return location.Host } |