Poll codespace on ErrCreateAsyncRetry error

- Introduce tests for the poller
- Attempt to fetch codespace for 2 mins
This commit is contained in:
Jose Garcia 2021-09-21 12:37:11 -04:00
parent 0b68aaab7e
commit 323462ca5c
2 changed files with 137 additions and 6 deletions

View file

@ -92,13 +92,19 @@ func create(opts *createOptions) error {
ctx, userResult.User, repository, machine, branch, locationResult.Location,
)
if err != nil {
// This error is returned by the API when the initial creation fails with a retryable error.
// A retryable error means that GitHub will retry to re-create Codespace and clients should poll
// the API and attempt to fetch the Codespace for the next two minutes.
if err == api.ErrCreateAsyncRetry {
createRetryCtx, cancelRetry := context.WithTimeout(ctx, 2*time.Minute)
defer cancelRetry()
log.Print("Switching to async provisioning...")
pollctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
codespace, err = pollForCodespace(pollctx, apiClient, log, userResult.User, codespace)
log.Print("\n")
codespace, err = pollForProvisionedCodespace(createRetryCtx, codespace)
if err != nil {
return fmt.Errorf("error creating codespace after retry: %w", err)
return fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err)
}
}
@ -118,8 +124,40 @@ func create(opts *createOptions) error {
return nil
}
func pollForProvisionedCodespace(ctx context.Context, provisioningCodespace *api.Codespace) (*api.Codespace, error) {
return nil, nil
type apiClient interface {
GetCodespaceToken(context.Context, string, string) (string, error)
GetCodespace(context.Context, string, string, string) (*api.Codespace, error)
}
// pollForCodespace polls the Codespaces API every second fetching the codespace.
// If it succeeds at fetching the codespace, we consider the codespace provisioned.
// Context should be cancelled to stop polling.
func pollForCodespace(
ctx context.Context, client apiClient, log *output.Logger, user *api.User, provisioningCodespace *api.Codespace,
) (*api.Codespace, error) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
log.Print(".")
token, err := client.GetCodespaceToken(ctx, user.Login, provisioningCodespace.Name)
if err != nil {
// Do nothing. We expect this to fail until the codespace is provisioned
continue
}
codespace, err := client.GetCodespace(ctx, token, user.Login, provisioningCodespace.Name)
if err != nil {
return nil, fmt.Errorf("failed to get codespace: %w", err)
}
return codespace, nil
}
}
}
// showStatus polls the codespace for a list of post create states and their status. It will keep polling

93
cmd/ghcs/create_test.go Normal file
View file

@ -0,0 +1,93 @@
package main
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
)
type mockAPIClient struct {
getCodespaceToken func(context.Context, string, string) (string, error)
getCodespace func(context.Context, string, string, string) (*api.Codespace, error)
}
func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) {
if m.getCodespaceToken == nil {
return "", errors.New("mock api client GetCodespaceToken not implemented")
}
return m.getCodespaceToken(ctx, userLogin, codespaceName)
}
func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) {
if m.getCodespace == nil {
return nil, errors.New("mock api client GetCodespace not implemented")
}
return m.getCodespace(ctx, token, userLogin, codespaceName)
}
func TestPollForCodespace(t *testing.T) {
logger := output.NewLogger(nil, nil, false)
user := &api.User{Login: "test"}
tmpCodespace := &api.Codespace{Name: "tmp-codespace"}
codespaceToken := "codespace-token"
ctxTimeout := 1 * time.Second
exceedTime := 2 * time.Second
exceedProvisioningTime := false
api := &mockAPIClient{
getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) {
if exceedProvisioningTime {
ticker := time.NewTicker(exceedTime)
defer ticker.Stop()
<-ticker.C
}
if userLogin != user.Login {
return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login)
}
if codespace != tmpCodespace.Name {
return "", fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name)
}
return codespaceToken, nil
},
getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*api.Codespace, error) {
if token != codespaceToken {
return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken)
}
if userLogin != user.Login {
return nil, fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login)
}
if codespace != tmpCodespace.Name {
return nil, fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name)
}
return tmpCodespace, nil
},
}
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
defer cancel()
codespace, err := pollForCodespace(ctx, api, logger, user, tmpCodespace)
if err != nil {
t.Error(err)
}
if tmpCodespace.Name != codespace.Name {
t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name)
}
exceedProvisioningTime = true
ctx, cancel = context.WithTimeout(ctx, ctxTimeout)
defer cancel()
_, err = pollForCodespace(ctx, api, logger, user, tmpCodespace)
if err == nil {
t.Error("expected context deadline exceeded error, got nil")
}
}