Add check to -i short circuit for the automatic key
This commit is contained in:
parent
c781ea520a
commit
d8b06ae8af
2 changed files with 40 additions and 12 deletions
|
|
@ -246,7 +246,11 @@ func useAutomaticSSHKeys(
|
|||
arg := args[i]
|
||||
|
||||
if arg == "-i" {
|
||||
// User specified the identity file so just trust it is correct
|
||||
if i+1 < len(args) && path.Base(args[i+1]) == automaticPrivateKeyName {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// User specified a custom identity file so just trust it is correct
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
|
@ -550,13 +554,11 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
|
|||
return fmt.Errorf("error formatting template: %w", err)
|
||||
}
|
||||
|
||||
sshDir, err := config.HomeDirPath(".ssh")
|
||||
automaticIdentityFilePath, err := automaticPrivateKeyPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding .ssh directory: %w", err)
|
||||
}
|
||||
|
||||
automaticIdentityFilePath := path.Join(sshDir, automaticPrivateKeyName)
|
||||
|
||||
ghExec := a.executable.Executable()
|
||||
for result := range sshUsers {
|
||||
if result.err != nil {
|
||||
|
|
@ -599,6 +601,15 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
|
|||
return status
|
||||
}
|
||||
|
||||
func automaticPrivateKeyPath() (string, error) {
|
||||
sshDir, err := config.HomeDirPath(".ssh")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return path.Join(sshDir, automaticPrivateKeyName), nil
|
||||
}
|
||||
|
||||
type cpOptions struct {
|
||||
sshOptions
|
||||
recursive bool // -r
|
||||
|
|
|
|||
|
|
@ -125,16 +125,33 @@ func TestAutomaticSSHKeyPairs(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDontUseAutomaticSSHKeysWithCustomPrivateKey(t *testing.T) {
|
||||
args := []string{"-i", "private-key"}
|
||||
result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
|
||||
func TestUseAutomaticSSHKeysIdentityFileArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
identityFileArg string
|
||||
wantResult bool
|
||||
}{
|
||||
{"custom-private-key", false},
|
||||
{automaticPrivateKeyName, true},
|
||||
{"", false}, // Edge case check for missing arg value
|
||||
}
|
||||
|
||||
if result != false {
|
||||
t.Errorf("Want useAutomaticSSHKeys to be false, got true")
|
||||
for _, tt := range tests {
|
||||
t.Logf("%v", tt.identityFileArg)
|
||||
|
||||
args := []string{"-i"}
|
||||
if tt.identityFileArg != "" {
|
||||
args = append(args, tt.identityFileArg)
|
||||
}
|
||||
|
||||
result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
|
||||
}
|
||||
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("Want useAutomaticSSHKeys to be %v, got %v", tt.wantResult, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue