diff --git a/pkg/cmd/pr/list/http.go b/pkg/cmd/pr/list/http.go index 47cbd9d61..5cf6a410e 100644 --- a/pkg/cmd/pr/list/http.go +++ b/pkg/cmd/pr/list/http.go @@ -10,8 +10,12 @@ import ( "github.com/cli/cli/v2/pkg/githubsearch" ) +func shouldUseSearch(filters prShared.FilterOptions) bool { + return filters.Draft != "" || filters.Author != "" || filters.Assignee != "" || filters.Search != "" || len(filters.Labels) > 0 +} + func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters prShared.FilterOptions, limit int) (*api.PullRequestAndTotalCount, error) { - if filters.Author != "" || filters.Assignee != "" || filters.Search != "" || len(filters.Labels) > 0 { + if shouldUseSearch(filters) { return searchPullRequests(httpClient, repo, filters, limit) } @@ -177,6 +181,10 @@ func searchPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters q.SetBaseBranch(filters.BaseBranch) } + if filters.Draft != "" { + q.SetDraft(filters.Draft) + } + pageLimit := min(limit, 100) variables := map[string]interface{}{ "q": q.String(), diff --git a/pkg/cmd/pr/list/list.go b/pkg/cmd/pr/list/list.go index dd62da726..a2cec0e43 100644 --- a/pkg/cmd/pr/list/list.go +++ b/pkg/cmd/pr/list/list.go @@ -37,6 +37,7 @@ type ListOptions struct { Author string Assignee string Search string + Draft string } func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Command { @@ -46,6 +47,8 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman Browser: f.Browser, } + var draft bool + cmd := &cobra.Command{ Use: "list", Short: "List and filter pull requests in this repository", @@ -74,6 +77,10 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman return &cmdutil.FlagError{Err: fmt.Errorf("invalid value for --limit: %v", opts.LimitResults)} } + if cmd.Flags().Changed("draft") { + opts.Draft = strconv.FormatBool(draft) + } + if runF != nil { return runF(opts) } @@ -92,6 +99,8 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman cmd.Flags().StringVarP(&opts.Author, "author", "A", "", "Filter by author") cmd.Flags().StringVarP(&opts.Assignee, "assignee", "a", "", "Filter by assignee") cmd.Flags().StringVarP(&opts.Search, "search", "S", "", "Search pull requests with `query`") + cmd.Flags().BoolVarP(&draft, "draft", "d", false, "Filter by draft state") + cmdutil.AddJSONFlags(cmd, &opts.Exporter, api.PullRequestFields) return cmd @@ -132,12 +141,12 @@ func listRun(opts *ListOptions) error { Labels: opts.Labels, BaseBranch: opts.BaseBranch, Search: opts.Search, + Draft: opts.Draft, Fields: defaultFields, } if opts.Exporter != nil { filters.Fields = opts.Exporter.Fields() } - if opts.WebMode { prListURL := ghrepo.GenerateRepoURL(baseRepo, "pulls") openURL, err := shared.ListURLWithQuery(prListURL, filters) diff --git a/pkg/cmd/pr/list/list_test.go b/pkg/cmd/pr/list/list_test.go index a912e4cb6..ff55089f0 100644 --- a/pkg/cmd/pr/list/list_test.go +++ b/pkg/cmd/pr/list/list_test.go @@ -176,39 +176,89 @@ func TestPRList_filteringAssignee(t *testing.T) { } } -func TestPRList_filteringAssigneeLabels(t *testing.T) { - http := initFakeHTTP() - defer http.Verify(t) +func TestPRList_filteringDraft(t *testing.T) { + tests := []struct { + name string + cli string + expectedQuery string + }{ + { + name: "draft", + cli: "--draft", + expectedQuery: `repo:OWNER/REPO is:pr is:open draft:true`, + }, + { + name: "non-draft", + cli: "--draft=false", + expectedQuery: `repo:OWNER/REPO is:pr is:open draft:false`, + }, + } - _, err := runCommand(http, true, `-l one,two -a hubot`) - if err == nil && err.Error() != "multiple labels with --assignee are not supported" { - t.Fatal(err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + http := initFakeHTTP() + defer http.Verify(t) + + http.Register( + httpmock.GraphQL(`query PullRequestSearch\b`), + httpmock.GraphQLQuery(`{}`, func(_ string, params map[string]interface{}) { + assert.Equal(t, test.expectedQuery, params["q"].(string)) + })) + + _, err := runCommand(http, true, test.cli) + if err != nil { + t.Fatal(err) + } + }) } } func TestPRList_withInvalidLimitFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - _, err := runCommand(http, true, `--limit=0`) - if err == nil && err.Error() != "invalid limit: 0" { - t.Errorf("error running command `issue list`: %v", err) - } + assert.EqualError(t, err, "invalid value for --limit: 0") } func TestPRList_web(t *testing.T) { - http := initFakeHTTP() - defer http.Verify(t) - - _, cmdTeardown := run.Stub() - defer cmdTeardown(t) - - output, err := runCommand(http, true, "--web -a peter -l bug -l docs -L 10 -s merged -B trunk") - if err != nil { - t.Errorf("error running command `pr list` with `--web` flag: %v", err) + tests := []struct { + name string + cli string + expectedBrowserURL string + }{ + { + name: "filters", + cli: "-a peter -l bug -l docs -L 10 -s merged -B trunk", + expectedBrowserURL: "https://github.com/OWNER/REPO/pulls?q=is%3Apr+is%3Amerged+assignee%3Apeter+label%3Abug+label%3Adocs+base%3Atrunk", + }, + { + name: "draft", + cli: "--draft=true", + expectedBrowserURL: "https://github.com/OWNER/REPO/pulls?q=is%3Apr+is%3Aopen+draft%3Atrue", + }, + { + name: "non-draft", + cli: "--draft=0", + expectedBrowserURL: "https://github.com/OWNER/REPO/pulls?q=is%3Apr+is%3Aopen+draft%3Afalse", + }, } - assert.Equal(t, "", output.String()) - assert.Equal(t, "Opening github.com/OWNER/REPO/pulls in your browser.\n", output.Stderr()) - assert.Equal(t, "https://github.com/OWNER/REPO/pulls?q=is%3Apr+is%3Amerged+assignee%3Apeter+label%3Abug+label%3Adocs+base%3Atrunk", output.BrowsedURL) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + http := initFakeHTTP() + defer http.Verify(t) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + output, err := runCommand(http, true, "--web "+test.cli) + if err != nil { + t.Errorf("error running command `pr list` with `--web` flag: %v", err) + } + + assert.Equal(t, "", output.String()) + assert.Equal(t, "Opening github.com/OWNER/REPO/pulls in your browser.\n", output.Stderr()) + assert.Equal(t, test.expectedBrowserURL, output.BrowsedURL) + }) + } } diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 92096e9d8..2333bcfa3 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -157,8 +157,8 @@ type FilterOptions struct { Mention string Milestone string Search string - - Fields []string + Draft string + Fields []string } func (opts *FilterOptions) IsDefault() bool { @@ -241,7 +241,9 @@ func SearchQueryBuild(options FilterOptions) string { if options.Search != "" { q.AddQuery(options.Search) } - + if options.Draft != "" { + q.SetDraft(options.Draft) + } return q.String() } diff --git a/pkg/cmd/pr/shared/params_test.go b/pkg/cmd/pr/shared/params_test.go index fa41ac307..8f3e793e5 100644 --- a/pkg/cmd/pr/shared/params_test.go +++ b/pkg/cmd/pr/shared/params_test.go @@ -16,6 +16,7 @@ func Test_listURLWithQuery(t *testing.T) { listURL string options FilterOptions } + tests := []struct { name string args args @@ -34,6 +35,32 @@ func Test_listURLWithQuery(t *testing.T) { want: "https://example.com/path?a=b&q=is%3Aissue+is%3Aopen", wantErr: false, }, + { + name: "draft", + args: args{ + listURL: "https://example.com/path", + options: FilterOptions{ + Entity: "pr", + State: "open", + Draft: "true", + }, + }, + want: "https://example.com/path?q=is%3Apr+is%3Aopen+draft%3Atrue", + wantErr: false, + }, + { + name: "non-draft", + args: args{ + listURL: "https://example.com/path", + options: FilterOptions{ + Entity: "pr", + State: "open", + Draft: "false", + }, + }, + want: "https://example.com/path?q=is%3Apr+is%3Aopen+draft%3Afalse", + wantErr: false, + }, { name: "all", args: args{ diff --git a/pkg/githubsearch/query.go b/pkg/githubsearch/query.go index f25707710..a8f3005a9 100644 --- a/pkg/githubsearch/query.go +++ b/pkg/githubsearch/query.go @@ -54,6 +54,7 @@ type Query struct { forkState string visibility string isArchived *bool + draft string } func (q *Query) InRepository(nameWithOwner string) { @@ -139,6 +140,10 @@ func (q *Query) SetArchived(isArchived bool) { q.isArchived = &isArchived } +func (q *Query) SetDraft(draft string) { + q.draft = draft +} + func (q *Query) String() string { var qs string @@ -198,6 +203,9 @@ func (q *Query) String() string { if q.headBranch != "" { qs += fmt.Sprintf("head:%s ", quote(q.headBranch)) } + if q.draft != "" { + qs += fmt.Sprintf("draft:%v ", q.draft) + } if q.sort != "" { qs += fmt.Sprintf("sort:%s ", q.sort)