diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index ff3e13962..93016bbf8 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -92,13 +92,19 @@ func create(opts *createOptions) error { ctx, userResult.User, repository, machine, branch, locationResult.Location, ) if err != nil { + // This error is returned by the API when the initial creation fails with a retryable error. + // A retryable error means that GitHub will retry to re-create Codespace and clients should poll + // the API and attempt to fetch the Codespace for the next two minutes. if err == api.ErrCreateAsyncRetry { - createRetryCtx, cancelRetry := context.WithTimeout(ctx, 2*time.Minute) - defer cancelRetry() + log.Print("Switching to async provisioning...") + pollctx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + codespace, err = pollForCodespace(pollctx, apiClient, log, userResult.User, codespace) + log.Print("\n") - codespace, err = pollForProvisionedCodespace(createRetryCtx, codespace) if err != nil { - return fmt.Errorf("error creating codespace after retry: %w", err) + return fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) } } @@ -118,8 +124,40 @@ func create(opts *createOptions) error { return nil } -func pollForProvisionedCodespace(ctx context.Context, provisioningCodespace *api.Codespace) (*api.Codespace, error) { - return nil, nil +type apiClient interface { + GetCodespaceToken(context.Context, string, string) (string, error) + GetCodespace(context.Context, string, string, string) (*api.Codespace, error) +} + +// pollForCodespace polls the Codespaces API every second fetching the codespace. +// If it succeeds at fetching the codespace, we consider the codespace provisioned. +// Context should be cancelled to stop polling. +func pollForCodespace( + ctx context.Context, client apiClient, log *output.Logger, user *api.User, provisioningCodespace *api.Codespace, +) (*api.Codespace, error) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + log.Print(".") + token, err := client.GetCodespaceToken(ctx, user.Login, provisioningCodespace.Name) + if err != nil { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + codespace, err := client.GetCodespace(ctx, token, user.Login, provisioningCodespace.Name) + if err != nil { + return nil, fmt.Errorf("failed to get codespace: %w", err) + } + + return codespace, nil + } + } } // showStatus polls the codespace for a list of post create states and their status. It will keep polling diff --git a/cmd/ghcs/create_test.go b/cmd/ghcs/create_test.go new file mode 100644 index 000000000..36769dc14 --- /dev/null +++ b/cmd/ghcs/create_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" +) + +type mockAPIClient struct { + getCodespaceToken func(context.Context, string, string) (string, error) + getCodespace func(context.Context, string, string, string) (*api.Codespace, error) +} + +func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { + if m.getCodespaceToken == nil { + return "", errors.New("mock api client GetCodespaceToken not implemented") + } + + return m.getCodespaceToken(ctx, userLogin, codespaceName) +} + +func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) { + if m.getCodespace == nil { + return nil, errors.New("mock api client GetCodespace not implemented") + } + + return m.getCodespace(ctx, token, userLogin, codespaceName) +} + +func TestPollForCodespace(t *testing.T) { + logger := output.NewLogger(nil, nil, false) + user := &api.User{Login: "test"} + tmpCodespace := &api.Codespace{Name: "tmp-codespace"} + codespaceToken := "codespace-token" + + ctxTimeout := 1 * time.Second + exceedTime := 2 * time.Second + exceedProvisioningTime := false + + api := &mockAPIClient{ + getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) { + if exceedProvisioningTime { + ticker := time.NewTicker(exceedTime) + defer ticker.Stop() + <-ticker.C + } + if userLogin != user.Login { + return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) + } + if codespace != tmpCodespace.Name { + return "", fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) + } + return codespaceToken, nil + }, + getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*api.Codespace, error) { + if token != codespaceToken { + return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) + } + if userLogin != user.Login { + return nil, fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) + } + if codespace != tmpCodespace.Name { + return nil, fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) + } + return tmpCodespace, nil + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + codespace, err := pollForCodespace(ctx, api, logger, user, tmpCodespace) + if err != nil { + t.Error(err) + } + if tmpCodespace.Name != codespace.Name { + t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name) + } + + exceedProvisioningTime = true + ctx, cancel = context.WithTimeout(ctx, ctxTimeout) + defer cancel() + + _, err = pollForCodespace(ctx, api, logger, user, tmpCodespace) + if err == nil { + t.Error("expected context deadline exceeded error, got nil") + } +}