diff --git a/pkg/cmd/auth/shared/login_flow_test.go b/pkg/cmd/auth/shared/login_flow_test.go index 8c8ba5d72..31cf2c107 100644 --- a/pkg/cmd/auth/shared/login_flow_test.go +++ b/pkg/cmd/auth/shared/login_flow_test.go @@ -101,10 +101,7 @@ func TestLogin(t *testing.T) { // simulate that the public key file has been generated _ = os.WriteFile(keyFile+".pub", []byte("PUBKEY asdf"), 0600) }) - opts.sshContext = ssh.Context{ - ConfigDir: dir, - KeygenExe: "ssh-keygen", - } + opts.sshContext = ssh.NewContextForTests(dir, "ssh-keygen") }, wantsConfig: map[string]string{ "example.com:user": "monalisa", @@ -112,6 +109,11 @@ func TestLogin(t *testing.T) { "example.com:git_protocol": "ssh", }, stderrAssert: func(t *testing.T, opts *LoginOptions, stderr string) { + sshDir, err := opts.sshContext.SshDir() + if err != nil { + t.Errorf("Could not load ssh config dir: %v", err) + } + assert.Equal(t, heredoc.Docf(` Tip: you can generate a Personal Access Token here https://example.com/settings/tokens The minimum required scopes are 'repo', 'read:org', 'admin:public_key'. @@ -119,7 +121,7 @@ func TestLogin(t *testing.T) { ✓ Configured git protocol ✓ Uploaded the SSH key to your GitHub account: %s ✓ Logged in as monalisa - `, filepath.Join(opts.sshContext.ConfigDir, "id_ed25519.pub")), stderr) + `, filepath.Join(sshDir, "id_ed25519.pub")), stderr) }, }, { @@ -179,10 +181,7 @@ func TestLogin(t *testing.T) { // simulate that the public key file has been generated _ = os.WriteFile(keyFile+".pub", []byte("PUBKEY asdf"), 0600) }) - opts.sshContext = ssh.Context{ - ConfigDir: dir, - KeygenExe: "ssh-keygen", - } + opts.sshContext = ssh.NewContextForTests(dir, "ssh-keygen") }, wantsConfig: map[string]string{ "example.com:user": "monalisa", @@ -190,6 +189,11 @@ func TestLogin(t *testing.T) { "example.com:git_protocol": "ssh", }, stderrAssert: func(t *testing.T, opts *LoginOptions, stderr string) { + sshDir, err := opts.sshContext.SshDir() + if err != nil { + t.Errorf("Could not load ssh config dir: %v", err) + } + assert.Equal(t, heredoc.Docf(` Tip: you can generate a Personal Access Token here https://example.com/settings/tokens The minimum required scopes are 'repo', 'read:org', 'admin:public_key'. @@ -197,7 +201,7 @@ func TestLogin(t *testing.T) { ✓ Configured git protocol ✓ Uploaded the SSH key to your GitHub account: %s ✓ Logged in as monalisa - `, filepath.Join(opts.sshContext.ConfigDir, "id_ed25519.pub")), stderr) + `, filepath.Join(sshDir, "id_ed25519.pub")), stderr) }, }, { diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index c6c3943e2..b79ec68c9 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -20,7 +20,6 @@ import ( "github.com/cli/cli/v2/internal/codespaces/api" "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" - "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/ssh" "github.com/cli/safeexec" @@ -336,10 +335,20 @@ func selectSSHKeys( return nil, false, errors.New("missing value to -i argument") } + privateKeyPath := args[i+1] + + // The --config setup will set the automatic key with -i, but it might not actually be created, so we need to ensure that here + if automaticPrivateKeyPath, _ := automaticPrivateKeyPath(sshContext); automaticPrivateKeyPath == privateKeyPath { + _, err := generateAutomaticSSHKeys(sshContext) + if err != nil { + return nil, false, fmt.Errorf("generating automatic keypair: %w", err) + } + } + // User manually specified an identity file so just trust it is correct return &ssh.KeyPair{ - PrivateKeyPath: args[i+1], - PublicKeyPath: args[i+1] + ".pub", + PrivateKeyPath: privateKeyPath, + PublicKeyPath: privateKeyPath + ".pub", }, false, nil } @@ -636,7 +645,8 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro return fmt.Errorf("error formatting template: %w", err) } - automaticIdentityFilePath, err := automaticPrivateKeyPath() + sshContext := ssh.Context{} + automaticIdentityFilePath, err := automaticPrivateKeyPath(sshContext) if err != nil { return fmt.Errorf("error finding .ssh directory: %w", err) } @@ -683,8 +693,8 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro return status } -func automaticPrivateKeyPath() (string, error) { - sshDir, err := config.HomeDirPath(".ssh") +func automaticPrivateKeyPath(sshContext ssh.Context) (string, error) { + sshDir, err := sshContext.SshDir() if err != nil { return "", err } diff --git a/pkg/cmd/codespace/ssh_test.go b/pkg/cmd/codespace/ssh_test.go index b76d304d7..cd04f05fa 100644 --- a/pkg/cmd/codespace/ssh_test.go +++ b/pkg/cmd/codespace/ssh_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "path" "path/filepath" "strings" "testing" @@ -68,9 +69,7 @@ func TestGenerateAutomaticSSHKeys(t *testing.T) { for _, tt := range tests { dir := t.TempDir() - sshContext := ssh.Context{ - ConfigDir: dir, - } + sshContext := ssh.NewContextForTests(dir, "") for _, file := range tt.existingFiles { f, err := os.Create(filepath.Join(dir, file)) @@ -125,6 +124,10 @@ func TestGenerateAutomaticSSHKeys(t *testing.T) { } func TestSelectSSHKeys(t *testing.T) { + // This string will be subsituted in sshArgs for test cases + // This is to work around the temp test ssh dir not being known until the test is executing + substituteSSHDir := "SUB_SSH_DIR" + tests := []struct { sshDirFiles []string sshConfigKeys []string @@ -139,7 +142,7 @@ func TestSelectSSHKeys(t *testing.T) { wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key", PublicKeyPath: "custom-private-key.pub"}, }, { - sshArgs: []string{"-i", automaticPrivateKeyName}, + sshArgs: []string{"-i", path.Join(substituteSSHDir, automaticPrivateKeyName)}, wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, }, { @@ -202,7 +205,7 @@ func TestSelectSSHKeys(t *testing.T) { for _, tt := range tests { sshDir := t.TempDir() - sshContext := ssh.Context{ConfigDir: sshDir} + sshContext := ssh.NewContextForTests(sshDir, "") for _, file := range tt.sshDirFiles { f, err := os.Create(filepath.Join(sshDir, file)) @@ -226,7 +229,12 @@ func TestSelectSSHKeys(t *testing.T) { t.Fatalf("could not write test config %v", err) } - tt.sshArgs = append([]string{"-F", configPath}, tt.sshArgs...) + var subbedSSHArgs []string + for _, arg := range tt.sshArgs { + subbedSSHArgs = append(subbedSSHArgs, strings.Replace(arg, substituteSSHDir, sshDir, -1)) + } + + tt.sshArgs = append([]string{"-F", configPath}, subbedSSHArgs...) gotKeyPair, gotShouldAddArg, err := selectSSHKeys(context.Background(), sshContext, tt.sshArgs, sshOptions{profile: tt.profileOpt}) @@ -254,11 +262,24 @@ func TestSelectSSHKeys(t *testing.T) { } // Strip the dir (sshDir) from the gotKeyPair paths so that they match wantKeyPair (which doesn't know the directory) - gotKeyPair.PrivateKeyPath = filepath.Base(gotKeyPair.PrivateKeyPath) - gotKeyPair.PublicKeyPath = filepath.Base(gotKeyPair.PublicKeyPath) + gotKeyPairJustFileNames := &ssh.KeyPair{ + PrivateKeyPath: filepath.Base(gotKeyPair.PrivateKeyPath), + PublicKeyPath: filepath.Base(gotKeyPair.PublicKeyPath), + } - if fmt.Sprintf("%v", gotKeyPair) != fmt.Sprintf("%v", tt.wantKeyPair) { - t.Errorf("Want selectSSHKeys result to be %v, got %v", tt.wantKeyPair, gotKeyPair) + if fmt.Sprintf("%v", gotKeyPairJustFileNames) != fmt.Sprintf("%v", tt.wantKeyPair) { + t.Errorf("Want selectSSHKeys result to be %v, got %v", tt.wantKeyPair, gotKeyPairJustFileNames) + } + + // If the automatic key pair is selected, it needs to exist no matter what + if strings.Contains(tt.wantKeyPair.PrivateKeyPath, automaticPrivateKeyName) { + if _, err := os.Stat(gotKeyPair.PrivateKeyPath); err != nil { + t.Errorf("Expected automatic key pair private key to exist, but it did not") + } + + if _, err := os.Stat(gotKeyPair.PublicKeyPath); err != nil { + t.Errorf("Expected automatic key pair public key to exist, but it did not") + } } } } diff --git a/pkg/ssh/ssh_keys.go b/pkg/ssh/ssh_keys.go index c750b608a..83d4f1b34 100644 --- a/pkg/ssh/ssh_keys.go +++ b/pkg/ssh/ssh_keys.go @@ -14,8 +14,17 @@ import ( ) type Context struct { - ConfigDir string - KeygenExe string + configDir string + keygenExe string +} + +// NewContextForTests creates a new `ssh.Context` with internal properties set to the +// specified values. It should only be used to inject test-specific setup. +func NewContextForTests(configDir, keygenExe string) Context { + return Context{ + configDir, + keygenExe, + } } type KeyPair struct { @@ -26,7 +35,7 @@ type KeyPair struct { var ErrKeyAlreadyExists = errors.New("SSH key already exists") func (c *Context) LocalPublicKeys() ([]string, error) { - sshDir, err := c.sshDir() + sshDir, err := c.SshDir() if err != nil { return nil, err } @@ -45,7 +54,7 @@ func (c *Context) GenerateSSHKey(keyName string, passphrase string) (*KeyPair, e return nil, err } - sshDir, err := c.sshDir() + sshDir, err := c.SshDir() if err != nil { return nil, err } @@ -79,20 +88,20 @@ func (c *Context) GenerateSSHKey(keyName string, passphrase string) (*KeyPair, e return &keyPair, nil } -func (c *Context) sshDir() (string, error) { - if c.ConfigDir != "" { - return c.ConfigDir, nil +func (c *Context) SshDir() (string, error) { + if c.configDir != "" { + return c.configDir, nil } dir, err := config.HomeDirPath(".ssh") if err == nil { - c.ConfigDir = dir + c.configDir = dir } return dir, err } func (c *Context) findKeygen() (string, error) { - if c.KeygenExe != "" { - return c.KeygenExe, nil + if c.keygenExe != "" { + return c.keygenExe, nil } keygenExe, err := safeexec.LookPath("ssh-keygen") @@ -107,7 +116,7 @@ func (c *Context) findKeygen() (string, error) { } if err == nil { - c.KeygenExe = keygenExe + c.keygenExe = keygenExe } return keygenExe, err }