diff --git a/pkg/cmd/auth/shared/login_flow.go b/pkg/cmd/auth/shared/login_flow.go index a3698a8b5..08087a8ac 100644 --- a/pkg/cmd/auth/shared/login_flow.go +++ b/pkg/cmd/auth/shared/login_flow.go @@ -104,7 +104,7 @@ func Login(opts *LoginOptions) error { } if sshChoice { - keyPair, err := opts.sshContext.GenerateSSHKey("id_ed25519", true, promptForSshKeyPassphrase) + keyPair, err := opts.sshContext.GenerateSSHKey("id_ed25519", ssh.WithPassphrasePrompt(promptForSshKeyPassphrase)) if err != nil { return err } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 12699e5c3..cc0502286 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -124,7 +124,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e sshContext := ssh.SshContext{} if shouldGenerateSSHKeys(args, opts) && sshContext.HasKeygen() { - keyPair, err := sshContext.GenerateSSHKey("codespaces", false, nil) + keyPair, err := sshContext.GenerateSSHKey("codespaces", ssh.WithNoErrorOnExitingKey()) if err != nil { return fmt.Errorf("failed to generate ssh keys: %s", err) } diff --git a/pkg/ssh/ssh_keys.go b/pkg/ssh/ssh_keys.go index 7ecd36fff..1e45a59ef 100644 --- a/pkg/ssh/ssh_keys.go +++ b/pkg/ssh/ssh_keys.go @@ -36,16 +36,32 @@ func (c *SshContext) HasKeygen() bool { return err == nil } -func (c *SshContext) GenerateSSHKey(keyName string, errorOnExists bool, promptPassphrase func() (string, error)) (SshKeyPair, error) { +type GenerateSSHKeyOptions struct { + noErrorOnExists bool + promptPassphrase func() (string, error) +} + +func WithNoErrorOnExitingKey() func(*GenerateSSHKeyOptions) { + return func(o *GenerateSSHKeyOptions) { + o.noErrorOnExists = false + } +} + +func WithPassphrasePrompt(prompt func() (string, error)) func(*GenerateSSHKeyOptions) { + return func(o *GenerateSSHKeyOptions) { + o.promptPassphrase = prompt + } +} + +func (c *SshContext) GenerateSSHKey(keyName string, configureOptions ...func(*GenerateSSHKeyOptions)) (*SshKeyPair, error) { keygenExe, err := c.findKeygen() if err != nil { - // TODO: is there a nicer way to do this default SshKeyPair? - return SshKeyPair{}, fmt.Errorf("could not find keygen executable") + return nil, fmt.Errorf("could not find keygen executable") } sshDir, err := c.sshDir() if err != nil { - return SshKeyPair{}, err + return nil, err } keyFile := filepath.Join(sshDir, keyName) keyPair := SshKeyPair{ @@ -53,29 +69,38 @@ func (c *SshContext) GenerateSSHKey(keyName string, errorOnExists bool, promptPa PrivateKeyPath: keyFile, } + opts := GenerateSSHKeyOptions{} + for _, configure := range configureOptions { + configure(&opts) + } + if _, err := os.Stat(keyFile); err == nil { - if errorOnExists { - return SshKeyPair{}, fmt.Errorf("refusing to overwrite file %s", keyFile) + if opts.noErrorOnExists { + return &keyPair, nil } else { - return keyPair, nil + return nil, fmt.Errorf("refusing to overwrite file %s", keyFile) } } if err := os.MkdirAll(filepath.Dir(keyFile), 0711); err != nil { - return SshKeyPair{}, err + return nil, err } var sshPassphrase string - if promptPassphrase != nil { - sshPassphrase, err = promptPassphrase() + if opts.promptPassphrase != nil { + sshPassphrase, err = opts.promptPassphrase() if err != nil { - return SshKeyPair{}, err + return nil, err } } - // TOOD: sshLabel was never set, so should -C just be removed? - keygenCmd := exec.Command(keygenExe, "-t", "ed25519", "-C", "", "-N", sshPassphrase, "-f", keyFile) - return keyPair, run.PrepareCmd(keygenCmd).Run() + keygenCmd := exec.Command(keygenExe, "-t", "ed25519", "-N", sshPassphrase, "-f", keyFile) + err = run.PrepareCmd(keygenCmd).Run() + if err != nil { + return nil, err + } + + return &keyPair, nil } func (c *SshContext) sshDir() (string, error) {