diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 302e0ee0f..950a8ea52 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -639,6 +639,46 @@ func (a *API) GetCodespacesMachines(ctx context.Context, repoID int, branch, loc return response.Machines, nil } +// GetCodespacesPermissionsCheck returns a bool indicating whether the user has accepted permissions for the given repo and devcontainer path. +func (a *API) GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) { + reqURL := fmt.Sprintf("%s/repositories/%d/codespaces/permissions_check", a.githubAPI, repoID) + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return false, fmt.Errorf("error creating request: %w", err) + } + + q := req.URL.Query() + q.Add("location", location) + q.Add("ref", branch) + q.Add("devcontainer_path", devcontainerPath) + req.URL.RawQuery = q.Encode() + + a.setHeaders(req) + resp, err := a.do(ctx, req, "/repositories/*/codespaces/permissions_check") + if err != nil { + return false, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return false, api.HandleHTTPError(resp) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("error reading response body: %w", err) + } + + var response struct { + Accepted bool `json:"accepted"` + } + if err := json.Unmarshal(b, &response); err != nil { + return false, fmt.Errorf("error unmarshalling response: %w", err) + } + + return response.Accepted, nil +} + // RepoSearchParameters are the optional parameters for searching for repositories. type RepoSearchParameters struct { // The maximum number of repos to return. At most 100 repos are returned even if this value is greater than 100. diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 3ad3463e6..78dae09e4 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -76,7 +76,8 @@ type apiClient interface { CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) - GetCodespacesMachines(ctx context.Context, repoID int, branch, location string, devcontainerPath string) ([]*api.Machine, error) + GetCodespacesMachines(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*api.Machine, error) + GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error) GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index a2782aef7..485f8e57c 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -19,6 +19,11 @@ const ( DEVCONTAINER_PROMPT_DEFAULT = "Default Codespaces configuration" ) +const ( + permissionsPollingInterval = 5 * time.Second + permissionsPollingTimeout = 1 * time.Minute +) + var ( DEFAULT_DEVCONTAINER_DEFINITIONS = []string{".devcontainer.json", ".devcontainer/devcontainer.json"} ) @@ -307,7 +312,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { return fmt.Errorf("error creating codespace: %w", err) } - codespace, err = a.handleAdditionalPermissions(ctx, createParams, aerr.AllowPermissionsURL) + codespace, err = a.handleAdditionalPermissions(ctx, createParams, aerr.AllowPermissionsURL, userInputs.Location) if err != nil { // this error could be a cmdutil.SilentError (in the case that the user opened the browser) so we don't want to wrap it return err @@ -331,7 +336,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { return nil } -func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api.CreateCodespaceParams, allowPermissionsURL string) (*api.Codespace, error) { +func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api.CreateCodespaceParams, allowPermissionsURL string, location string) (*api.Codespace, error) { var ( isInteractive = a.io.CanPrompt() cs = a.io.ColorScheme() @@ -372,13 +377,44 @@ func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api // if the user chose to continue in the browser, open the URL if answers.Accept == choices[0] { - fmt.Fprintln(a.io.ErrOut, "Please re-run the create request after accepting permissions in the browser.") if err := a.browser.Browse(allowPermissionsURL); err != nil { return nil, fmt.Errorf("error opening browser: %w", err) } - // browser opened successfully but we do not know if they accepted the permissions - // so we must exit and wait for the user to attempt the create again - return nil, cmdutil.SilentError + } + + // Poll until the user has accepted the permissions or timeout + err := a.RunWithProgress("Waiting for permissions to be accepted in the browser", func() (err error) { + ctx, cancel := context.WithTimeout(ctx, permissionsPollingTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + accepted, err := a.apiClient.GetCodespacesPermissionsCheck(ctx, createParams.RepositoryID, createParams.Branch, location, createParams.DevContainerPath) + if err != nil { + done <- err + return + } + + if accepted { + done <- nil + return + } + + // Wait before polling again + time.Sleep(permissionsPollingInterval) + } + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for permissions to be accepted in the browser") + } + }) + if err != nil { + return nil, fmt.Errorf("error polling for permissions: %w", err) } // if the user chose to create the codespace without the permissions, @@ -386,7 +422,7 @@ func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api createParams.PermissionsOptOut = true var codespace *api.Codespace - err := a.RunWithProgress("Creating codespace", func() (err error) { + err = a.RunWithProgress("Creating codespace", func() (err error) { codespace, err = a.apiClient.CreateCodespace(ctx, createParams) return }) diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index aad15c025..d6cf37eeb 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -41,8 +41,8 @@ import ( // GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) { // panic("mock out the GetCodespacesMachines method") // }, -// HTTPClientFunc: func() (*http.Client, error) { -// panic("mock out the HTTPClient method") +// GetCodespacesPermissionsCheckFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) { +// panic("mock out the GetCodespacesPermissionsCheck method") // }, // GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) { // panic("mock out the GetOrgMemberCodespace method") @@ -53,6 +53,9 @@ import ( // GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) { // panic("mock out the GetUser method") // }, +// HTTPClientFunc: func() (*http.Client, error) { +// panic("mock out the HTTPClient method") +// }, // ListCodespacesFunc: func(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error) { // panic("mock out the ListCodespaces method") // }, @@ -99,8 +102,8 @@ type apiClientMock struct { // GetCodespacesMachinesFunc mocks the GetCodespacesMachines method. GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) - // HTTPClientFunc mocks the HTTPClient method. - HTTPClientFunc func() (*http.Client, error) + // GetCodespacesPermissionsCheckFunc mocks the GetCodespacesPermissionsCheck method. + GetCodespacesPermissionsCheckFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) // GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method. GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) @@ -111,6 +114,9 @@ type apiClientMock struct { // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error) + // HTTPClientFunc mocks the HTTPClient method. + HTTPClientFunc func() (*http.Client, error) + // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error) @@ -202,8 +208,18 @@ type apiClientMock struct { // DevcontainerPath is the devcontainerPath argument value. DevcontainerPath string } - // HTTPClient holds details about calls to the HTTPClient method. - HTTPClient []struct { + // GetCodespacesPermissionsCheck holds details about calls to the GetCodespacesPermissionsCheck method. + GetCodespacesPermissionsCheck []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoID is the repoID argument value. + RepoID int + // Branch is the branch argument value. + Branch string + // Location is the location argument value. + Location string + // DevcontainerPath is the devcontainerPath argument value. + DevcontainerPath string } // GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method. GetOrgMemberCodespace []struct { @@ -228,6 +244,9 @@ type apiClientMock struct { // Ctx is the ctx argument value. Ctx context.Context } + // HTTPClient holds details about calls to the HTTPClient method. + HTTPClient []struct { + } // ListCodespaces holds details about calls to the ListCodespaces method. ListCodespaces []struct { // Ctx is the ctx argument value. @@ -276,10 +295,11 @@ type apiClientMock struct { lockGetCodespaceRepoSuggestions sync.RWMutex lockGetCodespaceRepositoryContents sync.RWMutex lockGetCodespacesMachines sync.RWMutex - lockHTTPClient sync.RWMutex + lockGetCodespacesPermissionsCheck sync.RWMutex lockGetOrgMemberCodespace sync.RWMutex lockGetRepository sync.RWMutex lockGetUser sync.RWMutex + lockHTTPClient sync.RWMutex lockListCodespaces sync.RWMutex lockListDevContainers sync.RWMutex lockServerURL sync.RWMutex @@ -611,30 +631,51 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct { return calls } -// HTTPClient calls HTTPClientFunc. -func (mock *apiClientMock) HTTPClient() (*http.Client, error) { - if mock.HTTPClientFunc == nil { - panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called") +// GetCodespacesPermissionsCheck calls GetCodespacesPermissionsCheckFunc. +func (mock *apiClientMock) GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) { + if mock.GetCodespacesPermissionsCheckFunc == nil { + panic("apiClientMock.GetCodespacesPermissionsCheckFunc: method is nil but apiClient.GetCodespacesPermissionsCheck was just called") } callInfo := struct { - }{} - mock.lockHTTPClient.Lock() - mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo) - mock.lockHTTPClient.Unlock() - return mock.HTTPClientFunc() + Ctx context.Context + RepoID int + Branch string + Location string + DevcontainerPath string + }{ + Ctx: ctx, + RepoID: repoID, + Branch: branch, + Location: location, + DevcontainerPath: devcontainerPath, + } + mock.lockGetCodespacesPermissionsCheck.Lock() + mock.calls.GetCodespacesPermissionsCheck = append(mock.calls.GetCodespacesPermissionsCheck, callInfo) + mock.lockGetCodespacesPermissionsCheck.Unlock() + return mock.GetCodespacesPermissionsCheckFunc(ctx, repoID, branch, location, devcontainerPath) } -// HTTPClientCalls gets all the calls that were made to HTTPClient. +// GetCodespacesPermissionsCheckCalls gets all the calls that were made to GetCodespacesPermissionsCheck. // Check the length with: // -// len(mockedapiClient.HTTPClientCalls()) -func (mock *apiClientMock) HTTPClientCalls() []struct { +// len(mockedapiClient.GetCodespacesPermissionsCheckCalls()) +func (mock *apiClientMock) GetCodespacesPermissionsCheckCalls() []struct { + Ctx context.Context + RepoID int + Branch string + Location string + DevcontainerPath string } { var calls []struct { + Ctx context.Context + RepoID int + Branch string + Location string + DevcontainerPath string } - mock.lockHTTPClient.RLock() - calls = mock.calls.HTTPClient - mock.lockHTTPClient.RUnlock() + mock.lockGetCodespacesPermissionsCheck.RLock() + calls = mock.calls.GetCodespacesPermissionsCheck + mock.lockGetCodespacesPermissionsCheck.RUnlock() return calls } @@ -750,6 +791,33 @@ func (mock *apiClientMock) GetUserCalls() []struct { return calls } +// HTTPClient calls HTTPClientFunc. +func (mock *apiClientMock) HTTPClient() (*http.Client, error) { + if mock.HTTPClientFunc == nil { + panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called") + } + callInfo := struct { + }{} + mock.lockHTTPClient.Lock() + mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo) + mock.lockHTTPClient.Unlock() + return mock.HTTPClientFunc() +} + +// HTTPClientCalls gets all the calls that were made to HTTPClient. +// Check the length with: +// +// len(mockedapiClient.HTTPClientCalls()) +func (mock *apiClientMock) HTTPClientCalls() []struct { +} { + var calls []struct { + } + mock.lockHTTPClient.RLock() + calls = mock.calls.HTTPClient + mock.lockHTTPClient.RUnlock() + return calls +} + // ListCodespaces calls ListCodespacesFunc. func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error) { if mock.ListCodespacesFunc == nil {