diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 33fbd092a..e99f8971d 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -67,23 +67,19 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) // parseSSHArgs parses SSH arguments into two distinct slices of flags // and command. It returns an error if flags are found after a command // or if a unary flag is provided without an argument. -func parseSSHArgs(args []string) ([]string, []string, error) { - var ( - cmdArgs []string - command []string - ) - +func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] - if strings.HasPrefix(arg, "-") { - if command != nil { - return nil, nil, fmt.Errorf("invalid flag after command: %s", arg) - } + if command != nil { + command = append(command, arg) + continue + } + if strings.HasPrefix(arg, "-") { cmdArgs = append(cmdArgs, arg) - if strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { if i++; i == len(args) { - return nil, nil, fmt.Errorf("invalid unary flag without argument: %s", arg) + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) } cmdArgs = append(cmdArgs, args[i]) diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index 04d52b090..5450adf1a 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -1,12 +1,16 @@ package codespaces -import "testing" +import ( + "fmt" + "testing" +) func TestParseSSHArgs(t *testing.T) { type testCase struct { Args []string ParsedArgs []string Command []string + Error bool } testCases := []testCase{ @@ -50,37 +54,39 @@ func TestParseSSHArgs(t *testing.T) { ParsedArgs: []string{"-L", "-l"}, Command: nil, }, + { + Args: []string{"-v", "echo", "-n", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-n", "test"}, + }, + { + Args: []string{"-b"}, + ParsedArgs: nil, + Command: nil, + Error: true, + }, } for _, tcase := range testCases { args, command, err := parseSSHArgs(tcase.Args) - if err != nil { - t.Errorf("received unexpected error: %w", err) + if err != nil && !tcase.Error { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + continue } - if len(args) != len(tcase.ParsedArgs) { - t.Fatalf("args do not match length of expected args. %#v, got '%d'", tcase, len(args)) - } - if len(command) != len(tcase.Command) { - t.Fatalf("command dooes not match length of expected command. %#v, got '%d'", tcase, len(command)) + if tcase.Error && err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + continue } - for i, arg := range args { - if arg != tcase.ParsedArgs[i] { - t.Fatalf("arg does not match expected parsed arg. %v, got '%s'", tcase, arg) - } + argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs) + if argsStr != parsedArgsStr { + t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr) } - for i, c := range command { - if c != tcase.Command[i] { - t.Fatalf("command does not match expected command. %v, got: '%v'", tcase, command) - } + + commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command) + if commandStr != parsedCommandStr { + t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr) } } } - -func TestParseSSHArgsError(t *testing.T) { - _, _, err := parseSSHArgs([]string{"-X", "test", "-Y"}) - if err == nil { - t.Error("expected an error for invalid args") - } -}