From 7d38d53ae449c6ec06f7b0579f1a189b02222a60 Mon Sep 17 00:00:00 2001 From: Nilesh Trivedi Date: Mon, 20 Aug 2018 19:24:38 +0530 Subject: cmd/puppeth: accept ssh identity in the server string (#17407) * cmd/puppeth: Accept identityfile in the server string with fallback to id_rsa * cmd/puppeth: code polishes + fix heath check double ports --- cmd/puppeth/ssh.go | 52 +++++++++++++++++++++++++------------------ cmd/puppeth/wizard_network.go | 8 +++---- 2 files changed, 34 insertions(+), 26 deletions(-) (limited to 'cmd') diff --git a/cmd/puppeth/ssh.go b/cmd/puppeth/ssh.go index 158261ce0..c50759606 100644 --- a/cmd/puppeth/ssh.go +++ b/cmd/puppeth/ssh.go @@ -45,33 +45,44 @@ type sshClient struct { // dial establishes an SSH connection to a remote node using the current user and // the user's configured private RSA key. If that fails, password authentication -// is fallen back to. The caller may override the login user via user@server:port. +// is fallen back to. server can be a string like user:identity@server:port. func dial(server string, pubkey []byte) (*sshClient, error) { - // Figure out a label for the server and a logger - label := server - if strings.Contains(label, ":") { - label = label[:strings.Index(label, ":")] - } - login := "" + // Figure out username, identity, hostname and port + hostname := "" + hostport := server + username := "" + identity := "id_rsa" // default + if strings.Contains(server, "@") { - login = label[:strings.Index(label, "@")] - label = label[strings.Index(label, "@")+1:] - server = server[strings.Index(server, "@")+1:] + prefix := server[:strings.Index(server, "@")] + if strings.Contains(prefix, ":") { + username = prefix[:strings.Index(prefix, ":")] + identity = prefix[strings.Index(prefix, ":")+1:] + } else { + username = prefix + } + hostport = server[strings.Index(server, "@")+1:] } - logger := log.New("server", label) + if strings.Contains(hostport, ":") { + hostname = hostport[:strings.Index(hostport, ":")] + } else { + hostname = hostport + hostport += ":22" + } + logger := log.New("server", server) logger.Debug("Attempting to establish SSH connection") user, err := user.Current() if err != nil { return nil, err } - if login == "" { - login = user.Username + if username == "" { + username = user.Username } // Configure the supported authentication methods (private key and password) var auths []ssh.AuthMethod - path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") + path := filepath.Join(user.HomeDir, ".ssh", identity) if buf, err := ioutil.ReadFile(path); err != nil { log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) } else { @@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) { } } auths = append(auths, ssh.PasswordCallback(func() (string, error) { - fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server) + fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server) blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) fmt.Println() return string(blob), err })) // Resolve the IP address of the remote server - addr, err := net.LookupHost(label) + addr, err := net.LookupHost(hostname) if err != nil { return nil, err } @@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) { return nil, errors.New("no IPs associated with domain") } // Try to dial in to the remote server - logger.Trace("Dialing remote SSH server", "user", login) - if !strings.Contains(server, ":") { - server += ":22" - } + logger.Trace("Dialing remote SSH server", "user", username) keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { // If no public key is known for SSH, ask the user to confirm if pubkey == nil { @@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) { // We have a mismatch, forbid connecting return errors.New("ssh key mismatch, readd the machine to update") } - client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) + client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck}) if err != nil { return nil, err } // Connection established, return our utility wrapper c := &sshClient{ - server: label, + server: hostname, address: addr[0], pubkey: pubkey, client: client, diff --git a/cmd/puppeth/wizard_network.go b/cmd/puppeth/wizard_network.go index d780c550b..c0ddcc2a3 100644 --- a/cmd/puppeth/wizard_network.go +++ b/cmd/puppeth/wizard_network.go @@ -62,14 +62,14 @@ func (w *wizard) manageServers() { } } -// makeServer reads a single line from stdin and interprets it as a hostname to -// connect to. It tries to establish a new SSH session and also executing some -// baseline validations. +// makeServer reads a single line from stdin and interprets it as +// username:identity@hostname to connect to. It tries to establish a +// new SSH session and also executing some baseline validations. // // If connection succeeds, the server is added to the wizards configs! func (w *wizard) makeServer() string { fmt.Println() - fmt.Println("Please enter remote server's address:") + fmt.Println("What is the remote server's address ([username[:identity]@]hostname[:port])?") // Read and dial the server to ensure docker is present input := w.readString() -- cgit v1.2.3