Eliminate API overfetching in pr commands

This completely rewrites the PR lookup mechanism so that the caller
must specify the GraphQL fields to query for each PR. Additionally, this
fixes some export problems with `pr view --json`.

Features:

- Each pr command now gets assigned a concept of a Finder. This makes it
  easier to stub the PR in tests without having to stub the underlying
  HTTP calls or git invocations.

- `pr view --web` is much faster since it only fetches the "url" field.

- `pr diff 123` now skips a whole API call where a whole PR was
  unnecessarily preloaded just to access its diff in a subsequent call.

- PullRequestGraphQL query builder is now used to construct queries.

- A bunch of individual commands are now freed of having to know about
  concepts such as BaseRepo, Branch, Config, or Remotes.
This commit is contained in:
Mislav Marohnić 2021-04-28 19:25:27 +02:00
parent d478a65254
commit 9bdc63c4ca
25 changed files with 769 additions and 1036 deletions

View file

@ -135,24 +135,6 @@ func CommentCreate(client *Client, repoHost string, params CommentCreateInput) (
return mutation.AddComment.CommentEdge.Node.URL, nil
}
func commentsFragment() string {
return `comments(last: 1) {
nodes {
author {
login
}
authorAssociation
body
createdAt
includesCreatedEdit
isMinimized
minimizedReason
` + reactionGroupsFragment() + `
}
totalCount
}`
}
func (c Comment) AuthorLogin() string {
return c.Author.Login
}

View file

@ -98,6 +98,10 @@ type IssuesDisabledError struct {
error
}
type Owner struct {
Login string `json:"login"`
}
type Author struct {
Login string `json:"login"`
}

View file

