💅 refactor and add tests for Secrets pagination

This commit is contained in:
Mislav Marohnić 2021-05-31 19:00:44 +02:00
parent cb60538709
commit 260f720c07
2 changed files with 87 additions and 16 deletions

View file

@ -1,13 +1,16 @@
package list
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/secret/shared"
"github.com/cli/cli/pkg/cmdutil"
@ -55,11 +58,10 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman
}
func listRun(opts *ListOptions) error {
c, err := opts.HttpClient()
client, err := opts.HttpClient()
if err != nil {
return fmt.Errorf("could not create http client: %w", err)
}
client := api.NewClientFromHTTP(c)
orgName := opts.OrgName
@ -145,7 +147,7 @@ func fmtVisibility(s Secret) string {
return ""
}
func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error) {
func getOrgSecrets(client httpClient, host, orgName string) ([]*Secret, error) {
secrets, err := getSecrets(client, host, fmt.Sprintf("orgs/%s/actions/secrets", orgName))
if err != nil {
return nil, err
@ -160,7 +162,7 @@ func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error)
continue
}
var result responseData
if err := client.REST(host, "GET", secret.SelectedReposURL, nil, &result); err != nil {
if _, err := apiGet(client, secret.SelectedReposURL, &result); err != nil {
return nil, fmt.Errorf("failed determining selected repositories for %s: %w", secret.Name, err)
}
secret.NumSelectedRepos = result.TotalCount
@ -169,7 +171,7 @@ func getOrgSecrets(client *api.Client, host, orgName string) ([]*Secret, error)
return secrets, nil
}
func getRepoSecrets(client *api.Client, repo ghrepo.Interface) ([]*Secret, error) {
func getRepoSecrets(client httpClient, repo ghrepo.Interface) ([]*Secret, error) {
return getSecrets(client, repo.RepoHost(), fmt.Sprintf("repos/%s/actions/secrets",
ghrepo.FullName(repo)))
}
@ -178,25 +180,63 @@ type secretsPayload struct {
Secrets []*Secret
}
func getSecrets(client *api.Client, host, path string) ([]*Secret, error) {
results := secretsPayload{}
type httpClient interface {
Do(*http.Request) (*http.Response, error)
}
perPage := 100
page := 1
func getSecrets(client httpClient, host, path string) ([]*Secret, error) {
var results []*Secret
url := fmt.Sprintf("%s%s?per_page=100", ghinstance.RESTPrefix(host), path)
for {
result := secretsPayload{}
err := client.REST(host, "GET", fmt.Sprintf("%s?per_page=%d&page=%d", path, perPage, page), nil, &result)
var payload secretsPayload
nextURL, err := apiGet(client, url, &payload)
if err != nil {
return nil, err
}
results.Secrets = append(results.Secrets, result.Secrets...)
if len(result.Secrets) == 0 || len(result.Secrets) < 100 {
results = append(results, payload.Secrets...)
if nextURL == "" {
break
}
page++
url = nextURL
}
return results.Secrets, nil
return results, nil
}
func apiGet(client httpClient, url string, data interface{}) (string, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode > 299 {
return "", api.HandleHTTPError(resp)
}
dec := json.NewDecoder(resp.Body)
if err := dec.Decode(data); err != nil {
return "", err
}
return findNextPage(resp.Header.Get("Link")), nil
}
var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`)
func findNextPage(link string) string {
for _, m := range linkRE.FindAllStringSubmatch(link, -1) {
if len(m) >= 2 && m[2] == "next" {
return m[1]
}
}
return ""
}

View file

@ -3,7 +3,9 @@ package list
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
@ -200,3 +202,32 @@ func Test_listRun(t *testing.T) {
})
}
}
func Test_getSecrets_pagination(t *testing.T) {
var requests []*http.Request
var client testClient = func(req *http.Request) (*http.Response, error) {
header := make(map[string][]string)
if len(requests) == 0 {
header["Link"] = []string{`<http://example.com/page/0>; rel="previous", <http://example.com/page/2>; rel="next"`}
}
requests = append(requests, req)
return &http.Response{
Request: req,
Body: ioutil.NopCloser(strings.NewReader(`{"secrets":[{},{}]}`)),
Header: header,
}, nil
}
secrets, err := getSecrets(client, "github.com", "path/to")
assert.NoError(t, err)
assert.Equal(t, 2, len(requests))
assert.Equal(t, 4, len(secrets))
assert.Equal(t, "https://api.github.com/path/to?per_page=100", requests[0].URL.String())
assert.Equal(t, "http://example.com/page/2", requests[1].URL.String())
}
type testClient func(*http.Request) (*http.Response, error)
func (c testClient) Do(req *http.Request) (*http.Response, error) {
return c(req)
}