From 9bdc63c4ca6de084d5a9db3b068aa6314e1e81e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 28 Apr 2021 19:25:27 +0200 Subject: [PATCH] Eliminate API overfetching in `pr` commands This completely rewrites the PR lookup mechanism so that the caller must specify the GraphQL fields to query for each PR. Additionally, this fixes some export problems with `pr view --json`. Features: - Each pr command now gets assigned a concept of a Finder. This makes it easier to stub the PR in tests without having to stub the underlying HTTP calls or git invocations. - `pr view --web` is much faster since it only fetches the "url" field. - `pr diff 123` now skips a whole API call where a whole PR was unnecessarily preloaded just to access its diff in a subsequent call. - PullRequestGraphQL query builder is now used to construct queries. - A bunch of individual commands are now freed of having to know about concepts such as BaseRepo, Branch, Config, or Remotes. --- api/queries_comments.go | 18 - api/queries_issue.go | 4 + api/queries_pr.go | 339 ++----------- api/queries_pr_test.go | 29 -- api/reaction_groups.go | 9 - pkg/cmd/pr/checkout/checkout.go | 32 +- pkg/cmd/pr/checks/checks.go | 34 +- pkg/cmd/pr/checks/checks_test.go | 23 +- pkg/cmd/pr/close/close.go | 25 +- pkg/cmd/pr/comment/comment.go | 32 +- pkg/cmd/pr/comment/comment_test.go | 33 +- pkg/cmd/pr/create/create.go | 13 +- pkg/cmd/pr/diff/diff.go | 26 +- pkg/cmd/pr/edit/edit.go | 25 +- pkg/cmd/pr/edit/edit_test.go | 2 - pkg/cmd/pr/merge/merge.go | 30 +- pkg/cmd/pr/merge/merge_test.go | 479 ++++++++---------- pkg/cmd/pr/ready/ready.go | 20 +- pkg/cmd/pr/reopen/reopen.go | 25 +- pkg/cmd/pr/review/review.go | 27 +- pkg/cmd/pr/shared/finder.go | 328 ++++++++++++ .../shared/{lookup_test.go => finder_test.go} | 30 +- pkg/cmd/pr/shared/lookup.go | 121 ----- pkg/cmd/pr/view/view.go | 96 +--- pkg/cmd/run/shared/shared.go | 5 +- 25 files changed, 769 insertions(+), 1036 deletions(-) create mode 100644 pkg/cmd/pr/shared/finder.go rename pkg/cmd/pr/shared/{lookup_test.go => finder_test.go} (78%) delete mode 100644 pkg/cmd/pr/shared/lookup.go diff --git a/api/queries_comments.go b/api/queries_comments.go index db6ad25e7..f02322c22 100644 --- a/api/queries_comments.go +++ b/api/queries_comments.go @@ -135,24 +135,6 @@ func CommentCreate(client *Client, repoHost string, params CommentCreateInput) ( return mutation.AddComment.CommentEdge.Node.URL, nil } -func commentsFragment() string { - return `comments(last: 1) { - nodes { - author { - login - } - authorAssociation - body - createdAt - includesCreatedEdit - isMinimized - minimizedReason - ` + reactionGroupsFragment() + ` - } - totalCount - }` -} - func (c Comment) AuthorLogin() string { return c.Author.Login } diff --git a/api/queries_issue.go b/api/queries_issue.go index 7804a7613..35e4e418f 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -98,6 +98,10 @@ type IssuesDisabledError struct { error } +type Owner struct { + Login string `json:"login"` +} + type Author struct { Login string `json:"login"` } diff --git a/api/queries_pr.go b/api/queries_pr.go index e3a575562..38f9f5ee4 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "sort" "strings" "time" @@ -34,6 +33,7 @@ type PullRequest struct { Number int Title string State string + Closed bool URL string BaseRefName string HeadRefName string @@ -57,10 +57,8 @@ type PullRequest struct { Author Author MergedBy *Author - HeadRepositoryOwner struct { - Login string `json:"login"` - } - HeadRepository struct { + HeadRepositoryOwner Owner + HeadRepository struct { Name string } IsCrossRepository bool @@ -77,27 +75,7 @@ type PullRequest struct { Commits struct { TotalCount int - Nodes []struct { - Commit struct { - Oid string - StatusCheckRollup struct { - Contexts struct { - Nodes []struct { - TypeName string `json:"__typename"` - Name string `json:"name"` - Context string `json:"context,omitempty"` - State string `json:"state,omitempty"` - Status string `json:"status"` - Conclusion string `json:"conclusion"` - StartedAt time.Time `json:"startedAt"` - CompletedAt time.Time `json:"completedAt"` - DetailsURL string `json:"detailsUrl"` - TargetURL string `json:"targetUrl,omitempty"` - } - } - } - } - } + Nodes []PullRequestCommit } Assignees Assignees Labels Labels @@ -113,6 +91,37 @@ type Commit struct { OID string `json:"oid"` } +type PullRequestCommit struct { + Commit PullRequestCommitCommit +} + +// PullRequestCommitCommit is like "Commit" but with StatusCheckRollup +type PullRequestCommitCommit struct { + Oid string + StatusCheckRollup struct { + Contexts struct { + Nodes []struct { + TypeName string `json:"__typename"` + Name string `json:"name"` + Context string `json:"context,omitempty"` + State string `json:"state,omitempty"` + Status string `json:"status"` + Conclusion string `json:"conclusion"` + StartedAt time.Time `json:"startedAt"` + CompletedAt time.Time `json:"completedAt"` + DetailsURL string `json:"detailsUrl"` + TargetURL string `json:"targetUrl,omitempty"` + } + } + } +} + +func (pr *PullRequest) StubCommit(oid string) { + pr.Commits.Nodes = append(pr.Commits.Nodes, PullRequestCommit{ + Commit: PullRequestCommitCommit{Oid: oid}, + }) +} + type PullRequestFile struct { Path string `json:"path"` Additions int `json:"additions"` @@ -138,14 +147,6 @@ func (r ReviewRequests) Logins() []string { return logins } -type NotFoundError struct { - error -} - -func (err *NotFoundError) Unwrap() error { - return err.error -} - func (pr PullRequest) HeadLabel() string { if pr.IsCrossRepository { return fmt.Sprintf("%s:%s", pr.HeadRepositoryOwner.Login, pr.HeadRefName) @@ -247,7 +248,7 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (io.Rea } if resp.StatusCode == 404 { - return nil, &NotFoundError{errors.New("pull request not found")} + return nil, errors.New("pull request not found") } else if resp.StatusCode != 200 { return nil, HandleHTTPError(resp) } @@ -560,274 +561,6 @@ func pullRequestFragment(httpClient *http.Client, hostname string) (string, erro return fragments, nil } -func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) { - cachedClient := NewCachedClient(httpClient, time.Hour*24) - if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil { - return "", err - } else if !prFeatures.HasStatusCheckRollup { - return "", nil - } - - return ` - commits(last: 1) { - totalCount - nodes { - commit { - oid - statusCheckRollup { - contexts(last: 100) { - nodes { - ...on StatusContext { - context - state - targetUrl - } - ...on CheckRun { - name - status - conclusion - startedAt - completedAt - detailsUrl - } - } - } - } - } - } - } - `, nil -} - -func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*PullRequest, error) { - type response struct { - Repository struct { - PullRequest PullRequest - } - } - - statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) - if err != nil { - return nil, err - } - - query := ` - query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { - repository(owner: $owner, name: $repo) { - pullRequest(number: $pr_number) { - id - url - number - title - state - closed - body - mergeable - additions - deletions - author { - login - } - ` + statusesFragment + ` - baseRefName - headRefName - headRepositoryOwner { - login - } - headRepository { - name - } - isCrossRepository - isDraft - maintainerCanModify - reviewRequests(first: 100) { - nodes { - requestedReviewer { - __typename - ...on User { - login - } - ...on Team { - name - } - } - } - totalCount - } - assignees(first: 100) { - nodes { - login - } - totalCount - } - labels(first: 100) { - nodes { - name - } - totalCount - } - projectCards(first: 100) { - nodes { - project { - name - } - column { - name - } - } - totalCount - } - milestone{ - title - } - ` + commentsFragment() + ` - ` + reactionGroupsFragment() + ` - } - } - }` - - variables := map[string]interface{}{ - "owner": repo.RepoOwner(), - "repo": repo.RepoName(), - "pr_number": number, - } - - var resp response - err = client.GraphQL(repo.RepoHost(), query, variables, &resp) - if err != nil { - return nil, err - } - - return &resp.Repository.PullRequest, nil -} - -func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters []string) (*PullRequest, error) { - type response struct { - Repository struct { - PullRequests struct { - Nodes []PullRequest - } - } - } - - statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) - if err != nil { - return nil, err - } - - query := ` - query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { - repository(owner: $owner, name: $repo) { - pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) { - nodes { - id - number - title - state - body - mergeable - additions - deletions - author { - login - } - ` + statusesFragment + ` - url - baseRefName - headRefName - headRepositoryOwner { - login - } - headRepository { - name - } - isCrossRepository - isDraft - maintainerCanModify - reviewRequests(first: 100) { - nodes { - requestedReviewer { - __typename - ...on User { - login - } - ...on Team { - name - } - } - } - totalCount - } - assignees(first: 100) { - nodes { - login - } - totalCount - } - labels(first: 100) { - nodes { - name - } - totalCount - } - projectCards(first: 100) { - nodes { - project { - name - } - column { - name - } - } - totalCount - } - milestone{ - title - } - ` + commentsFragment() + ` - ` + reactionGroupsFragment() + ` - } - } - } - }` - - branchWithoutOwner := headBranch - if idx := strings.Index(headBranch, ":"); idx >= 0 { - branchWithoutOwner = headBranch[idx+1:] - } - - variables := map[string]interface{}{ - "owner": repo.RepoOwner(), - "repo": repo.RepoName(), - "headRefName": branchWithoutOwner, - "states": stateFilters, - } - - var resp response - err = client.GraphQL(repo.RepoHost(), query, variables, &resp) - if err != nil { - return nil, err - } - - prs := resp.Repository.PullRequests.Nodes - sortPullRequestsByState(prs) - - for _, pr := range prs { - if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) { - return &pr, nil - } - } - - return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} -} - -// sortPullRequestsByState sorts a PullRequest slice by open-first -func sortPullRequestsByState(prs []PullRequest) { - sort.SliceStable(prs, func(a, b int) bool { - return prs[a].State == "OPEN" - }) -} - // CreatePullRequest creates a pull request in a GitHub repository func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) { query := ` diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 5441be950..886f16dd0 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -158,32 +158,3 @@ func Test_determinePullRequestFeatures(t *testing.T) { }) } } - -func Test_sortPullRequestsByState(t *testing.T) { - prs := []PullRequest{ - { - BaseRefName: "test1", - State: "MERGED", - }, - { - BaseRefName: "test2", - State: "CLOSED", - }, - { - BaseRefName: "test3", - State: "OPEN", - }, - } - - sortPullRequestsByState(prs) - - if prs[0].BaseRefName != "test3" { - t.Errorf("prs[0]: got %s, want %q", prs[0].BaseRefName, "test3") - } - if prs[1].BaseRefName != "test1" { - t.Errorf("prs[1]: got %s, want %q", prs[1].BaseRefName, "test1") - } - if prs[2].BaseRefName != "test2" { - t.Errorf("prs[2]: got %s, want %q", prs[2].BaseRefName, "test2") - } -} diff --git a/api/reaction_groups.go b/api/reaction_groups.go index 769edc6aa..08ae53040 100644 --- a/api/reaction_groups.go +++ b/api/reaction_groups.go @@ -57,12 +57,3 @@ var reactionEmoji = map[string]string{ "ROCKET": "\U0001f680", "EYES": "\U0001f440", } - -func reactionGroupsFragment() string { - return `reactionGroups { - content - users { - totalCount - } - }` -} diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index f7f73bb28..03d04a1a9 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -24,10 +24,11 @@ type CheckoutOptions struct { HttpClient func() (*http.Client, error) Config func() (config.Config, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) Remotes func() (context.Remotes, error) Branch func() (string, error) + Finder shared.PRFinder + SelectorArg string RecurseSubmodules bool Force bool @@ -48,8 +49,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr Short: "Check out a pull request in git", Args: cmdutil.ExactArgs(1, "argument required"), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if len(args) > 0 { opts.SelectorArg = args[0] @@ -70,18 +70,10 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr } func checkoutRun(opts *CheckoutOptions) error { - remotes, err := opts.Remotes() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - - httpClient, err := opts.HttpClient() - if err != nil { - return err - } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -90,8 +82,12 @@ func checkoutRun(opts *CheckoutOptions) error { if err != nil { return err } - protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol") + + remotes, err := opts.Remotes() + if err != nil { + return err + } baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()) baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol) if baseRemote != nil { @@ -112,6 +108,12 @@ func checkoutRun(opts *CheckoutOptions) error { if headRemote != nil { cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...) } else { + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo) if err != nil { return err diff --git a/pkg/cmd/pr/checks/checks.go b/pkg/cmd/pr/checks/checks.go index 1a11f3d44..b9091379f 100644 --- a/pkg/cmd/pr/checks/checks.go +++ b/pkg/cmd/pr/checks/checks.go @@ -3,13 +3,10 @@ package checks import ( "errors" "fmt" - "net/http" "sort" "time" "github.com/MakeNowJust/heredoc" - "github.com/cli/cli/api" - "github.com/cli/cli/context" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" @@ -23,26 +20,19 @@ type browser interface { } type ChecksOptions struct { - HttpClient func() (*http.Client, error) - IO *iostreams.IOStreams - Browser browser - BaseRepo func() (ghrepo.Interface, error) - Branch func() (string, error) - Remotes func() (context.Remotes, error) + IO *iostreams.IOStreams + Browser browser - WebMode bool + Finder shared.PRFinder SelectorArg string + WebMode bool } func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Command { opts := &ChecksOptions{ - IO: f.IOStreams, - HttpClient: f.HttpClient, - Branch: f.Branch, - Remotes: f.Remotes, - BaseRepo: f.BaseRepo, - Browser: f.Browser, + IO: f.IOStreams, + Browser: f.Browser, } cmd := &cobra.Command{ @@ -56,8 +46,7 @@ func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Co `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -81,13 +70,10 @@ func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Co } func checksRun(opts *ChecksOptions) error { - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } diff --git a/pkg/cmd/pr/checks/checks_test.go b/pkg/cmd/pr/checks/checks_test.go index fdcce3f9e..ca743d856 100644 --- a/pkg/cmd/pr/checks/checks_test.go +++ b/pkg/cmd/pr/checks/checks_test.go @@ -2,10 +2,8 @@ package checks import ( "bytes" - "net/http" "testing" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/internal/run" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/httpmock" @@ -174,10 +172,7 @@ func Test_checksRun(t *testing.T) { io.SetStdoutTTY(!tt.nontty) opts := &ChecksOptions{ - IO: io, - BaseRepo: func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil - }, + IO: io, SelectorArg: "123", } @@ -190,10 +185,6 @@ func Test_checksRun(t *testing.T) { reg.Register(httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.FileResponse(tt.fixture)) } - opts.HttpClient = func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - } - err := checksRun(opts) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) @@ -246,15 +237,9 @@ func TestChecksRun_web(t *testing.T) { defer teardown(t) err := checksRun(&ChecksOptions{ - IO: io, - Browser: browser, - WebMode: true, - HttpClient: func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - }, - BaseRepo: func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil - }, + IO: io, + Browser: browser, + WebMode: true, SelectorArg: "123", }) assert.NoError(t, err) diff --git a/pkg/cmd/pr/close/close.go b/pkg/cmd/pr/close/close.go index 850bd7631..a1e6e3dff 100644 --- a/pkg/cmd/pr/close/close.go +++ b/pkg/cmd/pr/close/close.go @@ -6,8 +6,6 @@ import ( "github.com/cli/cli/api" "github.com/cli/cli/git" - "github.com/cli/cli/internal/config" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -16,11 +14,11 @@ import ( type CloseOptions struct { HttpClient func() (*http.Client, error) - Config func() (config.Config, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) Branch func() (string, error) + Finder shared.PRFinder + SelectorArg string DeleteBranch bool DeleteLocalBranch bool @@ -30,7 +28,6 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm opts := &CloseOptions{ IO: f.IOStreams, HttpClient: f.HttpClient, - Config: f.Config, Branch: f.Branch, } @@ -39,8 +36,7 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm Short: "Close a pull request", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if len(args) > 0 { opts.SelectorArg = args[0] @@ -62,13 +58,10 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm func closeRun(opts *CloseOptions) error { cs := opts.IO.ColorScheme() - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, nil, nil, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -81,6 +74,12 @@ func closeRun(opts *CloseOptions) error { return nil } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + err = api.PullRequestClose(apiClient, baseRepo, pr) if err != nil { return fmt.Errorf("API call failed: %w", err) diff --git a/pkg/cmd/pr/comment/comment.go b/pkg/cmd/pr/comment/comment.go index a9ec5e9d3..85845259c 100644 --- a/pkg/cmd/pr/comment/comment.go +++ b/pkg/cmd/pr/comment/comment.go @@ -2,11 +2,8 @@ package comment import ( "errors" - "net/http" "github.com/MakeNowJust/heredoc" - "github.com/cli/cli/api" - "github.com/cli/cli/context" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" @@ -48,7 +45,13 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err if len(args) > 0 { selector = args[0] } - opts.RetrieveCommentable = retrievePR(f.HttpClient, f.BaseRepo, f.Branch, f.Remotes, selector) + finder := shared.NewFinder(f) + opts.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { + return finder.Find(shared.FindOptions{ + Selector: selector, + Fields: []string{"id", "url"}, + }) + } return shared.CommentablePreRun(cmd, opts) }, RunE: func(cmd *cobra.Command, args []string) error { @@ -74,24 +77,3 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err return cmd } - -func retrievePR(httpClient func() (*http.Client, error), - baseRepo func() (ghrepo.Interface, error), - branch func() (string, error), - remotes func() (context.Remotes, error), - selector string) func() (shared.Commentable, ghrepo.Interface, error) { - return func() (shared.Commentable, ghrepo.Interface, error) { - httpClient, err := httpClient() - if err != nil { - return nil, nil, err - } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, repo, err := shared.PRFromArgs(apiClient, baseRepo, branch, remotes, selector) - if err != nil { - return nil, nil, err - } - - return pr, repo, nil - } -} diff --git a/pkg/cmd/pr/comment/comment_test.go b/pkg/cmd/pr/comment/comment_test.go index 429af7cda..859a57069 100644 --- a/pkg/cmd/pr/comment/comment_test.go +++ b/pkg/cmd/pr/comment/comment_test.go @@ -8,7 +8,7 @@ import ( "path/filepath" "testing" - "github.com/cli/cli/context" + "github.com/cli/cli/api" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" @@ -224,7 +224,6 @@ func Test_commentRun(t *testing.T) { ConfirmSubmitSurvey: func() (bool, error) { return true, nil }, }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { - mockPullRequestFromNumber(t, reg) mockCommentCreate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n", @@ -238,9 +237,6 @@ func Test_commentRun(t *testing.T) { OpenInBrowser: func(string) error { return nil }, }, - httpStubs: func(t *testing.T, reg *httpmock.Registry) { - mockPullRequestFromNumber(t, reg) - }, stderr: "Opening github.com/OWNER/REPO/pull/123 in your browser.\n", }, { @@ -253,7 +249,6 @@ func Test_commentRun(t *testing.T) { EditSurvey: func() (string, error) { return "comment body", nil }, }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { - mockPullRequestFromNumber(t, reg) mockCommentCreate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n", @@ -266,7 +261,6 @@ func Test_commentRun(t *testing.T) { Body: "comment body", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { - mockPullRequestFromNumber(t, reg) mockCommentCreate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n", @@ -280,16 +274,20 @@ func Test_commentRun(t *testing.T) { reg := &httpmock.Registry{} defer reg.Verify(t) - tt.httpStubs(t, reg) + if tt.httpStubs != nil { + tt.httpStubs(t, reg) + } httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil } - branch := func() (string, error) { return "", nil } - remotes := func() (context.Remotes, error) { return nil, nil } tt.input.IO = io tt.input.HttpClient = httpClient - tt.input.RetrieveCommentable = retrievePR(httpClient, baseRepo, branch, remotes, "123") + tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { + return &api.PullRequest{ + Number: 123, + URL: "https://github.com/OWNER/REPO/pull/123", + }, ghrepo.New("OWNER", "REPO"), nil + } t.Run(tt.name, func(t *testing.T) { err := shared.CommentableRun(tt.input) @@ -300,17 +298,6 @@ func Test_commentRun(t *testing.T) { } } -func mockPullRequestFromNumber(_ *testing.T, reg *httpmock.Registry) { - reg.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "number": 123, - "url": "https://github.com/OWNER/REPO/pull/123" - } } } }`), - ) -} - func mockCommentCreate(t *testing.T, reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`mutation CommentCreate\b`), diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index fc2e70e26..c368e4e43 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -36,6 +36,7 @@ type CreateOptions struct { Remotes func() (context.Remotes, error) Branch func() (string, error) Browser browser + Finder shared.PRFinder TitleProvided bool BodyProvided bool @@ -117,6 +118,8 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co `), Args: cmdutil.NoArgsQuoteReminder, RunE: func(cmd *cobra.Command, args []string) error { + opts.Finder = shared.NewFinder(f) + opts.TitleProvided = cmd.Flags().Changed("title") opts.RepoOverride, _ = cmd.Flags().GetString("repo") noMaintainerEdit, _ := cmd.Flags().GetBool("no-maintainer-edit") @@ -220,9 +223,13 @@ func createRun(opts *CreateOptions) (err error) { state.Body = opts.Body } - existingPR, err := api.PullRequestForBranch( - client, ctx.BaseRepo, ctx.BaseBranch, ctx.HeadBranchLabel, []string{"OPEN"}) - var notFound *api.NotFoundError + existingPR, _, err := opts.Finder.Find(shared.FindOptions{ + Selector: ctx.HeadBranchLabel, + BaseBranch: ctx.BaseBranch, + States: []string{"OPEN"}, + Fields: []string{"url"}, + }) + var notFound *shared.NotFoundError if err != nil && !errors.As(err, ¬Found) { return fmt.Errorf("error checking for existing pull request: %w", err) } diff --git a/pkg/cmd/pr/diff/diff.go b/pkg/cmd/pr/diff/diff.go index 00a41c657..fa040a4aa 100644 --- a/pkg/cmd/pr/diff/diff.go +++ b/pkg/cmd/pr/diff/diff.go @@ -11,8 +11,6 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -22,9 +20,8 @@ import ( type DiffOptions struct { HttpClient func() (*http.Client, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - Remotes func() (context.Remotes, error) - Branch func() (string, error) + + Finder shared.PRFinder SelectorArg string UseColor string @@ -34,8 +31,6 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman opts := &DiffOptions{ IO: f.IOStreams, HttpClient: f.HttpClient, - Remotes: f.Remotes, - Branch: f.Branch, } cmd := &cobra.Command{ @@ -49,8 +44,7 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -81,17 +75,21 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman } func diffRun(opts *DiffOptions) error { + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, + Fields: []string{"number"}, + } + pr, baseRepo, err := opts.Finder.Find(findOptions) + if err != nil { + return err + } + httpClient, err := opts.HttpClient() if err != nil { return err } apiClient := api.NewClientFromHTTP(httpClient) - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) - if err != nil { - return err - } - diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number) if err != nil { return fmt.Errorf("could not find pull request diff: %w", err) diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go index dbf0321f9..888dc6ec2 100644 --- a/pkg/cmd/pr/edit/edit.go +++ b/pkg/cmd/pr/edit/edit.go @@ -7,7 +7,6 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" "github.com/cli/cli/internal/config" "github.com/cli/cli/internal/ghrepo" shared "github.com/cli/cli/pkg/cmd/pr/shared" @@ -20,10 +19,8 @@ import ( type EditOptions struct { HttpClient func() (*http.Client, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - Remotes func() (context.Remotes, error) - Branch func() (string, error) + Finder shared.PRFinder Surveyor Surveyor Fetcher EditableOptionsFetcher EditorRetriever EditorRetriever @@ -38,8 +35,6 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman opts := &EditOptions{ IO: f.IOStreams, HttpClient: f.HttpClient, - Remotes: f.Remotes, - Branch: f.Branch, Surveyor: surveyor{}, Fetcher: fetcher{}, EditorRetriever: editorRetriever{config: f.Config}, @@ -66,8 +61,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if len(args) > 0 { opts.SelectorArg = args[0] @@ -155,13 +149,10 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman } func editRun(opts *EditOptions) error { - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, repo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, repo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -184,6 +175,12 @@ func editRun(opts *EditOptions) error { } } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + opts.IO.StartProgressIndicator() err = opts.Fetcher.EditableOptionsFetch(apiClient, repo, &editable) opts.IO.StopProgressIndicator() diff --git a/pkg/cmd/pr/edit/edit_test.go b/pkg/cmd/pr/edit/edit_test.go index 586036910..c918f4a4c 100644 --- a/pkg/cmd/pr/edit/edit_test.go +++ b/pkg/cmd/pr/edit/edit_test.go @@ -450,11 +450,9 @@ func Test_editRun(t *testing.T) { tt.httpStubs(t, reg) httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil } tt.input.IO = io tt.input.HttpClient = httpClient - tt.input.BaseRepo = baseRepo t.Run(tt.name, func(t *testing.T) { err := editRun(tt.input) diff --git a/pkg/cmd/pr/merge/merge.go b/pkg/cmd/pr/merge/merge.go index 31a91d6be..298d59b7f 100644 --- a/pkg/cmd/pr/merge/merge.go +++ b/pkg/cmd/pr/merge/merge.go @@ -8,10 +8,8 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" "github.com/cli/cli/git" "github.com/cli/cli/internal/config" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -26,12 +24,11 @@ type editor interface { type MergeOptions struct { HttpClient func() (*http.Client, error) - Config func() (config.Config, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - Remotes func() (context.Remotes, error) Branch func() (string, error) + Finder shared.PRFinder + SelectorArg string DeleteBranch bool MergeMethod PullRequestMergeMethod @@ -52,8 +49,6 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm opts := &MergeOptions{ IO: f.IOStreams, HttpClient: f.HttpClient, - Config: f.Config, - Remotes: f.Remotes, Branch: f.Branch, } @@ -76,8 +71,7 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -136,7 +130,7 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm opts.Editor = &userEditor{ io: opts.IO, - config: opts.Config, + config: f.Config, } if runF != nil { @@ -160,19 +154,23 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm func mergeRun(opts *MergeOptions) error { cs := opts.IO.ColorScheme() - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, + Fields: []string{"id", "number", "state", "title", "commits", "mergeable", "headRepositoryOwner", "headRefName"}, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } isTerminal := opts.IO.IsStdoutTTY() + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + if opts.AutoMergeDisable { err := disableAutoMerge(httpClient, baseRepo, pr.ID) if err != nil { diff --git a/pkg/cmd/pr/merge/merge_test.go b/pkg/cmd/pr/merge/merge_test.go index 16ad0ed6d..df562e194 100644 --- a/pkg/cmd/pr/merge/merge_test.go +++ b/pkg/cmd/pr/merge/merge_test.go @@ -13,11 +13,9 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" - "github.com/cli/cli/git" - "github.com/cli/cli/internal/config" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/internal/run" + "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/httpmock" "github.com/cli/cli/pkg/iostreams" @@ -197,6 +195,14 @@ func Test_NewCmdMerge(t *testing.T) { } } +func baseRepo(owner, repo, branch string) ghrepo.Interface { + return api.InitRepoHostname(&api.Repository{ + Name: repo, + Owner: api.RepositoryOwner{Login: owner}, + DefaultBranchRef: api.BranchRef{Name: branch}, + }, "github.com") +} + func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*test.CmdOut, error) { io, _, stdout, stderr := iostreams.Test() io.SetStdoutTTY(isTTY) @@ -208,24 +214,6 @@ func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*t HttpClient: func() (*http.Client, error) { return &http.Client{Transport: rt}, nil }, - Config: func() (config.Config, error) { - return config.NewBlankConfig(), nil - }, - BaseRepo: func() (ghrepo.Interface, error) { - return api.InitRepoHostname(&api.Repository{ - Name: "REPO", - Owner: api.RepositoryOwner{Login: "OWNER"}, - DefaultBranchRef: api.BranchRef{Name: "master"}, - }, "github.com"), nil - }, - Remotes: func() (context.Remotes, error) { - return context.Remotes{ - { - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("OWNER", "REPO"), - }, - }, nil - }, Branch: func() (string, error) { return branch, nil }, @@ -259,17 +247,18 @@ func initFakeHTTP() *httpmock.Registry { func TestPrMerge(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 1, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) + + shared.RunCommandFinder( + "1", + &api.PullRequest{ + ID: "THE-ID", + Number: 1, + State: "OPEN", + Title: "The title of the PR", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -296,17 +285,18 @@ func TestPrMerge(t *testing.T) { func TestPrMerge_nontty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 1, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) + + shared.RunCommandFinder( + "1", + &api.PullRequest{ + ID: "THE-ID", + Number: 1, + State: "OPEN", + Title: "The title of the PR", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -330,17 +320,18 @@ func TestPrMerge_nontty(t *testing.T) { func TestPrMerge_withRepoFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 1, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) + + shared.RunCommandFinder( + "1", + &api.PullRequest{ + ID: "THE-ID", + Number: 1, + State: "OPEN", + Title: "The title of the PR", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -367,10 +358,19 @@ func TestPrMerge_withRepoFlag(t *testing.T) { func TestPrMerge_deleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - // FIXME: references fixture from another package - httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json")) + + shared.RunCommandFinder( + "", + &api.PullRequest{ + ID: "PR_10", + Number: 10, + State: "OPEN", + Title: "Blueberries are a good fruit", + HeadRefName: "blueberries", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -385,8 +385,6 @@ func TestPrMerge_deleteBranch(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git .+ show .+ HEAD`, 1, "") - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") cs.Register(`git checkout master`, 0, "") cs.Register(`git rev-parse --verify refs/heads/blueberries`, 0, "") cs.Register(`git branch -D blueberries`, 0, "") @@ -406,10 +404,19 @@ func TestPrMerge_deleteBranch(t *testing.T) { func TestPrMerge_deleteNonCurrentBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - // FIXME: references fixture from another package - httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json")) + + shared.RunCommandFinder( + "blueberries", + &api.PullRequest{ + ID: "PR_10", + Number: 10, + State: "OPEN", + Title: "Blueberries are a good fruit", + HeadRefName: "blueberries", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -439,59 +446,24 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) { `), output.Stderr()) } -func TestPrMerge_noPrNumberGiven(t *testing.T) { - http := initFakeHTTP() - defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - // FIXME: references fixture from another package - httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json")) - http.Register( - httpmock.GraphQL(`mutation PullRequestMerge\b`), - httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { - assert.Equal(t, "PR_10", input["pullRequestId"].(string)) - assert.Equal(t, "MERGE", input["mergeMethod"].(string)) - assert.NotContains(t, input, "commitHeadline") - })) - - cs, cmdTeardown := run.Stub() - defer cmdTeardown(t) - - cs.Register(`git .+ show .+ HEAD`, 1, "") - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") - - output, err := runCommand(http, "blueberries", true, "pr merge --merge") - if err != nil { - t.Fatalf("error running command `pr merge`: %v", err) - } - - assert.Equal(t, "", output.String()) - assert.Equal(t, heredoc.Doc(` - ✓ Merged pull request #10 (Blueberries are a good fruit) - `), output.Stderr()) -} - func Test_nonDivergingPullRequest(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "PR_10", - "title": "Blueberries are a good fruit", - "number": 10, - "commits": { - "nodes": [{ - "commit": { - "oid": "COMMITSHA1" - } - }], - "totalCount": 1 - } - }] } } } }`)) + + pr := &api.PullRequest{ + ID: "PR_10", + Number: 10, + Title: "Blueberries are a good fruit", + State: "OPEN", + } + pr.StubCommit("COMMITSHA1") + + shared.RunCommandFinder( + "", + pr, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -504,7 +476,6 @@ func Test_nonDivergingPullRequest(t *testing.T) { defer cmdTeardown(t) cs.Register(`git .+ show .+ HEAD`, 0, "COMMITSHA1,title") - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") output, err := runCommand(http, "blueberries", true, "pr merge --merge") if err != nil { @@ -519,24 +490,21 @@ func Test_nonDivergingPullRequest(t *testing.T) { func Test_divergingPullRequestWarning(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "PR_10", - "title": "Blueberries are a good fruit", - "number": 10, - "commits": { - "nodes": [{ - "commit": { - "oid": "COMMITSHA1" - } - }], - "totalCount": 1 - } - }] } } } }`)) + + pr := &api.PullRequest{ + ID: "PR_10", + Number: 10, + Title: "Blueberries are a good fruit", + State: "OPEN", + } + pr.StubCommit("COMMITSHA1") + + shared.RunCommandFinder( + "", + pr, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -549,7 +517,6 @@ func Test_divergingPullRequestWarning(t *testing.T) { defer cmdTeardown(t) cs.Register(`git .+ show .+ HEAD`, 0, "COMMITSHA2,title") - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") output, err := runCommand(http, "blueberries", true, "pr merge --merge") if err != nil { @@ -565,20 +532,18 @@ func Test_divergingPullRequestWarning(t *testing.T) { func Test_pullRequestWithoutCommits(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "PR_10", - "title": "Blueberries are a good fruit", - "number": 10, - "commits": { - "nodes": [], - "totalCount": 0 - } - }] } } } }`)) + + shared.RunCommandFinder( + "", + &api.PullRequest{ + ID: "PR_10", + Number: 10, + Title: "Blueberries are a good fruit", + State: "OPEN", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -587,11 +552,9 @@ func Test_pullRequestWithoutCommits(t *testing.T) { assert.NotContains(t, input, "commitHeadline") })) - cs, cmdTeardown := run.Stub() + _, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") - output, err := runCommand(http, "blueberries", true, "pr merge --merge") if err != nil { t.Fatalf("error running command `pr merge`: %v", err) @@ -605,17 +568,18 @@ func Test_pullRequestWithoutCommits(t *testing.T) { func TestPrMerge_rebase(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 2, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) + + shared.RunCommandFinder( + "2", + &api.PullRequest{ + ID: "THE-ID", + Number: 2, + Title: "The title of the PR", + State: "OPEN", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -642,17 +606,18 @@ func TestPrMerge_rebase(t *testing.T) { func TestPrMerge_squash(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 3, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) + + shared.RunCommandFinder( + "3", + &api.PullRequest{ + ID: "THE-ID", + Number: 3, + Title: "The title of the PR", + State: "OPEN", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -678,22 +643,18 @@ func TestPrMerge_squash(t *testing.T) { func TestPrMerge_alreadyMerged(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { - "pullRequest": { - "number": 4, - "title": "The title of the PR", - "state": "MERGED", - "baseRefName": "master", - "headRefName": "blueberries", - "headRepositoryOwner": { - "login": "OWNER" - }, - "isCrossRepository": false - } - } } }`)) + + shared.RunCommandFinder( + "4", + &api.PullRequest{ + ID: "THE-ID", + Number: 4, + State: "MERGED", + HeadRefName: "blueberries", + BaseRefName: "master", + }, + baseRepo("OWNER", "REPO", "master"), + ) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -718,12 +679,17 @@ func TestPrMerge_alreadyMerged(t *testing.T) { func TestPrMerge_alreadyMerged_nonInteractive(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { - "pullRequest": { "number": 4, "title": "The title of the PR", "state": "MERGED"} - } } }`)) + + shared.RunCommandFinder( + "4", + &api.PullRequest{ + ID: "THE-ID", + Number: 4, + State: "MERGED", + HeadRepositoryOwner: api.Owner{Login: "monalisa"}, + }, + baseRepo("OWNER", "REPO", "master"), + ) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -740,15 +706,18 @@ func TestPrMerge_alreadyMerged_nonInteractive(t *testing.T) { func TestPRMerge_interactive(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "THE-ID", - "number": 3 - }] } } } }`)) + + shared.RunCommandFinder( + "", + &api.PullRequest{ + ID: "THE-ID", + Number: 3, + Title: "It was the best of times", + HeadRefName: "blueberries", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`query RepositoryInfo\b`), httpmock.StringResponse(` @@ -765,11 +734,9 @@ func TestPRMerge_interactive(t *testing.T) { assert.NotContains(t, input, "commitHeadline") })) - cs, cmdTeardown := run.Stub() + _, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") - as, surveyTeardown := prompt.InitAskStubber() defer surveyTeardown() @@ -789,16 +756,18 @@ func TestPRMerge_interactive(t *testing.T) { func TestPRMerge_interactiveWithDeleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "THE-ID", - "title": "It was the best of times", - "number": 3 - }] } } } }`)) + + shared.RunCommandFinder( + "", + &api.PullRequest{ + ID: "THE-ID", + Number: 3, + Title: "It was the best of times", + HeadRefName: "blueberries", + }, + baseRepo("OWNER", "REPO", "master"), + ) + http.Register( httpmock.GraphQL(`query RepositoryInfo\b`), httpmock.StringResponse(` @@ -821,7 +790,6 @@ func TestPRMerge_interactiveWithDeleteBranch(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") cs.Register(`git checkout master`, 0, "") cs.Register(`git rev-parse --verify refs/heads/blueberries`, 0, "") cs.Register(`git branch -D blueberries`, 0, "") @@ -851,16 +819,6 @@ func TestPRMerge_interactiveSquashEditCommitMsg(t *testing.T) { tr := initFakeHTTP() defer tr.Verify(t) - - tr.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "headRepositoryOwner": {"login": "OWNER"}, - "id": "THE-ID", - "number": 3, - "title": "title" - } } } }`)) tr.Register( httpmock.GraphQL(`query RepositoryInfo\b`), httpmock.StringResponse(` @@ -902,25 +860,28 @@ func TestPRMerge_interactiveSquashEditCommitMsg(t *testing.T) { }, SelectorArg: "https://github.com/OWNER/REPO/pull/123", InteractiveMode: true, + Finder: shared.NewMockFinder( + "https://github.com/OWNER/REPO/pull/123", + &api.PullRequest{ID: "THE-ID", Number: 123, Title: "title"}, + ghrepo.New("OWNER", "REPO"), + ), }) assert.NoError(t, err) assert.Equal(t, "", stdout.String()) - assert.Equal(t, "✓ Squashed and merged pull request #3 (title)\n", stderr.String()) + assert.Equal(t, "✓ Squashed and merged pull request #123 (title)\n", stderr.String()) } func TestPRMerge_interactiveCancelled(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - http.Register( - httpmock.GraphQL(`query PullRequestForBranch\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequests": { "nodes": [{ - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"}, - "id": "THE-ID", - "number": 3 - }] } } } }`)) + + shared.RunCommandFinder( + "", + &api.PullRequest{ID: "THE-ID", Number: 123}, + ghrepo.New("OWNER", "REPO"), + ) + http.Register( httpmock.GraphQL(`query RepositoryInfo\b`), httpmock.StringResponse(` @@ -930,11 +891,9 @@ func TestPRMerge_interactiveCancelled(t *testing.T) { "squashMergeAllowed": true } } }`)) - cs, cmdTeardown := run.Stub() + _, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "") - as, surveyTeardown := prompt.InitAskStubber() defer surveyTeardown() @@ -971,18 +930,6 @@ func TestMergeRun_autoMerge(t *testing.T) { tr := initFakeHTTP() defer tr.Verify(t) - - tr.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 123, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) tr.Register( httpmock.GraphQL(`mutation PullRequestAutoMerge\b`), httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) { @@ -1001,6 +948,11 @@ func TestMergeRun_autoMerge(t *testing.T) { SelectorArg: "https://github.com/OWNER/REPO/pull/123", AutoMergeEnable: true, MergeMethod: PullRequestMergeMethodSquash, + Finder: shared.NewMockFinder( + "https://github.com/OWNER/REPO/pull/123", + &api.PullRequest{ID: "THE-ID", Number: 123}, + ghrepo.New("OWNER", "REPO"), + ), }) assert.NoError(t, err) @@ -1015,21 +967,11 @@ func TestMergeRun_disableAutoMerge(t *testing.T) { tr := initFakeHTTP() defer tr.Verify(t) - - tr.Register( - httpmock.GraphQL(`query PullRequestByNumber\b`), - httpmock.StringResponse(` - { "data": { "repository": { "pullRequest": { - "id": "THE-ID", - "number": 123, - "title": "The title of the PR", - "state": "OPEN", - "headRefName": "blueberries", - "headRepositoryOwner": {"login": "OWNER"} - } } } }`)) tr.Register( httpmock.GraphQL(`mutation PullRequestAutoMergeDisable\b`), - httpmock.StringResponse(`{}`)) + httpmock.GraphQLQuery(`{}`, func(s string, m map[string]interface{}) { + assert.Equal(t, map[string]interface{}{"prID": "THE-ID"}, m) + })) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -1041,6 +983,11 @@ func TestMergeRun_disableAutoMerge(t *testing.T) { }, SelectorArg: "https://github.com/OWNER/REPO/pull/123", AutoMergeDisable: true, + Finder: shared.NewMockFinder( + "https://github.com/OWNER/REPO/pull/123", + &api.PullRequest{ID: "THE-ID", Number: 123}, + ghrepo.New("OWNER", "REPO"), + ), }) assert.NoError(t, err) diff --git a/pkg/cmd/pr/ready/ready.go b/pkg/cmd/pr/ready/ready.go index 78b20532a..a00563b4b 100644 --- a/pkg/cmd/pr/ready/ready.go +++ b/pkg/cmd/pr/ready/ready.go @@ -24,6 +24,8 @@ type ReadyOptions struct { Remotes func() (context.Remotes, error) Branch func() (string, error) + Finder shared.PRFinder + SelectorArg string } @@ -47,8 +49,7 @@ func NewCmdReady(f *cmdutil.Factory, runF func(*ReadyOptions) error) *cobra.Comm `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -71,13 +72,10 @@ func NewCmdReady(f *cmdutil.Factory, runF func(*ReadyOptions) error) *cobra.Comm func readyRun(opts *ReadyOptions) error { cs := opts.IO.ColorScheme() - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -90,6 +88,12 @@ func readyRun(opts *ReadyOptions) error { return nil } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + err = api.PullRequestReady(apiClient, baseRepo, pr) if err != nil { return fmt.Errorf("API call failed: %w", err) diff --git a/pkg/cmd/pr/reopen/reopen.go b/pkg/cmd/pr/reopen/reopen.go index f22b8bfd2..72fb13659 100644 --- a/pkg/cmd/pr/reopen/reopen.go +++ b/pkg/cmd/pr/reopen/reopen.go @@ -5,8 +5,6 @@ import ( "net/http" "github.com/cli/cli/api" - "github.com/cli/cli/internal/config" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -15,9 +13,9 @@ import ( type ReopenOptions struct { HttpClient func() (*http.Client, error) - Config func() (config.Config, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) + + Finder shared.PRFinder SelectorArg string } @@ -26,7 +24,6 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co opts := &ReopenOptions{ IO: f.IOStreams, HttpClient: f.HttpClient, - Config: f.Config, } cmd := &cobra.Command{ @@ -34,8 +31,7 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co Short: "Reopen a pull request", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if len(args) > 0 { opts.SelectorArg = args[0] @@ -54,13 +50,10 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co func reopenRun(opts *ReopenOptions) error { cs := opts.IO.ColorScheme() - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, nil, nil, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -75,6 +68,12 @@ func reopenRun(opts *ReopenOptions) error { return nil } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + err = api.PullRequestReopen(apiClient, baseRepo, pr) if err != nil { return fmt.Errorf("API call failed: %w", err) diff --git a/pkg/cmd/pr/review/review.go b/pkg/cmd/pr/review/review.go index 1ff213bcb..606af85dc 100644 --- a/pkg/cmd/pr/review/review.go +++ b/pkg/cmd/pr/review/review.go @@ -8,9 +8,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" "github.com/cli/cli/internal/config" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -24,9 +22,8 @@ type ReviewOptions struct { HttpClient func() (*http.Client, error) Config func() (config.Config, error) IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - Remotes func() (context.Remotes, error) - Branch func() (string, error) + + Finder shared.PRFinder SelectorArg string InteractiveMode bool @@ -39,8 +36,6 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co IO: f.IOStreams, HttpClient: f.HttpClient, Config: f.Config, - Remotes: f.Remotes, - Branch: f.Branch, } var ( @@ -74,8 +69,7 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -151,13 +145,10 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co } func reviewRun(opts *ReviewOptions) error { - httpClient, err := opts.HttpClient() - if err != nil { - return err + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, } - apiClient := api.NewClientFromHTTP(httpClient) - - pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -183,6 +174,12 @@ func reviewRun(opts *ReviewOptions) error { } } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + err = api.AddReview(apiClient, baseRepo, pr, reviewData) if err != nil { return fmt.Errorf("failed to create review: %w", err) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go new file mode 100644 index 000000000..cd7e3c314 --- /dev/null +++ b/pkg/cmd/pr/shared/finder.go @@ -0,0 +1,328 @@ +package shared + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/cli/cli/api" + "github.com/cli/cli/context" + "github.com/cli/cli/git" + "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/set" +) + +type PRFinder interface { + Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) +} + +type progressIndicator interface { + StartProgressIndicator() + StopProgressIndicator() +} + +type finder struct { + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + remotesFn func() (context.Remotes, error) + httpClient func() (*http.Client, error) + progress progressIndicator + + repo ghrepo.Interface + prNumber int + branchName string +} + +func NewFinder(factory *cmdutil.Factory) PRFinder { + if runCommandFinder != nil { + f := runCommandFinder + runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} + return f + } + + return &finder{ + baseRepoFn: factory.BaseRepo, + branchFn: factory.Branch, + remotesFn: factory.Remotes, + httpClient: factory.HttpClient, + progress: factory.IOStreams, + } +} + +var runCommandFinder PRFinder + +// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. +func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) { + runCommandFinder = NewMockFinder(selector, pr, repo) +} + +type FindOptions struct { + // Selector can be a number with optional `#` prefix, a branch name with optional `:` prefix, or + // a PR URL. + Selector string + // Fields lists the GraphQL fields to fetch for the PullRequest. + Fields []string + // BaseBranch is the name of the base branch to scope the PR-for-branch lookup to. + BaseBranch string + // States lists the possible PR states to scope the PR-for-branch lookup to. + States []string +} + +func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { + if len(opts.Fields) == 0 { + return nil, nil, errors.New("Find error: no fields specified") + } + + _ = f.parseURL(opts.Selector) + + if f.repo == nil { + repo, err := f.baseRepoFn() + if err != nil { + return nil, nil, fmt.Errorf("could not determine base repo: %w", err) + } + f.repo = repo + } + + if opts.Selector == "" { + if err := f.parseCurrentBranch(); err != nil { + return nil, nil, err + } + } else if f.prNumber == 0 { + if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil { + f.prNumber = prNumber + } else { + f.branchName = opts.Selector + } + } + + httpClient, err := f.httpClient() + if err != nil { + return nil, nil, err + } + + if f.progress != nil { + f.progress.StartProgressIndicator() + defer f.progress.StopProgressIndicator() + } + + if f.prNumber > 0 { + if len(opts.Fields) == 1 && opts.Fields[0] == "number" { + // avoid hitting the API if we already have all the information + return &api.PullRequest{Number: f.prNumber}, f.repo, nil + } + pr, err := findByNumber(httpClient, f.repo, f.prNumber, opts.Fields) + return pr, f.repo, err + } + + pr, err := findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, opts.Fields) + + // TODO: preload view: api.ReviewsForPullRequest, api.CommentsForPullRequest + // TODO: preload checks: get all checks + return pr, f.repo, err +} + +var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) + +func (f *finder) parseURL(prURL string) error { + if prURL == "" { + return fmt.Errorf("invalid URL: %q", prURL) + } + + u, err := url.Parse(prURL) + if err != nil { + return err + } + + if u.Scheme != "https" && u.Scheme != "http" { + return fmt.Errorf("invalid scheme: %s", u.Scheme) + } + + m := pullURLRE.FindStringSubmatch(u.Path) + if m == nil { + return fmt.Errorf("not a pull request URL: %s", prURL) + } + + f.repo = ghrepo.NewWithHost(m[1], m[2], u.Hostname()) + f.prNumber, _ = strconv.Atoi(m[3]) + return nil +} + +var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) + +func (f *finder) parseCurrentBranch() error { + prHeadRef, err := f.branchFn() + if err != nil { + return err + } + + branchConfig := git.ReadBranchConfig(prHeadRef) + + // the branch is configured to merge a special PR head ref + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + f.prNumber, _ = strconv.Atoi(m[1]) + return nil + } + + var branchOwner string + if branchConfig.RemoteURL != nil { + // the branch merges from a remote specified by URL + if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { + branchOwner = r.RepoOwner() + } + } else if branchConfig.RemoteName != "" { + // the branch merges from a remote specified by name + rem, _ := f.remotesFn() + if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { + branchOwner = r.RepoOwner() + } + } + + if branchOwner != "" { + if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { + prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") + } + // prepend `OWNER:` if this branch is pushed to a fork + if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) { + prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) + } + } + + f.branchName = prHeadRef + return nil +} + +func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { + type response struct { + Repository struct { + PullRequest api.PullRequest + } + } + + query := fmt.Sprintf(` + query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $pr_number) {%s} + } + }`, api.PullRequestGraphQL(fields)) + + variables := map[string]interface{}{ + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + "pr_number": number, + } + + var resp response + client := api.NewClientFromHTTP(httpClient) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + if err != nil { + return nil, err + } + + return &resp.Repository.PullRequest, nil +} + +func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) { + type response struct { + Repository struct { + PullRequests struct { + Nodes []api.PullRequest + } + } + } + + fieldSet := set.NewStringSet() + fieldSet.AddValues(fields) + // these fields are required for filtering below + fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"}) + + query := fmt.Sprintf(` + query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { + repository(owner: $owner, name: $repo) { + pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) { + nodes {%s} + } + } + }`, api.PullRequestGraphQL(fieldSet.ToSlice())) + + branchWithoutOwner := headBranch + if idx := strings.Index(headBranch, ":"); idx >= 0 { + branchWithoutOwner = headBranch[idx+1:] + } + + variables := map[string]interface{}{ + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + "headRefName": branchWithoutOwner, + "states": stateFilters, + } + + var resp response + client := api.NewClientFromHTTP(httpClient) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + if err != nil { + return nil, err + } + + prs := resp.Repository.PullRequests.Nodes + sort.SliceStable(prs, func(a, b int) bool { + return prs[a].State == "OPEN" && prs[b].State != "OPEN" + }) + + for _, pr := range prs { + if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) { + return &pr, nil + } + } + + return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} +} + +type NotFoundError struct { + error +} + +func (err *NotFoundError) Unwrap() error { + return err.error +} + +func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) PRFinder { + return &mockFinder{ + expectSelector: selector, + pr: pr, + repo: repo, + } +} + +type mockFinder struct { + called bool + expectSelector string + pr *api.PullRequest + repo ghrepo.Interface + err error +} + +func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { + if m.err != nil { + return nil, nil, m.err + } + if m.expectSelector != opts.Selector { + return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) + } + if m.called { + return nil, nil, errors.New("mockFinder used more than once") + } + m.called = true + + if m.pr.HeadRepositoryOwner.Login == "" { + // pose as same-repo PR by default + m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner() + } + + return m.pr, m.repo, nil +} diff --git a/pkg/cmd/pr/shared/lookup_test.go b/pkg/cmd/pr/shared/finder_test.go similarity index 78% rename from pkg/cmd/pr/shared/lookup_test.go rename to pkg/cmd/pr/shared/finder_test.go index 4d843d7ae..f3600962e 100644 --- a/pkg/cmd/pr/shared/lookup_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -4,13 +4,12 @@ import ( "net/http" "testing" - "github.com/cli/cli/api" "github.com/cli/cli/context" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) -func TestPRFromArgs(t *testing.T) { +func TestFind(t *testing.T) { type args struct { baseRepoFn func() (ghrepo.Interface, error) branchFn func() (string, error) @@ -68,12 +67,6 @@ func TestPRFromArgs(t *testing.T) { baseRepoFn: nil, }, httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query PullRequest_fields\b`), - httpmock.StringResponse(`{"data":{}}`)) - r.Register( - httpmock.GraphQL(`query PullRequest_fields2\b`), - httpmock.StringResponse(`{"data":{}}`)) r.Register( httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`{"data":{"repository":{ @@ -87,17 +80,30 @@ func TestPRFromArgs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { reg := &httpmock.Registry{} + defer reg.Verify(t) if tt.httpStub != nil { tt.httpStub(reg) } - httpClient := &http.Client{Transport: reg} - pr, repo, err := PRFromArgs(api.NewClientFromHTTP(httpClient), tt.args.baseRepoFn, tt.args.branchFn, tt.args.remotesFn, tt.args.selector) + + f := finder{ + httpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + baseRepoFn: tt.args.baseRepoFn, + branchFn: tt.args.branchFn, + remotesFn: tt.args.remotesFn, + } + + pr, repo, err := f.Find(FindOptions{ + Selector: tt.args.selector, + }) if (err != nil) != tt.wantErr { - t.Errorf("IssueFromArg() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr) return } + if pr.Number != tt.wantPR { - t.Errorf("want issue #%d, got #%d", tt.wantPR, pr.Number) + t.Errorf("want pr #%d, got #%d", tt.wantPR, pr.Number) } repoURL := ghrepo.GenerateRepoURL(repo, "") if repoURL != tt.wantRepo { diff --git a/pkg/cmd/pr/shared/lookup.go b/pkg/cmd/pr/shared/lookup.go deleted file mode 100644 index 06e9221c0..000000000 --- a/pkg/cmd/pr/shared/lookup.go +++ /dev/null @@ -1,121 +0,0 @@ -package shared - -import ( - "fmt" - "net/url" - "regexp" - "strconv" - "strings" - - "github.com/cli/cli/api" - "github.com/cli/cli/context" - "github.com/cli/cli/git" - "github.com/cli/cli/internal/ghrepo" -) - -// PRFromArgs looks up the pull request from either the number/branch/URL argument or one belonging to the current branch -// -// NOTE: this API isn't great, but is here as a compatibility layer between old-style and new-style commands -func PRFromArgs(apiClient *api.Client, baseRepoFn func() (ghrepo.Interface, error), branchFn func() (string, error), remotesFn func() (context.Remotes, error), arg string) (*api.PullRequest, ghrepo.Interface, error) { - if arg != "" { - // First check to see if the prString is a url, return repo from url if found. This - // is run first because we don't need to run determineBaseRepo for this path - pr, r, err := prFromURL(apiClient, arg) - if pr != nil || err != nil { - return pr, r, err - } - } - - repo, err := baseRepoFn() - if err != nil { - return nil, nil, fmt.Errorf("could not determine base repo: %w", err) - } - - // If there are no args see if we can guess the PR from the current branch - if arg == "" { - pr, err := prForCurrentBranch(apiClient, repo, branchFn, remotesFn) - return pr, repo, err - } else { - // Next see if the prString is a number and use that to look up the url - pr, err := prFromNumberString(apiClient, repo, arg) - if pr != nil || err != nil { - return pr, repo, err - } - - // Last see if it is a branch name - pr, err = api.PullRequestForBranch(apiClient, repo, "", arg, nil) - return pr, repo, err - } -} - -func prFromNumberString(apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) { - if prNumber, err := strconv.Atoi(strings.TrimPrefix(s, "#")); err == nil { - return api.PullRequestByNumber(apiClient, repo, prNumber) - } - - return nil, nil -} - -var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) - -func prFromURL(apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) { - u, err := url.Parse(s) - if err != nil { - return nil, nil, nil - } - - if u.Scheme != "https" && u.Scheme != "http" { - return nil, nil, nil - } - - m := pullURLRE.FindStringSubmatch(u.Path) - if m == nil { - return nil, nil, nil - } - - repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) - prNumberString := m[3] - pr, err := prFromNumberString(apiClient, repo, prNumberString) - return pr, repo, err -} - -func prForCurrentBranch(apiClient *api.Client, repo ghrepo.Interface, branchFn func() (string, error), remotesFn func() (context.Remotes, error)) (*api.PullRequest, error) { - prHeadRef, err := branchFn() - if err != nil { - return nil, err - } - - branchConfig := git.ReadBranchConfig(prHeadRef) - - // the branch is configured to merge a special PR head ref - prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) - if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - return prFromNumberString(apiClient, repo, m[1]) - } - - var branchOwner string - if branchConfig.RemoteURL != nil { - // the branch merges from a remote specified by URL - if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { - branchOwner = r.RepoOwner() - } - } else if branchConfig.RemoteName != "" { - // the branch merges from a remote specified by name - rem, _ := remotesFn() - if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { - branchOwner = r.RepoOwner() - } - } - - if branchOwner != "" { - if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { - prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - } - // prepend `OWNER:` if this branch is pushed to a fork - if !strings.EqualFold(branchOwner, repo.RepoOwner()) { - prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) - } - } - - return api.PullRequestForBranch(apiClient, repo, "", prHeadRef, nil) -} diff --git a/pkg/cmd/pr/view/view.go b/pkg/cmd/pr/view/view.go index 4e6300297..b8e32189d 100644 --- a/pkg/cmd/pr/view/view.go +++ b/pkg/cmd/pr/view/view.go @@ -3,17 +3,12 @@ package view import ( "errors" "fmt" - "net/http" "sort" "strconv" "strings" - "sync" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" - "github.com/cli/cli/context" - "github.com/cli/cli/internal/config" - "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" @@ -27,14 +22,10 @@ type browser interface { } type ViewOptions struct { - HttpClient func() (*http.Client, error) - Config func() (config.Config, error) - IO *iostreams.IOStreams - Browser browser - BaseRepo func() (ghrepo.Interface, error) - Remotes func() (context.Remotes, error) - Branch func() (string, error) + IO *iostreams.IOStreams + Browser browser + Finder shared.PRFinder Exporter cmdutil.Exporter SelectorArg string @@ -44,12 +35,8 @@ type ViewOptions struct { func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command { opts := &ViewOptions{ - IO: f.IOStreams, - HttpClient: f.HttpClient, - Config: f.Config, - Remotes: f.Remotes, - Branch: f.Branch, - Browser: f.Browser, + IO: f.IOStreams, + Browser: f.Browser, } cmd := &cobra.Command{ @@ -65,8 +52,7 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + opts.Finder = shared.NewFinder(f) if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 { return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")} @@ -90,10 +76,26 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman return cmd } +var defaultFields = []string{ + "url", "number", "title", "state", "body", "author", + "isDraft", "maintainerCanModify", "mergeable", "additions", "deletions", + "baseRefName", "headRefName", "headRepositoryOwner", "headRepository", "isCrossRepository", + "reviewRequests", "reviews", "assignees", "labels", "projectCards", "milestone", + "comments", // TODO: fetch only 1 last comment unless `opts.Comments` was set + "reactionGroups", +} + func viewRun(opts *ViewOptions) error { - opts.IO.StartProgressIndicator() - pr, err := retrievePullRequest(opts) - opts.IO.StopProgressIndicator() + findOptions := shared.FindOptions{ + Selector: opts.SelectorArg, + Fields: defaultFields, + } + if opts.BrowserMode { + findOptions.Fields = []string{"url"} + } else if opts.Exporter != nil { + findOptions.Fields = opts.Exporter.Fields() + } + pr, _, err := opts.Finder.Find(findOptions) if err != nil { return err } @@ -413,51 +415,3 @@ func prStateWithDraft(pr *api.PullRequest) string { return pr.State } - -func retrievePullRequest(opts *ViewOptions) (*api.PullRequest, error) { - httpClient, err := opts.HttpClient() - if err != nil { - return nil, err - } - - apiClient := api.NewClientFromHTTP(httpClient) - - pr, repo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) - if err != nil { - return nil, err - } - - if opts.BrowserMode { - return pr, nil - } - - var errp, errc error - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - var reviews *api.PullRequestReviews - reviews, errp = api.ReviewsForPullRequest(apiClient, repo, pr) - pr.Reviews = *reviews - }() - - if opts.Comments { - wg.Add(1) - go func() { - defer wg.Done() - var comments *api.Comments - comments, errc = api.CommentsForPullRequest(apiClient, repo, pr) - pr.Comments = *comments - }() - } - - wg.Wait() - - if errp != nil { - err = errp - } - if errc != nil { - err = errc - } - return pr, err -} diff --git a/pkg/cmd/run/shared/shared.go b/pkg/cmd/run/shared/shared.go index afd926114..4fcf73250 100644 --- a/pkg/cmd/run/shared/shared.go +++ b/pkg/cmd/run/shared/shared.go @@ -134,11 +134,10 @@ func GetAnnotations(client *api.Client, repo ghrepo.Interface, job Job) ([]Annot err := client.REST(repo.RepoHost(), "GET", path, nil, &result) if err != nil { - var notFound *api.NotFoundError - if !errors.As(err, ¬Found) { + var httpError api.HTTPError + if errors.As(err, &httpError) && httpError.StatusCode == 404 { return []Annotation{}, nil } - return nil, err }