@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
@ -34,6 +33,7 @@ type PullRequest struct {
Number int
Title string
State string
Closed bool
URL string
BaseRefName string
HeadRefName string
@ -57,10 +57,8 @@ type PullRequest struct {
Author Author
MergedBy *Author
HeadRepositoryOwner struct {
Login string `json:"login"`
}
HeadRepository struct {
HeadRepositoryOwner Owner
HeadRepository struct {
Name string
}
IsCrossRepository bool
@ -77,27 +75,7 @@ type PullRequest struct {
Commits struct {
TotalCount int
Nodes []struct {
Commit struct {
Oid string
StatusCheckRollup struct {
Contexts struct {
Nodes []struct {
TypeName string `json:"__typename"`
Name string `json:"name"`
Context string `json:"context,omitempty"`
State string `json:"state,omitempty"`
Status string `json:"status"`
Conclusion string `json:"conclusion"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
DetailsURL string `json:"detailsUrl"`
TargetURL string `json:"targetUrl,omitempty"`
}
}
}
}
}
Nodes []PullRequestCommit
}
Assignees Assignees
Labels Labels
@ -113,6 +91,37 @@ type Commit struct {
OID string `json:"oid"`
}
type PullRequestCommit struct {
Commit PullRequestCommitCommit
}
// PullRequestCommitCommit is like "Commit" but with StatusCheckRollup
type PullRequestCommitCommit struct {
Oid string
StatusCheckRollup struct {
Contexts struct {
Nodes []struct {
TypeName string `json:"__typename"`
Name string `json:"name"`
Context string `json:"context,omitempty"`
State string `json:"state,omitempty"`
Status string `json:"status"`
Conclusion string `json:"conclusion"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
DetailsURL string `json:"detailsUrl"`
TargetURL string `json:"targetUrl,omitempty"`
}
}
}
}
func (pr *PullRequest) StubCommit(oid string) {
pr.Commits.Nodes = append(pr.Commits.Nodes, PullRequestCommit{
Commit: PullRequestCommitCommit{Oid: oid},
})
}
type PullRequestFile struct {
Path string `json:"path"`
Additions int `json:"additions"`
@ -138,14 +147,6 @@ func (r ReviewRequests) Logins() []string {
return logins
}
type NotFoundError struct {
error
}
func (err *NotFoundError) Unwrap() error {
return err.error
}
func (pr PullRequest) HeadLabel() string {
if pr.IsCrossRepository {
return fmt.Sprintf("%s:%s", pr.HeadRepositoryOwner.Login, pr.HeadRefName)
@ -247,7 +248,7 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (io.Rea
}
if resp.StatusCode == 404 {
return nil, &NotFoundError{errors.New("pull request not found")}
return nil, errors.New("pull request not found")
} else if resp.StatusCode != 200 {
return nil, HandleHTTPError(resp)
}
@ -560,274 +561,6 @@ func pullRequestFragment(httpClient *http.Client, hostname string) (string, erro
return fragments, nil
}
func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) {
cachedClient := NewCachedClient(httpClient, time.Hour*24)
if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil {
return "", err
} else if !prFeatures.HasStatusCheckRollup {
return "", nil
}
return `
commits(last: 1) {
totalCount
nodes {
commit {
oid
statusCheckRollup {
contexts(last: 100) {
nodes {
...on StatusContext {
context
state
targetUrl
}
...on CheckRun {
name
status
conclusion
startedAt
completedAt
detailsUrl
}
}
}
}
}
}
}
`, nil
}
func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*PullRequest, error) {
type response struct {
Repository struct {
PullRequest PullRequest
}
}
statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
query := `
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pr_number) {
id
url
number
title
state
closed
body
mergeable
additions
deletions
author {
login
}
` + statusesFragment + `
baseRefName
headRefName
headRepositoryOwner {
login
}
headRepository {
name
}
isCrossRepository
isDraft
maintainerCanModify
reviewRequests(first: 100) {
nodes {
requestedReviewer {
__typename
...on User {
login
}
...on Team {
name
}
}
}
totalCount
}
assignees(first: 100) {
nodes {
login
}
totalCount
}
labels(first: 100) {
nodes {
name
}
totalCount
}
projectCards(first: 100) {
nodes {
project {
name
}
column {
name
}
}
totalCount
}
milestone{
title
}
` + commentsFragment() + `
` + reactionGroupsFragment() + `
}
}
}`
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"pr_number": number,
}
var resp response
err = client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
return &resp.Repository.PullRequest, nil
}
func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters []string) (*PullRequest, error) {
type response struct {
Repository struct {
PullRequests struct {
Nodes []PullRequest
}
}
}
statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
query := `
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
repository(owner: $owner, name: $repo) {
pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
nodes {
id
number
title
state
body
mergeable
additions
deletions
author {
login
}
` + statusesFragment + `
url
baseRefName
headRefName
headRepositoryOwner {
login
}
headRepository {
name
}
isCrossRepository
isDraft
maintainerCanModify
reviewRequests(first: 100) {
nodes {
requestedReviewer {
__typename
...on User {
login
}
...on Team {
name
}
}
}
totalCount
}
assignees(first: 100) {
nodes {
login
}
totalCount
}
labels(first: 100) {
nodes {
name
}
totalCount
}
projectCards(first: 100) {
nodes {
project {
name
}
column {
name
}
}
totalCount
}
milestone{
title
}
` + commentsFragment() + `
` + reactionGroupsFragment() + `
}
}
}
}`
branchWithoutOwner := headBranch
if idx := strings.Index(headBranch, ":"); idx >= 0 {
branchWithoutOwner = headBranch[idx+1:]
}
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"states": stateFilters,
}
var resp response
err = client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
prs := resp.Repository.PullRequests.Nodes
sortPullRequestsByState(prs)
for _, pr := range prs {
if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) {
return &pr, nil
}
}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
}
// sortPullRequestsByState sorts a PullRequest slice by open-first
func sortPullRequestsByState(prs []PullRequest) {
sort.SliceStable(prs, func(a, b int) bool {
return prs[a].State == "OPEN"
})
}
// CreatePullRequest creates a pull request in a GitHub repository
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
query := `

View file

@ -158,32 +158,3 @@ func Test_determinePullRequestFeatures(t *testing.T) {
})
}
}
func Test_sortPullRequestsByState(t *testing.T) {
prs := []PullRequest{
{
BaseRefName: "test1",
State: "MERGED",
},
{
BaseRefName: "test2",
State: "CLOSED",
},
{
BaseRefName: "test3",
State: "OPEN",
},
}
sortPullRequestsByState(prs)
if prs[0].BaseRefName != "test3" {
t.Errorf("prs[0]: got %s, want %q", prs[0].BaseRefName, "test3")
}
if prs[1].BaseRefName != "test1" {
t.Errorf("prs[1]: got %s, want %q", prs[1].BaseRefName, "test1")
}
if prs[2].BaseRefName != "test2" {
t.Errorf("prs[2]: got %s, want %q", prs[2].BaseRefName, "test2")
}
}

View file

@ -57,12 +57,3 @@ var reactionEmoji = map[string]string{
"ROCKET": "\U0001f680",
"EYES": "\U0001f440",
}
func reactionGroupsFragment() string {
return `reactionGroups {
content
users {
totalCount
}
}`
}

View file

@ -24,10 +24,11 @@ type CheckoutOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
RecurseSubmodules bool
Force bool
@ -48,8 +49,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
Short: "Check out a pull request in git",
Args: cmdutil.ExactArgs(1, "argument required"),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
@ -70,18 +70,10 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
}
func checkoutRun(opts *CheckoutOptions) error {
remotes, err := opts.Remotes()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -90,8 +82,12 @@ func checkoutRun(opts *CheckoutOptions) error {
if err != nil {
return err
}
protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
remotes, err := opts.Remotes()
if err != nil {
return err
}
baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
if baseRemote != nil {
@ -112,6 +108,12 @@ func checkoutRun(opts *CheckoutOptions) error {
if headRemote != nil {
cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
} else {
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
if err != nil {
return err

View file

@ -3,13 +3,10 @@ package checks
import (
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
@ -23,26 +20,19 @@ type browser interface {
}
type ChecksOptions struct {
HttpClient func() (*http.Client, error)
IO *iostreams.IOStreams
Browser browser
BaseRepo func() (ghrepo.Interface, error)
Branch func() (string, error)
Remotes func() (context.Remotes, error)
IO *iostreams.IOStreams
Browser browser
WebMode bool
Finder shared.PRFinder
SelectorArg string
WebMode bool
}
func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Command {
opts := &ChecksOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Branch: f.Branch,
Remotes: f.Remotes,
BaseRepo: f.BaseRepo,
Browser: f.Browser,
IO: f.IOStreams,
Browser: f.Browser,
}
cmd := &cobra.Command{
@ -56,8 +46,7 @@ func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Co
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -81,13 +70,10 @@ func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Co
}
func checksRun(opts *ChecksOptions) error {
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}

View file

@ -2,10 +2,8 @@ package checks
import (
"bytes"
"net/http"
"testing"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/internal/run"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/httpmock"
@ -174,10 +172,7 @@ func Test_checksRun(t *testing.T) {
io.SetStdoutTTY(!tt.nontty)
opts := &ChecksOptions{
IO: io,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
IO: io,
SelectorArg: "123",
}
@ -190,10 +185,6 @@ func Test_checksRun(t *testing.T) {
reg.Register(httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.FileResponse(tt.fixture))
}
opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
err := checksRun(opts)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
@ -246,15 +237,9 @@ func TestChecksRun_web(t *testing.T) {
defer teardown(t)
err := checksRun(&ChecksOptions{
IO: io,
Browser: browser,
WebMode: true,
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
IO: io,
Browser: browser,
WebMode: true,
SelectorArg: "123",
})
assert.NoError(t, err)

View file

@ -6,8 +6,6 @@ import (
"github.com/cli/cli/api"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -16,11 +14,11 @@ import (
type CloseOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
DeleteBranch bool
DeleteLocalBranch bool
@ -30,7 +28,6 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm
opts := &CloseOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Config: f.Config,
Branch: f.Branch,
}
@ -39,8 +36,7 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm
Short: "Close a pull request",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
@ -62,13 +58,10 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm
func closeRun(opts *CloseOptions) error {
cs := opts.IO.ColorScheme()
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, nil, nil, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -81,6 +74,12 @@ func closeRun(opts *CloseOptions) error {
return nil
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
err = api.PullRequestClose(apiClient, baseRepo, pr)
if err != nil {
return fmt.Errorf("API call failed: %w", err)

View file

@ -2,11 +2,8 @@ package comment
import (
"errors"
"net/http"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
@ -48,7 +45,13 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err
if len(args) > 0 {
selector = args[0]
}
opts.RetrieveCommentable = retrievePR(f.HttpClient, f.BaseRepo, f.Branch, f.Remotes, selector)
finder := shared.NewFinder(f)
opts.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) {
return finder.Find(shared.FindOptions{
Selector: selector,
Fields: []string{"id", "url"},
})
}
return shared.CommentablePreRun(cmd, opts)
},
RunE: func(cmd *cobra.Command, args []string) error {
@ -74,24 +77,3 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err
return cmd
}
func retrievePR(httpClient func() (*http.Client, error),
baseRepo func() (ghrepo.Interface, error),
branch func() (string, error),
remotes func() (context.Remotes, error),
selector string) func() (shared.Commentable, ghrepo.Interface, error) {
return func() (shared.Commentable, ghrepo.Interface, error) {
httpClient, err := httpClient()
if err != nil {
return nil, nil, err
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, repo, err := shared.PRFromArgs(apiClient, baseRepo, branch, remotes, selector)
if err != nil {
return nil, nil, err
}
return pr, repo, nil
}
}

View file

@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/cli/cli/context"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
@ -224,7 +224,6 @@ func Test_commentRun(t *testing.T) {
ConfirmSubmitSurvey: func() (bool, error) { return true, nil },
},
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
mockPullRequestFromNumber(t, reg)
mockCommentCreate(t, reg)
},
stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n",
@ -238,9 +237,6 @@ func Test_commentRun(t *testing.T) {
OpenInBrowser: func(string) error { return nil },
},
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
mockPullRequestFromNumber(t, reg)
},
stderr: "Opening github.com/OWNER/REPO/pull/123 in your browser.\n",
},
{
@ -253,7 +249,6 @@ func Test_commentRun(t *testing.T) {
EditSurvey: func() (string, error) { return "comment body", nil },
},
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
mockPullRequestFromNumber(t, reg)
mockCommentCreate(t, reg)
},
stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n",
@ -266,7 +261,6 @@ func Test_commentRun(t *testing.T) {
Body: "comment body",
},
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
mockPullRequestFromNumber(t, reg)
mockCommentCreate(t, reg)
},
stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n",
@ -280,16 +274,20 @@ func Test_commentRun(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
tt.httpStubs(t, reg)
if tt.httpStubs != nil {
tt.httpStubs(t, reg)
}
httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }
baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }
branch := func() (string, error) { return "", nil }
remotes := func() (context.Remotes, error) { return nil, nil }
tt.input.IO = io
tt.input.HttpClient = httpClient
tt.input.RetrieveCommentable = retrievePR(httpClient, baseRepo, branch, remotes, "123")
tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) {
return &api.PullRequest{
Number: 123,
URL: "https://github.com/OWNER/REPO/pull/123",
}, ghrepo.New("OWNER", "REPO"), nil
}
t.Run(tt.name, func(t *testing.T) {
err := shared.CommentableRun(tt.input)
@ -300,17 +298,6 @@ func Test_commentRun(t *testing.T) {
}
}
func mockPullRequestFromNumber(_ *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"number": 123,
"url": "https://github.com/OWNER/REPO/pull/123"
} } } }`),
)
}
func mockCommentCreate(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`mutation CommentCreate\b`),

