diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index f65fa1109..19528061a 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -82,9 +82,12 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow } dst := fmt.Sprintf("%s@localhost", sshUser) - cmd := codespaces.NewRemoteCommand( + cmd, err := codespaces.NewRemoteCommand( ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) + if err != nil { + return fmt.Errorf("remote command: %w", err) + } tunnelClosed := make(chan error, 1) go func() { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 6d5f2376b..4ece84d91 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -18,10 +18,10 @@ func newSSHCmd() *cobra.Command { var sshServerPort int sshCmd := &cobra.Command{ - Use: "ssh", + Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) + return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -36,7 +36,7 @@ func init() { rootCmd.AddCommand(newSSHCmd()) } -func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -89,7 +89,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index d5ea876da..36c8bf5b2 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,6 +2,7 @@ package codespaces import ( "context" + "fmt" "os" "os/exec" "strconv" @@ -11,8 +12,11 @@ import ( // Shell runs an interactive secure shell over an existing // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). -func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, "") +func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + if err != nil { + return fmt.Errorf("failed to create ssh command: %w", err) + } if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,25 +27,33 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) - return cmd +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs := []string{dst, "-C"} // Always use Compression - if command == "" { - // if we are in a shell send X11 and X11Trust - cmdArgs = append(cmdArgs, "-X", "-Y") + // The ssh command syntax is: ssh [flags] user@host command [args...] + // There is no way to specify the user@host destination as a flag. + // Unfortunately, that means we need to know which user-provided words are + // SSH flags and which are command arguments so that we can place + // them before or after the destination, and that means we need to know all + // the flags and their arities. + cmdArgs, command, err := parseSSHArgs(cmdArgs) + if err != nil { + return nil, nil, err } cmdArgs = append(cmdArgs, connArgs...) - if command != "" { - cmdArgs = append(cmdArgs, command) + cmdArgs = append(cmdArgs, "-C") // Compression + cmdArgs = append(cmdArgs, dst) // user@host + + if command != nil { + cmdArgs = append(cmdArgs, command...) } cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) @@ -49,5 +61,30 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd, connArgs + return cmd, connArgs, nil +} + +// parseSSHArgs parses SSH arguments into two distinct slices of flags and command. +// It returns an error if a unary flag is provided without an argument. +func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { + for i := 0; i < len(args); i++ { + arg := args[i] + + // if we've started parsing the command, set it to the rest of the args + if !strings.HasPrefix(arg, "-") { + command = args[i:] + break + } + + cmdArgs = append(cmdArgs, arg) + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) + } + + cmdArgs = append(cmdArgs, args[i]) + } + } + + return cmdArgs, command, nil } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go new file mode 100644 index 000000000..c804f6000 --- /dev/null +++ b/internal/codespaces/ssh_test.go @@ -0,0 +1,105 @@ +package codespaces + +import ( + "fmt" + "testing" +) + +func TestParseSSHArgs(t *testing.T) { + type testCase struct { + Args []string + ParsedArgs []string + Command []string + Error string + } + + testCases := []testCase{ + {}, // empty test case + { + Args: []string{"-X", "-Y"}, + ParsedArgs: []string{"-X", "-Y"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"somecommand"}, + ParsedArgs: []string{}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"echo", "test"}, + ParsedArgs: []string{}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"-v", "echo", "hello", "world"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "hello", "world"}, + }, + { + Args: []string{"-L", "-l"}, + ParsedArgs: []string{"-L", "-l"}, + Command: nil, + }, + { + Args: []string{"-v", "echo", "-n", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-n", "test"}, + }, + { + Args: []string{"-v", "echo", "-b", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-b", "test"}, + }, + { + Args: []string{"-b"}, + ParsedArgs: nil, + Command: nil, + Error: "ssh flag: -b requires an argument", + }, + } + + for _, tcase := range testCases { + args, command, err := parseSSHArgs(tcase.Args) + if tcase.Error != "" { + if err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + } + + if err.Error() != tcase.Error { + t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error) + } + + continue + } + + if err != nil { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + continue + } + + argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs) + if argsStr != parsedArgsStr { + t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr) + } + + commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command) + if commandStr != parsedCommandStr { + t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr) + } + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 99683b51d..408f11941 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -89,10 +89,14 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { - cmd := NewRemoteCommand( + cmd, err := NewRemoteCommand( ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) + if err != nil { + return nil, fmt.Errorf("remote command: %w", err) + } + stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil {