From 260f720c0738c403132f10477a1565cd57477c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 31 May 2021 19:00:44 +0200 Subject: [PATCH] :nail_care: refactor and add tests for Secrets pagination --- pkg/cmd/secret/list/list.go | 72 +++++++++++++++++++++++++------- pkg/cmd/secret/list/list_test.go | 31 ++++++++++++++ 2 files changed, 87 insertions(+), 16 deletions(-) diff --git a/pkg/cmd/secret/list/list.go b/pkg/cmd/secret/list/list.go index e512295ac..97c3ac27a 100644 --- a/pkg/cmd/secret/list/list.go +++ b/pkg/cmd/secret/list/list.go @@ -1,13 +1,16 @@ package list import ( + "encoding/json" "fmt" "net/http" + "regexp" "strings" "time" "github.com/cli/cli/api" "github.com/cli/cli/internal/config" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/secret/shared" "github.com/cli/cli/pkg/cmdutil" @@ -55,11 +58,10 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman } func listRun(opts *ListOptions) error { - c, err := opts.HttpClient() + client, err := opts.HttpClient() if err != nil { return fmt.Errorf("could not create http client: %w", err) } - client := api.NewClientFromHTTP(c) orgName := opts.OrgName @@ -145,7 +147,7 @@ func fmtVisibility(s Secret) string { return "" } -func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error) { +func getOrgSecrets(client httpClient, host, orgName string) ([]*Secret, error) { secrets, err := getSecrets(client, host, fmt.Sprintf("orgs/%s/actions/secrets", orgName)) if err != nil { return nil, err @@ -160,7 +162,7 @@ func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error) continue } var result responseData - if err := client.REST(host, "GET", secret.SelectedReposURL, nil, &result); err != nil { + if _, err := apiGet(client, secret.SelectedReposURL, &result); err != nil { return nil, fmt.Errorf("failed determining selected repositories for %s: %w", secret.Name, err) } secret.NumSelectedRepos = result.TotalCount @@ -169,7 +171,7 @@ func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error) return secrets, nil } -func getRepoSecrets(client *api.Client, repo ghrepo.Interface) ([]*Secret, error) { +func getRepoSecrets(client httpClient, repo ghrepo.Interface) ([]*Secret, error) { return getSecrets(client, repo.RepoHost(), fmt.Sprintf("repos/%s/actions/secrets", ghrepo.FullName(repo))) } @@ -178,25 +180,63 @@ type secretsPayload struct { Secrets []*Secret } -func getSecrets(client *api.Client, host, path string) ([]*Secret, error) { - results := secretsPayload{} +type httpClient interface { + Do(*http.Request) (*http.Response, error) +} - perPage := 100 - page := 1 +func getSecrets(client httpClient, host, path string) ([]*Secret, error) { + var results []*Secret + url := fmt.Sprintf("%s%s?per_page=100", ghinstance.RESTPrefix(host), path) for { - result := secretsPayload{} - err := client.REST(host, "GET", fmt.Sprintf("%s?per_page=%d&page=%d", path, perPage, page), nil, &result) + var payload secretsPayload + nextURL, err := apiGet(client, url, &payload) if err != nil { return nil, err } - results.Secrets = append(results.Secrets, result.Secrets...) - if len(result.Secrets) == 0 || len(result.Secrets) < 100 { + results = append(results, payload.Secrets...) + + if nextURL == "" { break } - - page++ + url = nextURL } - return results.Secrets, nil + return results, nil +} + +func apiGet(client httpClient, url string, data interface{}) (string, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode > 299 { + return "", api.HandleHTTPError(resp) + } + + dec := json.NewDecoder(resp.Body) + if err := dec.Decode(data); err != nil { + return "", err + } + + return findNextPage(resp.Header.Get("Link")), nil +} + +var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`) + +func findNextPage(link string) string { + for _, m := range linkRE.FindAllStringSubmatch(link, -1) { + if len(m) >= 2 && m[2] == "next" { + return m[1] + } + } + return "" } diff --git a/pkg/cmd/secret/list/list_test.go b/pkg/cmd/secret/list/list_test.go index 8620ca012..e3975b1e6 100644 --- a/pkg/cmd/secret/list/list_test.go +++ b/pkg/cmd/secret/list/list_test.go @@ -3,7 +3,9 @@ package list import ( "bytes" "fmt" + "io/ioutil" "net/http" + "strings" "testing" "time" @@ -200,3 +202,32 @@ func Test_listRun(t *testing.T) { }) } } + +func Test_getSecrets_pagination(t *testing.T) { + var requests []*http.Request + var client testClient = func(req *http.Request) (*http.Response, error) { + header := make(map[string][]string) + if len(requests) == 0 { + header["Link"] = []string{`; rel="previous", ; rel="next"`} + } + requests = append(requests, req) + return &http.Response{ + Request: req, + Body: ioutil.NopCloser(strings.NewReader(`{"secrets":[{},{}]}`)), + Header: header, + }, nil + } + + secrets, err := getSecrets(client, "github.com", "path/to") + assert.NoError(t, err) + assert.Equal(t, 2, len(requests)) + assert.Equal(t, 4, len(secrets)) + assert.Equal(t, "https://api.github.com/path/to?per_page=100", requests[0].URL.String()) + assert.Equal(t, "http://example.com/page/2", requests[1].URL.String()) +} + +type testClient func(*http.Request) (*http.Response, error) + +func (c testClient) Do(req *http.Request) (*http.Response, error) { + return c(req) +}