View file

@ -36,6 +36,7 @@ type CreateOptions struct {
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Browser browser
Finder shared.PRFinder
TitleProvided bool
BodyProvided bool
@ -117,6 +118,8 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
`),
Args: cmdutil.NoArgsQuoteReminder,
RunE: func(cmd *cobra.Command, args []string) error {
opts.Finder = shared.NewFinder(f)
opts.TitleProvided = cmd.Flags().Changed("title")
opts.RepoOverride, _ = cmd.Flags().GetString("repo")
noMaintainerEdit, _ := cmd.Flags().GetBool("no-maintainer-edit")
@ -220,9 +223,13 @@ func createRun(opts *CreateOptions) (err error) {
state.Body = opts.Body
}
existingPR, err := api.PullRequestForBranch(
client, ctx.BaseRepo, ctx.BaseBranch, ctx.HeadBranchLabel, []string{"OPEN"})
var notFound *api.NotFoundError
existingPR, _, err := opts.Finder.Find(shared.FindOptions{
Selector: ctx.HeadBranchLabel,
BaseBranch: ctx.BaseBranch,
States: []string{"OPEN"},
Fields: []string{"url"},
})
var notFound *shared.NotFoundError
if err != nil && !errors.As(err, &notFound) {
return fmt.Errorf("error checking for existing pull request: %w", err)
}

View file

@ -11,8 +11,6 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -22,9 +20,8 @@ import (
type DiffOptions struct {
HttpClient func() (*http.Client, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
UseColor string
@ -34,8 +31,6 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman
opts := &DiffOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Remotes: f.Remotes,
Branch: f.Branch,
}
cmd := &cobra.Command{
@ -49,8 +44,7 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -81,17 +75,21 @@ func NewCmdDiff(f *cmdutil.Factory, runF func(*DiffOptions) error) *cobra.Comman
}
func diffRun(opts *DiffOptions) error {
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
Fields: []string{"number"},
}
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
if err != nil {
return err
}
diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number)
if err != nil {
return fmt.Errorf("could not find pull request diff: %w", err)

View file

@ -7,7 +7,6 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
shared "github.com/cli/cli/pkg/cmd/pr/shared"
@ -20,10 +19,8 @@ import (
type EditOptions struct {
HttpClient func() (*http.Client, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
Surveyor Surveyor
Fetcher EditableOptionsFetcher
EditorRetriever EditorRetriever
@ -38,8 +35,6 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman
opts := &EditOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Remotes: f.Remotes,
Branch: f.Branch,
Surveyor: surveyor{},
Fetcher: fetcher{},
EditorRetriever: editorRetriever{config: f.Config},
@ -66,8 +61,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
@ -155,13 +149,10 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman
}
func editRun(opts *EditOptions) error {
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, repo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, repo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -184,6 +175,12 @@ func editRun(opts *EditOptions) error {
}
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
opts.IO.StartProgressIndicator()
err = opts.Fetcher.EditableOptionsFetch(apiClient, repo, &editable)
opts.IO.StopProgressIndicator()

View file

@ -450,11 +450,9 @@ func Test_editRun(t *testing.T) {
tt.httpStubs(t, reg)
httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }
baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }
tt.input.IO = io
tt.input.HttpClient = httpClient
tt.input.BaseRepo = baseRepo
t.Run(tt.name, func(t *testing.T) {
err := editRun(tt.input)

View file

@ -8,10 +8,8 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -26,12 +24,11 @@ type editor interface {
type MergeOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
DeleteBranch bool
MergeMethod PullRequestMergeMethod
@ -52,8 +49,6 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm
opts := &MergeOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Config: f.Config,
Remotes: f.Remotes,
Branch: f.Branch,
}
@ -76,8 +71,7 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -136,7 +130,7 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm
opts.Editor = &userEditor{
io: opts.IO,
config: opts.Config,
config: f.Config,
}
if runF != nil {
@ -160,19 +154,23 @@ func NewCmdMerge(f *cmdutil.Factory, runF func(*MergeOptions) error) *cobra.Comm
func mergeRun(opts *MergeOptions) error {
cs := opts.IO.ColorScheme()
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
Fields: []string{"id", "number", "state", "title", "commits", "mergeable", "headRepositoryOwner", "headRefName"},
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
isTerminal := opts.IO.IsStdoutTTY()
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
if opts.AutoMergeDisable {
err := disableAutoMerge(httpClient, baseRepo, pr.ID)
if err != nil {

View file

@ -13,11 +13,9 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/internal/run"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/httpmock"
"github.com/cli/cli/pkg/iostreams"
@ -197,6 +195,14 @@ func Test_NewCmdMerge(t *testing.T) {
}
}
func baseRepo(owner, repo, branch string) ghrepo.Interface {
return api.InitRepoHostname(&api.Repository{
Name: repo,
Owner: api.RepositoryOwner{Login: owner},
DefaultBranchRef: api.BranchRef{Name: branch},
}, "github.com")
}
func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*test.CmdOut, error) {
io, _, stdout, stderr := iostreams.Test()
io.SetStdoutTTY(isTTY)
@ -208,24 +214,6 @@ func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*t
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: rt}, nil
},
Config: func() (config.Config, error) {
return config.NewBlankConfig(), nil
},
BaseRepo: func() (ghrepo.Interface, error) {
return api.InitRepoHostname(&api.Repository{
Name: "REPO",
Owner: api.RepositoryOwner{Login: "OWNER"},
DefaultBranchRef: api.BranchRef{Name: "master"},
}, "github.com"), nil
},
Remotes: func() (context.Remotes, error) {
return context.Remotes{
{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("OWNER", "REPO"),
},
}, nil
},
Branch: func() (string, error) {
return branch, nil
},
@ -259,17 +247,18 @@ func initFakeHTTP() *httpmock.Registry {
func TestPrMerge(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 1,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
shared.RunCommandFinder(
"1",
&api.PullRequest{
ID: "THE-ID",
Number: 1,
State: "OPEN",
Title: "The title of the PR",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -296,17 +285,18 @@ func TestPrMerge(t *testing.T) {
func TestPrMerge_nontty(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 1,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
shared.RunCommandFinder(
"1",
&api.PullRequest{
ID: "THE-ID",
Number: 1,
State: "OPEN",
Title: "The title of the PR",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -330,17 +320,18 @@ func TestPrMerge_nontty(t *testing.T) {
func TestPrMerge_withRepoFlag(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 1,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
shared.RunCommandFinder(
"1",
&api.PullRequest{
ID: "THE-ID",
Number: 1,
State: "OPEN",
Title: "The title of the PR",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -367,10 +358,19 @@ func TestPrMerge_withRepoFlag(t *testing.T) {
func TestPrMerge_deleteBranch(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
// FIXME: references fixture from another package
httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json"))
shared.RunCommandFinder(
"",
&api.PullRequest{
ID: "PR_10",
Number: 10,
State: "OPEN",
Title: "Blueberries are a good fruit",
HeadRefName: "blueberries",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -385,8 +385,6 @@ func TestPrMerge_deleteBranch(t *testing.T) {
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git .+ show .+ HEAD`, 1, "")
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
cs.Register(`git checkout master`, 0, "")
cs.Register(`git rev-parse --verify refs/heads/blueberries`, 0, "")
cs.Register(`git branch -D blueberries`, 0, "")
@ -406,10 +404,19 @@ func TestPrMerge_deleteBranch(t *testing.T) {
func TestPrMerge_deleteNonCurrentBranch(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
// FIXME: references fixture from another package
httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json"))
shared.RunCommandFinder(
"blueberries",
&api.PullRequest{
ID: "PR_10",
Number: 10,
State: "OPEN",
Title: "Blueberries are a good fruit",
HeadRefName: "blueberries",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -439,59 +446,24 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) {
`), output.Stderr())
}
func TestPrMerge_noPrNumberGiven(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
// FIXME: references fixture from another package
httpmock.FileResponse("../view/fixtures/prViewPreviewWithMetadataByBranch.json"))
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
assert.Equal(t, "PR_10", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
assert.NotContains(t, input, "commitHeadline")
}))
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git .+ show .+ HEAD`, 1, "")
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "pr merge --merge")
if err != nil {
t.Fatalf("error running command `pr merge`: %v", err)
}
assert.Equal(t, "", output.String())
assert.Equal(t, heredoc.Doc(`
Merged pull request #10 (Blueberries are a good fruit)
`), output.Stderr())
}
func Test_nonDivergingPullRequest(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "PR_10",
"title": "Blueberries are a good fruit",
"number": 10,
"commits": {
"nodes": [{
"commit": {
"oid": "COMMITSHA1"
}
}],
"totalCount": 1
}
}] } } } }`))
pr := &api.PullRequest{
ID: "PR_10",
Number: 10,
Title: "Blueberries are a good fruit",
State: "OPEN",
}
pr.StubCommit("COMMITSHA1")
shared.RunCommandFinder(
"",
pr,
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -504,7 +476,6 @@ func Test_nonDivergingPullRequest(t *testing.T) {
defer cmdTeardown(t)
cs.Register(`git .+ show .+ HEAD`, 0, "COMMITSHA1,title")
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "pr merge --merge")
if err != nil {
@ -519,24 +490,21 @@ func Test_nonDivergingPullRequest(t *testing.T) {
func Test_divergingPullRequestWarning(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "PR_10",
"title": "Blueberries are a good fruit",
"number": 10,
"commits": {
"nodes": [{
"commit": {
"oid": "COMMITSHA1"
}
}],
"totalCount": 1
}
}] } } } }`))
pr := &api.PullRequest{
ID: "PR_10",
Number: 10,
Title: "Blueberries are a good fruit",
State: "OPEN",
}
pr.StubCommit("COMMITSHA1")
shared.RunCommandFinder(
"",
pr,
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -549,7 +517,6 @@ func Test_divergingPullRequestWarning(t *testing.T) {
defer cmdTeardown(t)
cs.Register(`git .+ show .+ HEAD`, 0, "COMMITSHA2,title")
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "pr merge --merge")
if err != nil {
@ -565,20 +532,18 @@ func Test_divergingPullRequestWarning(t *testing.T) {
func Test_pullRequestWithoutCommits(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "PR_10",
"title": "Blueberries are a good fruit",
"number": 10,
"commits": {
"nodes": [],
"totalCount": 0
}
}] } } } }`))
shared.RunCommandFinder(
"",
&api.PullRequest{
ID: "PR_10",
Number: 10,
Title: "Blueberries are a good fruit",
State: "OPEN",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -587,11 +552,9 @@ func Test_pullRequestWithoutCommits(t *testing.T) {
assert.NotContains(t, input, "commitHeadline")
}))
cs, cmdTeardown := run.Stub()
_, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "pr merge --merge")
if err != nil {
t.Fatalf("error running command `pr merge`: %v", err)
@ -605,17 +568,18 @@ func Test_pullRequestWithoutCommits(t *testing.T) {
func TestPrMerge_rebase(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 2,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
shared.RunCommandFinder(
"2",
&api.PullRequest{
ID: "THE-ID",
Number: 2,
Title: "The title of the PR",
State: "OPEN",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -642,17 +606,18 @@ func TestPrMerge_rebase(t *testing.T) {
func TestPrMerge_squash(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 3,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
shared.RunCommandFinder(
"3",
&api.PullRequest{
ID: "THE-ID",
Number: 3,
Title: "The title of the PR",
State: "OPEN",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -678,22 +643,18 @@ func TestPrMerge_squash(t *testing.T) {
func TestPrMerge_alreadyMerged(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": {
"pullRequest": {
"number": 4,
"title": "The title of the PR",
"state": "MERGED",
"baseRefName": "master",
"headRefName": "blueberries",
"headRepositoryOwner": {
"login": "OWNER"
},
"isCrossRepository": false
}
} } }`))
shared.RunCommandFinder(
"4",
&api.PullRequest{
ID: "THE-ID",
Number: 4,
State: "MERGED",
HeadRefName: "blueberries",
BaseRefName: "master",
},
baseRepo("OWNER", "REPO", "master"),
)
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
@ -718,12 +679,17 @@ func TestPrMerge_alreadyMerged(t *testing.T) {
func TestPrMerge_alreadyMerged_nonInteractive(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": {
"pullRequest": { "number": 4, "title": "The title of the PR", "state": "MERGED"}
} } }`))
shared.RunCommandFinder(
"4",
&api.PullRequest{
ID: "THE-ID",
Number: 4,
State: "MERGED",
HeadRepositoryOwner: api.Owner{Login: "monalisa"},
},
baseRepo("OWNER", "REPO", "master"),
)
_, cmdTeardown := run.Stub()
defer cmdTeardown(t)
@ -740,15 +706,18 @@ func TestPrMerge_alreadyMerged_nonInteractive(t *testing.T) {
func TestPRMerge_interactive(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "THE-ID",
"number": 3
}] } } } }`))
shared.RunCommandFinder(
"",
&api.PullRequest{
ID: "THE-ID",
Number: 3,
Title: "It was the best of times",
HeadRefName: "blueberries",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`
@ -765,11 +734,9 @@ func TestPRMerge_interactive(t *testing.T) {
assert.NotContains(t, input, "commitHeadline")
}))
cs, cmdTeardown := run.Stub()
_, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
as, surveyTeardown := prompt.InitAskStubber()
defer surveyTeardown()
@ -789,16 +756,18 @@ func TestPRMerge_interactive(t *testing.T) {
func TestPRMerge_interactiveWithDeleteBranch(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "THE-ID",
"title": "It was the best of times",
"number": 3
}] } } } }`))
shared.RunCommandFinder(
"",
&api.PullRequest{
ID: "THE-ID",
Number: 3,
Title: "It was the best of times",
HeadRefName: "blueberries",
},
baseRepo("OWNER", "REPO", "master"),
)
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`
@ -821,7 +790,6 @@ func TestPRMerge_interactiveWithDeleteBranch(t *testing.T) {
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
cs.Register(`git checkout master`, 0, "")
cs.Register(`git rev-parse --verify refs/heads/blueberries`, 0, "")
cs.Register(`git branch -D blueberries`, 0, "")
@ -851,16 +819,6 @@ func TestPRMerge_interactiveSquashEditCommitMsg(t *testing.T) {
tr := initFakeHTTP()
defer tr.Verify(t)
tr.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"headRepositoryOwner": {"login": "OWNER"},
"id": "THE-ID",
"number": 3,
"title": "title"
} } } }`))
tr.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`
@ -902,25 +860,28 @@ func TestPRMerge_interactiveSquashEditCommitMsg(t *testing.T) {
},
SelectorArg: "https://github.com/OWNER/REPO/pull/123",
InteractiveMode: true,
Finder: shared.NewMockFinder(
"https://github.com/OWNER/REPO/pull/123",
&api.PullRequest{ID: "THE-ID", Number: 123, Title: "title"},
ghrepo.New("OWNER", "REPO"),
),
})
assert.NoError(t, err)
assert.Equal(t, "", stdout.String())
assert.Equal(t, "✓ Squashed and merged pull request #3 (title)\n", stderr.String())
assert.Equal(t, "✓ Squashed and merged pull request #123 (title)\n", stderr.String())
}
func TestPRMerge_interactiveCancelled(t *testing.T) {
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequests": { "nodes": [{
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"},
"id": "THE-ID",
"number": 3
}] } } } }`))
shared.RunCommandFinder(
"",
&api.PullRequest{ID: "THE-ID", Number: 123},
ghrepo.New("OWNER", "REPO"),
)
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`
@ -930,11 +891,9 @@ func TestPRMerge_interactiveCancelled(t *testing.T) {
"squashMergeAllowed": true
} } }`))
cs, cmdTeardown := run.Stub()
_, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git config --get-regexp.+branch\\\.blueberries\\\.`, 0, "")
as, surveyTeardown := prompt.InitAskStubber()
defer surveyTeardown()
@ -971,18 +930,6 @@ func TestMergeRun_autoMerge(t *testing.T) {
tr := initFakeHTTP()
defer tr.Verify(t)
tr.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 123,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
tr.Register(
httpmock.GraphQL(`mutation PullRequestAutoMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
@ -1001,6 +948,11 @@ func TestMergeRun_autoMerge(t *testing.T) {
SelectorArg: "https://github.com/OWNER/REPO/pull/123",
AutoMergeEnable: true,
MergeMethod: PullRequestMergeMethodSquash,
Finder: shared.NewMockFinder(
"https://github.com/OWNER/REPO/pull/123",
&api.PullRequest{ID: "THE-ID", Number: 123},
ghrepo.New("OWNER", "REPO"),
),
})
assert.NoError(t, err)
@ -1015,21 +967,11 @@ func TestMergeRun_disableAutoMerge(t *testing.T) {
tr := initFakeHTTP()
defer tr.Verify(t)
tr.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 123,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
tr.Register(
httpmock.GraphQL(`mutation PullRequestAutoMergeDisable\b`),
httpmock.StringResponse(`{}`))
httpmock.GraphQLQuery(`{}`, func(s string, m map[string]interface{}) {
assert.Equal(t, map[string]interface{}{"prID": "THE-ID"}, m)
}))
_, cmdTeardown := run.Stub()
defer cmdTeardown(t)
@ -1041,6 +983,11 @@ func TestMergeRun_disableAutoMerge(t *testing.T) {
},
SelectorArg: "https://github.com/OWNER/REPO/pull/123",
AutoMergeDisable: true,
Finder: shared.NewMockFinder(
"https://github.com/OWNER/REPO/pull/123",
&api.PullRequest{ID: "THE-ID", Number: 123},
ghrepo.New("OWNER", "REPO"),
),
})
assert.NoError(t, err)

