diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index c11e635f6..8eec820a8 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -94,6 +94,7 @@ func New(serverURL, apiURL, vscsURL string, httpClient httpClient) *API { // User represents a GitHub user. type User struct { Login string `json:"login"` + Type string `json:"type"` } // GetUser returns the user associated with the given token. @@ -556,6 +557,50 @@ func (a *API) GetCodespaceRepoSuggestions(ctx context.Context, partialSearch str return repoNames, nil } +// GetCodespaceBillableOwner returns the billable owner and expected default values for +// codespaces created by the user for a given repository. +func (a *API) GetCodespaceBillableOwner(ctx context.Context, nwo string) (*User, error) { + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+nwo+"/codespaces/new", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/repos/*/codespaces/new") + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } else if resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("you cannot create codespaces with that repository") + } else if resp.StatusCode != http.StatusOK { + return nil, api.HandleHTTPError(resp) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + var response struct { + BillableOwner User `json:"billable_owner"` + Defaults struct { + DevcontainerPath string `json:"devcontainer_path"` + Location string `json:"location"` + } + } + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + // While this response contains further helpful information ahead of codespace creation, + // we're only referencing the billable owner today. + return &response.BillableOwner, nil +} + // CreateCodespaceParams are the required parameters for provisioning a Codespace. type CreateCodespaceParams struct { RepositoryID int diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 6720aa0bb..6d272bcbb 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -120,6 +120,7 @@ type apiClient interface { GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error) GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) + GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error) } var errNoCodespaces = errors.New("you have no codespaces") diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index 49d950e9f..764fe4314 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -111,12 +111,9 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { Location: opts.location, } - if userInputs.Repository == "" { - branchPrompt := "Branch (leave blank for default branch):" - if userInputs.Branch != "" { - branchPrompt = "Branch:" - } - questions := []*survey.Question{ + promptForRepoAndBranch := userInputs.Repository == "" + if promptForRepoAndBranch { + repoQuestions := []*survey.Question{ { Name: "repository", Prompt: &survey.Input{ @@ -128,15 +125,8 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { }, Validate: survey.Required, }, - { - Name: "branch", - Prompt: &survey.Input{ - Message: branchPrompt, - Default: userInputs.Branch, - }, - }, } - if err := ask(questions, &userInputs); err != nil { + if err := ask(repoQuestions, &userInputs); err != nil { return fmt.Errorf("failed to prompt: %w", err) } } @@ -152,6 +142,37 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { return fmt.Errorf("error getting repository: %w", err) } + a.StartProgressIndicatorWithLabel("Validating repository for codespaces") + billableOwner, err := a.apiClient.GetCodespaceBillableOwner(ctx, userInputs.Repository) + a.StopProgressIndicator() + + if err != nil { + return fmt.Errorf("error checking codespace ownership: %w", err) + } else if billableOwner != nil && billableOwner.Type == "Organization" { + cs := a.io.ColorScheme() + fmt.Fprintln(a.io.Out, cs.Blue("✓ Codespaces usage for this repository is paid for by "+billableOwner.Login)) + } + + if promptForRepoAndBranch { + branchPrompt := "Branch (leave blank for default branch):" + if userInputs.Branch != "" { + branchPrompt = "Branch:" + } + branchQuestions := []*survey.Question{ + { + Name: "branch", + Prompt: &survey.Input{ + Message: branchPrompt, + Default: userInputs.Branch, + }, + }, + } + + if err := ask(branchQuestions, &userInputs); err != nil { + return fmt.Errorf("failed to prompt: %w", err) + } + } + branch := userInputs.Branch if branch == "" { branch = repository.DefaultBranch diff --git a/pkg/cmd/codespace/create_test.go b/pkg/cmd/codespace/create_test.go index 910c73495..7a14a6056 100644 --- a/pkg/cmd/codespace/create_test.go +++ b/pkg/cmd/codespace/create_test.go @@ -36,6 +36,12 @@ func TestApp_Create(t *testing.T) { DefaultBranch: "main", }, nil }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil }, @@ -80,9 +86,6 @@ func TestApp_Create(t *testing.T) { name: "create codespace with default branch shows idle timeout notice if present", fields: fields{ apiClient: &apiClientMock{ - GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { - return "EUROPE", nil - }, GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { return &api.Repository{ ID: 1234, @@ -90,6 +93,12 @@ func TestApp_Create(t *testing.T) { DefaultBranch: "main", }, nil }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) { return []*api.Machine{ { @@ -131,9 +140,6 @@ func TestApp_Create(t *testing.T) { name: "create codespace with default branch with default devcontainer if no path provided and no devcontainer files exist in the repo", fields: fields{ apiClient: &apiClientMock{ - GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { - return "EUROPE", nil - }, GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { return &api.Repository{ ID: 1234, @@ -141,6 +147,12 @@ func TestApp_Create(t *testing.T) { DefaultBranch: "main", }, nil }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { return []api.DevContainerEntry{}, nil }, @@ -187,9 +199,6 @@ func TestApp_Create(t *testing.T) { name: "returns error when getting devcontainer paths fails", fields: fields{ apiClient: &apiClientMock{ - GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { - return "EUROPE", nil - }, GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { return &api.Repository{ ID: 1234, @@ -197,6 +206,12 @@ func TestApp_Create(t *testing.T) { DefaultBranch: "main", }, nil }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { return nil, fmt.Errorf("some error") }, @@ -215,9 +230,6 @@ func TestApp_Create(t *testing.T) { name: "create codespace with default branch does not show idle timeout notice if not conntected to terminal", fields: fields{ apiClient: &apiClientMock{ - GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { - return "EUROPE", nil - }, GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { return &api.Repository{ ID: 1234, @@ -225,6 +237,12 @@ func TestApp_Create(t *testing.T) { DefaultBranch: "main", }, nil }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { return []api.DevContainerEntry{}, nil }, @@ -268,6 +286,12 @@ func TestApp_Create(t *testing.T) { name: "create codespace that requires accepting additional permissions", fields: fields{ apiClient: &apiClientMock{ + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Login: "monalisa", + Type: "User", + }, nil + }, GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { return &api.Repository{ ID: 1234, @@ -315,6 +339,119 @@ Open this URL in your browser to review and authorize additional permissions: ex Alternatively, you can run "create" with the "--default-permissions" option to continue without authorizing additional permissions. `, }, + { + name: "returns error when user can't create codepaces for a repository", + fields: fields{ + apiClient: &apiClientMock{ + GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { + return &api.Repository{ + ID: 1234, + FullName: nwo, + DefaultBranch: "main", + }, nil + }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return nil, fmt.Errorf("some error") + }, + }, + }, + opts: createOptions{ + repo: "megacorp/private", + branch: "", + machine: "GIGA", + showStatus: false, + idleTimeout: 30 * time.Minute, + }, + wantErr: fmt.Errorf("error checking codespace ownership: some error"), + }, + { + name: "mentions billable owner when org covers codepaces for a repository", + fields: fields{ + apiClient: &apiClientMock{ + GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { + return &api.Repository{ + ID: 1234, + FullName: nwo, + DefaultBranch: "main", + }, nil + }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Type: "Organization", + Login: "megacorp", + }, nil + }, + ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { + return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil + }, + GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) { + return []*api.Machine{ + { + Name: "GIGA", + DisplayName: "Gigabits of a machine", + }, + }, nil + }, + CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + return &api.Codespace{ + Name: "megacorp-private-abcd1234", + }, nil + }, + }, + }, + opts: createOptions{ + repo: "megacorp/private", + branch: "", + machine: "GIGA", + showStatus: false, + idleTimeout: 30 * time.Minute, + }, + wantStdout: "✓ Codespaces usage for this repository is paid for by megacorp\nmegacorp-private-abcd1234\n", + }, + { + name: "doesn't mention billable owner when it's the individual", + fields: fields{ + apiClient: &apiClientMock{ + GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { + return &api.Repository{ + ID: 1234, + FullName: nwo, + DefaultBranch: "main", + }, nil + }, + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Type: "User", + Login: "monalisa", + }, nil + }, + ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) { + return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil + }, + GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) { + return []*api.Machine{ + { + Name: "GIGA", + DisplayName: "Gigabits of a machine", + }, + }, nil + }, + CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + return &api.Codespace{ + Name: "megacorp-private-abcd1234", + }, nil + }, + }, + }, + opts: createOptions{ + repo: "megacorp/private", + branch: "", + machine: "GIGA", + showStatus: false, + idleTimeout: 30 * time.Minute, + }, + wantStdout: "megacorp-private-abcd1234\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 897b39f1d..3999692f2 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -31,8 +31,8 @@ import ( // GetCodespaceFunc: func(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) { // panic("mock out the GetCodespace method") // }, -// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { -// panic("mock out the GetCodespaceRegionLocation method") +// GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { +// panic("mock out the GetCodespaceBillableOwner method") // }, // GetCodespaceRepoSuggestionsFunc: func(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) { // panic("mock out the GetCodespaceRepoSuggestions method") @@ -83,8 +83,8 @@ type apiClientMock struct { // GetCodespaceFunc mocks the GetCodespace method. GetCodespaceFunc func(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) - // GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method. - GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error) + // GetCodespaceBillableOwnerFunc mocks the GetCodespaceBillableOwner method. + GetCodespaceBillableOwnerFunc func(ctx context.Context, nwo string) (*api.User, error) // GetCodespaceRepoSuggestionsFunc mocks the GetCodespaceRepoSuggestions method. GetCodespaceRepoSuggestionsFunc func(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) @@ -154,10 +154,12 @@ type apiClientMock struct { // IncludeConnection is the includeConnection argument value. IncludeConnection bool } - // GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method. - GetCodespaceRegionLocation []struct { + // GetCodespaceBillableOwner holds details about calls to the GetCodespaceBillableOwner method. + GetCodespaceBillableOwner []struct { // Ctx is the ctx argument value. Ctx context.Context + // Nwo is the nwo argument value. + Nwo string } // GetCodespaceRepoSuggestions holds details about calls to the GetCodespaceRepoSuggestions method. GetCodespaceRepoSuggestions []struct { @@ -238,7 +240,7 @@ type apiClientMock struct { lockDeleteCodespace sync.RWMutex lockEditCodespace sync.RWMutex lockGetCodespace sync.RWMutex - lockGetCodespaceRegionLocation sync.RWMutex + lockGetCodespaceBillableOwner sync.RWMutex lockGetCodespaceRepoSuggestions sync.RWMutex lockGetCodespaceRepositoryContents sync.RWMutex lockGetCodespacesMachines sync.RWMutex @@ -433,34 +435,38 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct { return calls } -// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc. -func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) { - if mock.GetCodespaceRegionLocationFunc == nil { - panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called") +// GetCodespaceBillableOwner calls GetCodespaceBillableOwnerFunc. +func (mock *apiClientMock) GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error) { + if mock.GetCodespaceBillableOwnerFunc == nil { + panic("apiClientMock.GetCodespaceBillableOwnerFunc: method is nil but apiClient.GetCodespaceBillableOwner was just called") } callInfo := struct { Ctx context.Context + Nwo string }{ Ctx: ctx, + Nwo: nwo, } - mock.lockGetCodespaceRegionLocation.Lock() - mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo) - mock.lockGetCodespaceRegionLocation.Unlock() - return mock.GetCodespaceRegionLocationFunc(ctx) + mock.lockGetCodespaceBillableOwner.Lock() + mock.calls.GetCodespaceBillableOwner = append(mock.calls.GetCodespaceBillableOwner, callInfo) + mock.lockGetCodespaceBillableOwner.Unlock() + return mock.GetCodespaceBillableOwnerFunc(ctx, nwo) } -// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation. +// GetCodespaceBillableOwnerCalls gets all the calls that were made to GetCodespaceBillableOwner. // Check the length with: -// len(mockedapiClient.GetCodespaceRegionLocationCalls()) -func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct { +// len(mockedapiClient.GetCodespaceBillableOwnerCalls()) +func (mock *apiClientMock) GetCodespaceBillableOwnerCalls() []struct { Ctx context.Context + Nwo string } { var calls []struct { Ctx context.Context + Nwo string } - mock.lockGetCodespaceRegionLocation.RLock() - calls = mock.calls.GetCodespaceRegionLocation - mock.lockGetCodespaceRegionLocation.RUnlock() + mock.lockGetCodespaceBillableOwner.RLock() + calls = mock.calls.GetCodespaceBillableOwner + mock.lockGetCodespaceBillableOwner.RUnlock() return calls }