From 91b3b99b767dd0367a49d87668885d83c43515bd Mon Sep 17 00:00:00 2001 From: nilvng Date: Sun, 15 Dec 2024 14:01:08 +1100 Subject: [PATCH] issue #2329: create shared PRLister --- pkg/cmd/pr/checkout/checkout.go | 39 +++--- pkg/cmd/pr/checkout/checkout_test.go | 41 ++++-- pkg/cmd/pr/shared/finder.go | 4 +- pkg/cmd/pr/shared/lister.go | 196 +++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 38 deletions(-) create mode 100644 pkg/cmd/pr/shared/lister.go diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index faf006c7e..b265ea25e 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -13,7 +13,6 @@ import ( "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" - "github.com/cli/cli/v2/pkg/cmd/pr/list" "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" @@ -30,6 +29,7 @@ type CheckoutOptions struct { Finder shared.PRFinder Prompter shared.Prompter + Lister shared.PRLister Interactive bool BaseRepo func() (ghrepo.Interface, error) @@ -67,6 +67,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { opts.Finder = shared.NewFinder(f) + opts.Lister = shared.NewLister(f) if len(args) > 0 { opts.SelectorArg = args[0] @@ -97,12 +98,7 @@ func checkoutRun(opts *CheckoutOptions) error { return err } - client, err := opts.HttpClient() - if err != nil { - return err - } - - pr, err := resolvePR(client, baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.IO) + pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO) if err != nil { return err } @@ -291,7 +287,7 @@ func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cm return nil } -func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, io *iostreams.IOStreams) (*api.PullRequest, error) { +func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) { // When non-interactive if pullRequestSelector != "" { pr, _, err := pullRequestFinder.Find(shared.FindOptions{ @@ -315,20 +311,21 @@ func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shar } // When interactive io.StartProgressIndicator() - listResult, err := list.ListPullRequests(httpClient, baseRepo, shared.FilterOptions{Entity: "pr", State: "open", Fields: []string{ - "number", - "title", - "state", - "url", - "isDraft", - "createdAt", + listResult, err := prLister.List(shared.ListOptions{ + State: "open", + Fields: []string{ + "number", + "title", + "state", + "isDraft", - "headRefName", - "headRepository", - "headRepositoryOwner", - "isCrossRepository", - "maintainerCanModify", - }}, 10) + "headRefName", + "headRepository", + "headRepositoryOwner", + "isCrossRepository", + "maintainerCanModify", + }, + LimitResults: 10}) io.StopProgressIndicator() if err != nil { return nil, err diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index a099a848d..bd8bc3984 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -28,6 +28,10 @@ import ( // repo: either "baseOwner/baseRepo" or "baseOwner/baseRepo:defaultBranch" // prHead: "headOwner/headRepo:headBranch" func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) { + return _stubPR(repo, prHead, 123, "PR title", "OPEN", false) +} + +func _stubPR(repo, prHead string, number int, title string, state string, isDraft bool) (ghrepo.Interface, *api.PullRequest) { defaultBranch := "" if idx := strings.IndexRune(repo, ':'); idx >= 0 { defaultBranch = repo[idx+1:] @@ -53,12 +57,16 @@ func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) { } return baseRepo, &api.PullRequest{ - Number: 123, + Number: number, HeadRefName: headRefName, HeadRepositoryOwner: api.Owner{Login: headRepo.RepoOwner()}, HeadRepository: &api.PRRepository{Name: headRepo.RepoName()}, IsCrossRepository: !ghrepo.IsSame(baseRepo, headRepo), MaintainerCanModify: false, + + Title: title, + State: state, + IsDraft: isDraft, } } @@ -224,10 +232,17 @@ func Test_checkoutRun(t *testing.T) { opts: &CheckoutOptions{ SelectorArg: "", Interactive: true, - Finder: func() shared.PRFinder { - baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature") - finder := shared.NewMockFinder("123", pr, baseRepo) - return finder + Lister: func() shared.PRLister { + _, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false) + _, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false) + _, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true) + lister := shared.NewMockLister(&api.PullRequestAndTotalCount{ + TotalCount: 3, + PullRequests: []api.PullRequest{ + *pr1, *pr2, *pr3, + }, SearchCapped: false}, nil) + lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) + return lister }(), BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil @@ -236,14 +251,11 @@ func Test_checkoutRun(t *testing.T) { return config.NewBlankConfig(), nil }, }, - httpStubs: func(reg *httpmock.Registry) { - reg.Register(httpmock.GraphQL(`query PullRequestList\b`), httpmock.FileResponse("./fixtures/prList.json")) - }, promptStubs: func(pm *prompter.MockPrompter) { pm.RegisterSelect("Select a pull request", - []string{"32\tDRAFT New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tOPEN Improve documentation [docs]"}, + []string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"}, func(_, _ string, opts []string) (int, error) { - return prompter.IndexFor(opts, "32\tDRAFT New feature [feature]") + return prompter.IndexFor(opts, "32\tOPEN New feature [feature]") }) }, runStubs: func(cs *run.CommandStubber) { @@ -263,11 +275,10 @@ func Test_checkoutRun(t *testing.T) { BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - }, - httpStubs: func(reg *httpmock.Registry) { - reg.Register(httpmock.GraphQL(`query PullRequestList\b`), - httpmock.StringResponse(`{"data":{"repository":{"pullRequests":{ "totalCount": 0,"nodes":[]}}}}`)) - + Lister: shared.NewMockLister(&api.PullRequestAndTotalCount{ + TotalCount: 0, + PullRequests: []api.PullRequest{}, + }, nil), }, remotes: map[string]string{ "origin": "OWNER/REPO", diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index e4a70eb22..b1fec267e 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -548,7 +548,7 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, if m.expectSelector != opts.Selector { return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) } - if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) { + if len(m.expectFields) > 0 && !IsEqualSet(m.expectFields, opts.Fields) { return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields) } if m.called { @@ -568,7 +568,7 @@ func (m *mockFinder) ExpectFields(fields []string) { m.expectFields = fields } -func isEqualSet(a, b []string) bool { +func IsEqualSet(a, b []string) bool { if len(a) != len(b) { return false } diff --git a/pkg/cmd/pr/shared/lister.go b/pkg/cmd/pr/shared/lister.go new file mode 100644 index 000000000..004f726d5 --- /dev/null +++ b/pkg/cmd/pr/shared/lister.go @@ -0,0 +1,196 @@ +package shared + +import ( + "fmt" + "net/http" + + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmdutil" + + api "github.com/cli/cli/v2/api" +) + +type PRLister interface { + List(opt ListOptions) (*api.PullRequestAndTotalCount, error) +} + +type ListOptions struct { + LimitResults int + + State string + BaseBranch string + HeadBranch string + + Fields []string +} + +type lister struct { + baseRepoFn func() (ghrepo.Interface, error) + httpClient func() (*http.Client, error) +} + +func NewLister(factory *cmdutil.Factory) PRLister { + return &lister{ + baseRepoFn: factory.BaseRepo, + httpClient: factory.HttpClient, + } +} + +func (l *lister) List(opts ListOptions) (*api.PullRequestAndTotalCount, error) { + repo, err := l.baseRepoFn() + if err != nil { + return nil, err + } + + client, err := l.httpClient() + if err != nil { + return nil, err + } + + return listPullRequests(client, repo, opts) +} + +func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters ListOptions) (*api.PullRequestAndTotalCount, error) { + type response struct { + Repository struct { + PullRequests struct { + Nodes []api.PullRequest + PageInfo struct { + HasNextPage bool + EndCursor string + } + TotalCount int + } + } + } + limit := filters.LimitResults + fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(filters.Fields)) + query := fragment + ` + query PullRequestList( + $owner: String!, + $repo: String!, + $limit: Int!, + $endCursor: String, + $baseBranch: String, + $headBranch: String, + $state: [PullRequestState!] = OPEN + ) { + repository(owner: $owner, name: $repo) { + pullRequests( + states: $state, + baseRefName: $baseBranch, + headRefName: $headBranch, + first: $limit, + after: $endCursor, + orderBy: {field: CREATED_AT, direction: DESC} + ) { + totalCount + nodes { + ...pr + } + pageInfo { + hasNextPage + endCursor + } + } + } + }` + + pageLimit := min(limit, 100) + variables := map[string]interface{}{ + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + } + + switch filters.State { + case "open": + variables["state"] = []string{"OPEN"} + case "closed": + variables["state"] = []string{"CLOSED", "MERGED"} + case "merged": + variables["state"] = []string{"MERGED"} + case "all": + variables["state"] = []string{"OPEN", "CLOSED", "MERGED"} + default: + return nil, fmt.Errorf("invalid state: %s", filters.State) + } + + if filters.BaseBranch != "" { + variables["baseBranch"] = filters.BaseBranch + } + if filters.HeadBranch != "" { + variables["headBranch"] = filters.HeadBranch + } + + res := api.PullRequestAndTotalCount{} + var check = make(map[int]struct{}) + client := api.NewClientFromHTTP(httpClient) + +loop: + for { + variables["limit"] = pageLimit + var data response + err := client.GraphQL(repo.RepoHost(), query, variables, &data) + if err != nil { + return nil, err + } + prData := data.Repository.PullRequests + res.TotalCount = prData.TotalCount + + for _, pr := range prData.Nodes { + if _, exists := check[pr.Number]; exists && pr.Number > 0 { + continue + } + check[pr.Number] = struct{}{} + + res.PullRequests = append(res.PullRequests, pr) + if len(res.PullRequests) == limit { + break loop + } + } + + if prData.PageInfo.HasNextPage { + variables["endCursor"] = prData.PageInfo.EndCursor + pageLimit = min(pageLimit, limit-len(res.PullRequests)) + } else { + break + } + } + + return &res, nil +} + +type mockLister struct { + called bool + expectFields []string + + result *api.PullRequestAndTotalCount + err error +} + +func NewMockLister(result *api.PullRequestAndTotalCount, err error) *mockLister { + return &mockLister{ + result: result, + err: err, + } +} + +func (m *mockLister) List(opt ListOptions) (*api.PullRequestAndTotalCount, error) { + m.called = true + if m.err != nil { + return nil, m.err + } + + if m.expectFields != nil { + + if !IsEqualSet(m.expectFields, opt.Fields) { + return nil, fmt.Errorf("unexpected fields: %v", opt.Fields) + } + } + + return m.result, m.err +} + +func (m *mockLister) ExpectFields(fields []string) { + m.expectFields = fields +}