From c35659c6a0d8c9dd7b7616bb91700385292f403a Mon Sep 17 00:00:00 2001 From: Peter Broadhurst Date: Wed, 19 Sep 2018 12:09:03 -0400 Subject: rpc: enable basic auth for websocket client (#17699) --- rpc/websocket.go | 27 ++++++++++++++++++++------ rpc/websocket_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) create mode 100644 rpc/websocket_test.go (limited to 'rpc') diff --git a/rpc/websocket.go b/rpc/websocket.go index e7a86ddae..eae8320e5 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" "fmt" "net" @@ -118,12 +119,7 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http return f } -// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server -// that is listening on the given endpoint. -// -// The context is used for the initial connection establishment. It does not -// affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { if origin == "" { var err error if origin, err = os.Hostname(); err != nil { @@ -140,6 +136,25 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error return nil, err } + if config.Location.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) + config.Header.Add("Authorization", "Basic "+b64auth) + config.Location.User = nil + } + return config, nil +} + +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + config, err := wsGetConfig(endpoint, origin) + if err != nil { + return nil, err + } + return newClient(ctx, func(ctx context.Context) (net.Conn, error) { return wsDialContext(ctx, config) }) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go new file mode 100644 index 000000000..5bf3780d6 --- /dev/null +++ b/rpc/websocket_test.go @@ -0,0 +1,54 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import "testing" + +func TestWSGetConfigNoAuth(t *testing.T) { + config, err := wsGetConfig("ws://example.com:1234", "") + if err != nil { + t.Logf("wsGetConfig failed: %s", err) + t.Fail() + return + } + if config.Location.User != nil { + t.Log("User should have been stripped from the URL") + t.Fail() + } + if config.Location.Hostname() != "example.com" || + config.Location.Port() != "1234" || config.Location.Scheme != "ws" { + t.Logf("Unexpected URL: %s", config.Location) + t.Fail() + } +} + +func TestWSGetConfigWithBasicAuth(t *testing.T) { + config, err := wsGetConfig("wss://testuser:test-PASS_01@example.com:1234", "") + if err != nil { + t.Logf("wsGetConfig failed: %s", err) + t.Fail() + return + } + if config.Location.User != nil { + t.Log("User should have been stripped from the URL") + t.Fail() + } + if config.Header.Get("Authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { + t.Log("Basic auth header is incorrect") + t.Fail() + } +} -- cgit v1.2.3