Add check to -i short circuit for the automatic key

This commit is contained in:
Caleb Brose 2022-08-16 10:15:43 -05:00
parent c781ea520a
commit d8b06ae8af
2 changed files with 40 additions and 12 deletions

View file

@ -246,7 +246,11 @@ func useAutomaticSSHKeys(
arg := args[i]
if arg == "-i" {
// User specified the identity file so just trust it is correct
if i+1 < len(args) && path.Base(args[i+1]) == automaticPrivateKeyName {
return true, nil
}
// User specified a custom identity file so just trust it is correct
return false, nil
}
@ -550,13 +554,11 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
return fmt.Errorf("error formatting template: %w", err)
}
sshDir, err := config.HomeDirPath(".ssh")
automaticIdentityFilePath, err := automaticPrivateKeyPath()
if err != nil {
return fmt.Errorf("error finding .ssh directory: %w", err)
}
automaticIdentityFilePath := path.Join(sshDir, automaticPrivateKeyName)
ghExec := a.executable.Executable()
for result := range sshUsers {
if result.err != nil {
@ -599,6 +601,15 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
return status
}
func automaticPrivateKeyPath() (string, error) {
sshDir, err := config.HomeDirPath(".ssh")
if err != nil {
return "", err
}
return path.Join(sshDir, automaticPrivateKeyName), nil
}
type cpOptions struct {
sshOptions
recursive bool // -r

View file

@ -125,16 +125,33 @@ func TestAutomaticSSHKeyPairs(t *testing.T) {
}
}
func TestDontUseAutomaticSSHKeysWithCustomPrivateKey(t *testing.T) {
args := []string{"-i", "private-key"}
result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{})
if err != nil {
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
func TestUseAutomaticSSHKeysIdentityFileArg(t *testing.T) {
tests := []struct {
identityFileArg string
wantResult bool
}{
{"custom-private-key", false},
{automaticPrivateKeyName, true},
{"", false}, // Edge case check for missing arg value
}
if result != false {
t.Errorf("Want useAutomaticSSHKeys to be false, got true")
for _, tt := range tests {
t.Logf("%v", tt.identityFileArg)
args := []string{"-i"}
if tt.identityFileArg != "" {
args = append(args, tt.identityFileArg)
}
result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{})
if err != nil {
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
}
if result != tt.wantResult {
t.Errorf("Want useAutomaticSSHKeys to be %v, got %v", tt.wantResult, result)
}
}
}