From d8b06ae8af4e4081382da7d463b6dbace4acfdc2 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Tue, 16 Aug 2022 10:15:43 -0500 Subject: [PATCH] Add check to -i short circuit for the automatic key --- pkg/cmd/codespace/ssh.go | 19 +++++++++++++++---- pkg/cmd/codespace/ssh_test.go | 33 +++++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 179514957..a41832153 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -246,7 +246,11 @@ func useAutomaticSSHKeys( arg := args[i] if arg == "-i" { - // User specified the identity file so just trust it is correct + if i+1 < len(args) && path.Base(args[i+1]) == automaticPrivateKeyName { + return true, nil + } + + // User specified a custom identity file so just trust it is correct return false, nil } @@ -550,13 +554,11 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro return fmt.Errorf("error formatting template: %w", err) } - sshDir, err := config.HomeDirPath(".ssh") + automaticIdentityFilePath, err := automaticPrivateKeyPath() if err != nil { return fmt.Errorf("error finding .ssh directory: %w", err) } - automaticIdentityFilePath := path.Join(sshDir, automaticPrivateKeyName) - ghExec := a.executable.Executable() for result := range sshUsers { if result.err != nil { @@ -599,6 +601,15 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro return status } +func automaticPrivateKeyPath() (string, error) { + sshDir, err := config.HomeDirPath(".ssh") + if err != nil { + return "", err + } + + return path.Join(sshDir, automaticPrivateKeyName), nil +} + type cpOptions struct { sshOptions recursive bool // -r diff --git a/pkg/cmd/codespace/ssh_test.go b/pkg/cmd/codespace/ssh_test.go index 7c0a02b47..7e0137641 100644 --- a/pkg/cmd/codespace/ssh_test.go +++ b/pkg/cmd/codespace/ssh_test.go @@ -125,16 +125,33 @@ func TestAutomaticSSHKeyPairs(t *testing.T) { } } -func TestDontUseAutomaticSSHKeysWithCustomPrivateKey(t *testing.T) { - args := []string{"-i", "private-key"} - result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{}) - - if err != nil { - t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err) +func TestUseAutomaticSSHKeysIdentityFileArg(t *testing.T) { + tests := []struct { + identityFileArg string + wantResult bool + }{ + {"custom-private-key", false}, + {automaticPrivateKeyName, true}, + {"", false}, // Edge case check for missing arg value } - if result != false { - t.Errorf("Want useAutomaticSSHKeys to be false, got true") + for _, tt := range tests { + t.Logf("%v", tt.identityFileArg) + + args := []string{"-i"} + if tt.identityFileArg != "" { + args = append(args, tt.identityFileArg) + } + + result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{}) + + if err != nil { + t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err) + } + + if result != tt.wantResult { + t.Errorf("Want useAutomaticSSHKeys to be %v, got %v", tt.wantResult, result) + } } }