diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 4c6ccb6d7..36c8bf5b2 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -70,21 +70,20 @@ func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] - if strings.HasPrefix(arg, "-") { - cmdArgs = append(cmdArgs, arg) - if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { - if i++; i == len(args) { - return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) - } - - cmdArgs = append(cmdArgs, args[i]) - } - continue + // if we've started parsing the command, set it to the rest of the args + if !strings.HasPrefix(arg, "-") { + command = args[i:] + break } - // if we've started parsing the command, set it to the rest of the args - command = args[i:] - break + cmdArgs = append(cmdArgs, arg) + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) + } + + cmdArgs = append(cmdArgs, args[i]) + } } return cmdArgs, command, nil diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index c3e1b4c0a..c804f6000 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -10,7 +10,7 @@ func TestParseSSHArgs(t *testing.T) { Args []string ParsedArgs []string Command []string - Error bool + Error string } testCases := []testCase{ @@ -69,19 +69,26 @@ func TestParseSSHArgs(t *testing.T) { Args: []string{"-b"}, ParsedArgs: nil, Command: nil, - Error: true, + Error: "ssh flag: -b requires an argument", }, } for _, tcase := range testCases { args, command, err := parseSSHArgs(tcase.Args) - if !tcase.Error && err != nil { - t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + if tcase.Error != "" { + if err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + } + + if err.Error() != tcase.Error { + t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error) + } + continue } - if tcase.Error && err == nil { - t.Errorf("expected error and got nil: %#v", tcase) + if err != nil { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) continue }