diff --git a/internal/api/api.go b/internal/api/api.go index eac2c3a88..50c4a03de 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -429,24 +429,10 @@ func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCod // errProvisioningInProgress indicates that codespace creation did not complete // within the GitHub API RPC time limit (10s), so it continues asynchronously. // We must poll the server to discover the outcome. - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - - return pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User, codespace.Name) -} - -type apiClient interface { - GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) - GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) -} - -// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. -// If it succeeds at fetching the codespace, we consider the codespace provisioned. -func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*Codespace, error) { - ctx, cancel := context.WithTimeout(ctx, duration) + ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) defer cancel() - ticker := time.NewTicker(interval) + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { @@ -455,7 +441,7 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio return nil, ctx.Err() case <-ticker.C: log.Print(".") - token, err := client.GetCodespaceToken(ctx, user, name) + token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name) if err != nil { if err == ErrNotProvisioned { // Do nothing. We expect this to fail until the codespace is provisioned @@ -465,7 +451,7 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio return nil, fmt.Errorf("failed to get codespace token: %w", err) } - codespace, err := client.GetCodespace(ctx, token, user, name) + codespace, err = a.GetCodespace(ctx, token, params.User, codespace.Name) if err != nil { return nil, fmt.Errorf("failed to get codespace: %w", err) } diff --git a/internal/api/api_test.go b/internal/api/api_test.go deleted file mode 100644 index eb5226a59..000000000 --- a/internal/api/api_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package api - -import ( - "context" - "errors" - "fmt" - "testing" - "time" - - "github.com/github/ghcs/cmd/ghcs/output" -) - -type mockAPIClient struct { - getCodespaceToken func(context.Context, string, string) (string, error) - getCodespace func(context.Context, string, string, string) (*Codespace, error) -} - -func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { - if m.getCodespaceToken == nil { - return "", errors.New("mock api client GetCodespaceToken not implemented") - } - - return m.getCodespaceToken(ctx, userLogin, codespaceName) -} - -func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) { - if m.getCodespace == nil { - return nil, errors.New("mock api client GetCodespace not implemented") - } - - return m.getCodespace(ctx, token, userLogin, codespaceName) -} - -func TestPollForCodespace(t *testing.T) { - logger := output.NewLogger(nil, nil, false) - user := &User{Login: "test"} - tmpCodespace := &Codespace{Name: "tmp-codespace"} - codespaceToken := "codespace-token" - ctx := context.Background() - - pollInterval := 50 * time.Millisecond - pollTimeout := 100 * time.Millisecond - - api := &mockAPIClient{ - getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) { - if userLogin != user.Login { - return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) - } - if codespace != tmpCodespace.Name { - return "", fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) - } - return codespaceToken, nil - }, - getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*Codespace, error) { - if token != codespaceToken { - return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) - } - if userLogin != user.Login { - return nil, fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) - } - if codespace != tmpCodespace.Name { - return nil, fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) - } - return tmpCodespace, nil - }, - } - - codespace, err := pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) - if err != nil { - t.Error(err) - } - if tmpCodespace.Name != codespace.Name { - t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name) - } - - // swap the durations to trigger a timeout - pollTimeout, pollInterval = pollInterval, pollTimeout - _, err = pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) - if err != context.DeadlineExceeded { - t.Error("expected context deadline exceeded error, got nil") - } -}