From f67ca53c079bd5eca59431f52407c04d39802ee8 Mon Sep 17 00:00:00 2001 From: cmbrose <5447118+cmbrose@users.noreply.github.com> Date: Fri, 3 Jun 2022 13:39:52 -0500 Subject: [PATCH] Refactor ssh_keys to a more common location --- pkg/cmd/auth/shared/login_flow.go | 49 +++++++++- pkg/cmd/auth/shared/login_flow_test.go | 7 +- pkg/cmd/auth/shared/ssh_keys.go | 129 ------------------------- pkg/cmd/codespace/ssh.go | 52 ++++++---- pkg/liveshare/session.go | 20 +++- pkg/ssh/ssh_keys.go | 109 +++++++++++++++++++++ 6 files changed, 209 insertions(+), 157 deletions(-) delete mode 100644 pkg/cmd/auth/shared/ssh_keys.go create mode 100644 pkg/ssh/ssh_keys.go diff --git a/pkg/cmd/auth/shared/login_flow.go b/pkg/cmd/auth/shared/login_flow.go index 68927b74e..aa1acaa82 100644 --- a/pkg/cmd/auth/shared/login_flow.go +++ b/pkg/cmd/auth/shared/login_flow.go @@ -3,6 +3,7 @@ package shared import ( "fmt" "net/http" + "os" "strings" "github.com/AlecAivazis/survey/v2" @@ -10,8 +11,10 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/authflow" "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/pkg/cmd/ssh-key/add" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/prompt" + "github.com/cli/cli/v2/pkg/ssh" ) const defaultSSHKeyTitle = "GitHub CLI" @@ -34,7 +37,7 @@ type LoginOptions struct { Executable string GitProtocol string - sshContext SshContext + sshContext ssh.SshContext } func Login(opts *LoginOptions) error { @@ -72,7 +75,7 @@ func Login(opts *LoginOptions) error { var keyToUpload string keyTitle := defaultSSHKeyTitle if opts.Interactive && gitProtocol == "ssh" { - pubKeys, err := opts.sshContext.localPublicKeys() + pubKeys, err := opts.sshContext.LocalPublicKeys() if err != nil { return err } @@ -89,11 +92,25 @@ func Login(opts *LoginOptions) error { if keyChoice < len(pubKeys) { keyToUpload = pubKeys[keyChoice] } - } else { + } else if opts.sshContext.HasKeygen() { + var sshChoice bool var err error - keyToUpload, err = opts.sshContext.GenerateSSHKey() + + err = prompt.SurveyAskOne(&survey.Confirm{ + Message: "Generate a new SSH key to add to your GitHub account?", + Default: true, + }, &sshChoice) + if err != nil { - return err + return fmt.Errorf("could not prompt: %w", err) + } + + if sshChoice { + keyPair, err := opts.sshContext.GenerateSSHKey("id_ed25519", true, promptForSshKeyPassphrase) + if err != nil { + return err + } + keyToUpload = keyPair.PublicKeyPath } } @@ -210,6 +227,18 @@ func Login(opts *LoginOptions) error { return nil } +func promptForSshKeyPassphrase() (string, error) { + var sshPassphrase string + err := prompt.SurveyAskOne(&survey.Password{ + Message: "Enter a passphrase for your new SSH key (Optional)", + }, &sshPassphrase) + if err != nil { + return "", fmt.Errorf("could not prompt: %w", err) + } + + return sshPassphrase, nil +} + func scopesSentence(scopes []string, isEnterprise bool) string { quoted := make([]string, len(scopes)) for i, s := range scopes { @@ -221,3 +250,13 @@ func scopesSentence(scopes []string, isEnterprise bool) string { } return strings.Join(quoted, ", ") } + +func sshKeyUpload(httpClient *http.Client, hostname, keyFile string, title string) error { + f, err := os.Open(keyFile) + if err != nil { + return err + } + defer f.Close() + + return add.SSHKeyUpload(httpClient, hostname, f, title) +} diff --git a/pkg/cmd/auth/shared/login_flow_test.go b/pkg/cmd/auth/shared/login_flow_test.go index 9cac2f367..3d99d1ebc 100644 --- a/pkg/cmd/auth/shared/login_flow_test.go +++ b/pkg/cmd/auth/shared/login_flow_test.go @@ -12,6 +12,7 @@ import ( "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/prompt" + "github.com/cli/cli/v2/pkg/ssh" "github.com/stretchr/testify/assert" ) @@ -84,9 +85,9 @@ func TestLogin_ssh(t *testing.T) { HTTPClient: &http.Client{Transport: &tr}, Hostname: "example.com", Interactive: true, - sshContext: SshContext{ - configDir: dir, - keygenExe: "ssh-keygen", + sshContext: ssh.SshContext{ + ConfigDir: dir, + KeygenExe: "ssh-keygen", }, }) assert.NoError(t, err) diff --git a/pkg/cmd/auth/shared/ssh_keys.go b/pkg/cmd/auth/shared/ssh_keys.go deleted file mode 100644 index 3dd596496..000000000 --- a/pkg/cmd/auth/shared/ssh_keys.go +++ /dev/null @@ -1,129 +0,0 @@ -package shared - -import ( - "fmt" - "net/http" - "os" - "os/exec" - "path/filepath" - "runtime" - - "github.com/AlecAivazis/survey/v2" - "github.com/cli/cli/v2/internal/config" - "github.com/cli/cli/v2/internal/run" - "github.com/cli/cli/v2/pkg/cmd/ssh-key/add" - "github.com/cli/cli/v2/pkg/prompt" - "github.com/cli/safeexec" -) - -type SshContext struct { - configDir string - keygenExe string -} - -func (c *SshContext) sshDir() (string, error) { - if c.configDir != "" { - return c.configDir, nil - } - dir, err := config.HomeDirPath(".ssh") - if err == nil { - c.configDir = dir - } - return dir, err -} - -func (c *SshContext) localPublicKeys() ([]string, error) { - sshDir, err := c.sshDir() - if err != nil { - return nil, err - } - - return filepath.Glob(filepath.Join(sshDir, "*.pub")) -} - -func (c *SshContext) findKeygen() (string, error) { - if c.keygenExe != "" { - return c.keygenExe, nil - } - - keygenExe, err := safeexec.LookPath("ssh-keygen") - if err != nil && runtime.GOOS == "windows" { - // We can try and find ssh-keygen in a Git for Windows install - if gitPath, err := safeexec.LookPath("git"); err == nil { - gitKeygen := filepath.Join(filepath.Dir(gitPath), "..", "usr", "bin", "ssh-keygen.exe") - if _, err = os.Stat(gitKeygen); err == nil { - return gitKeygen, nil - } - } - } - - if err == nil { - c.keygenExe = keygenExe - } - return keygenExe, err -} - -func (c *SshContext) GenerateSSHKey() (string, error) { - return c.GenerateSSHKeyWithOptions("id_ed25519", true) -} - -func (c *SshContext) GenerateSSHKeyWithOptions(keyName string, errorOnExists bool) (string, error) { - keygenExe, err := c.findKeygen() - if err != nil { - // give up silently if `ssh-keygen` is not available - return "", nil - } - - // TODO: Prompt after searching for existing key - var sshChoice bool - err = prompt.SurveyAskOne(&survey.Confirm{ - // TODO: Change this message if we're not uploading - Message: "Generate a new SSH key to add to your GitHub account?", - Default: true, - }, &sshChoice) - if err != nil { - return "", fmt.Errorf("could not prompt: %w", err) - } - if !sshChoice { - return "", nil - } - - sshDir, err := c.sshDir() - if err != nil { - return "", err - } - keyFile := filepath.Join(sshDir, keyName) - if _, err := os.Stat(keyFile); err == nil { - if errorOnExists { - return "", fmt.Errorf("refusing to overwrite file %s", keyFile) - } else { - return keyFile + ".pub", nil - } - } - - if err := os.MkdirAll(filepath.Dir(keyFile), 0711); err != nil { - return "", err - } - - var sshLabel string - var sshPassphrase string - err = prompt.SurveyAskOne(&survey.Password{ - Message: "Enter a passphrase for your new SSH key (Optional)", - }, &sshPassphrase) - if err != nil { - return "", fmt.Errorf("could not prompt: %w", err) - } - - keygenCmd := exec.Command(keygenExe, "-t", "ed25519", "-C", sshLabel, "-N", sshPassphrase, "-f", keyFile) - return keyFile + ".pub", run.PrepareCmd(keygenCmd).Run() -} - -func sshKeyUpload(httpClient *http.Client, hostname, keyFile string, title string) error { - f, err := os.Open(keyFile) - if err != nil { - return err - } - defer f.Close() - - return add.SSHKeyUpload(httpClient, hostname, f, title) -} diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index f25e148c6..1dd601db5 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -18,9 +18,9 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" - "github.com/cli/cli/v2/pkg/cmd/auth/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/liveshare" + "github.com/cli/cli/v2/pkg/ssh" "github.com/spf13/cobra" ) @@ -116,22 +116,23 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e ctx, cancel := context.WithCancel(ctx) defer cancel() - sshContext := shared.SshContext{} - keyFile, err := sshContext.GenerateSSHKeyWithOptions("codespaces", false) - if err != nil { - return err + liveshareSSHOptions := liveshare.StartSSHServerOptions{} + args := sshArgs + if opts.scpArgs != nil { + args = opts.scpArgs } - fmt.Println(keyFile) + sshContext := ssh.SshContext{} + if shouldGenerateSSHKeys(args, opts) && sshContext.HasKeygen() { + keyPair, err := sshContext.GenerateSSHKey("codespaces", false, nil) + if err != nil { + return fmt.Errorf("failed to generate ssh keys: %s", err) + } - // TODO: Fix bug that we read the entire file versus only the first line (where the SSH key is) - userPublicKey, err := os.ReadFile(keyFile) - if err != nil { - return err + liveshareSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath + args = append(args, "-i", keyPair.PrivateKeyPath) } - fmt.Println(userPublicKey) - codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace) if err != nil { return err @@ -144,9 +145,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e defer safeClose(session, &err) a.StartProgressIndicatorWithLabel("Fetching SSH Details") - remoteSSHServerPort, sshUser, err := session.StartSSHServerWithOptions(ctx, liveshare.StartSSHServerOptions{ - UserPublicKey: string(userPublicKey), - }) + remoteSSHServerPort, sshUser, err := session.StartSSHServerWithOptions(ctx, liveshareSSHOptions) a.StopProgressIndicator() if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) @@ -187,9 +186,10 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e go func() { var err error if opts.scpArgs != nil { - err = codespaces.Copy(ctx, opts.scpArgs, localSSHServerPort, connectDestination) + // args is the correct variable to use here, we just use scpArgs as the check for which command to run + err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination) } else { - err = codespaces.Shell(ctx, a.errLogger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) + err = codespaces.Shell(ctx, a.errLogger, args, localSSHServerPort, connectDestination, usingCustomPort) } shellClosed <- err }() @@ -205,6 +205,24 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } } +func shouldGenerateSSHKeys(args []string, opts sshOptions) bool { + if opts.profile != "" { + // The profile may specify the identity file so cautiously don't override that option + return false + } + + for _, arg := range args { + if arg == "-i" { + // User specified the identity file so it should exist + return false + } + + // TODO: should -F have similar behavior? + } + + return true +} + func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 09950a939..7cc1da441 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -3,7 +3,9 @@ package liveshare import ( "context" "fmt" + "os" "strconv" + "strings" "time" ) @@ -18,7 +20,7 @@ type Session struct { } type StartSSHServerOptions struct { - UserPublicKey string `json:"userPublicKey"` + UserPublicKeyFile string } // Close should be called by users to clean up RPC and SSH resources whenever the session @@ -50,6 +52,10 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { // necessary, applies specified options, and returns the port on which it listens and // the user name clients should provide. func (s *Session) StartSSHServerWithOptions(ctx context.Context, opts StartSSHServerOptions) (int, string, error) { + var params struct { + UserPublicKey string `json:"userPublicKey"` + } + var response struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` @@ -57,11 +63,19 @@ func (s *Session) StartSSHServerWithOptions(ctx context.Context, opts StartSSHSe Message string `json:"message"` } + if opts.UserPublicKeyFile != "" { + publicKeyBytes, err := os.ReadFile(opts.UserPublicKeyFile) + if err != nil { + return 0, "", fmt.Errorf("failed to read public key file: %s", err) + } + + params.UserPublicKey = strings.TrimSpace(string(publicKeyBytes)) + } + // Add param with key here, update corresponding on C# side // TODO: Use this object once we update the service // params := []interface{}{opts} - params := []string{opts.UserPublicKey} - if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", ¶ms, &response); err != nil { + if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", params, &response); err != nil { return 0, "", err } diff --git a/pkg/ssh/ssh_keys.go b/pkg/ssh/ssh_keys.go new file mode 100644 index 000000000..e3572082d --- /dev/null +++ b/pkg/ssh/ssh_keys.go @@ -0,0 +1,109 @@ +package ssh + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/run" + "github.com/cli/safeexec" +) + +type SshContext struct { + ConfigDir string + KeygenExe string +} + +type SshKeyPair struct { + PublicKeyPath string + PrivateKeyPath string +} + +func (c *SshContext) LocalPublicKeys() ([]string, error) { + sshDir, err := c.sshDir() + if err != nil { + return nil, err + } + + return filepath.Glob(filepath.Join(sshDir, "*.pub")) +} + +func (c *SshContext) HasKeygen() bool { + _, err := c.findKeygen() + return err == nil +} + +func (c *SshContext) GenerateSSHKey(keyName string, errorOnExists bool, promptPassphrase func() (string, error)) (SshKeyPair, error) { + keygenExe, err := c.findKeygen() + if err != nil { + // TODO: is there a nicer way to do this default SshKeyPair? + return SshKeyPair{}, fmt.Errorf("could not find keygen executable") + } + + sshDir, err := c.sshDir() + if err != nil { + return SshKeyPair{}, err + } + keyFile := filepath.Join(sshDir, keyName) + keyPair := SshKeyPair{ + PublicKeyPath: keyFile + ".pub", + PrivateKeyPath: keyFile, + } + + if _, err := os.Stat(keyFile); err == nil { + if errorOnExists { + return SshKeyPair{}, fmt.Errorf("refusing to overwrite file %s", keyFile) + } else { + return keyPair, nil + } + } + + if err := os.MkdirAll(filepath.Dir(keyFile), 0711); err != nil { + return SshKeyPair{}, err + } + + var sshPassphrase string + if promptPassphrase != nil { + sshPassphrase, err = promptPassphrase() + } + + // TOOD: sshLabel was never set, so should -C just be removed? + keygenCmd := exec.Command(keygenExe, "-t", "ed25519", "-C", "", "-N", sshPassphrase, "-f", keyFile) + return keyPair, run.PrepareCmd(keygenCmd).Run() +} + +func (c *SshContext) sshDir() (string, error) { + if c.ConfigDir != "" { + return c.ConfigDir, nil + } + dir, err := config.HomeDirPath(".ssh") + if err == nil { + c.ConfigDir = dir + } + return dir, err +} + +func (c *SshContext) findKeygen() (string, error) { + if c.KeygenExe != "" { + return c.KeygenExe, nil + } + + keygenExe, err := safeexec.LookPath("ssh-keygen") + if err != nil && runtime.GOOS == "windows" { + // We can try and find ssh-keygen in a Git for Windows install + if gitPath, err := safeexec.LookPath("git"); err == nil { + gitKeygen := filepath.Join(filepath.Dir(gitPath), "..", "usr", "bin", "ssh-keygen.exe") + if _, err = os.Stat(gitKeygen); err == nil { + return gitKeygen, nil + } + } + } + + if err == nil { + c.KeygenExe = keygenExe + } + return keygenExe, err +}