From 6014b31d0335eced8c45e15cff4975d3261d0152 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Fri, 1 Jul 2022 13:08:44 -0400 Subject: [PATCH] Fix case where codespace created while in provisioning state causes panic --- internal/codespaces/api/api.go | 14 +++- internal/codespaces/api/api_test.go | 110 ++++++++++++++++++++-------- 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index cb19b4390..8c2bb80e2 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -706,7 +706,7 @@ type CreateCodespaceParams struct { // fails to create. func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate(ctx, params) - if err != errProvisioningInProgress { + if !errors.Is(err, errProvisioningInProgress) { return codespace, err } @@ -802,7 +802,17 @@ func (a *API) startCreate(ctx context.Context, params *CreateCodespaceParams) (* defer resp.Body.Close() if resp.StatusCode == http.StatusAccepted { - return nil, errProvisioningInProgress // RPC finished before result of creation known + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + var response Codespace + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %w", err) + } + + return &response, errProvisioningInProgress // RPC finished before result of creation known } else if resp.StatusCode == http.StatusUnauthorized { var ( ue AcceptPermissionsRequiredError diff --git a/internal/codespaces/api/api_test.go b/internal/codespaces/api/api_test.go index 482a6feef..d8d99afb6 100644 --- a/internal/codespaces/api/api_test.go +++ b/internal/codespaces/api/api_test.go @@ -66,47 +66,93 @@ func createFakeListEndpointServer(t *testing.T, initalTotal int, finalTotal int) })) } -func createFakeCreateEndpointServer(t *testing.T, initalTotal int, finalTotal int) *httptest.Server { +func createFakeCreateEndpointServer(t *testing.T, wantStatus int) *httptest.Server { + t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/user/codespaces" { - t.Fatal("Incorrect path") + // create endpoint + if r.URL.Path == "/user/codespaces" { + body := r.Body + if body == nil { + t.Fatal("No body") + } + defer body.Close() + + var params startCreateRequest + err := json.NewDecoder(body).Decode(¶ms) + if err != nil { + t.Fatal("error:", err) + } + + if params.RepositoryID != 1 { + t.Fatal("Expected RepositoryID to be 1. Got: ", params.RepositoryID) + } + + if params.IdleTimeoutMinutes != 10 { + t.Fatal("Expected IdleTimeoutMinutes to be 10. Got: ", params.IdleTimeoutMinutes) + } + + if *params.RetentionPeriodMinutes != 0 { + t.Fatal("Expected RetentionPeriodMinutes to be 0. Got: ", *params.RetentionPeriodMinutes) + } + + response := Codespace{ + Name: "codespace-1", + } + + if wantStatus == 0 { + wantStatus = http.StatusCreated + } + + data, _ := json.Marshal(response) + w.WriteHeader(wantStatus) + fmt.Fprint(w, string(data)) + return } - body := r.Body - if body == nil { - t.Fatal("No body") - } - defer body.Close() - - var params startCreateRequest - err := json.NewDecoder(body).Decode(¶ms) - if err != nil { - t.Fatal("error:", err) + // get endpoint hit for testing pending status + if r.URL.Path == "/user/codespaces/codespace-1" { + response := Codespace{ + Name: "codespace-1", + State: CodespaceStateAvailable, + } + data, _ := json.Marshal(response) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, string(data)) + return } - if params.RepositoryID != 1 { - t.Fatal("Expected RepositoryID to be 1. Got: ", params.RepositoryID) - } - - if params.IdleTimeoutMinutes != 10 { - t.Fatal("Expected IdleTimeoutMinutes to be 10. Got: ", params.IdleTimeoutMinutes) - } - - if *params.RetentionPeriodMinutes != 0 { - t.Fatal("Expected RetentionPeriodMinutes to be 0. Got: ", *params.RetentionPeriodMinutes) - } - - response := Codespace{ - Name: "codespace-1", - } - - data, _ := json.Marshal(response) - fmt.Fprint(w, string(data)) + t.Fatal("Incorrect path") })) } func TestCreateCodespaces(t *testing.T) { - svr := createFakeCreateEndpointServer(t, 200, 200) + svr := createFakeCreateEndpointServer(t, http.StatusCreated) + defer svr.Close() + + api := API{ + githubAPI: svr.URL, + client: &http.Client{}, + } + + ctx := context.TODO() + retentionPeriod := 0 + params := &CreateCodespaceParams{ + RepositoryID: 1, + IdleTimeoutMinutes: 10, + RetentionPeriodMinutes: &retentionPeriod, + } + codespace, err := api.CreateCodespace(ctx, params) + if err != nil { + t.Fatal(err) + } + + if codespace.Name != "codespace-1" { + t.Fatalf("expected codespace-1, got %s", codespace.Name) + } +} + +func TestCreateCodespaces_Pending(t *testing.T) { + svr := createFakeCreateEndpointServer(t, http.StatusAccepted) defer svr.Close() api := API{