diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 0bc8615fc..531b2b3f8 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -34,6 +34,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strconv" "strings" "time" @@ -189,16 +190,43 @@ type CodespaceEnvironmentConnection struct { HostPublicKeys []string `json:"hostPublicKeys"` } +// codespacesListResponse is the response body for the `/user/codespaces` endpoint +type getCodespacesListResponse struct { + Codespaces []*Codespace `json:"codespaces"` + TotalCount int `json:"total_count"` +} + // ListCodespaces returns a list of codespaces for the user. -func (a *API) ListCodespaces(ctx context.Context) ([]*Codespace, error) { +// It consumes all pages returned by the API until all codespaces have been fetched. +func (a *API) ListCodespaces(ctx context.Context) (codespaces []*Codespace, err error) { + per_page := 100 + for page := 1; ; page++ { + response, err := a.fetchCodespaces(ctx, page, per_page) + if err != nil { + return nil, err + } + codespaces = append(codespaces, response.Codespaces...) + if page*per_page >= response.TotalCount { + break + } + } + + return codespaces, nil +} + +func (a *API) fetchCodespaces(ctx context.Context, page int, per_page int) (response *getCodespacesListResponse, err error) { req, err := http.NewRequest( http.MethodGet, a.githubAPI+"/user/codespaces", nil, ) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } - a.setHeaders(req) + q := req.URL.Query() + q.Add("page", strconv.Itoa(page)) + q.Add("per_page", strconv.Itoa(per_page)) + + req.URL.RawQuery = q.Encode() resp, err := a.do(ctx, req, "/user/codespaces") if err != nil { return nil, fmt.Errorf("error making request: %w", err) @@ -214,13 +242,10 @@ func (a *API) ListCodespaces(ctx context.Context) ([]*Codespace, error) { return nil, jsonErrorResponse(b) } - var response struct { - Codespaces []*Codespace `json:"codespaces"` - } if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response: %w", err) } - return response.Codespaces, nil + return response, nil } // GetCodespace returns the user codespace based on the provided name. diff --git a/internal/codespaces/api/api_test.go b/internal/codespaces/api/api_test.go index 8bbe1c8a9..a9d81e442 100644 --- a/internal/codespaces/api/api_test.go +++ b/internal/codespaces/api/api_test.go @@ -6,26 +6,61 @@ import ( "fmt" "net/http" "net/http/httptest" + "strconv" "testing" ) -func TestListCodespaces(t *testing.T) { - codespaces := []*Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - }, +func generateCodespaceList(start int, end int) []*Codespace { + codespacesList := []*Codespace{} + for i := start; i < end; i++ { + codespacesList = append(codespacesList, &Codespace{ + Name: fmt.Sprintf("codespace-%d", i), + }) } - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return codespacesList +} + +func createFakeListEndpointServer(t *testing.T, initalTotal int, finalTotal int) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/codespaces" { + t.Fatal("Incorrect path") + } + + page := 1 + if r.URL.Query().Get("page") != "" { + page, _ = strconv.Atoi(r.URL.Query().Get("page")) + } + + per_page := 0 + if r.URL.Query().Get("per_page") != "" { + per_page, _ = strconv.Atoi(r.URL.Query().Get("per_page")) + } + response := struct { Codespaces []*Codespace `json:"codespaces"` + TotalCount int `json:"total_count"` }{ - Codespaces: codespaces, + Codespaces: []*Codespace{}, + TotalCount: finalTotal, } + + if page == 1 { + response.Codespaces = generateCodespaceList(0, per_page) + response.TotalCount = initalTotal + } else if page == 2 { + response.Codespaces = generateCodespaceList(per_page, per_page*2) + response.TotalCount = finalTotal + } else { + t.Fatal("Should not check extra page") + } + data, _ := json.Marshal(response) fmt.Fprint(w, string(data)) })) +} + +func TestListCodespaces(t *testing.T) { + svr := createFakeListEndpointServer(t, 200, 200) defer svr.Close() api := API{ @@ -38,13 +73,53 @@ func TestListCodespaces(t *testing.T) { if err != nil { t.Fatal(err) } - - if len(codespaces) != 1 { - t.Fatalf("expected 1 codespace, got %d", len(codespaces)) + if len(codespaces) != 200 { + t.Fatalf("expected 100 codespace, got %d", len(codespaces)) } - if codespaces[0].Name != "testcodespace" { - t.Fatalf("expected testcodespace, got %s", codespaces[0].Name) + if codespaces[0].Name != "codespace-0" { + t.Fatalf("expected codespace-0, got %s", codespaces[0].Name) } + if codespaces[199].Name != "codespace-199" { + t.Fatalf("expected codespace-199, got %s", codespaces[0].Name) + } +} + +func TestMidIterationDeletion(t *testing.T) { + svr := createFakeListEndpointServer(t, 200, 199) + defer svr.Close() + + api := API{ + githubAPI: svr.URL, + client: &http.Client{}, + token: "faketoken", + } + ctx := context.TODO() + codespaces, err := api.ListCodespaces(ctx) + if err != nil { + t.Fatal(err) + } + if len(codespaces) != 200 { + t.Fatalf("expected 200 codespace, got %d", len(codespaces)) + } +} + +func TestMidIterationAddition(t *testing.T) { + svr := createFakeListEndpointServer(t, 199, 200) + defer svr.Close() + + api := API{ + githubAPI: svr.URL, + client: &http.Client{}, + token: "faketoken", + } + ctx := context.TODO() + codespaces, err := api.ListCodespaces(ctx) + if err != nil { + t.Fatal(err) + } + if len(codespaces) != 200 { + t.Fatalf("expected 200 codespace, got %d", len(codespaces)) + } }