diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index ed86a2c6c..6b31f0f69 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -605,6 +605,64 @@ func (a *API) DeleteCodespace(ctx context.Context, codespaceName string) error { return nil } +// ListDevContainers returns a list of valid devcontainer.json files for the repo. Pass a negative limit to request all pages from +// the API until all devcontainer.json files have been fetched. +func (a *API) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []string, err error) { + perPage := 100 + if limit > 0 && limit < 100 { + perPage = limit + } + + listURL := fmt.Sprintf("%s/repositories/%d/codespaces/devcontainers?per_page=%d", a.githubAPI, repoID, perPage) + if branch != "" { + listURL += "&ref=" + branch + } + for { + req, err := http.NewRequest(http.MethodGet, listURL, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + a.setHeaders(req) + + resp, err := a.do(ctx, req, fmt.Sprintf("/repositories/%d/codespaces/devcontainers", repoID)) + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, api.HandleHTTPError(resp) + } + + var response struct { + Devcontainers []string `json:"devcontainers"` + } + dec := json.NewDecoder(resp.Body) + if err := dec.Decode(&response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + nextURL := findNextPage(resp.Header.Get("Link")) + devcontainers = append(devcontainers, response.Devcontainers...) + + if nextURL == "" || (limit > 0 && len(devcontainers) >= limit) { + break + } + + if newPerPage := limit - len(devcontainers); limit > 0 && newPerPage < 100 { + u, _ := url.Parse(nextURL) + q := u.Query() + q.Set("per_page", strconv.Itoa(newPerPage)) + u.RawQuery = q.Encode() + listURL = u.String() + } else { + listURL = nextURL + } + } + + return devcontainers, nil +} + type getCodespaceRepositoryContentsResponse struct { Content string `json:"content"` } diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 1107ae6a5..20d9cfe20 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -67,6 +67,7 @@ type apiClient interface { GetCodespaceRegionLocation(ctx context.Context) (string, error) GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) + ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []string, err error) } var errNoCodespaces = errors.New("you have no codespaces") diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index 8f884f8da..7964fb1c0 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -91,6 +91,37 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { branch = repository.DefaultBranch } + devContainerPath := opts.devContainerPath + + // now that we have repo+branch, we can list available devcontainer.json files (if any) + if len(opts.devContainerPath) < 1 { + a.StartProgressIndicatorWithLabel("Fetching devcontainer.json files") + devContainerPaths, err := a.apiClient.ListDevContainers(ctx, repository.ID, branch, 100) + if err != nil { + return fmt.Errorf("error getting devcontainer.json paths: %w", err) + } + a.StopProgressIndicator() + + if len(devContainerPaths) > 0 { + devContainerPathQuestion := &survey.Question{ + Name: "devContainerPath", + Prompt: &survey.Select{ + Message: "Devcontainer.json file:", + Options: append([]string{"none"}, devContainerPaths...), + }, + } + + if err := ask([]*survey.Question{devContainerPathQuestion}, &devContainerPath); err != nil { + return fmt.Errorf("failed to prompt: %w", err) + } + } + + if devContainerPath == "none" { + // special arg allows users to opt out of devcontainer.json selection + devContainerPath = "" + } + } + locationResult := <-locationCh if locationResult.Err != nil { return fmt.Errorf("error getting codespace region location: %w", locationResult.Err) @@ -111,7 +142,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { Machine: machine, Location: locationResult.Location, IdleTimeoutMinutes: int(opts.idleTimeout.Minutes()), - DevContainerPath: opts.devContainerPath, + DevContainerPath: devContainerPath, }) a.StopProgressIndicator() if err != nil { diff --git a/pkg/cmd/codespace/create_test.go b/pkg/cmd/codespace/create_test.go index c5c1ea075..0a7db6294 100644 --- a/pkg/cmd/codespace/create_test.go +++ b/pkg/cmd/codespace/create_test.go @@ -19,6 +19,7 @@ func TestApp_Create(t *testing.T) { fields fields opts createOptions wantErr bool + wantErrMsg string wantStdout string wantStderr string }{ @@ -70,17 +71,103 @@ func TestApp_Create(t *testing.T) { }, wantStdout: "monalisa-dotfiles-abcd1234\n", }, + { + name: "create codespace with default branch with default devcontainer if no path provided", + fields: fields{ + apiClient: &apiClientMock{ + GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { + return "EUROPE", nil + }, + GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { + return &api.Repository{ + ID: 1234, + FullName: nwo, + DefaultBranch: "main", + }, nil + }, + ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]string, error) { + return []string{}, nil + }, + GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) { + return []*api.Machine{ + { + Name: "GIGA", + DisplayName: "Gigabits of a machine", + }, + }, nil + }, + CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + if params.Branch != "main" { + return nil, fmt.Errorf("got branch %q, want %q", params.Branch, "main") + } + if params.IdleTimeoutMinutes != 30 { + return nil, fmt.Errorf("idle timeout minutes was %v", params.IdleTimeoutMinutes) + } + if params.DevContainerPath != "" { + return nil, fmt.Errorf("got dev container path %q, want %q", params.DevContainerPath, ".devcontainer/foobar/devcontainer.json") + } + return &api.Codespace{ + Name: "monalisa-dotfiles-abcd1234", + }, nil + }, + }, + }, + opts: createOptions{ + repo: "monalisa/dotfiles", + branch: "", + machine: "GIGA", + showStatus: false, + idleTimeout: 30 * time.Minute, + }, + wantStdout: "monalisa-dotfiles-abcd1234\n", + }, + { + name: "returns error when getting devcontainer paths fails", + fields: fields{ + apiClient: &apiClientMock{ + GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { + return "EUROPE", nil + }, + GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { + return &api.Repository{ + ID: 1234, + FullName: nwo, + DefaultBranch: "main", + }, nil + }, + ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]string, error) { + return nil, fmt.Errorf("some error") + }, + }, + }, + opts: createOptions{ + repo: "monalisa/dotfiles", + branch: "", + machine: "GIGA", + showStatus: false, + idleTimeout: 30 * time.Minute, + }, + wantErr: true, + wantErrMsg: "error getting devcontainer.json paths: some error", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { io, _, stdout, stderr := iostreams.Test() + io.SetStdinTTY(true) + io.SetStdoutTTY(true) + a := &App{ io: io, apiClient: tt.fields.apiClient, } - if err := a.Create(context.Background(), tt.opts); (err != nil) != tt.wantErr { + err := a.Create(context.Background(), tt.opts) + if (err != nil) != tt.wantErr { t.Errorf("App.Create() error = %v, wantErr %v", err, tt.wantErr) } + if tt.wantErrMsg != "" && err.Error() != tt.wantErrMsg { + t.Errorf("err message = %v, wantErrMsg %v", err.Error(), tt.wantErrMsg) + } if got := stdout.String(); got != tt.wantStdout { t.Errorf("stdout = %v, want %v", got, tt.wantStdout) } diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 8d40934da..c27d5b48d 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -46,6 +46,9 @@ import ( // ListCodespacesFunc: func(ctx context.Context, limit int) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, +// ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]string, error) { +// panic("mock out the ListDevContainers method") +// }, // StartCodespaceFunc: func(ctx context.Context, name string) error { // panic("mock out the StartCodespace method") // }, @@ -89,6 +92,9 @@ type apiClientMock struct { // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, limit int) ([]*api.Codespace, error) + // ListDevContainersFunc mocks the ListDevContainers method. + ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]string, error) + // StartCodespaceFunc mocks the StartCodespace method. StartCodespaceFunc func(ctx context.Context, name string) error @@ -171,6 +177,17 @@ type apiClientMock struct { // Limit is the limit argument value. Limit int } + // ListDevContainers holds details about calls to the ListDevContainers method. + ListDevContainers []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoID is the repoID argument value. + RepoID int + // Branch is the branch argument value. + Branch string + // Limit is the limit argument value. + Limit int + } // StartCodespace holds details about calls to the StartCodespace method. StartCodespace []struct { // Ctx is the ctx argument value. @@ -196,6 +213,7 @@ type apiClientMock struct { lockGetRepository sync.RWMutex lockGetUser sync.RWMutex lockListCodespaces sync.RWMutex + lockListDevContainers sync.RWMutex lockStartCodespace sync.RWMutex lockStopCodespace sync.RWMutex } @@ -558,6 +576,49 @@ func (mock *apiClientMock) ListCodespacesCalls() []struct { return calls } +// ListDevContainers calls ListDevContainersFunc. +func (mock *apiClientMock) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) ([]string, error) { + if mock.ListDevContainersFunc == nil { + panic("apiClientMock.ListDevContainersFunc: method is nil but apiClient.ListDevContainers was just called") + } + callInfo := struct { + Ctx context.Context + RepoID int + Branch string + Limit int + }{ + Ctx: ctx, + RepoID: repoID, + Branch: branch, + Limit: limit, + } + mock.lockListDevContainers.Lock() + mock.calls.ListDevContainers = append(mock.calls.ListDevContainers, callInfo) + mock.lockListDevContainers.Unlock() + return mock.ListDevContainersFunc(ctx, repoID, branch, limit) +} + +// ListDevContainersCalls gets all the calls that were made to ListDevContainers. +// Check the length with: +// len(mockedapiClient.ListDevContainersCalls()) +func (mock *apiClientMock) ListDevContainersCalls() []struct { + Ctx context.Context + RepoID int + Branch string + Limit int +} { + var calls []struct { + Ctx context.Context + RepoID int + Branch string + Limit int + } + mock.lockListDevContainers.RLock() + calls = mock.calls.ListDevContainers + mock.lockListDevContainers.RUnlock() + return calls +} + // StartCodespace calls StartCodespaceFunc. func (mock *apiClientMock) StartCodespace(ctx context.Context, name string) error { if mock.StartCodespaceFunc == nil {