diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 505e292ae..ab294ef4a 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -104,8 +104,9 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { // Repository represents a GitHub repository. type Repository struct { - ID int `json:"id"` - FullName string `json:"full_name"` + ID int `json:"id"` + FullName string `json:"full_name"` + DefaultBranch string `json:"default_branch"` } // GetRepository returns the repository associated with the given owner and name. diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index e0a4de7a4..ac74658c3 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -41,34 +41,56 @@ func newCreateCmd(app *App) *cobra.Command { // Create creates a new Codespace func (a *App) Create(ctx context.Context, opts createOptions) error { locationCh := getLocation(ctx, a.apiClient) - userCh := getUser(ctx, a.apiClient) - repo, err := getRepoName(opts.repo) - if err != nil { - return fmt.Errorf("error getting repository name: %w", err) + userInputs := struct { + Repository string + Branch string + }{ + Repository: opts.repo, + Branch: opts.branch, } - branch, err := getBranchName(opts.branch) - if err != nil { - return fmt.Errorf("error getting branch name: %w", err) + + if userInputs.Repository == "" { + branchPrompt := "Branch (leave blank for default branch):" + if userInputs.Branch != "" { + branchPrompt = "Branch:" + } + questions := []*survey.Question{ + { + Name: "repository", + Prompt: &survey.Input{Message: "Repository:"}, + Validate: survey.Required, + }, + { + Name: "branch", + Prompt: &survey.Input{ + Message: branchPrompt, + Default: userInputs.Branch, + }, + }, + } + if err := ask(questions, &userInputs); err != nil { + return fmt.Errorf("failed to prompt: %w", err) + } } a.StartProgressIndicatorWithLabel("Fetching repository") - repository, err := a.apiClient.GetRepository(ctx, repo) + repository, err := a.apiClient.GetRepository(ctx, userInputs.Repository) a.StopProgressIndicator() if err != nil { return fmt.Errorf("error getting repository: %w", err) } + branch := userInputs.Branch + if branch == "" { + branch = repository.DefaultBranch + } + locationResult := <-locationCh if locationResult.Err != nil { return fmt.Errorf("error getting codespace region location: %w", locationResult.Err) } - userResult := <-userCh - if userResult.Err != nil { - return fmt.Errorf("error getting codespace user: %w", userResult.Err) - } - machine, err := getMachineName(ctx, a.apiClient, repository.ID, opts.machine, branch, locationResult.Location) if err != nil { return fmt.Errorf("error getting machine type: %w", err) @@ -90,7 +112,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { } if opts.showStatus { - if err := a.showStatus(ctx, userResult.User, codespace); err != nil { + if err := a.showStatus(ctx, codespace); err != nil { return fmt.Errorf("show status: %w", err) } } @@ -102,7 +124,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. -func (a *App) showStatus(ctx context.Context, user *api.User, codespace *api.Codespace) error { +func (a *App) showStatus(ctx context.Context, codespace *api.Codespace) error { var ( lastState codespaces.PostCreateState breakNextState bool @@ -163,21 +185,6 @@ func (a *App) showStatus(ctx context.Context, user *api.User, codespace *api.Cod return nil } -type getUserResult struct { - User *api.User - Err error -} - -// getUser fetches the user record associated with the GITHUB_TOKEN -func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult { - ch := make(chan getUserResult, 1) - go func() { - user, err := apiClient.GetUser(ctx) - ch <- getUserResult{user, err} - }() - return ch -} - type locationResult struct { Location string Err error @@ -193,40 +200,6 @@ func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult return ch } -// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty. -func getRepoName(repo string) (string, error) { - if repo != "" { - return repo, nil - } - - repoSurvey := []*survey.Question{ - { - Name: "repository", - Prompt: &survey.Input{Message: "Repository:"}, - Validate: survey.Required, - }, - } - err := ask(repoSurvey, &repo) - return repo, err -} - -// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty. -func getBranchName(branch string) (string, error) { - if branch != "" { - return branch, nil - } - - branchSurvey := []*survey.Question{ - { - Name: "branch", - Prompt: &survey.Input{Message: "Branch:"}, - Validate: survey.Required, - }, - } - err := ask(branchSurvey, &branch) - return branch, err -} - // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. func getMachineName(ctx context.Context, apiClient apiClient, repoID int, machine, branch, location string) (string, error) { machines, err := apiClient.GetCodespacesMachines(ctx, repoID, branch, location)