issue #2329: create shared PRLister
This commit is contained in:
parent
d7cabf18f7
commit
91b3b99b76
4 changed files with 242 additions and 38 deletions
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/cli/cli/v2/internal/gh"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/internal/text"
|
||||
"github.com/cli/cli/v2/pkg/cmd/pr/list"
|
||||
"github.com/cli/cli/v2/pkg/cmd/pr/shared"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
|
|
@ -30,6 +29,7 @@ type CheckoutOptions struct {
|
|||
|
||||
Finder shared.PRFinder
|
||||
Prompter shared.Prompter
|
||||
Lister shared.PRLister
|
||||
|
||||
Interactive bool
|
||||
BaseRepo func() (ghrepo.Interface, error)
|
||||
|
|
@ -67,6 +67,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
|
|||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
opts.Finder = shared.NewFinder(f)
|
||||
opts.Lister = shared.NewLister(f)
|
||||
|
||||
if len(args) > 0 {
|
||||
opts.SelectorArg = args[0]
|
||||
|
|
@ -97,12 +98,7 @@ func checkoutRun(opts *CheckoutOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := opts.HttpClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pr, err := resolvePR(client, baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.IO)
|
||||
pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -291,7 +287,7 @@ func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cm
|
|||
return nil
|
||||
}
|
||||
|
||||
func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, io *iostreams.IOStreams) (*api.PullRequest, error) {
|
||||
func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompter, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) {
|
||||
// When non-interactive
|
||||
if pullRequestSelector != "" {
|
||||
pr, _, err := pullRequestFinder.Find(shared.FindOptions{
|
||||
|
|
@ -315,20 +311,21 @@ func resolvePR(httpClient *http.Client, baseRepo ghrepo.Interface, prompter shar
|
|||
}
|
||||
// When interactive
|
||||
io.StartProgressIndicator()
|
||||
listResult, err := list.ListPullRequests(httpClient, baseRepo, shared.FilterOptions{Entity: "pr", State: "open", Fields: []string{
|
||||
"number",
|
||||
"title",
|
||||
"state",
|
||||
"url",
|
||||
"isDraft",
|
||||
"createdAt",
|
||||
listResult, err := prLister.List(shared.ListOptions{
|
||||
State: "open",
|
||||
Fields: []string{
|
||||
"number",
|
||||
"title",
|
||||
"state",
|
||||
"isDraft",
|
||||
|
||||
"headRefName",
|
||||
"headRepository",
|
||||
"headRepositoryOwner",
|
||||
"isCrossRepository",
|
||||
"maintainerCanModify",
|
||||
}}, 10)
|
||||
"headRefName",
|
||||
"headRepository",
|
||||
"headRepositoryOwner",
|
||||
"isCrossRepository",
|
||||
"maintainerCanModify",
|
||||
},
|
||||
LimitResults: 10})
|
||||
io.StopProgressIndicator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ import (
|
|||
// repo: either "baseOwner/baseRepo" or "baseOwner/baseRepo:defaultBranch"
|
||||
// prHead: "headOwner/headRepo:headBranch"
|
||||
func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
|
||||
return _stubPR(repo, prHead, 123, "PR title", "OPEN", false)
|
||||
}
|
||||
|
||||
func _stubPR(repo, prHead string, number int, title string, state string, isDraft bool) (ghrepo.Interface, *api.PullRequest) {
|
||||
defaultBranch := ""
|
||||
if idx := strings.IndexRune(repo, ':'); idx >= 0 {
|
||||
defaultBranch = repo[idx+1:]
|
||||
|
|
@ -53,12 +57,16 @@ func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
|
|||
}
|
||||
|
||||
return baseRepo, &api.PullRequest{
|
||||
Number: 123,
|
||||
Number: number,
|
||||
HeadRefName: headRefName,
|
||||
HeadRepositoryOwner: api.Owner{Login: headRepo.RepoOwner()},
|
||||
HeadRepository: &api.PRRepository{Name: headRepo.RepoName()},
|
||||
IsCrossRepository: !ghrepo.IsSame(baseRepo, headRepo),
|
||||
MaintainerCanModify: false,
|
||||
|
||||
Title: title,
|
||||
State: state,
|
||||
IsDraft: isDraft,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -224,10 +232,17 @@ func Test_checkoutRun(t *testing.T) {
|
|||
opts: &CheckoutOptions{
|
||||
SelectorArg: "",
|
||||
Interactive: true,
|
||||
Finder: func() shared.PRFinder {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
finder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
return finder
|
||||
Lister: func() shared.PRLister {
|
||||
_, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
|
||||
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
|
||||
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
|
||||
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 3,
|
||||
PullRequests: []api.PullRequest{
|
||||
*pr1, *pr2, *pr3,
|
||||
}, SearchCapped: false}, nil)
|
||||
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
|
||||
return lister
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
|
|
@ -236,14 +251,11 @@ func Test_checkoutRun(t *testing.T) {
|
|||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
},
|
||||
httpStubs: func(reg *httpmock.Registry) {
|
||||
reg.Register(httpmock.GraphQL(`query PullRequestList\b`), httpmock.FileResponse("./fixtures/prList.json"))
|
||||
},
|
||||
promptStubs: func(pm *prompter.MockPrompter) {
|
||||
pm.RegisterSelect("Select a pull request",
|
||||
[]string{"32\tDRAFT New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tOPEN Improve documentation [docs]"},
|
||||
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
|
||||
func(_, _ string, opts []string) (int, error) {
|
||||
return prompter.IndexFor(opts, "32\tDRAFT New feature [feature]")
|
||||
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
|
||||
})
|
||||
},
|
||||
runStubs: func(cs *run.CommandStubber) {
|
||||
|
|
@ -263,11 +275,10 @@ func Test_checkoutRun(t *testing.T) {
|
|||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
},
|
||||
},
|
||||
httpStubs: func(reg *httpmock.Registry) {
|
||||
reg.Register(httpmock.GraphQL(`query PullRequestList\b`),
|
||||
httpmock.StringResponse(`{"data":{"repository":{"pullRequests":{ "totalCount": 0,"nodes":[]}}}}`))
|
||||
|
||||
Lister: shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 0,
|
||||
PullRequests: []api.PullRequest{},
|
||||
}, nil),
|
||||
},
|
||||
remotes: map[string]string{
|
||||
"origin": "OWNER/REPO",
|
||||
|
|
|
|||
|
|
@ -548,7 +548,7 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface,
|
|||
if m.expectSelector != opts.Selector {
|
||||
return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
|
||||
}
|
||||
if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
|
||||
if len(m.expectFields) > 0 && !IsEqualSet(m.expectFields, opts.Fields) {
|
||||
return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
|
||||
}
|
||||
if m.called {
|
||||
|
|
@ -568,7 +568,7 @@ func (m *mockFinder) ExpectFields(fields []string) {
|
|||
m.expectFields = fields
|
||||
}
|
||||
|
||||
func isEqualSet(a, b []string) bool {
|
||||
func IsEqualSet(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
196
pkg/cmd/pr/shared/lister.go
Normal file
196
pkg/cmd/pr/shared/lister.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
|
||||
api "github.com/cli/cli/v2/api"
|
||||
)
|
||||
|
||||
type PRLister interface {
|
||||
List(opt ListOptions) (*api.PullRequestAndTotalCount, error)
|
||||
}
|
||||
|
||||
type ListOptions struct {
|
||||
LimitResults int
|
||||
|
||||
State string
|
||||
BaseBranch string
|
||||
HeadBranch string
|
||||
|
||||
Fields []string
|
||||
}
|
||||
|
||||
type lister struct {
|
||||
baseRepoFn func() (ghrepo.Interface, error)
|
||||
httpClient func() (*http.Client, error)
|
||||
}
|
||||
|
||||
func NewLister(factory *cmdutil.Factory) PRLister {
|
||||
return &lister{
|
||||
baseRepoFn: factory.BaseRepo,
|
||||
httpClient: factory.HttpClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lister) List(opts ListOptions) (*api.PullRequestAndTotalCount, error) {
|
||||
repo, err := l.baseRepoFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := l.httpClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listPullRequests(client, repo, opts)
|
||||
}
|
||||
|
||||
func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters ListOptions) (*api.PullRequestAndTotalCount, error) {
|
||||
type response struct {
|
||||
Repository struct {
|
||||
PullRequests struct {
|
||||
Nodes []api.PullRequest
|
||||
PageInfo struct {
|
||||
HasNextPage bool
|
||||
EndCursor string
|
||||
}
|
||||
TotalCount int
|
||||
}
|
||||
}
|
||||
}
|
||||
limit := filters.LimitResults
|
||||
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(filters.Fields))
|
||||
query := fragment + `
|
||||
query PullRequestList(
|
||||
$owner: String!,
|
||||
$repo: String!,
|
||||
$limit: Int!,
|
||||
$endCursor: String,
|
||||
$baseBranch: String,
|
||||
$headBranch: String,
|
||||
$state: [PullRequestState!] = OPEN
|
||||
) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
pullRequests(
|
||||
states: $state,
|
||||
baseRefName: $baseBranch,
|
||||
headRefName: $headBranch,
|
||||
first: $limit,
|
||||
after: $endCursor,
|
||||
orderBy: {field: CREATED_AT, direction: DESC}
|
||||
) {
|
||||
totalCount
|
||||
nodes {
|
||||
...pr
|
||||
}
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
pageLimit := min(limit, 100)
|
||||
variables := map[string]interface{}{
|
||||
"owner": repo.RepoOwner(),
|
||||
"repo": repo.RepoName(),
|
||||
}
|
||||
|
||||
switch filters.State {
|
||||
case "open":
|
||||
variables["state"] = []string{"OPEN"}
|
||||
case "closed":
|
||||
variables["state"] = []string{"CLOSED", "MERGED"}
|
||||
case "merged":
|
||||
variables["state"] = []string{"MERGED"}
|
||||
case "all":
|
||||
variables["state"] = []string{"OPEN", "CLOSED", "MERGED"}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid state: %s", filters.State)
|
||||
}
|
||||
|
||||
if filters.BaseBranch != "" {
|
||||
variables["baseBranch"] = filters.BaseBranch
|
||||
}
|
||||
if filters.HeadBranch != "" {
|
||||
variables["headBranch"] = filters.HeadBranch
|
||||
}
|
||||
|
||||
res := api.PullRequestAndTotalCount{}
|
||||
var check = make(map[int]struct{})
|
||||
client := api.NewClientFromHTTP(httpClient)
|
||||
|
||||
loop:
|
||||
for {
|
||||
variables["limit"] = pageLimit
|
||||
var data response
|
||||
err := client.GraphQL(repo.RepoHost(), query, variables, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prData := data.Repository.PullRequests
|
||||
res.TotalCount = prData.TotalCount
|
||||
|
||||
for _, pr := range prData.Nodes {
|
||||
if _, exists := check[pr.Number]; exists && pr.Number > 0 {
|
||||
continue
|
||||
}
|
||||
check[pr.Number] = struct{}{}
|
||||
|
||||
res.PullRequests = append(res.PullRequests, pr)
|
||||
if len(res.PullRequests) == limit {
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if prData.PageInfo.HasNextPage {
|
||||
variables["endCursor"] = prData.PageInfo.EndCursor
|
||||
pageLimit = min(pageLimit, limit-len(res.PullRequests))
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
type mockLister struct {
|
||||
called bool
|
||||
expectFields []string
|
||||
|
||||
result *api.PullRequestAndTotalCount
|
||||
err error
|
||||
}
|
||||
|
||||
func NewMockLister(result *api.PullRequestAndTotalCount, err error) *mockLister {
|
||||
return &mockLister{
|
||||
result: result,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockLister) List(opt ListOptions) (*api.PullRequestAndTotalCount, error) {
|
||||
m.called = true
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
if m.expectFields != nil {
|
||||
|
||||
if !IsEqualSet(m.expectFields, opt.Fields) {
|
||||
return nil, fmt.Errorf("unexpected fields: %v", opt.Fields)
|
||||
}
|
||||
}
|
||||
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
func (m *mockLister) ExpectFields(fields []string) {
|
||||
m.expectFields = fields
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue