Merge pull request #5816 from cli/jshorty/codespace-pre-flight-api

Notify user when codespace usage is covered by organization
This commit is contained in:
Jake Shorty 2022-06-20 11:33:24 -06:00 committed by GitHub
commit faef89144d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 257 additions and 47 deletions

View file

@ -94,6 +94,7 @@ func New(serverURL, apiURL, vscsURL string, httpClient httpClient) *API {
// User represents a GitHub user.
type User struct {
Login string `json:"login"`
Type string `json:"type"`
}
// GetUser returns the user associated with the given token.
@ -556,6 +557,50 @@ func (a *API) GetCodespaceRepoSuggestions(ctx context.Context, partialSearch str
return repoNames, nil
}
// GetCodespaceBillableOwner returns the billable owner and expected default values for
// codespaces created by the user for a given repository.
func (a *API) GetCodespaceBillableOwner(ctx context.Context, nwo string) (*User, error) {
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+nwo+"/codespaces/new", nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/repos/*/codespaces/new")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
} else if resp.StatusCode == http.StatusForbidden {
return nil, fmt.Errorf("you cannot create codespaces with that repository")
} else if resp.StatusCode != http.StatusOK {
return nil, api.HandleHTTPError(resp)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
var response struct {
BillableOwner User `json:"billable_owner"`
Defaults struct {
DevcontainerPath string `json:"devcontainer_path"`
Location string `json:"location"`
}
}
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
// While this response contains further helpful information ahead of codespace creation,
// we're only referencing the billable owner today.
return &response.BillableOwner, nil
}
// CreateCodespaceParams are the required parameters for provisioning a Codespace.
type CreateCodespaceParams struct {
RepositoryID int

View file

@ -120,6 +120,7 @@ type apiClient interface {
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)
GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error)
}
var errNoCodespaces = errors.New("you have no codespaces")

View file

@ -111,12 +111,9 @@ func (a *App) Create(ctx context.Context, opts createOptions) error {
Location: opts.location,
}
if userInputs.Repository == "" {
branchPrompt := "Branch (leave blank for default branch):"
if userInputs.Branch != "" {
branchPrompt = "Branch:"
}
questions := []*survey.Question{
promptForRepoAndBranch := userInputs.Repository == ""
if promptForRepoAndBranch {
repoQuestions := []*survey.Question{
{
Name: "repository",
Prompt: &survey.Input{
@ -128,15 +125,8 @@ func (a *App) Create(ctx context.Context, opts createOptions) error {
},
Validate: survey.Required,
},
{
Name: "branch",
Prompt: &survey.Input{
Message: branchPrompt,
Default: userInputs.Branch,
},
},
}
if err := ask(questions, &userInputs); err != nil {
if err := ask(repoQuestions, &userInputs); err != nil {
return fmt.Errorf("failed to prompt: %w", err)
}
}
@ -152,6 +142,37 @@ func (a *App) Create(ctx context.Context, opts createOptions) error {
return fmt.Errorf("error getting repository: %w", err)
}
a.StartProgressIndicatorWithLabel("Validating repository for codespaces")
billableOwner, err := a.apiClient.GetCodespaceBillableOwner(ctx, userInputs.Repository)
a.StopProgressIndicator()
if err != nil {
return fmt.Errorf("error checking codespace ownership: %w", err)
} else if billableOwner != nil && billableOwner.Type == "Organization" {
cs := a.io.ColorScheme()
fmt.Fprintln(a.io.Out, cs.Blue("✓ Codespaces usage for this repository is paid for by "+billableOwner.Login))
}
if promptForRepoAndBranch {
branchPrompt := "Branch (leave blank for default branch):"
if userInputs.Branch != "" {
branchPrompt = "Branch:"
}
branchQuestions := []*survey.Question{
{
Name: "branch",
Prompt: &survey.Input{
Message: branchPrompt,
Default: userInputs.Branch,
},
},
}
if err := ask(branchQuestions, &userInputs); err != nil {
return fmt.Errorf("failed to prompt: %w", err)
}
}
branch := userInputs.Branch
if branch == "" {
branch = repository.DefaultBranch

View file

@ -36,6 +36,12 @@ func TestApp_Create(t *testing.T) {
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil
},
@ -80,9 +86,6 @@ func TestApp_Create(t *testing.T) {
name: "create codespace with default branch shows idle timeout notice if present",
fields: fields{
apiClient: &apiClientMock{
GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
return "EUROPE", nil
},
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
@ -90,6 +93,12 @@ func TestApp_Create(t *testing.T) {
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) {
return []*api.Machine{
{
@ -131,9 +140,6 @@ func TestApp_Create(t *testing.T) {
name: "create codespace with default branch with default devcontainer if no path provided and no devcontainer files exist in the repo",
fields: fields{
apiClient: &apiClientMock{
GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
return "EUROPE", nil
},
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
@ -141,6 +147,12 @@ func TestApp_Create(t *testing.T) {
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return []api.DevContainerEntry{}, nil
},
@ -187,9 +199,6 @@ func TestApp_Create(t *testing.T) {
name: "returns error when getting devcontainer paths fails",
fields: fields{
apiClient: &apiClientMock{
GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
return "EUROPE", nil
},
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
@ -197,6 +206,12 @@ func TestApp_Create(t *testing.T) {
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return nil, fmt.Errorf("some error")
},
@ -215,9 +230,6 @@ func TestApp_Create(t *testing.T) {
name: "create codespace with default branch does not show idle timeout notice if not conntected to terminal",
fields: fields{
apiClient: &apiClientMock{
GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
return "EUROPE", nil
},
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
@ -225,6 +237,12 @@ func TestApp_Create(t *testing.T) {
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return []api.DevContainerEntry{}, nil
},
@ -268,6 +286,12 @@ func TestApp_Create(t *testing.T) {
name: "create codespace that requires accepting additional permissions",
fields: fields{
apiClient: &apiClientMock{
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Login: "monalisa",
Type: "User",
}, nil
},
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
@ -315,6 +339,119 @@ Open this URL in your browser to review and authorize additional permissions: ex
Alternatively, you can run "create" with the "--default-permissions" option to continue without authorizing additional permissions.
`,
},
{
name: "returns error when user can't create codepaces for a repository",
fields: fields{
apiClient: &apiClientMock{
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
FullName: nwo,
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return nil, fmt.Errorf("some error")
},
},
},
opts: createOptions{
repo: "megacorp/private",
branch: "",
machine: "GIGA",
showStatus: false,
idleTimeout: 30 * time.Minute,
},
wantErr: fmt.Errorf("error checking codespace ownership: some error"),
},
{
name: "mentions billable owner when org covers codepaces for a repository",
fields: fields{
apiClient: &apiClientMock{
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
FullName: nwo,
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Type: "Organization",
Login: "megacorp",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil
},
GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) {
return []*api.Machine{
{
Name: "GIGA",
DisplayName: "Gigabits of a machine",
},
}, nil
},
CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
return &api.Codespace{
Name: "megacorp-private-abcd1234",
}, nil
},
},
},
opts: createOptions{
repo: "megacorp/private",
branch: "",
machine: "GIGA",
showStatus: false,
idleTimeout: 30 * time.Minute,
},
wantStdout: "✓ Codespaces usage for this repository is paid for by megacorp\nmegacorp-private-abcd1234\n",
},
{
name: "doesn't mention billable owner when it's the individual",
fields: fields{
apiClient: &apiClientMock{
GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
return &api.Repository{
ID: 1234,
FullName: nwo,
DefaultBranch: "main",
}, nil
},
GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
return &api.User{
Type: "User",
Login: "monalisa",
}, nil
},
ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]api.DevContainerEntry, error) {
return []api.DevContainerEntry{{Path: ".devcontainer/devcontainer.json"}}, nil
},
GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) {
return []*api.Machine{
{
Name: "GIGA",
DisplayName: "Gigabits of a machine",
},
}, nil
},
CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
return &api.Codespace{
Name: "megacorp-private-abcd1234",
}, nil
},
},
},
opts: createOptions{
repo: "megacorp/private",
branch: "",
machine: "GIGA",
showStatus: false,
idleTimeout: 30 * time.Minute,
},
wantStdout: "megacorp-private-abcd1234\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -31,8 +31,8 @@ import (
// GetCodespaceFunc: func(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) {
// panic("mock out the GetCodespace method")
// },
// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
// panic("mock out the GetCodespaceRegionLocation method")
// GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) {
// panic("mock out the GetCodespaceBillableOwner method")
// },
// GetCodespaceRepoSuggestionsFunc: func(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) {
// panic("mock out the GetCodespaceRepoSuggestions method")
@ -83,8 +83,8 @@ type apiClientMock struct {
// GetCodespaceFunc mocks the GetCodespace method.
GetCodespaceFunc func(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
// GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method.
GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error)
// GetCodespaceBillableOwnerFunc mocks the GetCodespaceBillableOwner method.
GetCodespaceBillableOwnerFunc func(ctx context.Context, nwo string) (*api.User, error)
// GetCodespaceRepoSuggestionsFunc mocks the GetCodespaceRepoSuggestions method.
GetCodespaceRepoSuggestionsFunc func(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)
@ -154,10 +154,12 @@ type apiClientMock struct {
// IncludeConnection is the includeConnection argument value.
IncludeConnection bool
}
// GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method.
GetCodespaceRegionLocation []struct {
// GetCodespaceBillableOwner holds details about calls to the GetCodespaceBillableOwner method.
GetCodespaceBillableOwner []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Nwo is the nwo argument value.
Nwo string
}
// GetCodespaceRepoSuggestions holds details about calls to the GetCodespaceRepoSuggestions method.
GetCodespaceRepoSuggestions []struct {
@ -238,7 +240,7 @@ type apiClientMock struct {
lockDeleteCodespace sync.RWMutex
lockEditCodespace sync.RWMutex
lockGetCodespace sync.RWMutex
lockGetCodespaceRegionLocation sync.RWMutex
lockGetCodespaceBillableOwner sync.RWMutex
lockGetCodespaceRepoSuggestions sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespacesMachines sync.RWMutex
@ -433,34 +435,38 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct {
return calls
}
// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc.
func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
if mock.GetCodespaceRegionLocationFunc == nil {
panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called")
// GetCodespaceBillableOwner calls GetCodespaceBillableOwnerFunc.
func (mock *apiClientMock) GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error) {
if mock.GetCodespaceBillableOwnerFunc == nil {
panic("apiClientMock.GetCodespaceBillableOwnerFunc: method is nil but apiClient.GetCodespaceBillableOwner was just called")
}
callInfo := struct {
Ctx context.Context
Nwo string
}{
Ctx: ctx,
Nwo: nwo,
}
mock.lockGetCodespaceRegionLocation.Lock()
mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo)
mock.lockGetCodespaceRegionLocation.Unlock()
return mock.GetCodespaceRegionLocationFunc(ctx)
mock.lockGetCodespaceBillableOwner.Lock()
mock.calls.GetCodespaceBillableOwner = append(mock.calls.GetCodespaceBillableOwner, callInfo)
mock.lockGetCodespaceBillableOwner.Unlock()
return mock.GetCodespaceBillableOwnerFunc(ctx, nwo)
}
// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation.
// GetCodespaceBillableOwnerCalls gets all the calls that were made to GetCodespaceBillableOwner.
// Check the length with:
// len(mockedapiClient.GetCodespaceRegionLocationCalls())
func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct {
// len(mockedapiClient.GetCodespaceBillableOwnerCalls())
func (mock *apiClientMock) GetCodespaceBillableOwnerCalls() []struct {
Ctx context.Context
Nwo string
} {
var calls []struct {
Ctx context.Context
Nwo string
}
mock.lockGetCodespaceRegionLocation.RLock()
calls = mock.calls.GetCodespaceRegionLocation
mock.lockGetCodespaceRegionLocation.RUnlock()
mock.lockGetCodespaceBillableOwner.RLock()
calls = mock.calls.GetCodespaceBillableOwner
mock.lockGetCodespaceBillableOwner.RUnlock()
return calls
}