diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index b05c971fe..00993e0fa 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,9 +518,9 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -// determineTrackingBranch is intended to try and find a remote branch on the same commit as the currently checked out +// tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out // HEAD, i.e. the local branch. -func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) *git.TrackingRef { +func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (git.TrackingRef, bool) { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} @@ -557,12 +557,11 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l continue } // Otherwise we can parse the returned ref into a tracking ref and return that - trackingRef := mustParseTrackingRef(r.Name) - return &trackingRef + return mustParseTrackingRef(r.Name), true } } - return nil + return git.TrackingRef{}, false } func mustParseTrackingRef(text string) git.TrackingRef { @@ -665,14 +664,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch) if isPushEnabled { // determine whether the head branch is already pushed to a remote - if pushedTo := determineTrackingBranch(gitClient, remotes, headBranch, headBranchConfig); pushedTo != nil { + if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false - if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil { + if r, err := remotes.FindByName(trackingRef.RemoteName); err == nil { headRepo = r headRemote = r - headBranchLabel = pushedTo.BranchName + headBranchLabel = trackingRef.BranchName if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), pushedTo.BranchName) + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.BranchName) } } } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 8d5a1d1b3..acb35f557 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1622,12 +1622,12 @@ func Test_createRun(t *testing.T) { } } -func Test_determineTrackingBranch(t *testing.T) { +func Test_tryDetermineTrackingRef(t *testing.T) { tests := []struct { name string cmdStubs func(*run.CommandStubber) remotes context.Remotes - assert func(ref *git.TrackingRef, t *testing.T) + assert func(ref git.TrackingRef, found bool, t *testing.T) }{ { name: "empty", @@ -1635,8 +1635,9 @@ func Test_determineTrackingBranch(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, { @@ -1655,8 +1656,9 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, { @@ -1679,9 +1681,12 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Equal(t, "upstream", ref.RemoteName) - assert.Equal(t, "feature", ref.BranchName) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Equal(t, git.TrackingRef{ + RemoteName: "upstream", + BranchName: "feature", + }, ref) + assert.True(t, found) }, }, { @@ -1702,8 +1707,9 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, } @@ -1719,8 +1725,9 @@ func Test_determineTrackingBranch(t *testing.T) { GitPath: "some/path/git", } headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") - ref := determineTrackingBranch(gitClient, tt.remotes, "feature", headBranchConfig) - tt.assert(ref, t) + ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig) + + tt.assert(ref, found, t) }) } }