Refactor finder.Find and replace parseCurrentBranch with parsePRRefs

I've been struggling horribly to reason through all of this code, and
after much mental gymnastics I identified the culprit as the overloaded
"branch" string returned by parseCurrentBranch.

This value was either the name of the branch that the PR we're looking for
is associated with, or that name prepended with the owner's name and a :
if we're on a branch, so:

PR branch: featureBranch
branch == "featureBranch"

If on Fork belonging to "ForkOwner"
branch == "ForkOwner:featureBranch"

Since this extra information was bundled up into this single string, it
complicated the responsibilities of parseCurrentBranch's "branch" return
value. Thus, I've teased out "branch" into the new PRRefs struct:

type PRRefs struct{
	BranchName string
	HeadRepo ghrepo.Interface
	BaseRepo ghrepo.Interface
}

This allows the new parsePRRefs function to move all the previous
"branch" string's information into structured data, and allows for a new
method on PRRefs, GetPRLabel(), to create the string that "branch"
previously held to pass into its downstream consumer, namely
findForBranch.

This also allowed for better test coverage, directly connecting the PRRefs
fields to the values contained in the git config. Overall, I am now
confident that this is doing what its supposed to do with respect to my
understanding of the various central and triangular git workflows we are
addressing.
This commit is contained in:
Tyler McGoffin 2025-01-22 15:39:43 -08:00
parent aef2642581
commit 41729b004d
3 changed files with 540 additions and 286 deletions

View file

@ -64,8 +64,10 @@ type BranchConfig struct {
RemoteName string
RemoteURL *url.URL
// MergeBase is the optional base branch to target in a new PR if `--base` is not specified.
MergeBase string
MergeRef string
MergeBase string
MergeRef string
// These are used to handle triangular workflows. They can be defined by either
// a remote.pushDefault or a branch.<name>.pushremote value set on the git config.
RemotePushDefault string
PushRemoteURL *url.URL
PushRemoteName string

View file

@ -41,9 +41,9 @@ type finder struct {
branchConfig func(string) (git.BranchConfig, error)
progress progressIndicator
repo ghrepo.Interface
prNumber int
branchName string
baseRefRepo ghrepo.Interface
prNumber int
branchName string
}
func NewFinder(factory *cmdutil.Factory) PRFinder {
@ -89,6 +89,22 @@ type FindOptions struct {
States []string
}
type PRRefs struct {
BranchName string
HeadRepo ghrepo.Interface
BaseRepo ghrepo.Interface
}
// GetPRLabel returns the string that the GitHub API uses to identify the PR. This is
// either just the branch name or, if the PR is originating from a fork, the fork owner
// and the branch name, like <owner>:<branch>.
func (s *PRRefs) GetPRLabel() string {
if s.HeadRepo == s.BaseRepo {
return s.BranchName
}
return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName)
}
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
if len(opts.Fields) == 0 {
return nil, nil, errors.New("Find error: no fields specified")
@ -96,26 +112,18 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
f.prNumber = prNumber
f.repo = repo
f.baseRefRepo = repo
}
if f.repo == nil {
if f.baseRefRepo == nil {
repo, err := f.baseRepoFn()
if err != nil {
return nil, nil, err
}
f.repo = repo
f.baseRefRepo = repo
}
if opts.Selector == "" {
if branch, prNumber, err := f.parseCurrentBranch(); err != nil {
return nil, nil, err
} else if prNumber > 0 {
f.prNumber = prNumber
} else {
f.branchName = branch
}
} else if f.prNumber == 0 {
if f.prNumber == 0 && opts.Selector != "" {
// If opts.Selector is a valid number then assume it is the
// PR number unless opts.BaseBranch is specified. This is a
// special case for PR create command which will always want
@ -127,8 +135,28 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
} else {
f.branchName = opts.Selector
}
} else {
currentBranchName, err := f.branchFn()
if err != nil {
return nil, nil, err
}
f.branchName = currentBranchName
}
// Get the branch config for the current branchName
branchConfig, err := f.branchConfig(f.branchName)
if err != nil {
return nil, nil, err
}
// Determine if the branch is configured to merge to a special PR ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
prNumber, _ := strconv.Atoi(m[1])
f.prNumber = prNumber
}
// Set up HTTP client
httpClient, err := f.httpClient()
if err != nil {
return nil, nil, err
@ -147,7 +175,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") {
cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24)
detector := fd.NewDetector(cachedClient, f.repo.RepoHost())
detector := fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost())
prFeatures, err := detector.PullRequestFeatures()
if err != nil {
return nil, nil, err
@ -168,36 +196,54 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
if f.prNumber > 0 {
if numberFieldOnly {
// avoid hitting the API if we already have all the information
return &api.PullRequest{Number: f.prNumber}, f.repo, nil
return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil
}
pr, err = findByNumber(httpClient, f.baseRefRepo, f.prNumber, fields.ToSlice())
if err != nil {
return pr, f.baseRefRepo, err
}
pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice())
} else {
pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice())
}
if err != nil {
return pr, f.repo, err
rems, err := f.remotesFn()
if err != nil {
return nil, nil, err
}
pushDefault, err := f.pushDefault()
if err != nil {
return nil, nil, err
}
prRefs, err := parsePRRefs(f.branchName, branchConfig, pushDefault, f.baseRefRepo, rems)
if err != nil {
return nil, nil, err
}
pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRLabel(), opts.States, fields.ToSlice())
if err != nil {
return pr, f.baseRefRepo, err
}
}
g, _ := errgroup.WithContext(context.Background())
if fields.Contains("reviews") {
g.Go(func() error {
return preloadPrReviews(httpClient, f.repo, pr)
return preloadPrReviews(httpClient, f.baseRefRepo, pr)
})
}
if fields.Contains("comments") {
g.Go(func() error {
return preloadPrComments(httpClient, f.repo, pr)
return preloadPrComments(httpClient, f.baseRefRepo, pr)
})
}
if fields.Contains("statusCheckRollup") {
g.Go(func() error {
return preloadPrChecks(httpClient, f.repo, pr)
return preloadPrChecks(httpClient, f.baseRefRepo, pr)
})
}
if getProjectItems {
g.Go(func() error {
apiClient := api.NewClientFromHTTP(httpClient)
err := api.ProjectsV2ItemsForPullRequest(apiClient, f.repo, pr)
err := api.ProjectsV2ItemsForPullRequest(apiClient, f.baseRefRepo, pr)
if err != nil && !api.ProjectsV2IgnorableError(err) {
return err
}
@ -205,7 +251,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
})
}
return pr, f.repo, g.Wait()
return pr, f.baseRefRepo, g.Wait()
}
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
@ -234,61 +280,53 @@ func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
return repo, prNumber, nil
}
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
func (f *finder) parseCurrentBranch() (string, int, error) {
prHeadRef, err := f.branchFn()
if err != nil {
return "", 0, err
func parsePRRefs(currentBranchName string, branchConfig git.BranchConfig, pushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PRRefs, error) {
prRefs := PRRefs{
BaseRepo: baseRefRepo,
}
branchConfig, err := f.branchConfig(prHeadRef)
if err != nil {
return "", 0, err
}
// the branch is configured to merge a special PR head ref
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
prNumber, _ := strconv.Atoi(m[1])
return "", prNumber, nil
}
var gitRemoteRepo ghrepo.Interface
if branchConfig.PushRemoteURL != nil {
// the branch merges from a remote specified by URL
if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
gitRemoteRepo = r
}
} else if branchConfig.PushRemoteName != "" {
rem, _ := f.remotesFn()
if r, err := rem.FindByName(branchConfig.PushRemoteName); err == nil {
gitRemoteRepo = r
// If @{push} resolves, then we have all the information we need to determine the head repo
// and branch name. It is of the form <remote>/<branch>.
if branchConfig.Push != "" {
for _, r := range rems {
// Find the remote who's name matches the push <remote> prefix
if strings.HasPrefix(branchConfig.Push, r.Name+"/") {
prRefs.BranchName = strings.TrimPrefix(branchConfig.Push, r.Name+"/")
prRefs.HeadRepo = r.Repo
return prRefs, nil
}
}
}
if gitRemoteRepo != nil {
if branchConfig.Push != "" {
prHeadRef = strings.TrimPrefix(branchConfig.Push, branchConfig.PushRemoteName+"/")
} else if pushDefault, _ := f.pushDefault(); (pushDefault == "upstream" || pushDefault == "tracking") &&
strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
// To get the HeadRepo, we look to the git config. The PushRemote{Name | URL} comes from
// one of the following, in order of precedence:
// 1. branch.<name>.pushRemote
// 2. remote.pushDefault
// 3. branch.<name>.remote
if branchConfig.PushRemoteName != "" {
if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil {
prRefs.HeadRepo = r.Repo
}
// prepend `OWNER:` if this branch is pushed to a fork
// This is determined by:
// - The repo having a different owner
// - The repo having the same owner but a different name (private org fork)
// I suspect that the implementation of the second case may be broken in the face
// of a repo rename, where the remote hasn't been updated locally. This is a
// frequent issue in commands that use SmartBaseRepoFunc. It's not any worse than not
// supporting this case at all though.
sameOwner := strings.EqualFold(gitRemoteRepo.RepoOwner(), f.repo.RepoOwner())
sameOwnerDifferentRepoName := sameOwner && !strings.EqualFold(gitRemoteRepo.RepoName(), f.repo.RepoName())
if !sameOwner || sameOwnerDifferentRepoName {
prHeadRef = fmt.Sprintf("%s:%s", gitRemoteRepo.RepoOwner(), prHeadRef)
} else if branchConfig.PushRemoteURL != nil {
if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil {
prRefs.HeadRepo = r
}
}
return prHeadRef, 0, nil
// We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier.
// unless the user has specified push.default = upstream or tracking, then we use the
// branch name from the merge ref.
prRefs.BranchName = currentBranchName
if pushDefault == "upstream" || pushDefault == "tracking" {
prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
}
// The PR merges from a branch in the same repo as the base branch (usually the default branch)
if prRefs.HeadRepo == nil {
prRefs.HeadRepo = baseRefRepo
}
return prRefs, nil
}
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
@ -321,7 +359,7 @@ func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fi
return &resp.Repository.PullRequest, nil
}
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {
PullRequests struct {
@ -348,9 +386,9 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h
}
}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
branchWithoutOwner := headBranch
if idx := strings.Index(headBranch, ":"); idx >= 0 {
branchWithoutOwner = headBranch[idx+1:]
branchWithoutOwner := headBranchWithOwnerIfFork
if idx := strings.Index(headBranchWithOwnerIfFork, ":"); idx >= 0 {
branchWithoutOwner = headBranchWithOwnerIfFork[idx+1:]
}
variables := map[string]interface{}{
@ -373,18 +411,17 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h
})
for _, pr := range prs {
headBranchMatches := pr.HeadLabel() == headBranch
headBranchMatches := pr.HeadLabel() == headBranchWithOwnerIfFork
baseBranchEmptyOrMatches := baseBranch == "" || pr.BaseRefName == baseBranch
// When the head is the default branch, it doesn't really make sense to show merged or closed PRs.
// https://github.com/cli/cli/issues/4263
isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch
isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranchWithOwnerIfFork
if headBranchMatches && baseBranchEmptyOrMatches && isNotClosedOrMergedWhenHeadIsDefault {
return &pr, nil
}
}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)}
}
func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {

View file

@ -18,20 +18,45 @@ type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) (git.BranchConfig, error)
remotesFn func() (context.Remotes, error)
pushDefault func() (string, error)
selector string
fields []string
baseBranch string
}
func TestFind(t *testing.T) {
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) (git.BranchConfig, error)
pushDefault func() (string, error)
remotesFn func() (context.Remotes, error)
selector string
fields []string
baseBranch string
// TODO: Abstract these out meaningfully for reuse in parsePRRefs tests
originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteOrigin := context.Remote{
Remote: &git.Remote{
Name: "origin",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
remoteOther := context.Remote{
Remote: &git.Remote{
Name: "other",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "OTHER-REPO"),
}
upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteUpstream := context.Remote{
Remote: &git.Remote{
Name: "upstream",
FetchURL: upstreamOwnerUrl,
},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}
tests := []struct {
name string
args args
@ -43,11 +68,13 @@ func TestFind(t *testing.T) {
{
name: "number argument",
args: args{
selector: "13",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
selector: "13",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -57,7 +84,7 @@ func TestFind(t *testing.T) {
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "number argument with base branch",
@ -65,9 +92,14 @@ func TestFind(t *testing.T) {
selector: "13",
baseBranch: "main",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -80,22 +112,25 @@ func TestFind(t *testing.T) {
"baseRefName": "main",
"headRefName": "13",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
"headRepositoryOwner": {"login":"ORIGINOWNER"}
}
]}
}}}`))
},
wantPR: 123,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "baseRepo is error",
args: args{
selector: "13",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return nil, errors.New("baseRepoErr")
selector: "13",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(nil, errors.New("baseRepoErr")),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
wantErr: true,
},
@ -110,24 +145,30 @@ func TestFind(t *testing.T) {
{
name: "number only",
args: args{
selector: "13",
fields: []string{"number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
selector: "13",
fields: []string{"number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: nil,
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "number with hash argument",
args: args{
selector: "#13",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
selector: "#13",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -137,14 +178,19 @@ func TestFind(t *testing.T) {
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "URL argument",
name: "PR URL argument",
args: args{
selector: "https://example.org/OWNER/REPO/pull/13/files",
fields: []string{"id", "number"},
baseRepoFn: nil,
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -157,13 +203,16 @@ func TestFind(t *testing.T) {
wantRepo: "https://example.org/OWNER/REPO",
},
{
name: "branch argument",
name: "when provided branch argument with an open and closed PR for that branch name, it returns the open PR",
args: args{
selector: "blueberries",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
selector: "blueberries",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -176,7 +225,7 @@ func TestFind(t *testing.T) {
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
"headRepositoryOwner": {"login":"ORIGINOWNER"}
},
{
"number": 13,
@ -184,13 +233,13 @@ func TestFind(t *testing.T) {
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": false,
"headRepositoryOwner": {"login":"OWNER"}
"headRepositoryOwner": {"login":"ORIGINOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "branch argument with base branch",
@ -201,6 +250,11 @@ func TestFind(t *testing.T) {
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -240,17 +294,8 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blueberries",
RemoteName: "origin",
Push: "origin/blueberries",
}, nil),
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("OWNER", "REPO"),
}}, nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -283,6 +328,7 @@ func TestFind(t *testing.T) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -320,114 +366,19 @@ func TestFind(t *testing.T) {
wantErr: true,
},
{
name: "current branch with upstream configuration",
name: "when the current branch is configured to push to and pull from 'upstream' and push.default = upstream but the repo push/pulls from 'origin', it finds the PR associated with the upstream repo and returns origin as the base repo",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
pushDefault: func() (string, error) { return "upstream", nil },
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
RemoteName: "origin",
PushRemoteName: "origin",
Push: "origin/blue-upstream-berries",
}, nil),
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}}, nil
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blue-upstream-berries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"UPSTREAMOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "current branch with upstream RemoteURL configuration",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (git.BranchConfig, error) {
u, _ := url.Parse("https://github.com/UPSTREAMOWNER/REPO")
return stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
RemoteURL: u,
PushRemoteURL: u,
}, nil)(branch)
},
pushDefault: func() (string, error) { return "upstream", nil },
remotesFn: nil,
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blue-upstream-berries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"UPSTREAMOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
},
{
name: "current branch with tracking (deprecated synonym of upstream) configuration",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
selector: "",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
RemoteName: "origin",
PushRemoteName: "origin",
Push: "origin/blue-upstream-berries",
PushRemoteName: "upstream",
}, nil),
pushDefault: func() (string, error) { return "tracking", nil },
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}}, nil
},
pushDefault: stubPushDefault("upstream", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -446,34 +397,55 @@ func TestFind(t *testing.T) {
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "the current branch is configured to push to and pull from a URL (upstream, in this example) that is different from what the repo is configured to push to and pull from (origin, in this example) and push.default = upstream, it finds the PR associated with the upstream repo and returns origin as the base repo",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
PushRemoteURL: remoteUpstream.Remote.FetchURL,
}, nil),
pushDefault: stubPushDefault("upstream", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
httpmock.StringResponse(`{"data":{"repository":{
"pullRequests":{"nodes":[
{
"number": 13,
"state": "OPEN",
"baseRefName": "main",
"headRefName": "blue-upstream-berries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"UPSTREAMOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "current branch with upstream and fork in same org",
args: args{
selector: "",
fields: []string{"id", "number"},
baseRepoFn: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("OWNER/REPO")
},
selector: "",
fields: []string{"id", "number"},
baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil),
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
RemoteName: "origin",
MergeRef: "refs/heads/main",
PushRemoteName: "origin",
Push: "origin/blueberries",
Push: "other/blueberries",
}, nil),
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("OWNER", "REPO-FORK"),
}, {
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.New("OWNER", "REPO"),
}}, nil
},
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -486,13 +458,13 @@ func TestFind(t *testing.T) {
"baseRefName": "main",
"headRefName": "blueberries",
"isCrossRepository": true,
"headRepositoryOwner": {"login":"OWNER"}
"headRepositoryOwner": {"login":"ORIGINOWNER"}
}
]}
}}}`))
},
wantPR: 13,
wantRepo: "https://github.com/OWNER/REPO",
wantRepo: "https://github.com/ORIGINOWNER/REPO",
},
{
name: "current branch made by pr checkout",
@ -533,6 +505,7 @@ func TestFind(t *testing.T) {
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
pushDefault: stubPushDefault("simple", nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -597,7 +570,11 @@ func TestFind(t *testing.T) {
branchFn: tt.args.branchFn,
branchConfig: tt.args.branchConfig,
pushDefault: tt.args.pushDefault,
remotesFn: tt.args.remotesFn,
remotesFn: stubRemotes(context.Remotes{
&remoteOrigin,
&remoteOther,
&remoteUpstream,
}, nil),
}
pr, repo, err := f.Find(FindOptions{
@ -630,42 +607,262 @@ func TestFind(t *testing.T) {
}
}
func Test_parseCurrentBranch(t *testing.T) {
func Test_parsePRRefs(t *testing.T) {
originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteOrigin := context.Remote{
Remote: &git.Remote{
Name: "origin",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
remoteOther := context.Remote{
Remote: &git.Remote{
Name: "other",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteUpstream := context.Remote{
Remote: &git.Remote{
Name: "upstream",
FetchURL: upstreamOwnerUrl,
},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}
tests := []struct {
name string
args args
wantSelector string
wantPR int
wantError error
name string
branchConfig git.BranchConfig
pushDefault string
currentBranchName string
baseRefRepo ghrepo.Interface
rems context.Remotes
wantPRRefs PRRefs
wantErr error
}{
{
name: "failed branch config",
args: args{
branchConfig: stubBranchConfig(git.BranchConfig{}, errors.New("branchConfigErr")),
branchFn: func() (string, error) {
return "blueberries", nil
},
name: "When the branch is called 'blueberries' with an empty branch config, it returns the correct PRRefs",
branchConfig: git.BranchConfig{},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
wantPRRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantSelector: "",
wantPR: 0,
wantError: errors.New("branchConfigErr"),
wantErr: nil,
},
{
name: "When the branch is called 'otherBranch' with an empty branch config, it returns the correct PRRefs",
branchConfig: git.BranchConfig{},
currentBranchName: "otherBranch",
baseRefRepo: remoteOrigin.Repo,
wantPRRefs: PRRefs{
BranchName: "otherBranch",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the branch name doesn't match the branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name",
branchConfig: git.BranchConfig{
Push: "origin/pushBranch",
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
},
wantPRRefs: PRRefs{
BranchName: "pushBranch",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the branch name doesn't match a different branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name",
branchConfig: git.BranchConfig{
Push: "origin/differentPushBranch",
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
},
wantPRRefs: PRRefs{
BranchName: "differentPushBranch",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name",
branchConfig: git.BranchConfig{
Push: "other/pushBranch",
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOther,
},
wantPRRefs: PRRefs{
BranchName: "pushBranch",
HeadRepo: remoteOther.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote is the same as the baseRepo, it returns the baseRepo as the PRRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote is different from the baseRepo, it returns the push remote repo as the PRRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
},
currentBranchName: "blueberries",
baseRefRepo: remoteUpstream.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteUpstream.Repo,
},
wantErr: nil,
},
{
name: "When the push remote defined by a URL and the baseRepo is different from the push remote, it returns the push remote repo as the PRRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteURL: remoteOrigin.Remote.FetchURL,
},
currentBranchName: "blueberries",
baseRefRepo: remoteUpstream.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteUpstream.Repo,
},
wantErr: nil,
},
{
name: "When the push remote and merge ref are configured to a different repo and push.default = upstream, it should return the branch name from the other repo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteUpstream.Remote.Name,
MergeRef: "refs/heads/blue-upstream-berries",
},
pushDefault: "upstream",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PRRefs{
BranchName: "blue-upstream-berries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote and merge ref are configured to a different repo and push.default = tracking, it should return the branch name from the other repo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteUpstream.Remote.Name,
MergeRef: "refs/heads/blue-upstream-berries",
},
pushDefault: "tracking",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PRRefs{
BranchName: "blue-upstream-berries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := finder{
httpClient: func() (*http.Client, error) {
return &http.Client{}, nil
},
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
branchConfig: tt.args.branchConfig,
remotesFn: tt.args.remotesFn,
prRefs, err := parsePRRefs(tt.currentBranchName, tt.branchConfig, tt.pushDefault, tt.baseRefRepo, tt.rems)
if tt.wantErr != nil {
require.Error(t, err)
assert.Equal(t, tt.wantErr, err)
} else {
require.NoError(t, err)
}
selector, pr, err := f.parseCurrentBranch()
assert.Equal(t, tt.wantSelector, selector)
assert.Equal(t, tt.wantPR, pr)
assert.Equal(t, tt.wantError, err)
assert.Equal(t, tt.wantPRRefs, prRefs)
})
}
}
func TestPRRefs_GetPRLabel(t *testing.T) {
originRepo := ghrepo.New("ORIGINOWNER", "REPO")
upstreamRepo := ghrepo.New("UPSTREAMOWNER", "REPO")
tests := []struct {
name string
prRefs PRRefs
want string
}{
{
name: "When the HeadRepo and BaseRepo match, it returns the branch name",
prRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: originRepo,
BaseRepo: originRepo,
},
want: "blueberries",
},
{
name: "When the HeadRepo and BaseRepo do not match, it returns the prepended HeadRepo owner to the branch name",
prRefs: PRRefs{
BranchName: "blueberries",
HeadRepo: originRepo,
BaseRepo: upstreamRepo,
},
want: "ORIGINOWNER:blueberries",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, tt.prRefs.GetPRLabel())
})
}
}
@ -675,3 +872,21 @@ func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (gi
return branchConfig, err
}
}
func stubRemotes(remotes context.Remotes, err error) func() (context.Remotes, error) {
return func() (context.Remotes, error) {
return remotes, err
}
}
func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interface, error) {
return func() (ghrepo.Interface, error) {
return baseRepo, err
}
}
func stubPushDefault(pushDefault string, err error) func() (string, error) {
return func() (string, error) {
return pushDefault, err
}
}