diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 91c8bcb8d..eb6d1bea6 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -83,7 +83,7 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := codespaces.Provision(ctx, log, apiClient, &codespaces.ProvisionParams{ + codespace, err := apiClient.ProvisionCodespace(ctx, log, &api.ProvisionCodespaceParams{ User: userResult.User, Repository: repository, Branch: branch, diff --git a/internal/api/api.go b/internal/api/api.go index 6b19b0703..c3ad0aadc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -36,6 +36,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/opentracing/opentracing-go" ) @@ -402,6 +403,81 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep return response.SKUs, nil } +// ProvisionCodespaceParams are the required parameters for provisioning a Codespace. +type ProvisionCodespaceParams struct { + User *User + Repository *Repository + Branch, Machine, Location string +} + +type logger interface { + Print(v ...interface{}) (int, error) + Println(v ...interface{}) (int, error) +} + +// ProvisionCodespace creates a codespace with the given parameters and handles polling in the case +// of initial creation failures. +func (a *API) ProvisionCodespace(ctx context.Context, log logger, params *ProvisionCodespaceParams) (*Codespace, error) { + codespace, err := a.createCodespace( + ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, + ) + if err != nil { + // This error is returned by the API when the initial creation fails with a retryable error. + // A retryable error means that GitHub will retry to re-create Codespace and clients should poll + // the API and attempt to fetch the Codespace for the next two minutes. + if err == errProvisioningInProgress { + pollTimeout := 2 * time.Minute + pollInterval := 1 * time.Second + log.Print(".") + codespace, err = pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) + log.Print("\n") + + if err != nil { + return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) + } + } + + return nil, err + } + + return codespace, nil +} + +type apiClient interface { + GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) + GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) +} + +// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. +// If it succeeds at fetching the codespace, we consider the codespace provisioned. +func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*Codespace, error) { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + log.Print(".") + token, err := client.GetCodespaceToken(ctx, user, name) + if err != nil { + if err == ErrNotProvisioned { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + return nil, fmt.Errorf("failed to get codespace token: %w", err) + } + + return client.GetCodespace(ctx, token, user, name) + } + } +} + type createCodespaceRequest struct { RepositoryID int `json:"repository_id"` Ref string `json:"ref"` @@ -409,9 +485,9 @@ type createCodespaceRequest struct { SkuName string `json:"sku_name"` } -var ErrProvisioningInProgress = errors.New("provisioning in progress") +var errProvisioningInProgress = errors.New("provisioning in progress") -func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { +func (a *API) createCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) @@ -442,7 +518,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos // being retried. For clients this means that they must implement a polling strategy // to check for the codespace existence for the next two minutes. We return an error // here so callers can detect and handle this condition. - return nil, ErrProvisioningInProgress + return nil, errProvisioningInProgress } var response Codespace diff --git a/internal/codespaces/codespaces_test.go b/internal/api/api_test.go similarity index 77% rename from internal/codespaces/codespaces_test.go rename to internal/api/api_test.go index 53aba0557..eb5226a59 100644 --- a/internal/codespaces/codespaces_test.go +++ b/internal/api/api_test.go @@ -1,4 +1,4 @@ -package codespaces +package api import ( "context" @@ -8,21 +8,11 @@ import ( "time" "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" ) type mockAPIClient struct { - createCodespace func(context.Context, *api.User, *api.Repository, string, string, string) (*api.Codespace, error) getCodespaceToken func(context.Context, string, string) (string, error) - getCodespace func(context.Context, string, string, string) (*api.Codespace, error) -} - -func (m *mockAPIClient) CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) { - if m.createCodespace == nil { - return nil, errors.New("mock api client CreateCodespace not implemented") - } - - return m.createCodespace(ctx, user, repo, machine, branch, location) + getCodespace func(context.Context, string, string, string) (*Codespace, error) } func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { @@ -33,7 +23,7 @@ func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codesp return m.getCodespaceToken(ctx, userLogin, codespaceName) } -func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) { +func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) { if m.getCodespace == nil { return nil, errors.New("mock api client GetCodespace not implemented") } @@ -43,8 +33,8 @@ func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, code func TestPollForCodespace(t *testing.T) { logger := output.NewLogger(nil, nil, false) - user := &api.User{Login: "test"} - tmpCodespace := &api.Codespace{Name: "tmp-codespace"} + user := &User{Login: "test"} + tmpCodespace := &Codespace{Name: "tmp-codespace"} codespaceToken := "codespace-token" ctx := context.Background() @@ -61,7 +51,7 @@ func TestPollForCodespace(t *testing.T) { } return codespaceToken, nil }, - getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*api.Codespace, error) { + getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*Codespace, error) { if token != codespaceToken { return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 8a0e21b3d..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -75,74 +75,3 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } - -type apiClient interface { - CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) - GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) - GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) -} - -// ProvisionParams are the required parameters for provisioning a Codespace. -type ProvisionParams struct { - User *api.User - Repository *api.Repository - Branch, Machine, Location string -} - -// Provision creates a codespace with the given parameters and handles polling in the case -// of initial creation failures. -func Provision(ctx context.Context, log logger, client apiClient, params *ProvisionParams) (*api.Codespace, error) { - codespace, err := client.CreateCodespace( - ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, - ) - if err != nil { - // This error is returned by the API when the initial creation fails with a retryable error. - // A retryable error means that GitHub will retry to re-create Codespace and clients should poll - // the API and attempt to fetch the Codespace for the next two minutes. - if err == api.ErrProvisioningInProgress { - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - log.Print(".") - codespace, err = pollForCodespace(ctx, client, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) - log.Print("\n") - - if err != nil { - return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) - } - } - - return nil, err - } - - return codespace, nil -} - -// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. -// If it succeeds at fetching the codespace, we consider the codespace provisioned. -func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*api.Codespace, error) { - ctx, cancel := context.WithTimeout(ctx, duration) - defer cancel() - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - log.Print(".") - token, err := client.GetCodespaceToken(ctx, user, name) - if err != nil { - if err == api.ErrNotProvisioned { - // Do nothing. We expect this to fail until the codespace is provisioned - continue - } - - return nil, fmt.Errorf("failed to get codespace token: %w", err) - } - - return client.GetCodespace(ctx, token, user, name) - } - } -}