diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 4e0973193..cb19b4390 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -268,19 +268,38 @@ func (c *Codespace) ExportData(fields []string) map[string]interface{} { return data } +type ListCodespacesOptions struct { + OrgName string + UserName string + RepoName string + Limit int +} + // ListCodespaces returns a list of codespaces for the user. Pass a negative limit to request all pages from // the API until all codespaces have been fetched. -func (a *API) ListCodespaces(ctx context.Context, limit int, orgName string, userName string) (codespaces []*Codespace, err error) { - perPage := 100 +func (a *API) ListCodespaces(ctx context.Context, opts ListCodespacesOptions) (codespaces []*Codespace, err error) { + var ( + perPage = 100 + limit = opts.Limit + ) + if limit > 0 && limit < 100 { perPage = limit } - var listURL string - var spanName string + var ( + listURL string + spanName string + ) - if orgName != "" { - if userName != "" { + if opts.RepoName != "" { + listURL = fmt.Sprintf("%s/repos/%s/codespaces?per_page=%d", a.githubAPI, opts.RepoName, perPage) + spanName = "/repos/*/codespaces" + } else if opts.OrgName != "" { + // the endpoints below can only be called by the organization admins + orgName := opts.OrgName + if opts.UserName != "" { + userName := opts.UserName listURL = fmt.Sprintf("%s/orgs/%s/members/%s/codespaces?per_page=%d", a.githubAPI, orgName, userName, perPage) spanName = "/orgs/*/members/*/codespaces" } else { diff --git a/internal/codespaces/api/api_test.go b/internal/codespaces/api/api_test.go index c48073795..482a6feef 100644 --- a/internal/codespaces/api/api_test.go +++ b/internal/codespaces/api/api_test.go @@ -140,7 +140,7 @@ func TestListCodespaces_limited(t *testing.T) { client: &http.Client{}, } ctx := context.TODO() - codespaces, err := api.ListCodespaces(ctx, 200, "", "") + codespaces, err := api.ListCodespaces(ctx, ListCodespacesOptions{Limit: 200}) if err != nil { t.Fatal(err) } @@ -165,7 +165,7 @@ func TestListCodespaces_unlimited(t *testing.T) { client: &http.Client{}, } ctx := context.TODO() - codespaces, err := api.ListCodespaces(ctx, -1, "", "") + codespaces, err := api.ListCodespaces(ctx, ListCodespacesOptions{}) if err != nil { t.Fatal(err) } diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 816191238..759cec3ce 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -108,7 +108,7 @@ type apiClient interface { GetUser(ctx context.Context) (*api.User, error) GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*api.Codespace, error) - ListCodespaces(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) + ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) DeleteCodespace(ctx context.Context, name string, orgName string, userName string) error StartCodespace(ctx context.Context, name string) error StopCodespace(ctx context.Context, name string, orgName string, userName string) error @@ -126,7 +126,7 @@ type apiClient interface { var errNoCodespaces = errors.New("you have no codespaces") func chooseCodespace(ctx context.Context, apiClient apiClient) (*api.Codespace, error) { - codespaces, err := apiClient.ListCodespaces(ctx, -1, "", "") + codespaces, err := apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{}) if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) } diff --git a/pkg/cmd/codespace/delete.go b/pkg/cmd/codespace/delete.go index 06edcc327..7167f84c0 100644 --- a/pkg/cmd/codespace/delete.go +++ b/pkg/cmd/codespace/delete.go @@ -80,7 +80,7 @@ func (a *App) Delete(ctx context.Context, opts deleteOptions) (err error) { nameFilter := opts.codespaceName if nameFilter == "" { a.StartProgressIndicatorWithLabel("Fetching codespaces") - codespaces, err = a.apiClient.ListCodespaces(ctx, -1, opts.orgName, opts.userName) + codespaces, err = a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{OrgName: opts.orgName, UserName: opts.userName}) a.StopProgressIndicator() if err != nil { return fmt.Errorf("error getting codespaces: %w", err) diff --git a/pkg/cmd/codespace/delete_test.go b/pkg/cmd/codespace/delete_test.go index b2700903b..3555a2a34 100644 --- a/pkg/cmd/codespace/delete_test.go +++ b/pkg/cmd/codespace/delete_test.go @@ -218,7 +218,7 @@ func TestDelete(t *testing.T) { }, } if tt.opts.codespaceName == "" { - apiMock.ListCodespacesFunc = func(_ context.Context, num int, orgName string, userName string) ([]*api.Codespace, error) { + apiMock.ListCodespacesFunc = func(_ context.Context, _ api.ListCodespacesOptions) ([]*api.Codespace, error) { return tt.codespaces, nil } } else { diff --git a/pkg/cmd/codespace/list.go b/pkg/cmd/codespace/list.go index 7cc60f0ca..b5aaf9f5c 100644 --- a/pkg/cmd/codespace/list.go +++ b/pkg/cmd/codespace/list.go @@ -14,6 +14,7 @@ import ( type listOptions struct { limit int + repo string orgName string userName string } @@ -42,6 +43,7 @@ func newListCmd(app *App) *cobra.Command { } listCmd.Flags().IntVarP(&opts.limit, "limit", "L", 30, "Maximum number of codespaces to list") + listCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "Repository name with owner: user/repo") listCmd.Flags().StringVarP(&opts.orgName, "org", "o", "", "The `login` handle of the organization to list codespaces for (admin-only)") listCmd.Flags().StringVarP(&opts.userName, "user", "u", "", "The `username` to list codespaces for (used with --org)") cmdutil.AddJSONFlags(listCmd, &exporter, api.CodespaceFields) @@ -50,8 +52,13 @@ func newListCmd(app *App) *cobra.Command { } func (a *App) List(ctx context.Context, opts *listOptions, exporter cmdutil.Exporter) error { + // if repo is provided, we don't accept orgName or userName + if opts.repo != "" && (opts.orgName != "" || opts.userName != "") { + return cmdutil.FlagErrorf("using `--org` or `--user` with `--repo` is not allowed") + } + a.StartProgressIndicatorWithLabel("Fetching codespaces") - codespaces, err := a.apiClient.ListCodespaces(ctx, opts.limit, opts.orgName, opts.userName) + codespaces, err := a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{Limit: opts.limit, RepoName: opts.repo, OrgName: opts.orgName, UserName: opts.userName}) a.StopProgressIndicator() if err != nil { return fmt.Errorf("error getting codespaces: %w", err) diff --git a/pkg/cmd/codespace/list_test.go b/pkg/cmd/codespace/list_test.go index 001806b43..46448869f 100644 --- a/pkg/cmd/codespace/list_test.go +++ b/pkg/cmd/codespace/list_test.go @@ -15,16 +15,17 @@ func TestApp_List(t *testing.T) { apiClient apiClient } tests := []struct { - name string - fields fields - opts *listOptions + name string + fields fields + opts *listOptions + wantError error }{ { name: "list codespaces, no flags", fields: fields{ apiClient: &apiClientMock{ - ListCodespacesFunc: func(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) { - if orgName != "" { + ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { + if opts.OrgName != "" { return nil, fmt.Errorf("should not be called with an orgName") } return []*api.Codespace{ @@ -41,12 +42,12 @@ func TestApp_List(t *testing.T) { name: "list codespaces, --org flag", fields: fields{ apiClient: &apiClientMock{ - ListCodespacesFunc: func(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) { - if orgName != "TestOrg" { - return nil, fmt.Errorf("Expected orgName to be TestOrg. Got %s", orgName) + ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { + if opts.OrgName != "TestOrg" { + return nil, fmt.Errorf("Expected orgName to be TestOrg. Got %s", opts.OrgName) } - if userName != "" { - return nil, fmt.Errorf("Expected userName to be blank. Got %s", userName) + if opts.UserName != "" { + return nil, fmt.Errorf("Expected userName to be blank. Got %s", opts.UserName) } return []*api.Codespace{ { @@ -64,12 +65,12 @@ func TestApp_List(t *testing.T) { name: "list codespaces, --org and --user flag", fields: fields{ apiClient: &apiClientMock{ - ListCodespacesFunc: func(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) { - if orgName != "TestOrg" { - return nil, fmt.Errorf("Expected orgName to be TestOrg. Got %s", orgName) + ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { + if opts.OrgName != "TestOrg" { + return nil, fmt.Errorf("Expected orgName to be TestOrg. Got %s", opts.OrgName) } - if userName != "jimmy" { - return nil, fmt.Errorf("Expected userName to be jimmy. Got %s", userName) + if opts.UserName != "jimmy" { + return nil, fmt.Errorf("Expected userName to be jimmy. Got %s", opts.UserName) } return []*api.Codespace{ { @@ -84,6 +85,44 @@ func TestApp_List(t *testing.T) { userName: "jimmy", }, }, + { + name: "list codespaces, --repo", + fields: fields{ + apiClient: &apiClientMock{ + ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { + if opts.RepoName == "" { + return nil, fmt.Errorf("Expected repository to not be nil") + } + if opts.RepoName != "cli/cli" { + return nil, fmt.Errorf("Expected repository name to be cli/cli. Got %s", opts.RepoName) + } + if opts.OrgName != "" { + return nil, fmt.Errorf("Expected orgName to be blank. Got %s", opts.OrgName) + } + if opts.UserName != "" { + return nil, fmt.Errorf("Expected userName to be blank. Got %s", opts.UserName) + } + return []*api.Codespace{ + { + DisplayName: "CS1", + }, + }, nil + }, + }, + }, + opts: &listOptions{ + repo: "cli/cli", + }, + }, + { + name: "list codespaces,--repo, --org and --user flag", + opts: &listOptions{ + repo: "cli/cli", + orgName: "TestOrg", + userName: "jimmy", + }, + wantError: fmt.Errorf("using `--org` or `--user` with `--repo` is not allowed"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -95,8 +134,13 @@ func TestApp_List(t *testing.T) { var exporter cmdutil.Exporter err := a.List(context.Background(), tt.opts, exporter) - if err != nil { - t.Error(err) + if (err != nil) != (tt.wantError != nil) { + t.Errorf("error = %v, wantErr %v", err, tt.wantError) + return + } + + if err != nil && err.Error() != tt.wantError.Error() { + t.Errorf("error = %v, wantErr %v", err, tt.wantError) } }) } diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 8d5c29d4e..fe3b61f00 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -52,7 +52,7 @@ import ( // GetUserFunc: func(ctx context.Context) (*api.User, error) { // panic("mock out the GetUser method") // }, -// ListCodespacesFunc: func(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) { +// ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, // ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { @@ -108,7 +108,7 @@ type apiClientMock struct { GetUserFunc func(ctx context.Context) (*api.User, error) // ListCodespacesFunc mocks the ListCodespaces method. - ListCodespacesFunc func(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) + ListCodespacesFunc func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) // ListDevContainersFunc mocks the ListDevContainers method. ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) @@ -227,12 +227,8 @@ type apiClientMock struct { ListCodespaces []struct { // Ctx is the ctx argument value. Ctx context.Context - // Limit is the limit argument value. - Limit int - // OrgName is the orgName argument value. - OrgName string - // UserName is the userName argument value. - UserName string + // Opts is the opts argument value. + Opts api.ListCodespacesOptions } // ListDevContainers holds details about calls to the ListDevContainers method. ListDevContainers []struct { @@ -739,41 +735,33 @@ func (mock *apiClientMock) GetUserCalls() []struct { } // ListCodespaces calls ListCodespacesFunc. -func (mock *apiClientMock) ListCodespaces(ctx context.Context, limit int, orgName string, userName string) ([]*api.Codespace, error) { +func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { if mock.ListCodespacesFunc == nil { panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called") } callInfo := struct { - Ctx context.Context - Limit int - OrgName string - UserName string + Ctx context.Context + Opts api.ListCodespacesOptions }{ - Ctx: ctx, - Limit: limit, - OrgName: orgName, - UserName: userName, + Ctx: ctx, + Opts: opts, } mock.lockListCodespaces.Lock() mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo) mock.lockListCodespaces.Unlock() - return mock.ListCodespacesFunc(ctx, limit, orgName, userName) + return mock.ListCodespacesFunc(ctx, opts) } // ListCodespacesCalls gets all the calls that were made to ListCodespaces. // Check the length with: // len(mockedapiClient.ListCodespacesCalls()) func (mock *apiClientMock) ListCodespacesCalls() []struct { - Ctx context.Context - Limit int - OrgName string - UserName string + Ctx context.Context + Opts api.ListCodespacesOptions } { var calls []struct { - Ctx context.Context - Limit int - OrgName string - UserName string + Ctx context.Context + Opts api.ListCodespacesOptions } mock.lockListCodespaces.RLock() calls = mock.calls.ListCodespaces diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index ac1f7ecfe..292bae0d7 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -235,7 +235,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro var csList []*api.Codespace if opts.codespace == "" { a.StartProgressIndicatorWithLabel("Fetching codespaces") - csList, err = a.apiClient.ListCodespaces(ctx, -1, "", "") + csList, err = a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{}) a.StopProgressIndicator() } else { var codespace *api.Codespace diff --git a/pkg/cmd/codespace/stop.go b/pkg/cmd/codespace/stop.go index e8a33f119..bc60e9dc1 100644 --- a/pkg/cmd/codespace/stop.go +++ b/pkg/cmd/codespace/stop.go @@ -43,7 +43,7 @@ func (a *App) StopCodespace(ctx context.Context, opts *stopOptions) error { if codespaceName == "" { a.StartProgressIndicatorWithLabel("Fetching codespaces") - codespaces, err := a.apiClient.ListCodespaces(ctx, -1, opts.orgName, ownerName) + codespaces, err := a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{OrgName: opts.orgName, UserName: ownerName}) a.StopProgressIndicator() if err != nil { return fmt.Errorf("failed to list codespaces: %w", err)