diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go new file mode 100644 index 000000000..cfcd989e2 --- /dev/null +++ b/cmd/ghcs/code.go @@ -0,0 +1,65 @@ +package ghcs + +import ( + "context" + "fmt" + "net/url" + + "github.com/skratchdot/open-golang/open" + "github.com/spf13/cobra" +) + +func newCodeCmd(app *App) *cobra.Command { + var ( + codespace string + useInsiders bool + ) + + codeCmd := &cobra.Command{ + Use: "code", + Short: "Open a codespace in VS Code", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + return app.VSCode(cmd.Context(), codespace, useInsiders) + }, + } + + codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code") + + return codeCmd +} + +// VSCode opens a codespace in the local VS VSCode application. +func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error { + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + if codespaceName == "" { + codespace, err := chooseCodespace(ctx, a.apiClient, user) + if err != nil { + if err == errNoCodespaces { + return err + } + return fmt.Errorf("error choosing codespace: %w", err) + } + codespaceName = codespace.Name + } + + url := vscodeProtocolURL(codespaceName, useInsiders) + if err := open.Run(url); err != nil { + return fmt.Errorf("error opening vscode URL %s: %s. (Is VS Code installed?)", url, err) + } + + return nil +} + +func vscodeProtocolURL(codespaceName string, useInsiders bool) string { + application := "vscode" + if useInsiders { + application = "vscode-insiders" + } + return fmt.Sprintf("%s://github.codespaces/connect?name=%s", application, url.QueryEscape(codespaceName)) +} diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go new file mode 100644 index 000000000..d69ffd964 --- /dev/null +++ b/cmd/ghcs/common.go @@ -0,0 +1,185 @@ +package ghcs + +// This file defines functions common to the entire ghcs command set. + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "sort" + + "github.com/AlecAivazis/survey/v2" + "github.com/AlecAivazis/survey/v2/terminal" + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" + "github.com/spf13/cobra" + "golang.org/x/term" +) + +type App struct { + apiClient apiClient + logger *output.Logger +} + +func NewApp(logger *output.Logger, apiClient apiClient) *App { + return &App{ + apiClient: apiClient, + logger: logger, + } +} + +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient +type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) + GetCodespaceToken(ctx context.Context, user, name string) (string, error) + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) + DeleteCodespace(ctx context.Context, user, name string) error + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error + CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) + GetRepository(ctx context.Context, nwo string) (*api.Repository, error) + AuthorizedKeys(ctx context.Context, user string) ([]byte, error) + GetCodespaceRegionLocation(ctx context.Context) (string, error) + GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error) + GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) +} + +var errNoCodespaces = errors.New("you have no codespaces") + +func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) { + codespaces, err := apiClient.ListCodespaces(ctx, user.Login) + if err != nil { + return nil, fmt.Errorf("error getting codespaces: %w", err) + } + return chooseCodespaceFromList(ctx, codespaces) +} + +func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (*api.Codespace, error) { + if len(codespaces) == 0 { + return nil, errNoCodespaces + } + + sort.Slice(codespaces, func(i, j int) bool { + return codespaces[i].CreatedAt > codespaces[j].CreatedAt + }) + + codespacesByName := make(map[string]*api.Codespace) + codespacesNames := make([]string, 0, len(codespaces)) + for _, codespace := range codespaces { + codespacesByName[codespace.Name] = codespace + codespacesNames = append(codespacesNames, codespace.Name) + } + + sshSurvey := []*survey.Question{ + { + Name: "codespace", + Prompt: &survey.Select{ + Message: "Choose codespace:", + Options: codespacesNames, + Default: codespacesNames[0], + }, + Validate: survey.Required, + }, + } + + var answers struct { + Codespace string + } + if err := ask(sshSurvey, &answers); err != nil { + return nil, fmt.Errorf("error getting answers: %w", err) + } + + codespace := codespacesByName[answers.Codespace] + return codespace, nil +} + +// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. +// It then fetches the codespace token and the codespace record. +func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { + if codespaceName == "" { + codespace, err = chooseCodespace(ctx, apiClient, user) + if err != nil { + if err == errNoCodespaces { + return nil, "", err + } + return nil, "", fmt.Errorf("choosing codespace: %w", err) + } + codespaceName = codespace.Name + + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token: %w", err) + } + } else { + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token for given codespace: %w", err) + } + + codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting full codespace details: %w", err) + } + } + + return codespace, token, nil +} + +func safeClose(closer io.Closer, err *error) { + if closeErr := closer.Close(); *err == nil { + *err = closeErr + } +} + +// hasTTY indicates whether the process connected to a terminal. +// It is not portable to assume stdin/stdout are fds 0 and 1. +var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) + +// ask asks survey questions on the terminal, using standard options. +// It fails unless hasTTY, but ideally callers should avoid calling it in that case. +func ask(qs []*survey.Question, response interface{}) error { + if !hasTTY { + return fmt.Errorf("no terminal") + } + err := survey.Ask(qs, response, survey.WithShowCursor(true)) + // The survey package temporarily clears the terminal's ISIG mode bit + // (see tcsetattr(3)) so the QUIT button (Ctrl-C) is reported as + // ASCII \x03 (ETX) instead of delivering SIGINT to the application. + // So we have to serve ourselves the SIGINT. + // + // https://github.com/AlecAivazis/survey/#why-isnt-ctrl-c-working + if err == terminal.InterruptErr { + self, _ := os.FindProcess(os.Getpid()) + _ = self.Signal(os.Interrupt) // assumes POSIX + + // Suspend the goroutine, to avoid a race between + // return from main and async delivery of INT signal. + select {} + } + return err +} + +// checkAuthorizedKeys reports an error if the user has not registered any SSH keys; +// see https://github.com/cli/cli/v2/issues/166#issuecomment-921769703. +// The check is not required for security but it improves the error message. +func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error { + keys, err := client.AuthorizedKeys(ctx, user) + if err != nil { + return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) + } + if len(keys) == 0 { + return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) + } + return nil // success +} + +var ErrTooManyArgs = errors.New("the command accepts no arguments") + +func noArgsConstraint(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return ErrTooManyArgs + } + return nil +} diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go new file mode 100644 index 000000000..c184a0ff7 --- /dev/null +++ b/cmd/ghcs/create.go @@ -0,0 +1,297 @@ +package ghcs + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" + "github.com/cli/cli/v2/internal/codespaces" + "github.com/fatih/camelcase" + "github.com/spf13/cobra" +) + +type createOptions struct { + repo string + branch string + machine string + showStatus bool +} + +func newCreateCmd(app *App) *cobra.Command { + opts := createOptions{} + + createCmd := &cobra.Command{ + Use: "create", + Short: "Create a codespace", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + return app.Create(cmd.Context(), opts) + }, + } + + createCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "repository name with owner: user/repo") + createCmd.Flags().StringVarP(&opts.branch, "branch", "b", "", "repository branch") + createCmd.Flags().StringVarP(&opts.machine, "machine", "m", "", "hardware specifications for the VM") + createCmd.Flags().BoolVarP(&opts.showStatus, "status", "s", false, "show status of post-create command and dotfiles") + + return createCmd +} + +// Create creates a new Codespace +func (a *App) Create(ctx context.Context, opts createOptions) error { + locationCh := getLocation(ctx, a.apiClient) + userCh := getUser(ctx, a.apiClient) + + repo, err := getRepoName(opts.repo) + if err != nil { + return fmt.Errorf("error getting repository name: %w", err) + } + branch, err := getBranchName(opts.branch) + if err != nil { + return fmt.Errorf("error getting branch name: %w", err) + } + + repository, err := a.apiClient.GetRepository(ctx, repo) + if err != nil { + return fmt.Errorf("error getting repository: %w", err) + } + + locationResult := <-locationCh + if locationResult.Err != nil { + return fmt.Errorf("error getting codespace region location: %w", locationResult.Err) + } + + userResult := <-userCh + if userResult.Err != nil { + return fmt.Errorf("error getting codespace user: %w", userResult.Err) + } + + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient) + if err != nil { + return fmt.Errorf("error getting machine type: %w", err) + } + if machine == "" { + return errors.New("there are no available machine types for this repository") + } + + a.logger.Print("Creating your codespace...") + codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{ + User: userResult.User.Login, + RepositoryID: repository.ID, + Branch: branch, + Machine: machine, + Location: locationResult.Location, + }) + a.logger.Print("\n") + if err != nil { + return fmt.Errorf("error creating codespace: %w", err) + } + + if opts.showStatus { + if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil { + return fmt.Errorf("show status: %w", err) + } + } + + a.logger.Printf("Codespace created: ") + + fmt.Fprintln(os.Stdout, codespace.Name) + + return nil +} + +// showStatus polls the codespace for a list of post create states and their status. It will keep polling +// until all states have finished. Once all states have finished, we poll once more to check if any new +// states have been introduced and stop polling otherwise. +func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error { + var lastState codespaces.PostCreateState + var breakNextState bool + + finishedStates := make(map[string]bool) + ctx, stopPolling := context.WithCancel(ctx) + defer stopPolling() + + poller := func(states []codespaces.PostCreateState) { + var inProgress bool + for _, state := range states { + if _, found := finishedStates[state.Name]; found { + continue // skip this state as we've processed it already + } + + if state.Name != lastState.Name { + log.Print(state.Name) + + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + lastState = state + log.Print("...") + break + } + + finishedStates[state.Name] = true + log.Println("..." + state.Status) + } else { + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + log.Print(".") + break + } + + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value + } + } + + if !inProgress { + if breakNextState { + stopPolling() + return + } + breakNextState = true + } + } + + err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller) + if err != nil { + if errors.Is(err, context.Canceled) && breakNextState { + return nil // we cancelled the context to stop polling, we can ignore the error + } + + return fmt.Errorf("failed to poll state changes from codespace: %w", err) + } + + return nil +} + +type getUserResult struct { + User *api.User + Err error +} + +// getUser fetches the user record associated with the GITHUB_TOKEN +func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult { + ch := make(chan getUserResult, 1) + go func() { + user, err := apiClient.GetUser(ctx) + ch <- getUserResult{user, err} + }() + return ch +} + +type locationResult struct { + Location string + Err error +} + +// getLocation fetches the closest Codespace datacenter region/location to the user. +func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult { + ch := make(chan locationResult, 1) + go func() { + location, err := apiClient.GetCodespaceRegionLocation(ctx) + ch <- locationResult{location, err} + }() + return ch +} + +// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty. +func getRepoName(repo string) (string, error) { + if repo != "" { + return repo, nil + } + + repoSurvey := []*survey.Question{ + { + Name: "repository", + Prompt: &survey.Input{Message: "Repository:"}, + Validate: survey.Required, + }, + } + err := ask(repoSurvey, &repo) + return repo, err +} + +// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty. +func getBranchName(branch string) (string, error) { + if branch != "" { + return branch, nil + } + + branchSurvey := []*survey.Question{ + { + Name: "branch", + Prompt: &survey.Input{Message: "Branch:"}, + Validate: survey.Required, + }, + } + err := ask(branchSurvey, &branch) + return branch, err +} + +// getMachineName prompts the user to select the machine type, or validates the machine if non-empty. +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) { + skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) + if err != nil { + return "", fmt.Errorf("error requesting machine instance types: %w", err) + } + + // if user supplied a machine type, it must be valid + // if no machine type was supplied, we don't error if there are no machine types for the current repo + if machine != "" { + for _, sku := range skus { + if machine == sku.Name { + return machine, nil + } + } + + availableSKUs := make([]string, len(skus)) + for i := 0; i < len(skus); i++ { + availableSKUs[i] = skus[i].Name + } + + return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs) + } else if len(skus) == 0 { + return "", nil + } + + if len(skus) == 1 { + return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior + } + + skuNames := make([]string, 0, len(skus)) + skuByName := make(map[string]*api.SKU) + for _, sku := range skus { + nameParts := camelcase.Split(sku.Name) + machineName := strings.Title(strings.ToLower(nameParts[0])) + skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) + skuNames = append(skuNames, skuName) + skuByName[skuName] = sku + } + + skuSurvey := []*survey.Question{ + { + Name: "sku", + Prompt: &survey.Select{ + Message: "Choose Machine Type:", + Options: skuNames, + Default: skuNames[0], + }, + Validate: survey.Required, + }, + } + + var skuAnswers struct{ SKU string } + if err := ask(skuSurvey, &skuAnswers); err != nil { + return "", fmt.Errorf("error getting SKU: %w", err) + } + + sku := skuByName[skuAnswers.SKU] + machine = sku.Name + + return machine, nil +} diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go new file mode 100644 index 000000000..228ed99ba --- /dev/null +++ b/cmd/ghcs/delete.go @@ -0,0 +1,180 @@ +package ghcs + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/cli/cli/v2/internal/api" + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" +) + +type deleteOptions struct { + deleteAll bool + skipConfirm bool + codespaceName string + repoFilter string + keepDays uint16 + + isInteractive bool + now func() time.Time + prompter prompter +} + +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_prompter.go . prompter +type prompter interface { + Confirm(message string) (bool, error) +} + +func newDeleteCmd(app *App) *cobra.Command { + opts := deleteOptions{ + isInteractive: hasTTY, + now: time.Now, + prompter: &surveyPrompter{}, + } + + deleteCmd := &cobra.Command{ + Use: "delete", + Short: "Delete a codespace", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + if opts.deleteAll && opts.repoFilter != "" { + return errors.New("both --all and --repo is not supported") + } + return app.Delete(cmd.Context(), opts) + }, + } + + deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace") + deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") + deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`") + deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes") + deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days") + + return deleteCmd +} + +func (a *App) Delete(ctx context.Context, opts deleteOptions) error { + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + var codespaces []*api.Codespace + nameFilter := opts.codespaceName + if nameFilter == "" { + codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login) + if err != nil { + return fmt.Errorf("error getting codespaces: %w", err) + } + + if !opts.deleteAll && opts.repoFilter == "" { + c, err := chooseCodespaceFromList(ctx, codespaces) + if err != nil { + return fmt.Errorf("error choosing codespace: %w", err) + } + nameFilter = c.Name + } + } else { + // TODO: this token is discarded and then re-requested later in DeleteCodespace + token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + + codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) + if err != nil { + return fmt.Errorf("error fetching codespace information: %w", err) + } + + codespaces = []*api.Codespace{codespace} + } + + codespacesToDelete := make([]*api.Codespace, 0, len(codespaces)) + lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays)) + for _, c := range codespaces { + if nameFilter != "" && c.Name != nameFilter { + continue + } + if opts.repoFilter != "" && !strings.EqualFold(c.RepositoryNWO, opts.repoFilter) { + continue + } + if opts.keepDays > 0 { + t, err := time.Parse(time.RFC3339, c.LastUsedAt) + if err != nil { + return fmt.Errorf("error parsing last_used_at timestamp %q: %w", c.LastUsedAt, err) + } + if t.After(lastUpdatedCutoffTime) { + continue + } + } + if !opts.skipConfirm { + confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive) + if err != nil { + return fmt.Errorf("unable to confirm: %w", err) + } + if !confirmed { + continue + } + } + codespacesToDelete = append(codespacesToDelete, c) + } + + if len(codespacesToDelete) == 0 { + return errors.New("no codespaces to delete") + } + + g := errgroup.Group{} + for _, c := range codespacesToDelete { + codespaceName := c.Name + g.Go(func() error { + if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { + _, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err) + return err + } + return nil + }) + } + + if err := g.Wait(); err != nil { + return errors.New("some codespaces failed to delete") + } + return nil +} + +func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) { + gs := codespace.Environment.GitStatus + hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges + if !hasUnsavedChanges { + return true, nil + } + if !isInteractive { + return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name) + } + return p.Confirm(fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name)) +} + +type surveyPrompter struct{} + +func (p *surveyPrompter) Confirm(message string) (bool, error) { + var confirmed struct { + Confirmed bool + } + q := []*survey.Question{ + { + Name: "confirmed", + Prompt: &survey.Confirm{ + Message: message, + }, + }, + } + if err := ask(q, &confirmed); err != nil { + return false, fmt.Errorf("failed to prompt: %w", err) + } + + return confirmed.Confirmed, nil +} diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go new file mode 100644 index 000000000..af30bc542 --- /dev/null +++ b/cmd/ghcs/delete_test.go @@ -0,0 +1,257 @@ +package ghcs + +import ( + "bytes" + "context" + "errors" + "fmt" + "sort" + "strings" + "testing" + "time" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" +) + +func TestDelete(t *testing.T) { + user := &api.User{Login: "hubot"} + now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z") + daysAgo := func(n int) string { + return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339) + } + + tests := []struct { + name string + opts deleteOptions + codespaces []*api.Codespace + confirms map[string]bool + deleteErr error + wantErr bool + wantDeleted []string + wantStdout string + wantStderr string + }{ + { + name: "by name", + opts: deleteOptions{ + codespaceName: "hubot-robawt-abc", + }, + codespaces: []*api.Codespace{ + { + Name: "hubot-robawt-abc", + }, + }, + wantDeleted: []string{"hubot-robawt-abc"}, + }, + { + name: "by repo", + opts: deleteOptions{ + repoFilter: "monalisa/spoon-knife", + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + RepositoryNWO: "monalisa/Spoon-Knife", + }, + { + Name: "hubot-robawt-abc", + RepositoryNWO: "hubot/ROBAWT", + }, + { + Name: "monalisa-spoonknife-c4f3", + RepositoryNWO: "monalisa/Spoon-Knife", + }, + }, + wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"}, + }, + { + name: "unused", + opts: deleteOptions{ + deleteAll: true, + keepDays: 3, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + LastUsedAt: daysAgo(1), + }, + { + Name: "hubot-robawt-abc", + LastUsedAt: daysAgo(4), + }, + { + Name: "monalisa-spoonknife-c4f3", + LastUsedAt: daysAgo(10), + }, + }, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + }, + { + name: "deletion failed", + opts: deleteOptions{ + deleteAll: true, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + }, + { + Name: "hubot-robawt-abc", + }, + }, + deleteErr: errors.New("aborted by test"), + wantErr: true, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-123"}, + wantStderr: heredoc.Doc(` + error deleting codespace "hubot-robawt-abc": aborted by test + error deleting codespace "monalisa-spoonknife-123": aborted by test + `), + }, + { + name: "with confirm", + opts: deleteOptions{ + isInteractive: true, + deleteAll: true, + skipConfirm: false, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUnpushedChanges: true, + }, + }, + }, + { + Name: "hubot-robawt-abc", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUncommitedChanges: true, + }, + }, + }, + { + Name: "monalisa-spoonknife-c4f3", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUnpushedChanges: false, + HasUncommitedChanges: false, + }, + }, + }, + }, + confirms: map[string]bool{ + "Codespace monalisa-spoonknife-123 has unsaved changes. OK to delete?": false, + "Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true, + }, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiMock := &apiClientMock{ + GetUserFunc: func(_ context.Context) (*api.User, error) { + return user, nil + }, + DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error { + if userLogin != user.Login { + return fmt.Errorf("unexpected user %q", userLogin) + } + if tt.deleteErr != nil { + return tt.deleteErr + } + return nil + }, + } + if tt.opts.codespaceName == "" { + apiMock.ListCodespacesFunc = func(_ context.Context, userLogin string) ([]*api.Codespace, error) { + if userLogin != user.Login { + return nil, fmt.Errorf("unexpected user %q", userLogin) + } + return tt.codespaces, nil + } + } else { + apiMock.GetCodespaceTokenFunc = func(_ context.Context, userLogin, name string) (string, error) { + if userLogin != user.Login { + return "", fmt.Errorf("unexpected user %q", userLogin) + } + return "CS_TOKEN", nil + } + apiMock.GetCodespaceFunc = func(_ context.Context, token, userLogin, name string) (*api.Codespace, error) { + if userLogin != user.Login { + return nil, fmt.Errorf("unexpected user %q", userLogin) + } + if token != "CS_TOKEN" { + return nil, fmt.Errorf("unexpected token %q", token) + } + return tt.codespaces[0], nil + } + } + opts := tt.opts + opts.now = func() time.Time { return now } + opts.prompter = &prompterMock{ + ConfirmFunc: func(msg string) (bool, error) { + res, found := tt.confirms[msg] + if !found { + return false, fmt.Errorf("unexpected prompt %q", msg) + } + return res, nil + }, + } + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + app := &App{ + apiClient: apiMock, + logger: output.NewLogger(stdout, stderr, false), + } + err := app.Delete(context.Background(), opts) + if (err != nil) != tt.wantErr { + t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) + } + if n := len(apiMock.GetUserCalls()); n != 1 { + t.Errorf("GetUser invoked %d times, expected %d", n, 1) + } + var gotDeleted []string + for _, delArgs := range apiMock.DeleteCodespaceCalls() { + gotDeleted = append(gotDeleted, delArgs.Name) + } + sort.Strings(gotDeleted) + if !sliceEquals(gotDeleted, tt.wantDeleted) { + t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted) + } + if out := stdout.String(); out != tt.wantStdout { + t.Errorf("stdout = %q, want %q", out, tt.wantStdout) + } + if out := sortLines(stderr.String()); out != tt.wantStderr { + t.Errorf("stderr = %q, want %q", out, tt.wantStderr) + } + }) + } +} + +func sliceEquals(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func sortLines(s string) string { + trailing := "" + if strings.HasSuffix(s, "\n") { + s = strings.TrimSuffix(s, "\n") + trailing = "\n" + } + lines := strings.Split(s, "\n") + sort.Strings(lines) + return strings.Join(lines, "\n") + trailing +} diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go new file mode 100644 index 000000000..679325eff --- /dev/null +++ b/cmd/ghcs/list.go @@ -0,0 +1,63 @@ +package ghcs + +import ( + "context" + "fmt" + "os" + + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" + "github.com/spf13/cobra" +) + +func newListCmd(app *App) *cobra.Command { + var asJSON bool + + listCmd := &cobra.Command{ + Use: "list", + Short: "List your codespaces", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + return app.List(cmd.Context(), asJSON) + }, + } + + listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") + + return listCmd +} + +func (a *App) List(ctx context.Context, asJSON bool) error { + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login) + if err != nil { + return fmt.Errorf("error getting codespaces: %w", err) + } + + table := output.NewTable(os.Stdout, asJSON) + table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) + for _, codespace := range codespaces { + table.Append([]string{ + codespace.Name, + codespace.RepositoryNWO, + codespace.Branch + dirtyStar(codespace.Environment.GitStatus), + codespace.Environment.State, + codespace.CreatedAt, + }) + } + + table.Render() + return nil +} + +func dirtyStar(status api.CodespaceEnvironmentGitStatus) string { + if status.HasUncommitedChanges || status.HasUnpushedChanges { + return "*" + } + + return "" +} diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go new file mode 100644 index 000000000..24d017e58 --- /dev/null +++ b/cmd/ghcs/logs.go @@ -0,0 +1,112 @@ +package ghcs + +import ( + "context" + "fmt" + "net" + + "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/liveshare" + "github.com/spf13/cobra" +) + +func newLogsCmd(app *App) *cobra.Command { + var ( + codespace string + follow bool + ) + + logsCmd := &cobra.Command{ + Use: "logs", + Short: "Access codespace logs", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + return app.Logs(cmd.Context(), codespace, follow) + }, + } + + logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs") + + return logsCmd +} + +func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) { + // Ensure all child tasks (port forwarding, remote exec) terminate before return. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("getting user: %w", err) + } + + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) + }() + + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) + if err != nil { + return fmt.Errorf("get or choose codespace: %w", err) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("connecting to Live Share: %w", err) + } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } + + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port + if err != nil { + return err + } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port + + a.logger.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) + if err != nil { + return fmt.Errorf("error getting ssh server details: %w", err) + } + + cmdType := "cat" + if follow { + cmdType = "tail -f" + } + + dst := fmt.Sprintf("%s@localhost", sshUser) + cmd, err := codespaces.NewRemoteCommand( + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ) + if err != nil { + return fmt.Errorf("remote command: %w", err) + } + + tunnelClosed := make(chan error, 1) + go func() { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil + }() + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Run() + }() + + select { + case err := <-tunnelClosed: + return fmt.Errorf("connection closed: %w", err) + case err := <-cmdDone: + if err != nil { + return fmt.Errorf("error retrieving logs: %w", err) + } + + return nil // success + } +} diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go new file mode 100644 index 000000000..455a8c302 --- /dev/null +++ b/cmd/ghcs/main/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "errors" + "fmt" + "io" + "net/http" + "os" + + "github.com/cli/cli/v2/cmd/ghcs" + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" + "github.com/spf13/cobra" +) + +func main() { + token := os.Getenv("GITHUB_TOKEN") + rootCmd := ghcs.NewRootCmd(ghcs.NewApp( + output.NewLogger(os.Stdout, os.Stderr, false), + api.New(token, http.DefaultClient), + )) + + // Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still + // function without the token set. + oldPreRun := rootCmd.PersistentPreRunE + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + if token == "" { + return errTokenMissing + } + if oldPreRun != nil { + return oldPreRun(cmd, args) + } + return nil + } + + if cmd, err := rootCmd.ExecuteC(); err != nil { + explainError(os.Stderr, err, cmd) + os.Exit(1) + } +} + +var errTokenMissing = errors.New("GITHUB_TOKEN is missing") + +func explainError(w io.Writer, err error, cmd *cobra.Command) { + if errors.Is(err, errTokenMissing) { + fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") + fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") + return + } + if errors.Is(err, ghcs.ErrTooManyArgs) { + _ = cmd.Usage() + return + } +} diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go new file mode 100644 index 000000000..f9733e9bc --- /dev/null +++ b/cmd/ghcs/mock_api.go @@ -0,0 +1,659 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ghcs + +import ( + "context" + "sync" + + "github.com/cli/cli/v2/internal/api" +) + +// apiClientMock is a mock implementation of apiClient. +// +// func TestSomethingThatUsesapiClient(t *testing.T) { +// +// // make and configure a mocked apiClient +// mockedapiClient := &apiClientMock{ +// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { +// panic("mock out the AuthorizedKeys method") +// }, +// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// panic("mock out the CreateCodespace method") +// }, +// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { +// panic("mock out the DeleteCodespace method") +// }, +// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { +// panic("mock out the GetCodespace method") +// }, +// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { +// panic("mock out the GetCodespaceRegionLocation method") +// }, +// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { +// panic("mock out the GetCodespaceRepositoryContents method") +// }, +// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { +// panic("mock out the GetCodespaceToken method") +// }, +// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { +// panic("mock out the GetCodespacesSKUs method") +// }, +// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { +// panic("mock out the GetRepository method") +// }, +// GetUserFunc: func(ctx context.Context) (*api.User, error) { +// panic("mock out the GetUser method") +// }, +// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) { +// panic("mock out the ListCodespaces method") +// }, +// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error { +// panic("mock out the StartCodespace method") +// }, +// } +// +// // use mockedapiClient in code that requires apiClient +// // and then make assertions. +// +// } +type apiClientMock struct { + // AuthorizedKeysFunc mocks the AuthorizedKeys method. + AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) + + // CreateCodespaceFunc mocks the CreateCodespace method. + CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) + + // DeleteCodespaceFunc mocks the DeleteCodespace method. + DeleteCodespaceFunc func(ctx context.Context, user string, name string) error + + // GetCodespaceFunc mocks the GetCodespace method. + GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) + + // GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method. + GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error) + + // GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) + + // GetCodespaceTokenFunc mocks the GetCodespaceToken method. + GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) + + // GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method. + GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) + + // GetRepositoryFunc mocks the GetRepository method. + GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) + + // GetUserFunc mocks the GetUser method. + GetUserFunc func(ctx context.Context) (*api.User, error) + + // ListCodespacesFunc mocks the ListCodespaces method. + ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error) + + // StartCodespaceFunc mocks the StartCodespace method. + StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error + + // calls tracks calls to the methods. + calls struct { + // AuthorizedKeys holds details about calls to the AuthorizedKeys method. + AuthorizedKeys []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + // CreateCodespace holds details about calls to the CreateCodespace method. + CreateCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *api.CreateCodespaceParams + } + // DeleteCodespace holds details about calls to the DeleteCodespace method. + DeleteCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } + // GetCodespace holds details about calls to the GetCodespace method. + GetCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } + // GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method. + GetCodespaceRegionLocation []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContents []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Codespace is the codespace argument value. + Codespace *api.Codespace + // Path is the path argument value. + Path string + } + // GetCodespaceToken holds details about calls to the GetCodespaceToken method. + GetCodespaceToken []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } + // GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method. + GetCodespacesSKUs []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User *api.User + // Repository is the repository argument value. + Repository *api.Repository + // Branch is the branch argument value. + Branch string + // Location is the location argument value. + Location string + } + // GetRepository holds details about calls to the GetRepository method. + GetRepository []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Nwo is the nwo argument value. + Nwo string + } + // GetUser holds details about calls to the GetUser method. + GetUser []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // ListCodespaces holds details about calls to the ListCodespaces method. + ListCodespaces []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + // StartCodespace holds details about calls to the StartCodespace method. + StartCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // Codespace is the codespace argument value. + Codespace *api.Codespace + } + } + lockAuthorizedKeys sync.RWMutex + lockCreateCodespace sync.RWMutex + lockDeleteCodespace sync.RWMutex + lockGetCodespace sync.RWMutex + lockGetCodespaceRegionLocation sync.RWMutex + lockGetCodespaceRepositoryContents sync.RWMutex + lockGetCodespaceToken sync.RWMutex + lockGetCodespacesSKUs sync.RWMutex + lockGetRepository sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex + lockStartCodespace sync.RWMutex +} + +// AuthorizedKeys calls AuthorizedKeysFunc. +func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + if mock.AuthorizedKeysFunc == nil { + panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockAuthorizedKeys.Lock() + mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo) + mock.lockAuthorizedKeys.Unlock() + return mock.AuthorizedKeysFunc(ctx, user) +} + +// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys. +// Check the length with: +// len(mockedapiClient.AuthorizedKeysCalls()) +func (mock *apiClientMock) AuthorizedKeysCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockAuthorizedKeys.RLock() + calls = mock.calls.AuthorizedKeys + mock.lockAuthorizedKeys.RUnlock() + return calls +} + +// CreateCodespace calls CreateCodespaceFunc. +func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + if mock.CreateCodespaceFunc == nil { + panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Params *api.CreateCodespaceParams + }{ + Ctx: ctx, + Params: params, + } + mock.lockCreateCodespace.Lock() + mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) + mock.lockCreateCodespace.Unlock() + return mock.CreateCodespaceFunc(ctx, params) +} + +// CreateCodespaceCalls gets all the calls that were made to CreateCodespace. +// Check the length with: +// len(mockedapiClient.CreateCodespaceCalls()) +func (mock *apiClientMock) CreateCodespaceCalls() []struct { + Ctx context.Context + Params *api.CreateCodespaceParams +} { + var calls []struct { + Ctx context.Context + Params *api.CreateCodespaceParams + } + mock.lockCreateCodespace.RLock() + calls = mock.calls.CreateCodespace + mock.lockCreateCodespace.RUnlock() + return calls +} + +// DeleteCodespace calls DeleteCodespaceFunc. +func (mock *apiClientMock) DeleteCodespace(ctx context.Context, user string, name string) error { + if mock.DeleteCodespaceFunc == nil { + panic("apiClientMock.DeleteCodespaceFunc: method is nil but apiClient.DeleteCodespace was just called") + } + callInfo := struct { + Ctx context.Context + User string + Name string + }{ + Ctx: ctx, + User: user, + Name: name, + } + mock.lockDeleteCodespace.Lock() + mock.calls.DeleteCodespace = append(mock.calls.DeleteCodespace, callInfo) + mock.lockDeleteCodespace.Unlock() + return mock.DeleteCodespaceFunc(ctx, user, name) +} + +// DeleteCodespaceCalls gets all the calls that were made to DeleteCodespace. +// Check the length with: +// len(mockedapiClient.DeleteCodespaceCalls()) +func (mock *apiClientMock) DeleteCodespaceCalls() []struct { + Ctx context.Context + User string + Name string +} { + var calls []struct { + Ctx context.Context + User string + Name string + } + mock.lockDeleteCodespace.RLock() + calls = mock.calls.DeleteCodespace + mock.lockDeleteCodespace.RUnlock() + return calls +} + +// GetCodespace calls GetCodespaceFunc. +func (mock *apiClientMock) GetCodespace(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { + if mock.GetCodespaceFunc == nil { + panic("apiClientMock.GetCodespaceFunc: method is nil but apiClient.GetCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + User string + Name string + }{ + Ctx: ctx, + Token: token, + User: user, + Name: name, + } + mock.lockGetCodespace.Lock() + mock.calls.GetCodespace = append(mock.calls.GetCodespace, callInfo) + mock.lockGetCodespace.Unlock() + return mock.GetCodespaceFunc(ctx, token, user, name) +} + +// GetCodespaceCalls gets all the calls that were made to GetCodespace. +// Check the length with: +// len(mockedapiClient.GetCodespaceCalls()) +func (mock *apiClientMock) GetCodespaceCalls() []struct { + Ctx context.Context + Token string + User string + Name string +} { + var calls []struct { + Ctx context.Context + Token string + User string + Name string + } + mock.lockGetCodespace.RLock() + calls = mock.calls.GetCodespace + mock.lockGetCodespace.RUnlock() + return calls +} + +// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc. +func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + if mock.GetCodespaceRegionLocationFunc == nil { + panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetCodespaceRegionLocation.Lock() + mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo) + mock.lockGetCodespaceRegionLocation.Unlock() + return mock.GetCodespaceRegionLocationFunc(ctx) +} + +// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation. +// Check the length with: +// len(mockedapiClient.GetCodespaceRegionLocationCalls()) +func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetCodespaceRegionLocation.RLock() + calls = mock.calls.GetCodespaceRegionLocation + mock.lockGetCodespaceRegionLocation.RUnlock() + return calls +} + +// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc. +func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { + if mock.GetCodespaceRepositoryContentsFunc == nil { + panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called") + } + callInfo := struct { + Ctx context.Context + Codespace *api.Codespace + Path string + }{ + Ctx: ctx, + Codespace: codespace, + Path: path, + } + mock.lockGetCodespaceRepositoryContents.Lock() + mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo) + mock.lockGetCodespaceRepositoryContents.Unlock() + return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path) +} + +// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents. +// Check the length with: +// len(mockedapiClient.GetCodespaceRepositoryContentsCalls()) +func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct { + Ctx context.Context + Codespace *api.Codespace + Path string +} { + var calls []struct { + Ctx context.Context + Codespace *api.Codespace + Path string + } + mock.lockGetCodespaceRepositoryContents.RLock() + calls = mock.calls.GetCodespaceRepositoryContents + mock.lockGetCodespaceRepositoryContents.RUnlock() + return calls +} + +// GetCodespaceToken calls GetCodespaceTokenFunc. +func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) { + if mock.GetCodespaceTokenFunc == nil { + panic("apiClientMock.GetCodespaceTokenFunc: method is nil but apiClient.GetCodespaceToken was just called") + } + callInfo := struct { + Ctx context.Context + User string + Name string + }{ + Ctx: ctx, + User: user, + Name: name, + } + mock.lockGetCodespaceToken.Lock() + mock.calls.GetCodespaceToken = append(mock.calls.GetCodespaceToken, callInfo) + mock.lockGetCodespaceToken.Unlock() + return mock.GetCodespaceTokenFunc(ctx, user, name) +} + +// GetCodespaceTokenCalls gets all the calls that were made to GetCodespaceToken. +// Check the length with: +// len(mockedapiClient.GetCodespaceTokenCalls()) +func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { + Ctx context.Context + User string + Name string +} { + var calls []struct { + Ctx context.Context + User string + Name string + } + mock.lockGetCodespaceToken.RLock() + calls = mock.calls.GetCodespaceToken + mock.lockGetCodespaceToken.RUnlock() + return calls +} + +// GetCodespacesSKUs calls GetCodespacesSKUsFunc. +func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { + if mock.GetCodespacesSKUsFunc == nil { + panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called") + } + callInfo := struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + }{ + Ctx: ctx, + User: user, + Repository: repository, + Branch: branch, + Location: location, + } + mock.lockGetCodespacesSKUs.Lock() + mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo) + mock.lockGetCodespacesSKUs.Unlock() + return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location) +} + +// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs. +// Check the length with: +// len(mockedapiClient.GetCodespacesSKUsCalls()) +func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string +} { + var calls []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + } + mock.lockGetCodespacesSKUs.RLock() + calls = mock.calls.GetCodespacesSKUs + mock.lockGetCodespacesSKUs.RUnlock() + return calls +} + +// GetRepository calls GetRepositoryFunc. +func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) { + if mock.GetRepositoryFunc == nil { + panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called") + } + callInfo := struct { + Ctx context.Context + Nwo string + }{ + Ctx: ctx, + Nwo: nwo, + } + mock.lockGetRepository.Lock() + mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo) + mock.lockGetRepository.Unlock() + return mock.GetRepositoryFunc(ctx, nwo) +} + +// GetRepositoryCalls gets all the calls that were made to GetRepository. +// Check the length with: +// len(mockedapiClient.GetRepositoryCalls()) +func (mock *apiClientMock) GetRepositoryCalls() []struct { + Ctx context.Context + Nwo string +} { + var calls []struct { + Ctx context.Context + Nwo string + } + mock.lockGetRepository.RLock() + calls = mock.calls.GetRepository + mock.lockGetRepository.RUnlock() + return calls +} + +// GetUser calls GetUserFunc. +func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { + if mock.GetUserFunc == nil { + panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetUser.Lock() + mock.calls.GetUser = append(mock.calls.GetUser, callInfo) + mock.lockGetUser.Unlock() + return mock.GetUserFunc(ctx) +} + +// GetUserCalls gets all the calls that were made to GetUser. +// Check the length with: +// len(mockedapiClient.GetUserCalls()) +func (mock *apiClientMock) GetUserCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetUser.RLock() + calls = mock.calls.GetUser + mock.lockGetUser.RUnlock() + return calls +} + +// ListCodespaces calls ListCodespacesFunc. +func (mock *apiClientMock) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) { + if mock.ListCodespacesFunc == nil { + panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockListCodespaces.Lock() + mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo) + mock.lockListCodespaces.Unlock() + return mock.ListCodespacesFunc(ctx, user) +} + +// ListCodespacesCalls gets all the calls that were made to ListCodespaces. +// Check the length with: +// len(mockedapiClient.ListCodespacesCalls()) +func (mock *apiClientMock) ListCodespacesCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockListCodespaces.RLock() + calls = mock.calls.ListCodespaces + mock.lockListCodespaces.RUnlock() + return calls +} + +// StartCodespace calls StartCodespaceFunc. +func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error { + if mock.StartCodespaceFunc == nil { + panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + Codespace *api.Codespace + }{ + Ctx: ctx, + Token: token, + Codespace: codespace, + } + mock.lockStartCodespace.Lock() + mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo) + mock.lockStartCodespace.Unlock() + return mock.StartCodespaceFunc(ctx, token, codespace) +} + +// StartCodespaceCalls gets all the calls that were made to StartCodespace. +// Check the length with: +// len(mockedapiClient.StartCodespaceCalls()) +func (mock *apiClientMock) StartCodespaceCalls() []struct { + Ctx context.Context + Token string + Codespace *api.Codespace +} { + var calls []struct { + Ctx context.Context + Token string + Codespace *api.Codespace + } + mock.lockStartCodespace.RLock() + calls = mock.calls.StartCodespace + mock.lockStartCodespace.RUnlock() + return calls +} diff --git a/cmd/ghcs/mock_prompter.go b/cmd/ghcs/mock_prompter.go new file mode 100644 index 000000000..56581b64d --- /dev/null +++ b/cmd/ghcs/mock_prompter.go @@ -0,0 +1,69 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ghcs + +import ( + "sync" +) + +// prompterMock is a mock implementation of prompter. +// +// func TestSomethingThatUsesprompter(t *testing.T) { +// +// // make and configure a mocked prompter +// mockedprompter := &prompterMock{ +// ConfirmFunc: func(message string) (bool, error) { +// panic("mock out the Confirm method") +// }, +// } +// +// // use mockedprompter in code that requires prompter +// // and then make assertions. +// +// } +type prompterMock struct { + // ConfirmFunc mocks the Confirm method. + ConfirmFunc func(message string) (bool, error) + + // calls tracks calls to the methods. + calls struct { + // Confirm holds details about calls to the Confirm method. + Confirm []struct { + // Message is the message argument value. + Message string + } + } + lockConfirm sync.RWMutex +} + +// Confirm calls ConfirmFunc. +func (mock *prompterMock) Confirm(message string) (bool, error) { + if mock.ConfirmFunc == nil { + panic("prompterMock.ConfirmFunc: method is nil but prompter.Confirm was just called") + } + callInfo := struct { + Message string + }{ + Message: message, + } + mock.lockConfirm.Lock() + mock.calls.Confirm = append(mock.calls.Confirm, callInfo) + mock.lockConfirm.Unlock() + return mock.ConfirmFunc(message) +} + +// ConfirmCalls gets all the calls that were made to Confirm. +// Check the length with: +// len(mockedprompter.ConfirmCalls()) +func (mock *prompterMock) ConfirmCalls() []struct { + Message string +} { + var calls []struct { + Message string + } + mock.lockConfirm.RLock() + calls = mock.calls.Confirm + mock.lockConfirm.RUnlock() + return calls +} diff --git a/cmd/ghcs/output/format_json.go b/cmd/ghcs/output/format_json.go new file mode 100644 index 000000000..8488e8dfa --- /dev/null +++ b/cmd/ghcs/output/format_json.go @@ -0,0 +1,55 @@ +package output + +import ( + "encoding/json" + "io" + "strings" + "unicode" +) + +type jsonwriter struct { + w io.Writer + pretty bool + cols []string + data []interface{} +} + +func (j *jsonwriter) SetHeader(cols []string) { + j.cols = cols +} + +func (j *jsonwriter) Append(values []string) { + row := make(map[string]string) + for i, v := range values { + row[camelize(j.cols[i])] = v + } + j.data = append(j.data, row) +} + +func (j *jsonwriter) Render() { + enc := json.NewEncoder(j.w) + if j.pretty { + enc.SetIndent("", " ") + } + _ = enc.Encode(j.data) +} + +func camelize(s string) string { + var b strings.Builder + capitalizeNext := false + for i, r := range s { + if r == ' ' { + capitalizeNext = true + continue + } + if capitalizeNext { + b.WriteRune(unicode.ToUpper(r)) + capitalizeNext = false + } else if i == 0 { + b.WriteRune(unicode.ToLower(r)) + } else { + b.WriteRune(r) + } + } + return b.String() +} diff --git a/cmd/ghcs/output/format_table.go b/cmd/ghcs/output/format_table.go new file mode 100644 index 000000000..e0345672d --- /dev/null +++ b/cmd/ghcs/output/format_table.go @@ -0,0 +1,31 @@ +package output + +import ( + "io" + "os" + + "github.com/olekukonko/tablewriter" + "golang.org/x/term" +) + +type Table interface { + SetHeader([]string) + Append([]string) + Render() +} + +func NewTable(w io.Writer, asJSON bool) Table { + isTTY := isTTY(w) + if asJSON { + return &jsonwriter{w: w, pretty: isTTY} + } + if isTTY { + return tablewriter.NewWriter(w) + } + return &tabwriter{w: w} +} + +func isTTY(w io.Writer) bool { + f, ok := w.(*os.File) + return ok && term.IsTerminal(int(f.Fd())) +} diff --git a/cmd/ghcs/output/format_tsv.go b/cmd/ghcs/output/format_tsv.go new file mode 100644 index 000000000..3f1d226ca --- /dev/null +++ b/cmd/ghcs/output/format_tsv.go @@ -0,0 +1,25 @@ +package output + +import ( + "fmt" + "io" +) + +type tabwriter struct { + w io.Writer +} + +func (j *tabwriter) SetHeader([]string) {} + +func (j *tabwriter) Append(values []string) { + var sep string + for i, v := range values { + if i == 1 { + sep = "\t" + } + fmt.Fprintf(j.w, "%s%s", sep, v) + } + fmt.Fprint(j.w, "\n") +} + +func (j *tabwriter) Render() {} diff --git a/cmd/ghcs/output/logger.go b/cmd/ghcs/output/logger.go new file mode 100644 index 000000000..6ad7513f1 --- /dev/null +++ b/cmd/ghcs/output/logger.go @@ -0,0 +1,74 @@ +package output + +import ( + "fmt" + "io" + "sync" +) + +// NewLogger returns a Logger that will write to the given stdout/stderr writers. +// Disable the Logger to prevent it from writing to stdout in a TTY environment. +func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { + return &Logger{ + out: stdout, + errout: stderr, + enabled: !disabled && isTTY(stdout), + } +} + +// Logger writes to the given stdout/stderr writers. +// If not enabled, Print functions will noop but Error functions will continue +// to write to the stderr writer. +type Logger struct { + mu sync.Mutex // guards the writers + out io.Writer + errout io.Writer + enabled bool +} + +// Print writes the arguments to the stdout writer. +func (l *Logger) Print(v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + + l.mu.Lock() + defer l.mu.Unlock() + return fmt.Fprint(l.out, v...) +} + +// Println writes the arguments to the stdout writer with a newline at the end. +func (l *Logger) Println(v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + + l.mu.Lock() + defer l.mu.Unlock() + return fmt.Fprintln(l.out, v...) +} + +// Printf writes the formatted arguments to the stdout writer. +func (l *Logger) Printf(f string, v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + + l.mu.Lock() + defer l.mu.Unlock() + return fmt.Fprintf(l.out, f, v...) +} + +// Errorf writes the formatted arguments to the stderr writer. +func (l *Logger) Errorf(f string, v ...interface{}) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + return fmt.Fprintf(l.errout, f, v...) +} + +// Errorln writes the arguments to the stderr writer with a newline at the end. +func (l *Logger) Errorln(v ...interface{}) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + return fmt.Fprintln(l.errout, v...) +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go new file mode 100644 index 000000000..27467d150 --- /dev/null +++ b/cmd/ghcs/ports.go @@ -0,0 +1,330 @@ +package ghcs + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "os" + "strconv" + "strings" + + "github.com/cli/cli/v2/cmd/ghcs/output" + "github.com/cli/cli/v2/internal/api" + "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/liveshare" + "github.com/muhammadmuzzammil1998/jsonc" + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" +) + +// newPortsCmd returns a Cobra "ports" command that displays a table of available ports, +// according to the specified flags. +func newPortsCmd(app *App) *cobra.Command { + var ( + codespace string + asJSON bool + ) + + portsCmd := &cobra.Command{ + Use: "ports", + Short: "List ports in a codespace", + Args: noArgsConstraint, + RunE: func(cmd *cobra.Command, args []string) error { + return app.ListPorts(cmd.Context(), codespace, asJSON) + }, + } + + portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") + + portsCmd.AddCommand(newPortsPublicCmd(app)) + portsCmd.AddCommand(newPortsPrivateCmd(app)) + portsCmd.AddCommand(newPortsForwardCmd(app)) + + return portsCmd +} + +// ListPorts lists known ports in a codespace. +func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) { + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) + if err != nil { + // TODO(josebalius): remove special handling of this error here and it other places + if err == errNoCodespaces { + return err + } + return fmt.Errorf("error choosing codespace: %w", err) + } + + devContainerCh := getDevContainer(ctx, a.apiClient, codespace) + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to Live Share: %w", err) + } + defer safeClose(session, &err) + + a.logger.Println("Loading ports...") + ports, err := session.GetSharedServers(ctx) + if err != nil { + return fmt.Errorf("error getting ports of shared servers: %w", err) + } + + devContainerResult := <-devContainerCh + if devContainerResult.err != nil { + // Warn about failure to read the devcontainer file. Not a ghcs command error. + _, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + } + + table := output.NewTable(os.Stdout, asJSON) + table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"}) + for _, port := range ports { + sourcePort := strconv.Itoa(port.SourcePort) + var portName string + if devContainerResult.devContainer != nil { + if attributes, ok := devContainerResult.devContainer.PortAttributes[sourcePort]; ok { + portName = attributes.Label + } + } + + table.Append([]string{ + portName, + sourcePort, + strings.ToUpper(strconv.FormatBool(port.IsPublic)), + fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort), + }) + } + table.Render() + + return nil +} + +type devContainerResult struct { + devContainer *devContainer + err error +} + +type devContainer struct { + PortAttributes map[string]portAttribute `json:"portsAttributes"` +} + +type portAttribute struct { + Label string `json:"label"` +} + +func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult { + ch := make(chan devContainerResult, 1) + go func() { + contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") + if err != nil { + ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)} + return + } + + if contents == nil { + ch <- devContainerResult{nil, nil} + return + } + + convertedJSON := normalizeJSON(jsonc.ToJSON(contents)) + if !jsonc.Valid(convertedJSON) { + ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")} + return + } + + var container devContainer + if err := json.Unmarshal(convertedJSON, &container); err != nil { + ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)} + return + } + + ch <- devContainerResult{&container, nil} + }() + return ch +} + +// newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. +func newPortsPublicCmd(app *App) *cobra.Command { + return &cobra.Command{ + Use: "public ", + Short: "Mark port as public", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %w", err) + } + + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true) + }, + } +} + +// newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. +func newPortsPrivateCmd(app *App) *cobra.Command { + return &cobra.Command{ + Use: "private ", + Short: "Mark port as private", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %w", err) + } + + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false) + }, + } +} + +func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) { + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) + if err != nil { + if err == errNoCodespaces { + return err + } + return fmt.Errorf("error getting codespace: %w", err) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to Live Share: %w", err) + } + defer safeClose(session, &err) + + port, err := strconv.Atoi(sourcePort) + if err != nil { + return fmt.Errorf("error reading port number: %w", err) + } + + if err := session.UpdateSharedVisibility(ctx, port, public); err != nil { + return fmt.Errorf("error update port to public: %w", err) + } + + state := "PUBLIC" + if !public { + state = "PRIVATE" + } + a.logger.Printf("Port %s is now %s.\n", sourcePort, state) + + return nil +} + +// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of +// port pairs from the codespace to localhost. +func newPortsForwardCmd(app *App) *cobra.Command { + return &cobra.Command{ + Use: "forward :...", + Short: "Forward ports", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %w", err) + } + + return app.ForwardPorts(cmd.Context(), codespace, args) + }, + } +} + +func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) { + portPairs, err := getPortPairs(ports) + if err != nil { + return fmt.Errorf("get port pairs: %w", err) + } + + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) + if err != nil { + if err == errNoCodespaces { + return err + } + return fmt.Errorf("error getting codespace: %w", err) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to Live Share: %w", err) + } + defer safeClose(session, &err) + + // Run forwarding of all ports concurrently, aborting all of + // them at the first failure, including cancellation of the context. + group, ctx := errgroup.WithContext(ctx) + for _, pair := range portPairs { + pair := pair + group.Go(func() error { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) + if err != nil { + return err + } + defer listen.Close() + a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) + fwd := liveshare.NewPortForwarder(session, name, pair.remote) + return fwd.ForwardToListener(ctx, listen) // error always non-nil + }) + } + return group.Wait() // first error +} + +type portPair struct { + remote, local int +} + +// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers. +func getPortPairs(ports []string) ([]portPair, error) { + pp := make([]portPair, 0, len(ports)) + + for _, portString := range ports { + parts := strings.Split(portString, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("port pair: %q is not valid", portString) + } + + remote, err := strconv.Atoi(parts[0]) + if err != nil { + return pp, fmt.Errorf("convert remote port to int: %w", err) + } + + local, err := strconv.Atoi(parts[1]) + if err != nil { + return pp, fmt.Errorf("convert local port to int: %w", err) + } + + pp = append(pp, portPair{remote, local}) + } + + return pp, nil +} + +func normalizeJSON(j []byte) []byte { + // remove trailing commas + return bytes.ReplaceAll(j, []byte("},}"), []byte("}}")) +} diff --git a/cmd/ghcs/root.go b/cmd/ghcs/root.go new file mode 100644 index 000000000..c9fdd2876 --- /dev/null +++ b/cmd/ghcs/root.go @@ -0,0 +1,30 @@ +package ghcs + +import ( + "github.com/spf13/cobra" +) + +var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. + +func NewRootCmd(app *App) *cobra.Command { + root := &cobra.Command{ + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) + SilenceErrors: false, // print errors automatically so that main need not + Long: `Unofficial CLI tool to manage GitHub Codespaces. + +Running commands requires the GITHUB_TOKEN environment variable to be set to a +token to access the GitHub API with.`, + Version: version, + } + + root.AddCommand(newCodeCmd(app)) + root.AddCommand(newCreateCmd(app)) + root.AddCommand(newDeleteCmd(app)) + root.AddCommand(newListCmd(app)) + root.AddCommand(newLogsCmd(app)) + root.AddCommand(newPortsCmd(app)) + root.AddCommand(newSSHCmd(app)) + + return root +} diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go new file mode 100644 index 000000000..5627952b6 --- /dev/null +++ b/cmd/ghcs/ssh.go @@ -0,0 +1,105 @@ +package ghcs + +import ( + "context" + "fmt" + "net" + + "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/liveshare" + "github.com/spf13/cobra" +) + +func newSSHCmd(app *App) *cobra.Command { + var sshProfile, codespaceName string + var sshServerPort int + + sshCmd := &cobra.Command{ + Use: "ssh [flags] [--] [ssh-flags] [command]", + Short: "SSH into a codespace", + RunE: func(cmd *cobra.Command, args []string) error { + return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) + }, + } + + sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use") + sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") + sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace") + + return sshCmd +} + +// SSH opens an ssh session or runs an ssh command in a codespace. +func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { + // Ensure all child tasks (e.g. port forwarding) terminate before return. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := a.apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) + }() + + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) + if err != nil { + return fmt.Errorf("get or choose codespace: %w", err) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to Live Share: %w", err) + } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } + + a.logger.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) + if err != nil { + return fmt.Errorf("error getting ssh server details: %w", err) + } + + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort)) + if err != nil { + return err + } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port + + connectDestination := sshProfile + if connectDestination == "" { + connectDestination = fmt.Sprintf("%s@localhost", sshUser) + } + + a.logger.Println("Ready...") + tunnelClosed := make(chan error, 1) + go func() { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil + }() + + shellClosed := make(chan error, 1) + go func() { + shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) + }() + + select { + case err := <-tunnelClosed: + return fmt.Errorf("tunnel closed: %w", err) + case err := <-shellClosed: + if err != nil { + return fmt.Errorf("shell closed: %w", err) + } + return nil // success + } +} diff --git a/go.mod b/go.mod index 1de8b09f9..8137a426a 100644 --- a/go.mod +++ b/go.mod @@ -12,9 +12,11 @@ require ( github.com/cli/safeexec v1.0.0 github.com/cpuguy83/go-md2man/v2 v2.0.0 github.com/creack/pty v1.1.13 + github.com/fatih/camelcase v1.0.0 github.com/gabriel-vasile/mimetype v1.1.2 github.com/google/go-cmp v0.5.5 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.2.1 github.com/henvic/httpretty v0.0.6 github.com/itchyny/gojq v0.12.4 @@ -24,17 +26,24 @@ require ( github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5 github.com/muesli/termenv v0.8.1 + github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 + github.com/olekukonko/tablewriter v0.0.5 + github.com/opentracing/opentracing-go v1.1.0 github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f + github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 + github.com/sourcegraph/jsonrpc2 v0.1.0 github.com/spf13/cobra v1.2.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/objx v0.1.1 // indirect github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 golang.org/x/term v0.0.0-20210503060354-a79de5458b56 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) replace github.com/shurcooL/graphql => github.com/cli/shurcooL-graphql v0.0.0-20200707151639-0f7232a2bf7e + +replace golang.org/x/crypto => github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 diff --git a/go.sum b/go.sum index cf3f5b57b..a77ef812f 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q= github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY= github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI= +github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8= +github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= github.com/cli/oauth v0.8.0 h1:YTFgPXSTvvDUFti3tR4o6q7Oll2SnQ9ztLwCAn4/IOA= github.com/cli/oauth v0.8.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI= @@ -103,6 +105,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8= +github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -182,6 +186,9 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/graph-gophers/graphql-go v0.0.0-20200622220639-c1d9693c95a6/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= @@ -273,8 +280,11 @@ github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5 h1:T+Fc6qGlSfM+z0J github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ= github.com/muesli/termenv v0.8.1 h1:9q230czSP3DHVpkaPDXGp0TOfAwyjyYwXlUCQxQSaBk= github.com/muesli/termenv v0.8.1/go.mod h1:kzt/D/4a88RoheZmwfqorY3A+tnsSMA9HJC/fQSFKo0= +github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 h1:0FrBxrkJ0hVembTb/e4EU5Ml6vLcOusAqymmYISg5Uo= +github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38/go.mod h1:saF2fIVw4banK0H4+/EuqfFLpRnoy5S+ECwTOCcRcSU= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= @@ -300,8 +310,12 @@ github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b h1:0/ecDXh/HTHRt github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b/go.mod h1:hAF0iLZy4td2EX+/8Tw+4nodhlMrwN3HupfaXj3zkGo= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk= +github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v1.2.1 h1:+KmjbUw1hriSNMF55oPrkZcb27aECyrj8V2ytv7kWDw= @@ -344,16 +358,6 @@ go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -396,7 +400,6 @@ golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -497,8 +500,9 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b h1:qh4f65QIVFjq9eBURLEYWqaEXmOyqdUyiBSgaXWccWk= golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210503060354-a79de5458b56 h1:b8jxX3zqjpqb2LklXPzKSGJhzyxCOZSz8ncv8Nv+y7w= golang.org/x/term v0.0.0-20210503060354-a79de5458b56/go.mod h1:tfny5GFUkzUvx4ps4ajbZsCe5lw1metzhBm9T3x7oIY= diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 000000000..cdf9b3f60 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,640 @@ +package api + +// For descriptions of service interfaces, see: +// - https://online.visualstudio.com/api/swagger (for visualstudio.com) +// - https://docs.github.com/en/rest/reference/repos (for api.github.com) +// - https://github.com/github/github/blob/master/app/api/codespaces.rb (for vscs_internal) +// TODO(adonovan): replace the last link with a public doc URL when available. + +// TODO(adonovan): a possible reorganization would be to split this +// file into three internal packages, one per backend service, and to +// rename api.API to github.Client: +// +// - github.GetUser(github.Client) +// - github.GetRepository(Client) +// - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents +// - github.AuthorizedKeys(Client, user) +// - codespaces.Create(Client, user, repo, sku, branch, location) +// - codespaces.Delete(Client, user, token, name) +// - codespaces.Get(Client, token, owner, name) +// - codespaces.GetMachineTypes(Client, user, repo, branch, location) +// - codespaces.GetToken(Client, login, name) +// - codespaces.List(Client, user) +// - codespaces.Start(Client, token, codespace) +// - visualstudio.GetRegionLocation(http.Client) // no dependency on github +// +// This would make the meaning of each operation clearer. + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "strings" + "time" + + "github.com/opentracing/opentracing-go" +) + +const githubAPI = "https://api.github.com" + +type API struct { + token string + client httpClient + githubAPI string +} + +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +func New(token string, httpClient httpClient) *API { + return &API{ + token: token, + client: httpClient, + githubAPI: githubAPI, + } +} + +type User struct { + Login string `json:"login"` +} + +func (a *API) GetUser(ctx context.Context) (*User, error) { + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/user", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/user") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response User + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return &response, nil +} + +func jsonErrorResponse(b []byte) error { + var response struct { + Message string `json:"message"` + } + if err := json.Unmarshal(b, &response); err != nil { + return fmt.Errorf("error unmarshaling error response: %w", err) + } + + return errors.New(response.Message) +} + +type Repository struct { + ID int `json:"id"` +} + +func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) { + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+strings.ToLower(nwo), nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/repos/*") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response Repository + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return &response, nil +} + +type Codespace struct { + Name string `json:"name"` + GUID string `json:"guid"` + CreatedAt string `json:"created_at"` + LastUsedAt string `json:"last_used_at"` + Branch string `json:"branch"` + RepositoryName string `json:"repository_name"` + RepositoryNWO string `json:"repository_nwo"` + OwnerLogin string `json:"owner_login"` + Environment CodespaceEnvironment `json:"environment"` +} + +type CodespaceEnvironment struct { + State string `json:"state"` + Connection CodespaceEnvironmentConnection `json:"connection"` + GitStatus CodespaceEnvironmentGitStatus `json:"gitStatus"` +} + +type CodespaceEnvironmentGitStatus struct { + Ahead int `json:"ahead"` + Behind int `json:"behind"` + Branch string `json:"branch"` + Commit string `json:"commit"` + HasUnpushedChanges bool `json:"hasUnpushedChanges"` + HasUncommitedChanges bool `json:"hasUncommitedChanges"` +} + +const ( + CodespaceEnvironmentStateAvailable = "Available" +) + +type CodespaceEnvironmentConnection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelayEndpoint string `json:"relayEndpoint"` + RelaySAS string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` +} + +func (a *API) ListCodespaces(ctx context.Context, user string) ([]*Codespace, error) { + req, err := http.NewRequest( + http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response struct { + Codespaces []*Codespace `json:"codespaces"` + } + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + return response.Codespaces, nil +} + +type getCodespaceTokenRequest struct { + MintRepositoryToken bool `json:"mint_repository_token"` +} + +type getCodespaceTokenResponse struct { + RepositoryToken string `json:"repository_token"` +} + +// ErrNotProvisioned is returned by GetCodespacesToken to indicate that the +// creation of a codespace is not yet complete and that the caller should try again. +var ErrNotProvisioned = errors.New("codespace not provisioned") + +func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) { + reqBody, err := json.Marshal(getCodespaceTokenRequest{true}) + if err != nil { + return "", fmt.Errorf("error preparing request body: %w", err) + } + + req, err := http.NewRequest( + http.MethodPost, + a.githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token", + bytes.NewBuffer(reqBody), + ) + if err != nil { + return "", fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token") + if err != nil { + return "", fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusUnprocessableEntity { + return "", ErrNotProvisioned + } + + return "", jsonErrorResponse(b) + } + + var response getCodespaceTokenResponse + if err := json.Unmarshal(b, &response); err != nil { + return "", fmt.Errorf("error unmarshaling response: %w", err) + } + + return response.RepositoryToken, nil +} + +func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) (*Codespace, error) { + req, err := http.NewRequest( + http.MethodGet, + a.githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace, + nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + // TODO: use a.setHeaders() + req.Header.Set("Authorization", "Bearer "+token) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response Codespace + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return &response, nil +} + +func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codespace) error { + req, err := http.NewRequest( + http.MethodPost, + a.githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start", + nil, + ) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + // TODO: use a.setHeaders() + req.Header.Set("Authorization", "Bearer "+token) + resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") + if err != nil { + return fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + // Error response may be a numeric code or a JSON {"message": "..."}. + if bytes.HasPrefix(b, []byte("{")) { + return jsonErrorResponse(b) // probably JSON + } + if len(b) > 100 { + b = append(b[:97], "..."...) + } + if strings.TrimSpace(string(b)) == "7" { + // Non-HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign. + // Ignore it. + } else { + return fmt.Errorf("failed to start codespace: %s", b) + } + } + + return nil +} + +type getCodespaceRegionLocationResponse struct { + Current string `json:"current"` +} + +func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + req, err := http.NewRequest(http.MethodGet, "https://online.visualstudio.com/api/v1/locations", nil) + if err != nil { + return "", fmt.Errorf("error creating request: %w", err) + } + + resp, err := a.do(ctx, req, req.URL.String()) + if err != nil { + return "", fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", jsonErrorResponse(b) + } + + var response getCodespaceRegionLocationResponse + if err := json.Unmarshal(b, &response); err != nil { + return "", fmt.Errorf("error unmarshaling response: %w", err) + } + + return response.Current, nil +} + +type SKU struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` +} + +func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + q := req.URL.Query() + q.Add("location", location) + q.Add("ref", branch) + q.Add("repository_id", strconv.Itoa(repository.ID)) + req.URL.RawQuery = q.Encode() + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response struct { + SKUs []*SKU `json:"skus"` + } + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return response.SKUs, nil +} + +// CreateCodespaceParams are the required parameters for provisioning a Codespace. +type CreateCodespaceParams struct { + User string + RepositoryID int + Branch, Machine, Location string +} + +// CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it +// fails to create. +func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) { + codespace, err := a.startCreate( + ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, + ) + if err != errProvisioningInProgress { + return codespace, err + } + + // errProvisioningInProgress indicates that codespace creation did not complete + // within the GitHub API RPC time limit (10s), so it continues asynchronously. + // We must poll the server to discover the outcome. + ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name) + if err != nil { + if err == ErrNotProvisioned { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + return nil, fmt.Errorf("failed to get codespace token: %w", err) + } + + codespace, err = a.GetCodespace(ctx, token, params.User, codespace.Name) + if err != nil { + return nil, fmt.Errorf("failed to get codespace: %w", err) + } + + return codespace, nil + } + } +} + +type startCreateRequest struct { + RepositoryID int `json:"repository_id"` + Ref string `json:"ref"` + Location string `json:"location"` + SkuName string `json:"sku_name"` +} + +var errProvisioningInProgress = errors.New("provisioning in progress") + +// startCreate starts the creation of a codespace. +// It may return success or an error, or errProvisioningInProgress indicating that the operation +// did not complete before the GitHub API's time limit for RPCs (10s), in which case the caller +// must poll the server to learn the outcome. +func (a *API) startCreate(ctx context.Context, user string, repository int, sku, branch, location string) (*Codespace, error) { + requestBody, err := json.Marshal(startCreateRequest{repository, branch, location, sku}) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", bytes.NewBuffer(requestBody)) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + switch { + case resp.StatusCode > http.StatusAccepted: + return nil, jsonErrorResponse(b) + case resp.StatusCode == http.StatusAccepted: + return nil, errProvisioningInProgress // RPC finished before result of creation known + } + + var response Codespace + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return &response, nil +} + +func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName string) error { + token, err := a.GetCodespaceToken(ctx, user, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + + req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces/"+codespaceName, nil) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + // TODO: use a.setHeaders() + req.Header.Set("Authorization", "Bearer "+token) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") + if err != nil { + return fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode > http.StatusAccepted { + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response body: %w", err) + } + return jsonErrorResponse(b) + } + + return nil +} + +type getCodespaceRepositoryContentsResponse struct { + Content string `json:"content"` +} + +func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + q := req.URL.Query() + q.Add("ref", codespace.Branch) + req.URL.RawQuery = q.Encode() + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/repos/*/contents/*") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, jsonErrorResponse(b) + } + + var response getCodespaceRepositoryContentsResponse + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + decoded, err := base64.StdEncoding.DecodeString(response.Content) + if err != nil { + return nil, fmt.Errorf("error decoding content: %w", err) + } + + return decoded, nil +} + +// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys +// format) registered by the specified GitHub user. +func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + url := fmt.Sprintf("https://github.com/%s.keys", user) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := a.do(ctx, req, "/user.keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned %s", resp.Status) + } + return b, nil +} + +func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { + // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. + span, ctx := opentracing.StartSpanFromContext(ctx, spanName) + defer span.Finish() + req = req.WithContext(ctx) + return a.client.Do(req) +} + +func (a *API) setHeaders(req *http.Request) { + if a.token != "" { + req.Header.Set("Authorization", "Bearer "+a.token) + } + req.Header.Set("Accept", "application/vnd.github.v3+json") +} diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 000000000..6fb162030 --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,50 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestListCodespaces(t *testing.T) { + codespaces := []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + }, + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := struct { + Codespaces []*Codespace `json:"codespaces"` + }{ + Codespaces: codespaces, + } + data, _ := json.Marshal(response) + fmt.Fprint(w, string(data)) + })) + defer svr.Close() + + api := API{ + githubAPI: svr.URL, + client: &http.Client{}, + token: "faketoken", + } + ctx := context.TODO() + codespaces, err := api.ListCodespaces(ctx, "testuser") + if err != nil { + t.Fatal(err) + } + + if len(codespaces) != 1 { + t.Fatalf("expected 1 codespace, got %d", len(codespaces)) + } + + if codespaces[0].Name != "testcodespace" { + t.Fatalf("expected testcodespace, got %s", codespaces[0].Name) + } + +} diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go new file mode 100644 index 000000000..f127cb362 --- /dev/null +++ b/internal/codespaces/codespaces.go @@ -0,0 +1,77 @@ +package codespaces + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/cli/cli/v2/internal/api" + "github.com/cli/cli/v2/internal/liveshare" +) + +type logger interface { + Print(v ...interface{}) (int, error) + Println(v ...interface{}) (int, error) +} + +func connectionReady(codespace *api.Codespace) bool { + return codespace.Environment.Connection.SessionID != "" && + codespace.Environment.Connection.SessionToken != "" && + codespace.Environment.Connection.RelayEndpoint != "" && + codespace.Environment.Connection.RelaySAS != "" && + codespace.Environment.State == api.CodespaceEnvironmentStateAvailable +} + +type apiClient interface { + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + GetCodespaceToken(ctx context.Context, user, codespace string) (string, error) + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error +} + +// ConnectToLiveshare waits for a Codespace to become running, +// and connects to it using a Live Share session. +func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { + var startedCodespace bool + if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + startedCodespace = true + log.Print("Starting your codespace...") + if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { + return nil, fmt.Errorf("error starting codespace: %w", err) + } + } + + for retries := 0; !connectionReady(codespace); retries++ { + if retries > 1 { + if retries%2 == 0 { + log.Print(".") + } + + time.Sleep(1 * time.Second) + } + + if retries == 30 { + return nil, errors.New("timed out while waiting for the codespace to start") + } + + var err error + codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) + if err != nil { + return nil, fmt.Errorf("error getting codespace: %w", err) + } + } + + if startedCodespace { + fmt.Print("\n") + } + + log.Println("Connecting to your codespace...") + + return liveshare.Connect(ctx, liveshare.Options{ + SessionID: codespace.Environment.Connection.SessionID, + SessionToken: codespace.Environment.Connection.SessionToken, + RelaySAS: codespace.Environment.Connection.RelaySAS, + RelayEndpoint: codespace.Environment.Connection.RelayEndpoint, + HostPublicKeys: codespace.Environment.Connection.HostPublicKeys, + }) +} diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go new file mode 100644 index 000000000..36c8bf5b2 --- /dev/null +++ b/internal/codespaces/ssh.go @@ -0,0 +1,90 @@ +package codespaces + +import ( + "context" + "fmt" + "os" + "os/exec" + "strconv" + "strings" +) + +// Shell runs an interactive secure shell over an existing +// port-forwarding session. It runs until the shell is terminated +// (including by cancellation of the context). +func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + if err != nil { + return fmt.Errorf("failed to create ssh command: %w", err) + } + + if usingCustomPort { + log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) + } + + return cmd.Run() +} + +// NewRemoteCommand returns an exec.Cmd that will securely run a shell +// command on the remote machine. +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + return cmd, err +} + +// newSSHCommand populates an exec.Cmd to run a command (or if blank, +// an interactive shell) over ssh. +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { + connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + + // The ssh command syntax is: ssh [flags] user@host command [args...] + // There is no way to specify the user@host destination as a flag. + // Unfortunately, that means we need to know which user-provided words are + // SSH flags and which are command arguments so that we can place + // them before or after the destination, and that means we need to know all + // the flags and their arities. + cmdArgs, command, err := parseSSHArgs(cmdArgs) + if err != nil { + return nil, nil, err + } + + cmdArgs = append(cmdArgs, connArgs...) + cmdArgs = append(cmdArgs, "-C") // Compression + cmdArgs = append(cmdArgs, dst) // user@host + + if command != nil { + cmdArgs = append(cmdArgs, command...) + } + + cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + + return cmd, connArgs, nil +} + +// parseSSHArgs parses SSH arguments into two distinct slices of flags and command. +// It returns an error if a unary flag is provided without an argument. +func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { + for i := 0; i < len(args); i++ { + arg := args[i] + + // if we've started parsing the command, set it to the rest of the args + if !strings.HasPrefix(arg, "-") { + command = args[i:] + break + } + + cmdArgs = append(cmdArgs, arg) + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) + } + + cmdArgs = append(cmdArgs, args[i]) + } + } + + return cmdArgs, command, nil +} diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go new file mode 100644 index 000000000..c804f6000 --- /dev/null +++ b/internal/codespaces/ssh_test.go @@ -0,0 +1,105 @@ +package codespaces + +import ( + "fmt" + "testing" +) + +func TestParseSSHArgs(t *testing.T) { + type testCase struct { + Args []string + ParsedArgs []string + Command []string + Error string + } + + testCases := []testCase{ + {}, // empty test case + { + Args: []string{"-X", "-Y"}, + ParsedArgs: []string{"-X", "-Y"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: nil, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"somecommand"}, + ParsedArgs: []string{}, + Command: []string{"somecommand"}, + }, + { + Args: []string{"echo", "test"}, + ParsedArgs: []string{}, + Command: []string{"echo", "test"}, + }, + { + Args: []string{"-v", "echo", "hello", "world"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "hello", "world"}, + }, + { + Args: []string{"-L", "-l"}, + ParsedArgs: []string{"-L", "-l"}, + Command: nil, + }, + { + Args: []string{"-v", "echo", "-n", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-n", "test"}, + }, + { + Args: []string{"-v", "echo", "-b", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-b", "test"}, + }, + { + Args: []string{"-b"}, + ParsedArgs: nil, + Command: nil, + Error: "ssh flag: -b requires an argument", + }, + } + + for _, tcase := range testCases { + args, command, err := parseSSHArgs(tcase.Args) + if tcase.Error != "" { + if err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + } + + if err.Error() != tcase.Error { + t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error) + } + + continue + } + + if err != nil { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + continue + } + + argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs) + if argsStr != parsedArgsStr { + t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr) + } + + commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command) + if commandStr != parsedCommandStr { + t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr) + } + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go new file mode 100644 index 000000000..7bd53b5e0 --- /dev/null +++ b/internal/codespaces/states.go @@ -0,0 +1,118 @@ +package codespaces + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net" + "strings" + "time" + + "github.com/cli/cli/v2/internal/api" + "github.com/cli/cli/v2/internal/liveshare" +) + +// PostCreateStateStatus is a string value representing the different statuses a state can have. +type PostCreateStateStatus string + +func (p PostCreateStateStatus) String() string { + return strings.Title(string(p)) +} + +const ( + PostCreateStateRunning PostCreateStateStatus = "running" + PostCreateStateSuccess PostCreateStateStatus = "succeeded" + PostCreateStateFailed PostCreateStateStatus = "failed" +) + +// PostCreateState is a combination of a state and status value that is captured +// during codespace creation. +type PostCreateState struct { + Name string `json:"name"` + Status PostCreateStateStatus `json:"status"` +} + +// PollPostCreateStates watches for state changes in a codespace, +// and calls the supplied poller for each batch of state changes. +// It runs until it encounters an error, including cancellation of the context. +func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) + if err != nil { + return fmt.Errorf("getting codespace token: %w", err) + } + + session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + if err != nil { + return fmt.Errorf("connect to Live Share: %w", err) + } + defer func() { + if closeErr := session.Close(); err == nil { + err = closeErr + } + }() + + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port + if err != nil { + return err + } + localPort := listen.Addr().(*net.TCPAddr).Port + + log.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) + if err != nil { + return fmt.Errorf("error getting ssh server details: %w", err) + } + + tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness + go func() { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil + }() + + t := time.NewTicker(1 * time.Second) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case err := <-tunnelClosed: + return fmt.Errorf("connection failed: %w", err) + + case <-t.C: + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) + if err != nil { + return fmt.Errorf("get post create output: %w", err) + } + + poller(states) + } + } +} + +func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { + cmd, err := NewRemoteCommand( + ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), + "cat /workspaces/.codespaces/shared/postCreateOutput.json", + ) + if err != nil { + return nil, fmt.Errorf("remote command: %w", err) + } + + stdout := new(bytes.Buffer) + cmd.Stdout = stdout + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("run command: %w", err) + } + var output struct { + Steps []PostCreateState `json:"steps"` + } + if err := json.Unmarshal(stdout.Bytes(), &output); err != nil { + return nil, fmt.Errorf("unmarshal output: %w", err) + } + + return output.Steps, nil +} diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go new file mode 100644 index 000000000..913f19195 --- /dev/null +++ b/internal/liveshare/client.go @@ -0,0 +1,150 @@ +// Package liveshare is a Go client library for the Visual Studio Live Share +// service, which provides collaborative, distibuted editing and debugging. +// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. +// +// It provides the ability for a Go program to connect to a Live Share +// workspace (Connect), to expose a TCP port on a remote host +// (UpdateSharedVisibility), to start an SSH server listening on an +// exposed port (StartSSHServer), and to forward connections between +// the remote port and a local listening TCP port (ForwardToListener) +// or a local Go reader/writer (Forward). +package liveshare + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/url" + "strings" + + "github.com/opentracing/opentracing-go" + "golang.org/x/crypto/ssh" +) + +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + HostPublicKeys []string + TLSConfig *tls.Config // (optional) +} + +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") + } + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") + } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") + } + + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil +} + +// Connect connects to a Live Share workspace specified by the +// options, and returns a session representing the connection. +// The caller must call the session's Close method to end the session. +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") + defer span.Finish() + + sock := newSocket(uri, opts.TLSConfig) + if err := sock.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting websocket: %w", err) + } + + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock) + if err := ssh.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting to ssh session: %w", err) + } + + rpc := newRPCClient(ssh) + rpc.connect(ctx) + + args := joinWorkspaceArgs{ + ID: opts.SessionID, + ConnectionMode: "local", + JoiningUserSessionToken: opts.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + var result joinWorkspaceResult + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + return nil, fmt.Errorf("error joining Live Share workspace: %w", err) + } + + return &Session{ssh: ssh, rpc: rpc}, nil +} + +type clientCapabilities struct { + IsNonInteractive bool `json:"isNonInteractive"` +} + +type joinWorkspaceArgs struct { + ID string `json:"id"` + ConnectionMode string `json:"connectionMode"` + JoiningUserSessionToken string `json:"joiningUserSessionToken"` + ClientCapabilities clientCapabilities `json:"clientCapabilities"` +} + +type joinWorkspaceResult struct { + SessionNumber int `json:"sessionNumber"` +} + +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` + } + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } + var streamID string + if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %w", err) + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + defer span.Finish() + _ = ctx // ctx is not currently used + + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) + } + go ssh.DiscardRequests(reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + if _, err = channel.SendRequest(requestType, true, nil); err != nil { + return nil, fmt.Errorf("error sending channel request: %w", err) + } + + return channel, nil +} diff --git a/internal/liveshare/client_test.go b/internal/liveshare/client_test.go new file mode 100644 index 000000000..9db20ed1c --- /dev/null +++ b/internal/liveshare/client_test.go @@ -0,0 +1,72 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/cli/cli/v2/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestConnect(t *testing.T) { + opts := Options{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + HostPublicKeys: []string{livesharetest.SSHPublicKey}, + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %w", err) + } + if joinWorkspaceReq.ID != opts.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } + return joinWorkspaceResult{1}, nil + } + + server, err := livesharetest.NewServer( + livesharetest.WithPassword(opts.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(opts.RelaySAS), + ) + if err != nil { + t.Errorf("error creating Live Share server: %w", err) + } + defer server.Close() + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + + ctx := context.Background() + + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + + done := make(chan error) + go func() { + _, err := Connect(ctx, opts) // ignore session + done <- err + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} diff --git a/internal/liveshare/options_test.go b/internal/liveshare/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/internal/liveshare/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/internal/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go new file mode 100644 index 000000000..fcc7ba767 --- /dev/null +++ b/internal/liveshare/port_forwarder.go @@ -0,0 +1,162 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "net" + + "github.com/opentracing/opentracing-go" +) + +// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. +type PortForwarder struct { + session *Session + name string + remotePort int +} + +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { + return &PortForwarder{ + session: session, + name: name, + remotePort: remotePort, + } +} + +// ForwardToListener forwards traffic between the container's remote +// port and a local port, which must already be listening for +// connections. (Accepting a listener rather than a port number avoids +// races against other processes opening ports, and against a client +// connecting to the socket prematurely.) +// +// ForwardToListener accepts and handles connections on the local port +// until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } + go func() { + for { + conn, err := listen.Accept() + if err != nil { + sendError(err) + return + } + + go func() { + if err := fwd.handleConnection(ctx, id, conn); err != nil { + sendError(err) + } + }() + } + }() + + return awaitError(ctx, errc) +} + +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + conn.Close() + return err + } + + // Create buffered channel so that send doesn't get stuck after context cancellation. + errc := make(chan error, 1) + go func() { + errc <- fwd.handleConnection(ctx, id, conn) + }() + return awaitError(ctx, errc) +} + +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) + } + return id, err +} + +func awaitError(ctx context.Context, errc <-chan error) error { + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() // canceled + } +} + +// handleConnection handles forwarding for a single accepted connection, then closes it. +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") + defer span.Finish() + + defer safeClose(conn, &err) + + channel, err := fwd.session.openStreamingChannel(ctx, id) + if err != nil { + return fmt.Errorf("error opening streaming channel for new connection: %w", err) + } + // Ideally we would call safeClose again, but (*ssh.channel).Close + // appears to have a bug that causes it return io.EOF spuriously + // if its peer closed first; see github.com/golang/go/issues/38115. + defer func() { + closeErr := channel.Close() + if err == nil && closeErr != io.EOF { + err = closeErr + } + }() + + // bi-directional copy of data. + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) + + // Wait until context is cancelled or both copies are done. + // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. + // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? + for i := 0; ; { + select { + case <-ctx.Done(): + return ctx.Err() + case <-errs: + i++ + if i == 2 { + return nil + } + } + } +} + +// safeClose reports the error (to *err) from closing the stream only +// if no other error was previously reported. +func safeClose(closer io.Closer, err *error) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} diff --git a/internal/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go new file mode 100644 index 000000000..eda7262c6 --- /dev/null +++ b/internal/liveshare/port_forwarder_test.go @@ -0,0 +1,95 @@ +package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/cli/cli/v2/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, session, err := makeMockSession() + if err != nil { + t.Errorf("create mock client: %w", err) + } + defer testServer.Close() + pf := NewPortForwarder(session, "ssh", 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock session: %w", err) + } + defer testServer.Close() + + listen, err := net.Listen("tcp", ":8000") + if err != nil { + t.Fatal(err) + } + defer listen.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error) + go func() { + const name, remote = "ssh", 8000 + done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %w", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %w", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} diff --git a/internal/liveshare/rpc.go b/internal/liveshare/rpc.go new file mode 100644 index 000000000..bfd214c89 --- /dev/null +++ b/internal/liveshare/rpc.go @@ -0,0 +1,41 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + + "github.com/opentracing/opentracing-go" + "github.com/sourcegraph/jsonrpc2" +) + +type rpcClient struct { + *jsonrpc2.Conn + conn io.ReadWriteCloser +} + +func newRPCClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn} +} + +func (r *rpcClient) connect(ctx context.Context) { + stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) +} + +func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, method) + defer span.Finish() + + waiter, err := r.Conn.DispatchCall(ctx, method, args) + if err != nil { + return fmt.Errorf("error dispatching %q call: %w", method, err) + } + + return waiter.Wait(ctx, result) +} + +type nullHandler struct{} + +func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { +} diff --git a/internal/liveshare/session.go b/internal/liveshare/session.go new file mode 100644 index 000000000..929e8605b --- /dev/null +++ b/internal/liveshare/session.go @@ -0,0 +1,99 @@ +package liveshare + +import ( + "context" + "fmt" + "strconv" +) + +// A Session represents the session between a connected Live Share client and server. +type Session struct { + ssh *sshSession + rpc *rpcClient +} + +// Close should be called by users to clean up RPC and SSH resources whenever the session +// is no longer active. +func (s *Session) Close() error { + // Closing the RPC conn closes the underlying stream (SSH) + // So we only need to close once + if err := s.rpc.Close(); err != nil { + s.ssh.Close() // close SSH and ignore error + return fmt.Errorf("error while closing Live Share session: %w", err) + } + + return nil +} + +// Port describes a port exposed by the container. +type Port struct { + SourcePort int `json:"sourcePort"` + DestinationPort int `json:"destinationPort"` + SessionName string `json:"sessionName"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + BrowseURL string `json:"browseUrl"` + IsPublic bool `json:"isPublic"` + IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` + HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` +} + +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} + var response Port + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err + } + + return channelID{response.StreamName, response.StreamCondition}, nil +} + +// GetSharedServers returns a description of each container port +// shared by a prior call to StartSharing by some client. +func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { + var response []*Port + if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { + return nil, err + } + + return response, nil +} + +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL +func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { + return err + } + + return nil +} + +// StartsSSHServer starts an SSH server in the container, installing sshd if necessary, +// and returns the port on which it listens and the user name clients should provide. +func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { + var response struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` + } + + if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return 0, "", err + } + + if !response.Result { + return 0, "", fmt.Errorf("failed to start server: %s", response.Message) + } + + port, err := strconv.Atoi(response.ServerPort) + if err != nil { + return 0, "", fmt.Errorf("failed to parse port: %w", err) + } + + return port, response.User, nil +} diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go new file mode 100644 index 000000000..1064e490b --- /dev/null +++ b/internal/liveshare/session_test.go @@ -0,0 +1,223 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/cli/cli/v2/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + const sessionToken = "session-token" + opts = append( + opts, + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer(opts...) + if err != nil { + return nil, nil, fmt.Errorf("error creating server: %w", err) + } + + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + HostPublicKeys: []string{livesharetest.SSHPublicKey}, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) + if err != nil { + return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) + } + return testServer, session, nil +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %w", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() + + if err != nil { + t.Errorf("error creating mock session: %w", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { + done <- fmt.Errorf("error sharing server: %w", err) + } + if streamID.name == "" || streamID.condition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return []*Port{&sharedServer}, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := session.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %w", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + done <- session.UpdateSharedVisibility(ctx, 80, true) + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestInvalidHostKey(t *testing.T) { + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + const sessionToken = "session-token" + opts := []livesharetest.ServerOption{ + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + } + testServer, err := livesharetest.NewServer(opts...) + if err != nil { + t.Errorf("error creating server: %w", err) + } + _, err = Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + HostPublicKeys: []string{}, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) + if err == nil { + t.Error("expected invalid host key error, got: nil") + } +} diff --git a/internal/liveshare/socket.go b/internal/liveshare/socket.go new file mode 100644 index 000000000..f66436f65 --- /dev/null +++ b/internal/liveshare/socket.go @@ -0,0 +1,100 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + tlsConfig *tls.Config + + conn *websocket.Conn + reader io.Reader +} + +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} +} + +func (s *socket) connect(ctx context.Context) error { + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + if s.reader == nil { + _, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/internal/liveshare/ssh.go b/internal/liveshare/ssh.go new file mode 100644 index 000000000..e7de9055a --- /dev/null +++ b/internal/liveshare/ssh.go @@ -0,0 +1,79 @@ +package liveshare + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "time" + + "golang.org/x/crypto/ssh" +) + +type sshSession struct { + *ssh.Session + token string + hostPublicKeys []string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer +} + +func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession { + return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket} +} + +func (s *sshSession) connect(ctx context.Context) error { + clientConfig := ssh.ClientConfig{ + User: "", + Auth: []ssh.AuthMethod{ + ssh.Password(s.token), + }, + HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, + HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error { + encodedKey := base64.StdEncoding.EncodeToString(key.Marshal()) + for _, hpk := range s.hostPublicKeys { + if encodedKey == hpk { + return nil // we found a match for expected public key, safely return + } + } + return errors.New("invalid host public key") + }, + Timeout: 10 * time.Second, + } + + sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) + if err != nil { + return fmt.Errorf("error creating ssh client connection: %w", err) + } + s.conn = sshClientConn + + sshClient := ssh.NewClient(sshClientConn, chans, reqs) + s.Session, err = sshClient.NewSession() + if err != nil { + return fmt.Errorf("error creating ssh client session: %w", err) + } + + s.reader, err = s.Session.StdoutPipe() + if err != nil { + return fmt.Errorf("error creating ssh session reader: %w", err) + } + + s.writer, err = s.Session.StdinPipe() + if err != nil { + return fmt.Errorf("error creating ssh session writer: %w", err) + } + + return nil +} + +func (s *sshSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s *sshSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) +} diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go new file mode 100644 index 000000000..3f038a15b --- /dev/null +++ b/internal/liveshare/test/server.go @@ -0,0 +1,334 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV +rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY +lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB +AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0 +5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8 +IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC +46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n +IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC +t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz +J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA +hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV +933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP +FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw== +-----END RSA PRIVATE KEY-----` + +const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==` + +// Server represents a LiveShare relay host server. +type Server struct { + password string + services map[string]RPCHandleFunc + relaySAS string + streams map[string]io.ReadWriter + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +// NewServer creates a new Server. ServerOptions can be passed to configure +// the SSH password, backing service, secrets and more. +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) + if err != nil { + return nil, fmt.Errorf("error parsing key: %w", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error, 1) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) + return server, nil +} + +// ServerOption is used to configure the Server. +type ServerOption func(*Server) error + +// WithPassword configures the Server password for SSH. +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +// WithService accepts a mock RPC service for the Server to invoke. +func WithService(serviceName string, handler RPCHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RPCHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +// WithRelaySAS configures the relay SAS configuration key. +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + +// WithStream allows you to specify a mock data stream for the server. +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +// Close closes the underlying httptest Server. +func (s *Server) Close() { + s.httptestServer.Close() +} + +// URL returns the httptest Server url. +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func makeConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + sendError(server.errCh, errors.New("error validating sas")) + return + } + } + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err)) + return + } + defer func() { + if err := c.Close(); err != nil { + sendError(server.errCh, err) + } + }() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err)) + return + } + go ssh.DiscardRequests(reqs) + + if err := handleChannels(ctx, server, chans); err != nil { + sendError(server.errCh, err) + } + } +} + +// sendError does a non-blocking send of the error to the err channel. +func sendError(errc chan<- error, err error) { + select { + case errc <- err: + default: + // channel is blocked with a previous error, so we ignore + // this current error + } +} + +// awaitError waits for the context to finish and returns its error (if any). +// It also waits for an err to come through the err channel. +func awaitError(ctx context.Context, errc <-chan error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errc: + return err + } +} + +// handleChannels services the sshChannels channel. For each SSH channel received +// it creates a go routine to service the channel's requests. It returns on the first +// error encountered. +func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error { + errc := make(chan error, 1) + go func() { + for sshCh := range sshChannels { + ch, reqs, err := sshCh.Accept() + if err != nil { + sendError(errc, fmt.Errorf("failed to accept channel: %w", err)) + return + } + + go func() { + if err := handleRequests(ctx, server, ch, reqs); err != nil { + sendError(errc, fmt.Errorf("failed to handle requests: %w", err)) + } + }() + + handleChannel(server, ch) + } + }() + return awaitError(ctx, errc) +} + +// handleRequests services the SSH channel requests channel. It replies to requests and +// when stream transport requests are encountered, creates a go routine to create a +// bi-directional data stream between the channel and server stream. It returns on the first error +// encountered. +func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error { + errc := make(chan error, 1) + go func() { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + sendError(errc, fmt.Errorf("error replying to channel request: %w", err)) + return + } + } + + if strings.HasPrefix(req.Type, "stream-transport") { + go func() { + if err := forwardStream(ctx, server, req.Type, channel); err != nil { + sendError(errc, fmt.Errorf("failed to forward stream: %w", err)) + } + }() + } + } + }() + + return awaitError(ctx, errc) +} + +// concurrentStream is a concurrency safe io.ReadWriter. +type concurrentStream struct { + sync.RWMutex + stream io.ReadWriter +} + +func newConcurrentStream(rw io.ReadWriter) *concurrentStream { + return &concurrentStream{stream: rw} +} + +func (cs *concurrentStream) Read(b []byte) (int, error) { + cs.RLock() + defer cs.RUnlock() + return cs.stream.Read(b) +} + +func (cs *concurrentStream) Write(b []byte) (int, error) { + cs.Lock() + defer cs.Unlock() + return cs.stream.Write(b) +} + +// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy +// runs until an error is encountered. +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + return fmt.Errorf("stream '%s' not found", simpleStreamName) + } + defer func() { + if closeErr := channel.Close(); err == nil && closeErr != io.EOF { + err = closeErr + } + }() + + errc := make(chan error, 2) + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + errc <- err + } + } + + csStream := newConcurrentStream(stream) + go copy(csStream, channel) + go copy(channel, csStream) + + return awaitError(ctx, errc) +} + +func handleChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) +} + +type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRPCHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked +// RPC service method and if found, it invokes the handler and replies to the request. +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method)) + return + } + + result, err := handler(req) + if err != nil { + sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err)) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + sendError(r.server.errCh, fmt.Errorf("error replying: %w", err)) + } +} diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go new file mode 100644 index 000000000..00cd64a1b --- /dev/null +++ b/internal/liveshare/test/socket.go @@ -0,0 +1,77 @@ +package livesharetest + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %w", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %w", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %w", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %w", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index f24c29211..bb36b9592 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -2,8 +2,12 @@ package root import ( "net/http" + "sync" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/cmd/ghcs" + "github.com/cli/cli/v2/cmd/ghcs/output" + ghcsApi "github.com/cli/cli/v2/internal/api" actionsCmd "github.com/cli/cli/v2/pkg/cmd/actions" aliasCmd "github.com/cli/cli/v2/pkg/cmd/alias" apiCmd "github.com/cli/cli/v2/pkg/cmd/api" @@ -78,6 +82,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) *cobra.Command { cmd.AddCommand(extensionCmd.NewCmdExtension(f)) cmd.AddCommand(secretCmd.NewCmdSecret(f)) cmd.AddCommand(sshKeyCmd.NewCmdSSHKey(f)) + cmd.AddCommand(newCodespaceCmd(f)) // the `api` command should not inherit any extra HTTP headers bareHTTPCmdFactory := *f @@ -121,3 +126,39 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er return factory.NewHTTPClient(f.IOStreams, cfg, version, false) } } + +func newCodespaceCmd(f *cmdutil.Factory) *cobra.Command { + cmd := ghcs.NewRootCmd(ghcs.NewApp( + output.NewLogger(f.IOStreams.Out, f.IOStreams.ErrOut, !f.IOStreams.IsStdoutTTY()), + ghcsApi.New("", &lazyLoadedHTTPClient{factory: f}), + )) + cmd.Use = "codespace" + cmd.Aliases = []string{"cs"} + cmd.Hidden = true + return cmd +} + +type lazyLoadedHTTPClient struct { + factory *cmdutil.Factory + + httpClientMu sync.RWMutex // guards httpClient + httpClient *http.Client +} + +func (l *lazyLoadedHTTPClient) Do(req *http.Request) (*http.Response, error) { + l.httpClientMu.RLock() + httpClient := l.httpClient + l.httpClientMu.RUnlock() + + if httpClient == nil { + l.httpClientMu.Lock() + defer l.httpClientMu.Unlock() + + var err error + l.httpClient, err = l.factory.HttpClient() + if err != nil { + return nil, err + } + } + return l.httpClient.Do(req) +}