issue #2329: create shared PRLister

This commit is contained in:
nilvng 2024-12-15 14:01:08 +11:00
parent d7cabf18f7
commit 91b3b99b76
4 changed files with 242 additions and 38 deletions

View file

@ -13,7 +13,6 @@ import (
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/cmd/pr/list"
"github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
@ -30,6 +29,7 @@ type CheckoutOptions struct {
Finder shared.PRFinder
Prompter shared.Prompter
Lister shared.PRLister
Interactive bool
BaseRepo func() (ghrepo.Interface, error)
@ -67,6 +67,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.Finder = shared.NewFinder(f)
opts.Lister = shared.NewLister(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
@ -97,12 +98,7 @@ func checkoutRun(opts *CheckoutOptions) error {
return err
}
client, err := opts.HttpClient()
if err != nil {
return err
}
pr, err := resolvePR(client, baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.IO)
pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO)
if err != nil {
return err
}
@ -291,7 +287,7 @@ func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cm
return nil
}
func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, io *iostreams.IOStreams) (*api.PullRequest, error) {
func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) {
// When non-interactive
if pullRequestSelector != "" {
pr, _, err := pullRequestFinder.Find(shared.FindOptions{
@ -315,20 +311,21 @@ func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shar
}
// When interactive
io.StartProgressIndicator()
listResult, err := list.ListPullRequests(httpClient, baseRepo, shared.FilterOptions{Entity: "pr", State: "open", Fields: []string{
"number",
"title",
"state",
"url",
"isDraft",
"createdAt",
listResult, err := prLister.List(shared.ListOptions{
State: "open",
Fields: []string{
"number",
"title",
"state",
"isDraft",
"headRefName",
"headRepository",
"headRepositoryOwner",
"isCrossRepository",
"maintainerCanModify",
}}, 10)
"headRefName",
"headRepository",
"headRepositoryOwner",
"isCrossRepository",
"maintainerCanModify",
},
LimitResults: 10})
io.StopProgressIndicator()
if err != nil {
return nil, err

View file

@ -28,6 +28,10 @@ import (
// repo: either "baseOwner/baseRepo" or "baseOwner/baseRepo:defaultBranch"
// prHead: "headOwner/headRepo:headBranch"
func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
return _stubPR(repo, prHead, 123, "PR title", "OPEN", false)
}
func _stubPR(repo, prHead string, number int, title string, state string, isDraft bool) (ghrepo.Interface, *api.PullRequest) {
defaultBranch := ""
if idx := strings.IndexRune(repo, ':'); idx >= 0 {
defaultBranch = repo[idx+1:]
@ -53,12 +57,16 @@ func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
}
return baseRepo, &api.PullRequest{
Number: 123,
Number: number,
HeadRefName: headRefName,
HeadRepositoryOwner: api.Owner{Login: headRepo.RepoOwner()},
HeadRepository: &api.PRRepository{Name: headRepo.RepoName()},
IsCrossRepository: !ghrepo.IsSame(baseRepo, headRepo),
MaintainerCanModify: false,
Title: title,
State: state,
IsDraft: isDraft,
}
}
@ -224,10 +232,17 @@ func Test_checkoutRun(t *testing.T) {
opts: &CheckoutOptions{
SelectorArg: "",
Interactive: true,
Finder: func() shared.PRFinder {
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
finder := shared.NewMockFinder("123", pr, baseRepo)
return finder
Lister: func() shared.PRLister {
_, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 3,
PullRequests: []api.PullRequest{
*pr1, *pr2, *pr3,
}, SearchCapped: false}, nil)
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
return lister
}(),
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
@ -236,14 +251,11 @@ func Test_checkoutRun(t *testing.T) {
return config.NewBlankConfig(), nil
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.GraphQL(`query PullRequestList\b`), httpmock.FileResponse("./fixtures/prList.json"))
},
promptStubs: func(pm *prompter.MockPrompter) {
pm.RegisterSelect("Select a pull request",
[]string{"32\tDRAFT New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tOPEN Improve documentation [docs]"},
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "32\tDRAFT New feature [feature]")
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
})
},
runStubs: func(cs *run.CommandStubber) {
@ -263,11 +275,10 @@ func Test_checkoutRun(t *testing.T) {
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.GraphQL(`query PullRequestList\b`),
httpmock.StringResponse(`{"data":{"repository":{"pullRequests":{ "totalCount": 0,"nodes":[]}}}}`))
Lister: shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 0,
PullRequests: []api.PullRequest{},
}, nil),
},
remotes: map[string]string{
"origin": "OWNER/REPO",

View file

@ -548,7 +548,7 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface,
if m.expectSelector != opts.Selector {
return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
}
if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
if len(m.expectFields) > 0 && !IsEqualSet(m.expectFields, opts.Fields) {
return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
}
if m.called {
@ -568,7 +568,7 @@ func (m *mockFinder) ExpectFields(fields []string) {
m.expectFields = fields
}
func isEqualSet(a, b []string) bool {
func IsEqualSet(a, b []string) bool {
if len(a) != len(b) {
return false
}

196
pkg/cmd/pr/shared/lister.go Normal file
View file

@ -0,0 +1,196 @@
package shared
import (
"fmt"
"net/http"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
api "github.com/cli/cli/v2/api"
)
type PRLister interface {
List(opt ListOptions) (*api.PullRequestAndTotalCount, error)
}
type ListOptions struct {
LimitResults int
State string
BaseBranch string
HeadBranch string
Fields []string
}
type lister struct {
baseRepoFn func() (ghrepo.Interface, error)
httpClient func() (*http.Client, error)
}
func NewLister(factory *cmdutil.Factory) PRLister {
return &lister{
baseRepoFn: factory.BaseRepo,
httpClient: factory.HttpClient,
}
}
func (l *lister) List(opts ListOptions) (*api.PullRequestAndTotalCount, error) {
repo, err := l.baseRepoFn()
if err != nil {
return nil, err
}
client, err := l.httpClient()
if err != nil {
return nil, err
}
return listPullRequests(client, repo, opts)
}
func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters ListOptions) (*api.PullRequestAndTotalCount, error) {
type response struct {
Repository struct {
PullRequests struct {
Nodes []api.PullRequest
PageInfo struct {
HasNextPage bool
EndCursor string
}
TotalCount int
}
}
}
limit := filters.LimitResults
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(filters.Fields))
query := fragment + `
query PullRequestList(
$owner: String!,
$repo: String!,
$limit: Int!,
$endCursor: String,
$baseBranch: String,
$headBranch: String,
$state: [PullRequestState!] = OPEN
) {
repository(owner: $owner, name: $repo) {
pullRequests(
states: $state,
baseRefName: $baseBranch,
headRefName: $headBranch,
first: $limit,
after: $endCursor,
orderBy: {field: CREATED_AT, direction: DESC}
) {
totalCount
nodes {
...pr
}
pageInfo {
hasNextPage
endCursor
}
}
}
}`
pageLimit := min(limit, 100)
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
}
switch filters.State {
case "open":
variables["state"] = []string{"OPEN"}
case "closed":
variables["state"] = []string{"CLOSED", "MERGED"}
case "merged":
variables["state"] = []string{"MERGED"}
case "all":
variables["state"] = []string{"OPEN", "CLOSED", "MERGED"}
default:
return nil, fmt.Errorf("invalid state: %s", filters.State)
}
if filters.BaseBranch != "" {
variables["baseBranch"] = filters.BaseBranch
}
if filters.HeadBranch != "" {
variables["headBranch"] = filters.HeadBranch
}
res := api.PullRequestAndTotalCount{}
var check = make(map[int]struct{})
client := api.NewClientFromHTTP(httpClient)
loop:
for {
variables["limit"] = pageLimit
var data response
err := client.GraphQL(repo.RepoHost(), query, variables, &data)
if err != nil {
return nil, err
}
prData := data.Repository.PullRequests
res.TotalCount = prData.TotalCount
for _, pr := range prData.Nodes {
if _, exists := check[pr.Number]; exists && pr.Number > 0 {
continue
}
check[pr.Number] = struct{}{}
res.PullRequests = append(res.PullRequests, pr)
if len(res.PullRequests) == limit {
break loop
}
}
if prData.PageInfo.HasNextPage {
variables["endCursor"] = prData.PageInfo.EndCursor
pageLimit = min(pageLimit, limit-len(res.PullRequests))
} else {
break
}
}
return &res, nil
}
type mockLister struct {
called bool
expectFields []string
result *api.PullRequestAndTotalCount
err error
}
func NewMockLister(result *api.PullRequestAndTotalCount, err error) *mockLister {
return &mockLister{
result: result,
err: err,
}
}
func (m *mockLister) List(opt ListOptions) (*api.PullRequestAndTotalCount, error) {
m.called = true
if m.err != nil {
return nil, m.err
}
if m.expectFields != nil {
if !IsEqualSet(m.expectFields, opt.Fields) {
return nil, fmt.Errorf("unexpected fields: %v", opt.Fields)
}
}
return m.result, m.err
}
func (m *mockLister) ExpectFields(fields []string) {
m.expectFields = fields
}