Merge pull request #8179 from dmgardiner25/poll-codespace-permissions

Poll permission acceptance endpoint on codespace creation
This commit is contained in:
David Gardiner 2023-10-16 11:42:59 -07:00 committed by GitHub
commit bc0f63b043
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 30 deletions

View file

@ -639,6 +639,46 @@ func (a *API) GetCodespacesMachines(ctx context.Context, repoID int, branch, loc
return response.Machines, nil
}
// GetCodespacesPermissionsCheck returns a bool indicating whether the user has accepted permissions for the given repo and devcontainer path.
func (a *API) GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) {
reqURL := fmt.Sprintf("%s/repositories/%d/codespaces/permissions_check", a.githubAPI, repoID)
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return false, fmt.Errorf("error creating request: %w", err)
}
q := req.URL.Query()
q.Add("location", location)
q.Add("ref", branch)
q.Add("devcontainer_path", devcontainerPath)
req.URL.RawQuery = q.Encode()
a.setHeaders(req)
resp, err := a.do(ctx, req, "/repositories/*/codespaces/permissions_check")
if err != nil {
return false, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return false, api.HandleHTTPError(resp)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("error reading response body: %w", err)
}
var response struct {
Accepted bool `json:"accepted"`
}
if err := json.Unmarshal(b, &response); err != nil {
return false, fmt.Errorf("error unmarshalling response: %w", err)
}
return response.Accepted, nil
}
// RepoSearchParameters are the optional parameters for searching for repositories.
type RepoSearchParameters struct {
// The maximum number of repos to return. At most 100 repos are returned even if this value is greater than 100.

View file

@ -76,7 +76,8 @@ type apiClient interface {
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
GetCodespacesMachines(ctx context.Context, repoID int, branch, location string, devcontainerPath string) ([]*api.Machine, error)
GetCodespacesMachines(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*api.Machine, error)
GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error)

View file

@ -19,6 +19,11 @@ const (
DEVCONTAINER_PROMPT_DEFAULT = "Default Codespaces configuration"
)
const (
permissionsPollingInterval = 5 * time.Second
permissionsPollingTimeout = 1 * time.Minute
)
var (
DEFAULT_DEVCONTAINER_DEFINITIONS = []string{".devcontainer.json", ".devcontainer/devcontainer.json"}
)
@ -307,7 +312,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error {
return fmt.Errorf("error creating codespace: %w", err)
}
codespace, err = a.handleAdditionalPermissions(ctx, createParams, aerr.AllowPermissionsURL)
codespace, err = a.handleAdditionalPermissions(ctx, createParams, aerr.AllowPermissionsURL, userInputs.Location)
if err != nil {
// this error could be a cmdutil.SilentError (in the case that the user opened the browser) so we don't want to wrap it
return err
@ -331,7 +336,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error {
return nil
}
func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api.CreateCodespaceParams, allowPermissionsURL string) (*api.Codespace, error) {
func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api.CreateCodespaceParams, allowPermissionsURL string, location string) (*api.Codespace, error) {
var (
isInteractive = a.io.CanPrompt()
cs = a.io.ColorScheme()
@ -372,13 +377,44 @@ func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api
// if the user chose to continue in the browser, open the URL
if answers.Accept == choices[0] {
fmt.Fprintln(a.io.ErrOut, "Please re-run the create request after accepting permissions in the browser.")
if err := a.browser.Browse(allowPermissionsURL); err != nil {
return nil, fmt.Errorf("error opening browser: %w", err)
}
// browser opened successfully but we do not know if they accepted the permissions
// so we must exit and wait for the user to attempt the create again
return nil, cmdutil.SilentError
}
// Poll until the user has accepted the permissions or timeout
err := a.RunWithProgress("Waiting for permissions to be accepted in the browser", func() (err error) {
ctx, cancel := context.WithTimeout(ctx, permissionsPollingTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
for {
accepted, err := a.apiClient.GetCodespacesPermissionsCheck(ctx, createParams.RepositoryID, createParams.Branch, location, createParams.DevContainerPath)
if err != nil {
done <- err
return
}
if accepted {
done <- nil
return
}
// Wait before polling again
time.Sleep(permissionsPollingInterval)
}
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("timed out waiting for permissions to be accepted in the browser")
}
})
if err != nil {
return nil, fmt.Errorf("error polling for permissions: %w", err)
}
// if the user chose to create the codespace without the permissions,
@ -386,7 +422,7 @@ func (a *App) handleAdditionalPermissions(ctx context.Context, createParams *api
createParams.PermissionsOptOut = true
var codespace *api.Codespace
err := a.RunWithProgress("Creating codespace", func() (err error) {
err = a.RunWithProgress("Creating codespace", func() (err error) {
codespace, err = a.apiClient.CreateCodespace(ctx, createParams)
return
})

View file

@ -41,8 +41,8 @@ import (
// GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) {
// panic("mock out the GetCodespacesMachines method")
// },
// HTTPClientFunc: func() (*http.Client, error) {
// panic("mock out the HTTPClient method")
// GetCodespacesPermissionsCheckFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) {
// panic("mock out the GetCodespacesPermissionsCheck method")
// },
// GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) {
// panic("mock out the GetOrgMemberCodespace method")
@ -53,6 +53,9 @@ import (
// GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) {
// panic("mock out the GetUser method")
// },
// HTTPClientFunc: func() (*http.Client, error) {
// panic("mock out the HTTPClient method")
// },
// ListCodespacesFunc: func(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error) {
// panic("mock out the ListCodespaces method")
// },
@ -99,8 +102,8 @@ type apiClientMock struct {
// GetCodespacesMachinesFunc mocks the GetCodespacesMachines method.
GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error)
// HTTPClientFunc mocks the HTTPClient method.
HTTPClientFunc func() (*http.Client, error)
// GetCodespacesPermissionsCheckFunc mocks the GetCodespacesPermissionsCheck method.
GetCodespacesPermissionsCheckFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error)
// GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method.
GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error)
@ -111,6 +114,9 @@ type apiClientMock struct {
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error)
// HTTPClientFunc mocks the HTTPClient method.
HTTPClientFunc func() (*http.Client, error)
// ListCodespacesFunc mocks the ListCodespaces method.
ListCodespacesFunc func(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error)
@ -202,8 +208,18 @@ type apiClientMock struct {
// DevcontainerPath is the devcontainerPath argument value.
DevcontainerPath string
}
// HTTPClient holds details about calls to the HTTPClient method.
HTTPClient []struct {
// GetCodespacesPermissionsCheck holds details about calls to the GetCodespacesPermissionsCheck method.
GetCodespacesPermissionsCheck []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// RepoID is the repoID argument value.
RepoID int
// Branch is the branch argument value.
Branch string
// Location is the location argument value.
Location string
// DevcontainerPath is the devcontainerPath argument value.
DevcontainerPath string
}
// GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method.
GetOrgMemberCodespace []struct {
@ -228,6 +244,9 @@ type apiClientMock struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// HTTPClient holds details about calls to the HTTPClient method.
HTTPClient []struct {
}
// ListCodespaces holds details about calls to the ListCodespaces method.
ListCodespaces []struct {
// Ctx is the ctx argument value.
@ -276,10 +295,11 @@ type apiClientMock struct {
lockGetCodespaceRepoSuggestions sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespacesMachines sync.RWMutex
lockHTTPClient sync.RWMutex
lockGetCodespacesPermissionsCheck sync.RWMutex
lockGetOrgMemberCodespace sync.RWMutex
lockGetRepository sync.RWMutex
lockGetUser sync.RWMutex
lockHTTPClient sync.RWMutex
lockListCodespaces sync.RWMutex
lockListDevContainers sync.RWMutex
lockServerURL sync.RWMutex
@ -611,30 +631,51 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct {
return calls
}
// HTTPClient calls HTTPClientFunc.
func (mock *apiClientMock) HTTPClient() (*http.Client, error) {
if mock.HTTPClientFunc == nil {
panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called")
// GetCodespacesPermissionsCheck calls GetCodespacesPermissionsCheckFunc.
func (mock *apiClientMock) GetCodespacesPermissionsCheck(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) (bool, error) {
if mock.GetCodespacesPermissionsCheckFunc == nil {
panic("apiClientMock.GetCodespacesPermissionsCheckFunc: method is nil but apiClient.GetCodespacesPermissionsCheck was just called")
}
callInfo := struct {
}{}
mock.lockHTTPClient.Lock()
mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo)
mock.lockHTTPClient.Unlock()
return mock.HTTPClientFunc()
Ctx context.Context
RepoID int
Branch string
Location string
DevcontainerPath string
}{
Ctx: ctx,
RepoID: repoID,
Branch: branch,
Location: location,
DevcontainerPath: devcontainerPath,
}
mock.lockGetCodespacesPermissionsCheck.Lock()
mock.calls.GetCodespacesPermissionsCheck = append(mock.calls.GetCodespacesPermissionsCheck, callInfo)
mock.lockGetCodespacesPermissionsCheck.Unlock()
return mock.GetCodespacesPermissionsCheckFunc(ctx, repoID, branch, location, devcontainerPath)
}
// HTTPClientCalls gets all the calls that were made to HTTPClient.
// GetCodespacesPermissionsCheckCalls gets all the calls that were made to GetCodespacesPermissionsCheck.
// Check the length with:
//
// len(mockedapiClient.HTTPClientCalls())
func (mock *apiClientMock) HTTPClientCalls() []struct {
// len(mockedapiClient.GetCodespacesPermissionsCheckCalls())
func (mock *apiClientMock) GetCodespacesPermissionsCheckCalls() []struct {
Ctx context.Context
RepoID int
Branch string
Location string
DevcontainerPath string
} {
var calls []struct {
Ctx context.Context
RepoID int
Branch string
Location string
DevcontainerPath string
}
mock.lockHTTPClient.RLock()
calls = mock.calls.HTTPClient
mock.lockHTTPClient.RUnlock()
mock.lockGetCodespacesPermissionsCheck.RLock()
calls = mock.calls.GetCodespacesPermissionsCheck
mock.lockGetCodespacesPermissionsCheck.RUnlock()
return calls
}
@ -750,6 +791,33 @@ func (mock *apiClientMock) GetUserCalls() []struct {
return calls
}
// HTTPClient calls HTTPClientFunc.
func (mock *apiClientMock) HTTPClient() (*http.Client, error) {
if mock.HTTPClientFunc == nil {
panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called")
}
callInfo := struct {
}{}
mock.lockHTTPClient.Lock()
mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo)
mock.lockHTTPClient.Unlock()
return mock.HTTPClientFunc()
}
// HTTPClientCalls gets all the calls that were made to HTTPClient.
// Check the length with:
//
// len(mockedapiClient.HTTPClientCalls())
func (mock *apiClientMock) HTTPClientCalls() []struct {
} {
var calls []struct {
}
mock.lockHTTPClient.RLock()
calls = mock.calls.HTTPClient
mock.lockHTTPClient.RUnlock()
return calls
}
// ListCodespaces calls ListCodespacesFunc.
func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts codespacesAPI.ListCodespacesOptions) ([]*codespacesAPI.Codespace, error) {
if mock.ListCodespacesFunc == nil {