diff --git a/api/queries_pr.go b/api/queries_pr.go index b5441a359..dfeb8608d 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -118,12 +118,6 @@ type PullRequestCommitCommit struct { } } -func (pr *PullRequest) StubCommit(oid string) { - pr.Commits.Nodes = append(pr.Commits.Nodes, PullRequestCommit{ - Commit: PullRequestCommitCommit{Oid: oid}, - }) -} - type PullRequestFile struct { Path string `json:"path"` Additions int `json:"additions"` diff --git a/pkg/cmd/pr/close/close.go b/pkg/cmd/pr/close/close.go index 041d48edc..2c9442ecb 100644 --- a/pkg/cmd/pr/close/close.go +++ b/pkg/cmd/pr/close/close.go @@ -122,7 +122,7 @@ func closeRun(opts *CloseOptions) error { } if pr.IsCrossRepository { - fmt.Fprintf(opts.IO.ErrOut, "%s Avoiding deleting the remote branch of a pull request from fork\n", cs.WarningIcon()) + fmt.Fprintf(opts.IO.ErrOut, "%s Skipped deleting the remote branch of a pull request from fork\n", cs.WarningIcon()) if !opts.DeleteLocalBranch { return nil } diff --git a/pkg/cmd/pr/close/close_test.go b/pkg/cmd/pr/close/close_test.go index 4024398dd..4a02dc13d 100644 --- a/pkg/cmd/pr/close/close_test.go +++ b/pkg/cmd/pr/close/close_test.go @@ -191,7 +191,7 @@ func TestPrClose_deleteBranch_crossRepo(t *testing.T) { assert.Equal(t, "", output.String()) assert.Equal(t, heredoc.Doc(` ✓ Closed pull request #96 (The title of the PR) - ! Avoiding deleting the remote branch of a pull request from fork + ! Skipped deleting the remote branch of a pull request from fork ✓ Deleted branch blueberries `), output.Stderr()) } diff --git a/pkg/cmd/pr/merge/merge_test.go b/pkg/cmd/pr/merge/merge_test.go index a91b9740d..4abb88aca 100644 --- a/pkg/cmd/pr/merge/merge_test.go +++ b/pkg/cmd/pr/merge/merge_test.go @@ -203,6 +203,12 @@ func baseRepo(owner, repo, branch string) ghrepo.Interface { }, "github.com") } +func stubCommit(pr *api.PullRequest, oid string) { + pr.Commits.Nodes = append(pr.Commits.Nodes, api.PullRequestCommit{ + Commit: api.PullRequestCommitCommit{Oid: oid}, + }) +} + func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*test.CmdOut, error) { io, _, stdout, stderr := iostreams.Test() io.SetStdoutTTY(isTTY) @@ -456,7 +462,7 @@ func Test_nonDivergingPullRequest(t *testing.T) { Title: "Blueberries are a good fruit", State: "OPEN", } - pr.StubCommit("COMMITSHA1") + stubCommit(pr, "COMMITSHA1") shared.RunCommandFinder( "", @@ -497,7 +503,7 @@ func Test_divergingPullRequestWarning(t *testing.T) { Title: "Blueberries are a good fruit", State: "OPEN", } - pr.StubCommit("COMMITSHA1") + stubCommit(pr, "COMMITSHA1") shared.RunCommandFinder( "", diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index c17885ab2..52b2a436a 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -28,11 +28,12 @@ type progressIndicator interface { } type finder struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - remotesFn func() (context.Remotes, error) - httpClient func() (*http.Client, error) - progress progressIndicator + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + remotesFn func() (context.Remotes, error) + httpClient func() (*http.Client, error) + branchConfig func(string) git.BranchConfig + progress progressIndicator repo ghrepo.Interface prNumber int @@ -47,11 +48,12 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { } return &finder{ - baseRepoFn: factory.BaseRepo, - branchFn: factory.Branch, - remotesFn: factory.Remotes, - httpClient: factory.HttpClient, - progress: factory.IOStreams, + baseRepoFn: factory.BaseRepo, + branchFn: factory.Branch, + remotesFn: factory.Remotes, + httpClient: factory.HttpClient, + progress: factory.IOStreams, + branchConfig: git.ReadBranchConfig, } } @@ -79,7 +81,10 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return nil, nil, errors.New("Find error: no fields specified") } - _ = f.parseURL(opts.Selector) + if repo, prNumber, err := f.parseURL(opts.Selector); err == nil { + f.prNumber = prNumber + f.repo = repo + } if f.repo == nil { repo, err := f.baseRepoFn() @@ -90,8 +95,12 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err } if opts.Selector == "" { - if err := f.parseCurrentBranch(); err != nil { + if branch, prNumber, err := f.parseCurrentBranch(); err != nil { return nil, nil, err + } else if prNumber > 0 { + f.prNumber = prNumber + } else { + f.branchName = branch } } else if f.prNumber == 0 { if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil { @@ -129,44 +138,44 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) -func (f *finder) parseURL(prURL string) error { +func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { if prURL == "" { - return fmt.Errorf("invalid URL: %q", prURL) + return nil, 0, fmt.Errorf("invalid URL: %q", prURL) } u, err := url.Parse(prURL) if err != nil { - return err + return nil, 0, err } if u.Scheme != "https" && u.Scheme != "http" { - return fmt.Errorf("invalid scheme: %s", u.Scheme) + return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme) } m := pullURLRE.FindStringSubmatch(u.Path) if m == nil { - return fmt.Errorf("not a pull request URL: %s", prURL) + return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL) } - f.repo = ghrepo.NewWithHost(m[1], m[2], u.Hostname()) - f.prNumber, _ = strconv.Atoi(m[3]) - return nil + repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) + prNumber, _ := strconv.Atoi(m[3]) + return repo, prNumber, nil } var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) -func (f *finder) parseCurrentBranch() error { +func (f *finder) parseCurrentBranch() (string, int, error) { prHeadRef, err := f.branchFn() if err != nil { - return err + return "", 0, err } - branchConfig := git.ReadBranchConfig(prHeadRef) + branchConfig := f.branchConfig(prHeadRef) // the branch is configured to merge a special PR head ref if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - f.prNumber, _ = strconv.Atoi(m[1]) - return nil + prNumber, _ := strconv.Atoi(m[1]) + return "", prNumber, nil } var branchOwner string @@ -193,8 +202,7 @@ func (f *finder) parseCurrentBranch() error { } } - f.branchName = prHeadRef - return nil + return prHeadRef, 0, nil } func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 7149bd4f8..488e470ff 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -1,21 +1,26 @@ package shared import ( + "errors" "net/http" + "net/url" "testing" "github.com/cli/cli/context" + "github.com/cli/cli/git" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) func TestFind(t *testing.T) { type args struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - remotesFn func() (context.Remotes, error) - selector string - fields []string + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + branchConfig func(string) git.BranchConfig + remotesFn func() (context.Remotes, error) + selector string + fields []string + baseBranch string } tests := []struct { name string @@ -44,6 +49,25 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://github.com/OWNER/REPO", }, + { + name: "baseRepo is error", + args: args{ + selector: "13", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return nil, errors.New("baseRepoErr") + }, + }, + wantErr: true, + }, + { + name: "blank fields is error", + args: args{ + selector: "13", + fields: []string{}, + }, + wantErr: true, + }, { name: "number only", args: args{ @@ -93,6 +117,233 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://example.org/OWNER/REPO", }, + { + name: "branch argument", + args: args{ + selector: "blueberries", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestForBranch\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequests":{"nodes":[ + { + "number": 14, + "state": "CLOSED", + "baseRefName": "main", + "headRefName": "blueberries", + "isCrossRepository": false, + "headRepositoryOwner": {"login":"OWNER"} + }, + { + "number": 13, + "state": "OPEN", + "baseRefName": "main", + "headRefName": "blueberries", + "isCrossRepository": false, + "headRepositoryOwner": {"login":"OWNER"} + } + ]} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, + { + name: "branch argument with base branch", + args: args{ + selector: "blueberries", + baseBranch: "main", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestForBranch\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequests":{"nodes":[ + { + "number": 14, + "state": "OPEN", + "baseRefName": "dev", + "headRefName": "blueberries", + "isCrossRepository": false, + "headRepositoryOwner": {"login":"OWNER"} + }, + { + "number": 13, + "state": "OPEN", + "baseRefName": "main", + "headRefName": "blueberries", + "isCrossRepository": false, + "headRepositoryOwner": {"login":"OWNER"} + } + ]} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, + { + name: "no argument reads current branch", + args: args{ + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: func(branch string) (c git.BranchConfig) { + return + }, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestForBranch\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequests":{"nodes":[ + { + "number": 13, + "state": "OPEN", + "baseRefName": "main", + "headRefName": "blueberries", + "isCrossRepository": false, + "headRepositoryOwner": {"login":"OWNER"} + } + ]} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, + { + name: "current branch is error", + args: args{ + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + branchFn: func() (string, error) { + return "", errors.New("branchErr") + }, + }, + wantErr: true, + }, + { + name: "current branch with upstream configuration", + args: args{ + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: func(branch string) (c git.BranchConfig) { + c.MergeRef = "refs/heads/blue-upstream-berries" + c.RemoteName = "origin" + return + }, + remotesFn: func() (context.Remotes, error) { + return context.Remotes{{ + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), + }}, nil + }, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestForBranch\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequests":{"nodes":[ + { + "number": 13, + "state": "OPEN", + "baseRefName": "main", + "headRefName": "blue-upstream-berries", + "isCrossRepository": true, + "headRepositoryOwner": {"login":"UPSTREAMOWNER"} + } + ]} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, + { + name: "current branch with upstream configuration", + args: args{ + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: func(branch string) (c git.BranchConfig) { + u, _ := url.Parse("https://github.com/UPSTREAMOWNER/REPO") + c.MergeRef = "refs/heads/blue-upstream-berries" + c.RemoteURL = u + return + }, + remotesFn: nil, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestForBranch\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequests":{"nodes":[ + { + "number": 13, + "state": "OPEN", + "baseRefName": "main", + "headRefName": "blue-upstream-berries", + "isCrossRepository": true, + "headRepositoryOwner": {"login":"UPSTREAMOWNER"} + } + ]} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, + { + name: "current branch made by pr checkout", + args: args{ + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("OWNER/REPO") + }, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: func(branch string) (c git.BranchConfig) { + c.MergeRef = "refs/pull/13/head" + return + }, + }, + httpStub: func(r *httpmock.Registry) { + r.Register( + httpmock.GraphQL(`query PullRequestByNumber\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequest":{"number":13} + }}}`)) + }, + wantPR: 13, + wantRepo: "https://github.com/OWNER/REPO", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -106,19 +357,30 @@ func TestFind(t *testing.T) { httpClient: func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }, - baseRepoFn: tt.args.baseRepoFn, - branchFn: tt.args.branchFn, - remotesFn: tt.args.remotesFn, + baseRepoFn: tt.args.baseRepoFn, + branchFn: tt.args.branchFn, + branchConfig: tt.args.branchConfig, + remotesFn: tt.args.remotesFn, } pr, repo, err := f.Find(FindOptions{ - Selector: tt.args.selector, - Fields: tt.args.fields, + Selector: tt.args.selector, + Fields: tt.args.fields, + BaseBranch: tt.args.baseBranch, }) if (err != nil) != tt.wantErr { t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr { + if tt.wantPR > 0 { + t.Error("wantPR field is not checked in error case") + } + if tt.wantRepo != "" { + t.Error("wantRepo field is not checked in error case") + } + return + } if pr.Number != tt.wantPR { t.Errorf("want pr #%d, got #%d", tt.wantPR, pr.Number)