diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 34c1bc095..1ce008120 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/cmd/ghcs/output" @@ -132,7 +133,27 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { return fmt.Errorf("error getting codespaces: %w", err) } - var deleted bool + delete := func(name string) error { + token, err := apiClient.GetCodespaceToken(ctx, user.Login, name) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + + if err := apiClient.DeleteCodespace(ctx, user, token, name); err != nil { + return fmt.Errorf("error deleting codespace: %w", err) + } + + return nil + } + + // Perform deletions in parallel, for performance, + // and to ensure all are attempted even if any one fails. + var ( + found bool + mu sync.Mutex // guards errs, logger + errs []error + wg sync.WaitGroup + ) for _, c := range codespaces { if !strings.EqualFold(c.RepositoryNWO, repo) { continue @@ -140,32 +161,46 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { confirmed, err := confirmDeletion(c, force) if err != nil { - return fmt.Errorf("deletion could not be confirmed: %w", err) + mu.Lock() + errs = append(errs, fmt.Errorf("deletion could not be confirmed: %w", err)) + mu.Unlock() + continue } if !confirmed { continue } - deleted = true - - token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) - if err != nil { - return fmt.Errorf("error getting codespace token: %w", err) - } - - if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) - } - - log.Printf("Codespace deleted: %s\n", c.Name) + found = true + c := c + wg.Add(1) + go func() { + defer wg.Done() + err := delete(c.Name) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, err) + } else { + log.Printf("Codespace deleted: %s\n", c.Name) + } + }() } - - if !deleted { + if !found { return fmt.Errorf("no codespace was found for repository: %s", repo) } + wg.Wait() - return list(&listOptions{}) + // Return first error, plus count of others. + if errs != nil { + err := errs[0] + if others := len(errs) - 1; others > 0 { + err = fmt.Errorf("%w (+%d more)", err, others) + } + return err + } + + return nil } func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index ccc150f08..7ee156012 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -51,7 +51,7 @@ func list(opts *listOptions) error { table.Append([]string{ codespace.Name, codespace.RepositoryNWO, - codespace.Name + dirtyStar(codespace.Environment.GitStatus), + codespace.Branch + dirtyStar(codespace.Environment.GitStatus), codespace.Environment.State, codespace.CreatedAt, }) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 4051cc209..2b50effd1 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -78,9 +78,12 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow } dst := fmt.Sprintf("%s@localhost", sshUser) - cmd := codespaces.NewRemoteCommand( + cmd, err := codespaces.NewRemoteCommand( ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) + if err != nil { + return fmt.Errorf("remote command: %w", err) + } tunnelClosed := make(chan error, 1) go func() { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 5063e8fc9..e4435853b 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -18,10 +18,10 @@ func newSSHCmd() *cobra.Command { var sshServerPort int sshCmd := &cobra.Command{ - Use: "ssh", + Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) + return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -32,7 +32,7 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -85,7 +85,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index d5ea876da..36c8bf5b2 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,6 +2,7 @@ package codespaces import ( "context" + "fmt" "os" "os/exec" "strconv" @@ -11,8 +12,11 @@ import ( // Shell runs an interactive secure shell over an existing // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). -func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, "") +func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + if err != nil { + return fmt.Errorf("failed to create ssh command: %w", err) + } if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,25 +27,33 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) - return cmd +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs := []string{dst, "-C"} // Always use Compression - if command == "" { - // if we are in a shell send X11 and X11Trust - cmdArgs = append(cmdArgs, "-X", "-Y") + // The ssh command syntax is: ssh [flags] user@host command [args...] + // There is no way to specify the user@host destination as a flag. + // Unfortunately, that means we need to know which user-provided words are + // SSH flags and which are command arguments so that we can place + // them before or after the destination, and that means we need to know all + // the flags and their arities. + cmdArgs, command, err := parseSSHArgs(cmdArgs) + if err != nil { + return nil, nil, err } cmdArgs = append(cmdArgs, connArgs...) - if command != "" { - cmdArgs = append(cmdArgs, command) + cmdArgs = append(cmdArgs, "-C") // Compression + cmdArgs = append(cmdArgs, dst) // user@host + + if command != nil { + cmdArgs = append(cmdArgs, command...) } cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) @@ -49,5 +61,30 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd, connArgs + return cmd, connArgs, nil +} + +// parseSSHArgs parses SSH arguments into two distinct slices of flags and command. +// It returns an error if a unary flag is provided without an argument. +func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { + for i := 0; i < len(args); i++ { + arg := args[i] + + // if we've started parsing the command, set it to the rest of the args + if !strings.HasPrefix(arg, "-") { + command = args[i:] + break + } + + cmdArgs = append(cmdArgs, arg) + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) + } + + cmdArgs = append(cmdArgs, args[i]) + } + } + + return cmdArgs, command, nil } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go new file mode 100644 index 000000000..c804f6000 --- /dev/null +++ b/internal/codespaces/ssh_test.go @@ -0,0 +1,105 @@ +package codespaces + +import ( + "fmt" + "testing" +) + +func TestParseSSHArgs(t *testing.T) { + type testCase struct { + Args []string + ParsedArgs []string + Command []string + Error string + } + + testCases := []testCase{ + {}, // empty test case + { + Args: []string{"-X", "-Y"}, + ParsedArgs: []string{"-X", "-Y"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"somecommand"}, + ParsedArgs: []string{}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"echo", "test"}, + ParsedArgs: []string{}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"-v", "echo", "hello", "world"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "hello", "world"}, + }, + { + Args: []string{"-L", "-l"}, + ParsedArgs: []string{"-L", "-l"}, + Command: nil, + }, + { + Args: []string{"-v", "echo", "-n", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-n", "test"}, + }, + { + Args: []string{"-v", "echo", "-b", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-b", "test"}, + }, + { + Args: []string{"-b"}, + ParsedArgs: nil, + Command: nil, + Error: "ssh flag: -b requires an argument", + }, + } + + for _, tcase := range testCases { + args, command, err := parseSSHArgs(tcase.Args) + if tcase.Error != "" { + if err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + } + + if err.Error() != tcase.Error { + t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error) + } + + continue + } + + if err != nil { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + continue + } + + argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs) + if argsStr != parsedArgsStr { + t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr) + } + + commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command) + if commandStr != parsedCommandStr { + t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr) + } + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 99683b51d..408f11941 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -89,10 +89,14 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { - cmd := NewRemoteCommand( + cmd, err := NewRemoteCommand( ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) + if err != nil { + return nil, fmt.Errorf("remote command: %w", err) + } + stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil {