💅 cleanup and tests for PR finder

This commit is contained in:
Mislav Marohnić 2021-05-18 09:58:21 +02:00
parent bc3bb97c43
commit 51f7cbdfde
6 changed files with 317 additions and 47 deletions

View file

@ -118,12 +118,6 @@ type PullRequestCommitCommit struct {
}
}
func (pr *PullRequest) StubCommit(oid string) {
pr.Commits.Nodes = append(pr.Commits.Nodes, PullRequestCommit{
Commit: PullRequestCommitCommit{Oid: oid},
})
}
type PullRequestFile struct {
Path string `json:"path"`
Additions int `json:"additions"`

View file

@ -122,7 +122,7 @@ func closeRun(opts *CloseOptions) error {
}
if pr.IsCrossRepository {
fmt.Fprintf(opts.IO.ErrOut, "%s Avoiding deleting the remote branch of a pull request from fork\n", cs.WarningIcon())
fmt.Fprintf(opts.IO.ErrOut, "%s Skipped deleting the remote branch of a pull request from fork\n", cs.WarningIcon())
if !opts.DeleteLocalBranch {
return nil
}

View file

@ -191,7 +191,7 @@ func TestPrClose_deleteBranch_crossRepo(t *testing.T) {
assert.Equal(t, "", output.String())
assert.Equal(t, heredoc.Doc(`
Closed pull request #96 (The title of the PR)
! Avoiding deleting the remote branch of a pull request from fork
! Skipped deleting the remote branch of a pull request from fork
Deleted branch blueberries
`), output.Stderr())
}

View file

@ -203,6 +203,12 @@ func baseRepo(owner, repo, branch string) ghrepo.Interface {
}, "github.com")
}
func stubCommit(pr *api.PullRequest, oid string) {
pr.Commits.Nodes = append(pr.Commits.Nodes, api.PullRequestCommit{
Commit: api.PullRequestCommitCommit{Oid: oid},
})
}
func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*test.CmdOut, error) {
io, _, stdout, stderr := iostreams.Test()
io.SetStdoutTTY(isTTY)
@ -456,7 +462,7 @@ func Test_nonDivergingPullRequest(t *testing.T) {
Title: "Blueberries are a good fruit",
State: "OPEN",
}
pr.StubCommit("COMMITSHA1")
stubCommit(pr, "COMMITSHA1")
shared.RunCommandFinder(
"",
@ -497,7 +503,7 @@ func Test_divergingPullRequestWarning(t *testing.T) {
Title: "Blueberries are a good fruit",
State: "OPEN",
}
pr.StubCommit("COMMITSHA1")
stubCommit(pr, "COMMITSHA1")
shared.RunCommandFinder(
"",

View file

@ -28,11 +28,12 @@ type progressIndicator interface {
}
type finder struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
remotesFn func() (context.Remotes, error)
httpClient func() (*http.Client, error)
progress progressIndicator
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
remotesFn func() (context.Remotes, error)
httpClient func() (*http.Client, error)
branchConfig func(string) git.BranchConfig
progress progressIndicator
repo ghrepo.Interface
prNumber int
@ -47,11 +48,12 @@ func NewFinder(factory *cmdutil.Factory) PRFinder {
}
return &finder{
baseRepoFn: factory.BaseRepo,
branchFn: factory.Branch,
remotesFn: factory.Remotes,
httpClient: factory.HttpClient,
progress: factory.IOStreams,
baseRepoFn: factory.BaseRepo,
branchFn: factory.Branch,
remotesFn: factory.Remotes,
httpClient: factory.HttpClient,
progress: factory.IOStreams,
branchConfig: git.ReadBranchConfig,
}
}
@ -79,7 +81,10 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
return nil, nil, errors.New("Find error: no fields specified")
}
_ = f.parseURL(opts.Selector)
if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
f.prNumber = prNumber
f.repo = repo
}
if f.repo == nil {
repo, err := f.baseRepoFn()
@ -90,8 +95,12 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
}
if opts.Selector == "" {
if err := f.parseCurrentBranch(); err != nil {
if branch, prNumber, err := f.parseCurrentBranch(); err != nil {
return nil, nil, err
} else if prNumber > 0 {
f.prNumber = prNumber
} else {
f.branchName = branch
}
} else if f.prNumber == 0 {
if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil {
@ -129,44 +138,44 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
func (f *finder) parseURL(prURL string) error {
func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
if prURL == "" {
return fmt.Errorf("invalid URL: %q", prURL)
return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
}
u, err := url.Parse(prURL)
if err != nil {
return err
return nil, 0, err
}
if u.Scheme != "https" && u.Scheme != "http" {
return fmt.Errorf("invalid scheme: %s", u.Scheme)
return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme)
}
m := pullURLRE.FindStringSubmatch(u.Path)
if m == nil {
return fmt.Errorf("not a pull request URL: %s", prURL)
return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
}
f.repo = ghrepo.NewWithHost(m[1], m[2], u.Hostname())
f.prNumber, _ = strconv.Atoi(m[3])
return nil
repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
prNumber, _ := strconv.Atoi(m[3])
return repo, prNumber, nil
}
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
func (f *finder) parseCurrentBranch() error {
func (f *finder) parseCurrentBranch() (string, int, error) {
prHeadRef, err := f.branchFn()
if err != nil {
return err
return "", 0, err
}
branchConfig := git.ReadBranchConfig(prHeadRef)
branchConfig := f.branchConfig(prHeadRef)
// the branch is configured to merge a special PR head ref
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
f.prNumber, _ = strconv.Atoi(m[1])
return nil
prNumber, _ := strconv.Atoi(m[1])
return "", prNumber, nil
}
var branchOwner string
@ -193,8 +202,7 @@ func (f *finder) parseCurrentBranch() error {
}
}
f.branchName = prHeadRef
return nil
return prHeadRef, 0, nil
}
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {

View file

@ -1,21 +1,26 @@
package shared
import (
"errors"
"net/http"
"net/url"
"testing"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/httpmock"
)
func TestFind(t *testing.T) {
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
remotesFn func() (context.Remotes, error)
selector string
fields []string
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) git.BranchConfig
remotesFn func() (context.Remotes, error)
selector string
fields []string
baseBranch string
}
tests := []struct {
name string
@ -44,6 +49,25 @@ func TestFind(t *testing.T) {
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "baseRepo is error",
args: args{
selector: "13",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return nil, errors.New("baseRepoErr")
},
},
wantErr: true,
},
{
name: "blank fields is error",
args: args{
selector: "13",
fields: []string{},
},
wantErr: true,
},
{
name: "number only",
args: args{
@ -93,6 +117,233 @@ func TestFind(t *testing.T) {
wantPR: 13,
wantRepo: "https://example.org/OWNER/REPO",
},
{
name: "branch argument",
args: args{
selector: "blueberries",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 14,
"state": "CLOSED",
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
},
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "branch argument with base branch",
args: args{
selector: "blueberries",
baseBranch: "main",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 14,
"state": "OPEN",
"baseRefName": "dev",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
},
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "no argument reads current branch",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
return
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "current branch is error",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "", errors.New("branchErr")
},
},
wantErr: true,
},
{
name: "current branch with upstream configuration",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.MergeRef = "refs/heads/blue-upstream-berries"
c.RemoteName = "origin"
return
},
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}}, nil
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blue-upstream-berries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"UPSTREAMOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "current branch with upstream configuration",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
u, _ := url.Parse("https://github.com/UPSTREAMOWNER/REPO")
c.MergeRef = "refs/heads/blue-upstream-berries"
c.RemoteURL = u
return
},
remotesFn: nil,
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blue-upstream-berries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"UPSTREAMOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "current branch made by pr checkout",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.MergeRef = "refs/pull/13/head"
return
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequest":{"number":13}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -106,19 +357,30 @@ func TestFind(t *testing.T) {
httpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
remotesFn: tt.args.remotesFn,
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
branchConfig: tt.args.branchConfig,
remotesFn: tt.args.remotesFn,
}
pr, repo, err := f.Find(FindOptions{
Selector: tt.args.selector,
Fields: tt.args.fields,
Selector: tt.args.selector,
Fields: tt.args.fields,
BaseBranch: tt.args.baseBranch,
})
if (err != nil) != tt.wantErr {
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
if tt.wantPR > 0 {
t.Error("wantPR field is not checked in error case")
}
if tt.wantRepo != "" {
t.Error("wantRepo field is not checked in error case")
}
return
}
if pr.Number != tt.wantPR {
t.Errorf("want pr #%d, got #%d", tt.wantPR, pr.Number)