View file

@ -24,6 +24,8 @@ type ReadyOptions struct {
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
}
@ -47,8 +49,7 @@ func NewCmdReady(f *cmdutil.Factory, runF func(*ReadyOptions) error) *cobra.Comm
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -71,13 +72,10 @@ func NewCmdReady(f *cmdutil.Factory, runF func(*ReadyOptions) error) *cobra.Comm
func readyRun(opts *ReadyOptions) error {
cs := opts.IO.ColorScheme()
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -90,6 +88,12 @@ func readyRun(opts *ReadyOptions) error {
return nil
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
err = api.PullRequestReady(apiClient, baseRepo, pr)
if err != nil {
return fmt.Errorf("API call failed: %w", err)

View file

@ -5,8 +5,6 @@ import (
"net/http"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -15,9 +13,9 @@ import (
type ReopenOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Finder shared.PRFinder
SelectorArg string
}
@ -26,7 +24,6 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co
opts := &ReopenOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Config: f.Config,
}
cmd := &cobra.Command{
@ -34,8 +31,7 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co
Short: "Reopen a pull request",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
@ -54,13 +50,10 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co
func reopenRun(opts *ReopenOptions) error {
cs := opts.IO.ColorScheme()
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, nil, nil, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -75,6 +68,12 @@ func reopenRun(opts *ReopenOptions) error {
return nil
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
err = api.PullRequestReopen(apiClient, baseRepo, pr)
if err != nil {
return fmt.Errorf("API call failed: %w", err)

View file

@ -8,9 +8,7 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -24,9 +22,8 @@ type ReviewOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
SelectorArg string
InteractiveMode bool
@ -39,8 +36,6 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co
IO: f.IOStreams,
HttpClient: f.HttpClient,
Config: f.Config,
Remotes: f.Remotes,
Branch: f.Branch,
}
var (
@ -74,8 +69,7 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -151,13 +145,10 @@ func NewCmdReview(f *cmdutil.Factory, runF func(*ReviewOptions) error) *cobra.Co
}
func reviewRun(opts *ReviewOptions) error {
httpClient, err := opts.HttpClient()
if err != nil {
return err
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -183,6 +174,12 @@ func reviewRun(opts *ReviewOptions) error {
}
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
err = api.AddReview(apiClient, baseRepo, pr, reviewData)
if err != nil {
return fmt.Errorf("failed to create review: %w", err)

328
pkg/cmd/pr/shared/finder.go Normal file
View file

@ -0,0 +1,328 @@
package shared
import (
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/set"
)
type PRFinder interface {
Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
type progressIndicator interface {
StartProgressIndicator()
StopProgressIndicator()
}
type finder struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
remotesFn func() (context.Remotes, error)
httpClient func() (*http.Client, error)
progress progressIndicator
repo ghrepo.Interface
prNumber int
branchName string
}
func NewFinder(factory *cmdutil.Factory) PRFinder {
if runCommandFinder != nil {
f := runCommandFinder
runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
return f
}
return &finder{
baseRepoFn: factory.BaseRepo,
branchFn: factory.Branch,
remotesFn: factory.Remotes,
httpClient: factory.HttpClient,
progress: factory.IOStreams,
}
}
var runCommandFinder PRFinder
// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) {
runCommandFinder = NewMockFinder(selector, pr, repo)
}
type FindOptions struct {
// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
// a PR URL.
Selector string
// Fields lists the GraphQL fields to fetch for the PullRequest.
Fields []string
// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
BaseBranch string
// States lists the possible PR states to scope the PR-for-branch lookup to.
States []string
}
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
if len(opts.Fields) == 0 {
return nil, nil, errors.New("Find error: no fields specified")
}
_ = f.parseURL(opts.Selector)
if f.repo == nil {
repo, err := f.baseRepoFn()
if err != nil {
return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
}
f.repo = repo
}
if opts.Selector == "" {
if err := f.parseCurrentBranch(); err != nil {
return nil, nil, err
}
} else if f.prNumber == 0 {
if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil {
f.prNumber = prNumber
} else {
f.branchName = opts.Selector
}
}
httpClient, err := f.httpClient()
if err != nil {
return nil, nil, err
}
if f.progress != nil {
f.progress.StartProgressIndicator()
defer f.progress.StopProgressIndicator()
}
if f.prNumber > 0 {
if len(opts.Fields) == 1 && opts.Fields[0] == "number" {
// avoid hitting the API if we already have all the information
return &api.PullRequest{Number: f.prNumber}, f.repo, nil
}
pr, err := findByNumber(httpClient, f.repo, f.prNumber, opts.Fields)
return pr, f.repo, err
}
pr, err := findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, opts.Fields)
// TODO: preload view: api.ReviewsForPullRequest, api.CommentsForPullRequest
// TODO: preload checks: get all checks
return pr, f.repo, err
}
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
func (f *finder) parseURL(prURL string) error {
if prURL == "" {
return fmt.Errorf("invalid URL: %q", prURL)
}
u, err := url.Parse(prURL)
if err != nil {
return err
}
if u.Scheme != "https" && u.Scheme != "http" {
return fmt.Errorf("invalid scheme: %s", u.Scheme)
}
m := pullURLRE.FindStringSubmatch(u.Path)
if m == nil {
return fmt.Errorf("not a pull request URL: %s", prURL)
}
f.repo = ghrepo.NewWithHost(m[1], m[2], u.Hostname())
f.prNumber, _ = strconv.Atoi(m[3])
return nil
}
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
func (f *finder) parseCurrentBranch() error {
prHeadRef, err := f.branchFn()
if err != nil {
return err
}
branchConfig := git.ReadBranchConfig(prHeadRef)
// the branch is configured to merge a special PR head ref
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
f.prNumber, _ = strconv.Atoi(m[1])
return nil
}
var branchOwner string
if branchConfig.RemoteURL != nil {
// the branch merges from a remote specified by URL
if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
branchOwner = r.RepoOwner()
}
} else if branchConfig.RemoteName != "" {
// the branch merges from a remote specified by name
rem, _ := f.remotesFn()
if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
branchOwner = r.RepoOwner()
}
}
if branchOwner != "" {
if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
}
// prepend `OWNER:` if this branch is pushed to a fork
if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) {
prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
}
}
f.branchName = prHeadRef
return nil
}
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {
PullRequest api.PullRequest
}
}
query := fmt.Sprintf(`
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pr_number) {%s}
}
}`, api.PullRequestGraphQL(fields))
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"pr_number": number,
}
var resp response
client := api.NewClientFromHTTP(httpClient)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
return &resp.Repository.PullRequest, nil
}
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {
PullRequests struct {
Nodes []api.PullRequest
}
}
}
fieldSet := set.NewStringSet()
fieldSet.AddValues(fields)
// these fields are required for filtering below
fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})
query := fmt.Sprintf(`
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
repository(owner: $owner, name: $repo) {
pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
nodes {%s}
}
}
}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
branchWithoutOwner := headBranch
if idx := strings.Index(headBranch, ":"); idx >= 0 {
branchWithoutOwner = headBranch[idx+1:]
}
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"states": stateFilters,
}
var resp response
client := api.NewClientFromHTTP(httpClient)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
prs := resp.Repository.PullRequests.Nodes
sort.SliceStable(prs, func(a, b int) bool {
return prs[a].State == "OPEN" && prs[b].State != "OPEN"
})
for _, pr := range prs {
if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) {
return &pr, nil
}
}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
}
type NotFoundError struct {
error
}
func (err *NotFoundError) Unwrap() error {
return err.error
}
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) PRFinder {
return &mockFinder{
expectSelector: selector,
pr: pr,
repo: repo,
}
}
type mockFinder struct {
called bool
expectSelector string
pr *api.PullRequest
repo ghrepo.Interface
err error
}
func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
if m.err != nil {
return nil, nil, m.err
}
if m.expectSelector != opts.Selector {
return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
}
if m.called {
return nil, nil, errors.New("mockFinder used more than once")
}
m.called = true
if m.pr.HeadRepositoryOwner.Login == "" {
// pose as same-repo PR by default
m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
}
return m.pr, m.repo, nil
}

View file

@ -4,13 +4,12 @@ import (
"net/http"
"testing"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/httpmock"
)
func TestPRFromArgs(t *testing.T) {
func TestFind(t *testing.T) {
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
@ -68,12 +67,6 @@ func TestPRFromArgs(t *testing.T) {
baseRepoFn: nil,
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequest_fields\b`),
httpmock.StringResponse(`{"data":{}}`))
r.Register(
httpmock.GraphQL(`query PullRequest_fields2\b`),
httpmock.StringResponse(`{"data":{}}`))
r.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`{"data":{"repository":{
@ -87,17 +80,30 @@ func TestPRFromArgs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
if tt.httpStub != nil {
tt.httpStub(reg)
}
httpClient := &http.Client{Transport: reg}
pr, repo, err := PRFromArgs(api.NewClientFromHTTP(httpClient), tt.args.baseRepoFn, tt.args.branchFn, tt.args.remotesFn, tt.args.selector)
f := finder{
httpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
remotesFn: tt.args.remotesFn,
}
pr, repo, err := f.Find(FindOptions{
Selector: tt.args.selector,
})
if (err != nil) != tt.wantErr {
t.Errorf("IssueFromArg() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if pr.Number != tt.wantPR {
t.Errorf("want issue #%d, got #%d", tt.wantPR, pr.Number)
t.Errorf("want pr #%d, got #%d", tt.wantPR, pr.Number)
}
repoURL := ghrepo.GenerateRepoURL(repo, "")
if repoURL != tt.wantRepo {

View file

@ -1,121 +0,0 @@
package shared
import (
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
)
// PRFromArgs looks up the pull request from either the number/branch/URL argument or one belonging to the current branch
//
// NOTE: this API isn't great, but is here as a compatibility layer between old-style and new-style commands
func PRFromArgs(apiClient *api.Client, baseRepoFn func() (ghrepo.Interface, error), branchFn func() (string, error), remotesFn func() (context.Remotes, error), arg string) (*api.PullRequest, ghrepo.Interface, error) {
if arg != "" {
// First check to see if the prString is a url, return repo from url if found. This
// is run first because we don't need to run determineBaseRepo for this path
pr, r, err := prFromURL(apiClient, arg)
if pr != nil || err != nil {
return pr, r, err
}
}
repo, err := baseRepoFn()
if err != nil {
return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
}
// If there are no args see if we can guess the PR from the current branch
if arg == "" {
pr, err := prForCurrentBranch(apiClient, repo, branchFn, remotesFn)
return pr, repo, err
} else {
// Next see if the prString is a number and use that to look up the url
pr, err := prFromNumberString(apiClient, repo, arg)
if pr != nil || err != nil {
return pr, repo, err
}
// Last see if it is a branch name
pr, err = api.PullRequestForBranch(apiClient, repo, "", arg, nil)
return pr, repo, err
}
}
func prFromNumberString(apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
if prNumber, err := strconv.Atoi(strings.TrimPrefix(s, "#")); err == nil {
return api.PullRequestByNumber(apiClient, repo, prNumber)
}
return nil, nil
}
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
func prFromURL(apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) {
u, err := url.Parse(s)
if err != nil {
return nil, nil, nil
}
if u.Scheme != "https" && u.Scheme != "http" {
return nil, nil, nil
}
m := pullURLRE.FindStringSubmatch(u.Path)
if m == nil {
return nil, nil, nil
}
repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
prNumberString := m[3]
pr, err := prFromNumberString(apiClient, repo, prNumberString)
return pr, repo, err
}
func prForCurrentBranch(apiClient *api.Client, repo ghrepo.Interface, branchFn func() (string, error), remotesFn func() (context.Remotes, error)) (*api.PullRequest, error) {
prHeadRef, err := branchFn()
if err != nil {
return nil, err
}
branchConfig := git.ReadBranchConfig(prHeadRef)
// the branch is configured to merge a special PR head ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
return prFromNumberString(apiClient, repo, m[1])
}
var branchOwner string
if branchConfig.RemoteURL != nil {
// the branch merges from a remote specified by URL
if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
branchOwner = r.RepoOwner()
}
} else if branchConfig.RemoteName != "" {
// the branch merges from a remote specified by name
rem, _ := remotesFn()
if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
branchOwner = r.RepoOwner()
}
}
if branchOwner != "" {
if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
}
// prepend `OWNER:` if this branch is pushed to a fork
if !strings.EqualFold(branchOwner, repo.RepoOwner()) {
prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
}
}
return api.PullRequestForBranch(apiClient, repo, "", prHeadRef, nil)
}

View file

@ -3,17 +3,12 @@ package view
import (
"errors"
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmd/pr/shared"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -27,14 +22,10 @@ type browser interface {
}
type ViewOptions struct {
HttpClient func() (*http.Client, error)
Config func() (config.Config, error)
IO *iostreams.IOStreams
Browser browser
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
Branch func() (string, error)
IO *iostreams.IOStreams
Browser browser
Finder shared.PRFinder
Exporter cmdutil.Exporter
SelectorArg string
@ -44,12 +35,8 @@ type ViewOptions struct {
func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command {
opts := &ViewOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Config: f.Config,
Remotes: f.Remotes,
Branch: f.Branch,
Browser: f.Browser,
IO: f.IOStreams,
Browser: f.Browser,
}
cmd := &cobra.Command{
@ -65,8 +52,7 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
opts.BaseRepo = f.BaseRepo
opts.Finder = shared.NewFinder(f)
if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
return &cmdutil.FlagError{Err: errors.New("argument required when using the --repo flag")}
@ -90,10 +76,26 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman
return cmd
}
var defaultFields = []string{
"url", "number", "title", "state", "body", "author",
"isDraft", "maintainerCanModify", "mergeable", "additions", "deletions",
"baseRefName", "headRefName", "headRepositoryOwner", "headRepository", "isCrossRepository",
"reviewRequests", "reviews", "assignees", "labels", "projectCards", "milestone",
"comments", // TODO: fetch only 1 last comment unless `opts.Comments` was set
"reactionGroups",
}
func viewRun(opts *ViewOptions) error {
opts.IO.StartProgressIndicator()
pr, err := retrievePullRequest(opts)
opts.IO.StopProgressIndicator()
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
Fields: defaultFields,
}
if opts.BrowserMode {
findOptions.Fields = []string{"url"}
} else if opts.Exporter != nil {
findOptions.Fields = opts.Exporter.Fields()
}
pr, _, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
@ -413,51 +415,3 @@ func prStateWithDraft(pr *api.PullRequest) string {
return pr.State
}
func retrievePullRequest(opts *ViewOptions) (*api.PullRequest, error) {
httpClient, err := opts.HttpClient()
if err != nil {
return nil, err
}
apiClient := api.NewClientFromHTTP(httpClient)
pr, repo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
if err != nil {
return nil, err
}
if opts.BrowserMode {
return pr, nil
}
var errp, errc error
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var reviews *api.PullRequestReviews
reviews, errp = api.ReviewsForPullRequest(apiClient, repo, pr)
pr.Reviews = *reviews
}()
if opts.Comments {
wg.Add(1)
go func() {
defer wg.Done()
var comments *api.Comments
comments, errc = api.CommentsForPullRequest(apiClient, repo, pr)
pr.Comments = *comments
}()
}
wg.Wait()
if errp != nil {
err = errp
}
if errc != nil {
err = errc
}
return pr, err
}

View file

@ -134,11 +134,10 @@ func GetAnnotations(client *api.Client, repo ghrepo.Interface, job Job) ([]Annot
err := client.REST(repo.RepoHost(), "GET", path, nil, &result)
if err != nil {
var notFound *api.NotFoundError
if !errors.As(err, &notFound) {
var httpError api.HTTPError
if errors.As(err, &httpError) && httpError.StatusCode == 404 {
return []Annotation{}, nil
}
return nil, err
}