From 7a1052ca339a2475fddd08c22cca1eacd8707251 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Fri, 3 Jan 2025 20:35:48 +0000 Subject: [PATCH 01/12] Set LocalBranch even if the git config fails --- .../pr-create-without-upstream-config.txtar | 27 +++++++++++++++++++ git/client.go | 3 ++- 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 acceptance/testdata/pr/pr-create-without-upstream-config.txtar diff --git a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar new file mode 100644 index 000000000..00f3535a7 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar @@ -0,0 +1,27 @@ +# This test is the same as pr-create-basic, except that the git push doesn't include the -u argument +# This causes a git config read to fail during gh pr create, but it should not be fatal + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING + +# Clone the repo +exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING + +# Prepare a branch to PR +cd $SCRIPT_NAME-$RANDOM_STRING +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' + +# Check the PR is indeed created +exec gh pr view +stdout 'Feature Title' diff --git a/git/client.go b/git/client.go index 35b9cf16c..774ce5e76 100644 --- a/git/client.go +++ b/git/client.go @@ -378,6 +378,8 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte, // ReadBranchConfig parses the `branch.BRANCH.(remote|merge|gh-merge-base)` part of git config. func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg BranchConfig) { + cfg.LocalName = branch + prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch)) args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|%s)$", prefix, MergeBaseConfig)} cmd, err := c.Command(ctx, args...) @@ -389,7 +391,6 @@ func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg Branc return } - cfg.LocalName = branch for _, line := range outputLines(out) { parts := strings.SplitN(line, " ", 2) if len(parts) < 2 { From 9d490547b8d20f168e9c665172a3a12204e2f497 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Fri, 3 Jan 2025 20:39:12 +0000 Subject: [PATCH 02/12] Alternative: remove LocalBranch from BranchConfig --- git/client.go | 2 -- git/client_test.go | 2 +- git/objects.go | 2 -- pkg/cmd/pr/create/create.go | 6 +++--- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/git/client.go b/git/client.go index 774ce5e76..1a6d9ae7f 100644 --- a/git/client.go +++ b/git/client.go @@ -378,8 +378,6 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte, // ReadBranchConfig parses the `branch.BRANCH.(remote|merge|gh-merge-base)` part of git config. func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg BranchConfig) { - cfg.LocalName = branch - prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch)) args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|%s)$", prefix, MergeBaseConfig)} cmd, err := c.Command(ctx, args...) diff --git a/git/client_test.go b/git/client_test.go index 0ec4f7334..fff1397f3 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -737,7 +737,7 @@ func TestClientReadBranchConfig(t *testing.T) { name: "read branch config", cmdStdout: "branch.trunk.remote origin\nbranch.trunk.merge refs/heads/trunk\nbranch.trunk.gh-merge-base trunk", wantCmdArgs: `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|gh-merge-base)$`, - wantBranchConfig: BranchConfig{LocalName: "trunk", RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"}, + wantBranchConfig: BranchConfig{RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"}, }, } for _, tt := range tests { diff --git a/git/objects.go b/git/objects.go index f371429dd..ae058a01c 100644 --- a/git/objects.go +++ b/git/objects.go @@ -71,8 +71,6 @@ type Commit struct { } type BranchConfig struct { - // LocalName of the branch. - LocalName string RemoteName string RemoteURL *url.URL // MergeBase is the optional base branch to target in a new PR if `--base` is not specified. diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 3bfa768b7..6a1e13849 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,7 +518,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, headBranchConfig *git.BranchConfig) *git.TrackingRef { +func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig *git.BranchConfig) *git.TrackingRef { refsForLookup := []string{"HEAD"} var trackingRefs []git.TrackingRef @@ -534,7 +534,7 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, h for _, remote := range remotes { tr := git.TrackingRef{ RemoteName: remote.Name, - BranchName: headBranchConfig.LocalName, + BranchName: localBranchName, } trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) @@ -647,7 +647,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch) if isPushEnabled { // determine whether the head branch is already pushed to a remote - if pushedTo := determineTrackingBranch(gitClient, remotes, &headBranchConfig); pushedTo != nil { + if pushedTo := determineTrackingBranch(gitClient, remotes, headBranch, &headBranchConfig); pushedTo != nil { isPushEnabled = false if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil { headRepo = r From 67749480d566fd2cc268ea68155d132197a5e258 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Fri, 3 Jan 2025 20:45:20 +0000 Subject: [PATCH 03/12] Fix test --- pkg/cmd/pr/create/create_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 27220d052..9d6584e6c 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1719,7 +1719,7 @@ func Test_determineTrackingBranch(t *testing.T) { GitPath: "some/path/git", } headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") - ref := determineTrackingBranch(gitClient, tt.remotes, &headBranchConfig) + ref := determineTrackingBranch(gitClient, tt.remotes, "feature", &headBranchConfig) tt.assert(ref, t) }) } From 3ae4e5da20ebd9f9f46342906c0a930b75310d3b Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:14:43 +0100 Subject: [PATCH 04/12] Document and rework pr create tracking branch lookup --- git/objects.go | 12 ++++++++++++ pkg/cmd/pr/create/create.go | 30 +++++++++++++++++++----------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/git/objects.go b/git/objects.go index ae058a01c..924a3059e 100644 --- a/git/objects.go +++ b/git/objects.go @@ -1,6 +1,7 @@ package git import ( + "fmt" "net/url" "strings" ) @@ -64,6 +65,17 @@ func (r TrackingRef) String() string { return "refs/remotes/" + r.RemoteName + "/" + r.BranchName } +func ParseTrackingRef(text string) (TrackingRef, error) { + parts := strings.SplitN(string(text), "/", 4) + if len(parts) != 4 { + return TrackingRef{}, fmt.Errorf("invalid tracking ref: %s", text) + } + return TrackingRef{ + RemoteName: parts[2], + BranchName: parts[3], + }, nil +} + type Commit struct { Sha string Title string diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 6a1e13849..df62dc889 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -519,15 +519,14 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u } func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig *git.BranchConfig) *git.TrackingRef { + // To try and determine the tracking ref for a local branch, we first construct a collection of refs + // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} - var trackingRefs []git.TrackingRef - - if headBranchConfig.RemoteName != "" { + if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { tr := git.TrackingRef{ RemoteName: headBranchConfig.RemoteName, BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } @@ -536,22 +535,31 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l RemoteName: remote.Name, BranchName: localBranchName, } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } + // Then we ask git for details about these refs, for example, refs/remotes/origin/trunk might return a hash + // for the remote tracking branch, trunk, for the remote, origin. If there is no ref, the git client returns + // no ref information. + // + // We also first check for the HEAD ref, so that we have the hash of the currently checked out commit. resolvedRefs, _ := gitClient.ShowRefs(context.Background(), refsForLookup) + + // If there is more than one resolved ref, that means that at least one ref was found in addition to the HEAD. if len(resolvedRefs) > 1 { + headRef := resolvedRefs[0] for _, r := range resolvedRefs[1:] { - if r.Hash != resolvedRefs[0].Hash { + // If the hash of the remote ref doesn't match the hash of HEAD then the remote branch is not in the same + // state, so it can't be used. + if r.Hash != headRef.Hash { continue } - for _, tr := range trackingRefs { - if tr.String() != r.Name { - continue - } - return &tr + // Otherwise we can parse the returned ref into a tracking ref and return that + trackingRef, err := git.ParseTrackingRef(r.Name) + if err != nil { + return nil } + return &trackingRef } } From dc077dc09ba68843e973b715c448e376135e6c67 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:44:34 +0100 Subject: [PATCH 05/12] Panic if tracking ref can't be reconstructed --- git/objects.go | 12 ------------ pkg/cmd/pr/create/create.go | 16 ++++++++++++---- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/git/objects.go b/git/objects.go index 924a3059e..ae058a01c 100644 --- a/git/objects.go +++ b/git/objects.go @@ -1,7 +1,6 @@ package git import ( - "fmt" "net/url" "strings" ) @@ -65,17 +64,6 @@ func (r TrackingRef) String() string { return "refs/remotes/" + r.RemoteName + "/" + r.BranchName } -func ParseTrackingRef(text string) (TrackingRef, error) { - parts := strings.SplitN(string(text), "/", 4) - if len(parts) != 4 { - return TrackingRef{}, fmt.Errorf("invalid tracking ref: %s", text) - } - return TrackingRef{ - RemoteName: parts[2], - BranchName: parts[3], - }, nil -} - type Commit struct { Sha string Title string diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index df62dc889..d7032714a 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -555,10 +555,7 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l continue } // Otherwise we can parse the returned ref into a tracking ref and return that - trackingRef, err := git.ParseTrackingRef(r.Name) - if err != nil { - return nil - } + trackingRef := mustParseTrackingRef(r.Name) return &trackingRef } } @@ -566,6 +563,17 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l return nil } +func mustParseTrackingRef(text string) git.TrackingRef { + parts := strings.SplitN(string(text), "/", 4) + if len(parts) != 4 { + panic(fmt.Errorf("invalid tracking ref: %s", text)) + } + return git.TrackingRef{ + RemoteName: parts[2], + BranchName: parts[3], + } +} + func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { var milestoneTitles []string if opts.Milestone != "" { From 05764b8114e59411cf41723124d83fbb400d9395 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:47:10 +0100 Subject: [PATCH 06/12] Don't use pointer for determineTrackingBranch branchConfig --- pkg/cmd/pr/create/create.go | 4 ++-- pkg/cmd/pr/create/create_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index d7032714a..40913db64 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,7 +518,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig *git.BranchConfig) *git.TrackingRef { +func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) *git.TrackingRef { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} @@ -663,7 +663,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch) if isPushEnabled { // determine whether the head branch is already pushed to a remote - if pushedTo := determineTrackingBranch(gitClient, remotes, headBranch, &headBranchConfig); pushedTo != nil { + if pushedTo := determineTrackingBranch(gitClient, remotes, headBranch, headBranchConfig); pushedTo != nil { isPushEnabled = false if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil { headRepo = r diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 9d6584e6c..8d5a1d1b3 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1719,7 +1719,7 @@ func Test_determineTrackingBranch(t *testing.T) { GitPath: "some/path/git", } headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") - ref := determineTrackingBranch(gitClient, tt.remotes, "feature", &headBranchConfig) + ref := determineTrackingBranch(gitClient, tt.remotes, "feature", headBranchConfig) tt.assert(ref, t) }) } From 27bd4b2aec09f1242177821ad72a01b01163c998 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:49:53 +0100 Subject: [PATCH 07/12] Doc determineTrackingBranch --- pkg/cmd/pr/create/create.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 40913db64..b05c971fe 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,6 +518,8 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } +// determineTrackingBranch is intended to try and find a remote branch on the same commit as the currently checked out +// HEAD, i.e. the local branch. func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) *git.TrackingRef { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. From b8c167970b70205cba2fd22e2bb2da378229f0d7 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:55:35 +0100 Subject: [PATCH 08/12] Avoid pointer return from determineTrackingBranch --- pkg/cmd/pr/create/create.go | 17 ++++++++-------- pkg/cmd/pr/create/create_test.go | 33 +++++++++++++++++++------------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index b05c971fe..00993e0fa 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,9 +518,9 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -// determineTrackingBranch is intended to try and find a remote branch on the same commit as the currently checked out +// tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out // HEAD, i.e. the local branch. -func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) *git.TrackingRef { +func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (git.TrackingRef, bool) { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} @@ -557,12 +557,11 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l continue } // Otherwise we can parse the returned ref into a tracking ref and return that - trackingRef := mustParseTrackingRef(r.Name) - return &trackingRef + return mustParseTrackingRef(r.Name), true } } - return nil + return git.TrackingRef{}, false } func mustParseTrackingRef(text string) git.TrackingRef { @@ -665,14 +664,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch) if isPushEnabled { // determine whether the head branch is already pushed to a remote - if pushedTo := determineTrackingBranch(gitClient, remotes, headBranch, headBranchConfig); pushedTo != nil { + if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false - if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil { + if r, err := remotes.FindByName(trackingRef.RemoteName); err == nil { headRepo = r headRemote = r - headBranchLabel = pushedTo.BranchName + headBranchLabel = trackingRef.BranchName if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), pushedTo.BranchName) + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.BranchName) } } } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 8d5a1d1b3..acb35f557 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1622,12 +1622,12 @@ func Test_createRun(t *testing.T) { } } -func Test_determineTrackingBranch(t *testing.T) { +func Test_tryDetermineTrackingRef(t *testing.T) { tests := []struct { name string cmdStubs func(*run.CommandStubber) remotes context.Remotes - assert func(ref *git.TrackingRef, t *testing.T) + assert func(ref git.TrackingRef, found bool, t *testing.T) }{ { name: "empty", @@ -1635,8 +1635,9 @@ func Test_determineTrackingBranch(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, { @@ -1655,8 +1656,9 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, { @@ -1679,9 +1681,12 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Equal(t, "upstream", ref.RemoteName) - assert.Equal(t, "feature", ref.BranchName) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Equal(t, git.TrackingRef{ + RemoteName: "upstream", + BranchName: "feature", + }, ref) + assert.True(t, found) }, }, { @@ -1702,8 +1707,9 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) + assert: func(ref git.TrackingRef, found bool, t *testing.T) { + assert.Zero(t, ref) + assert.False(t, found) }, }, } @@ -1719,8 +1725,9 @@ func Test_determineTrackingBranch(t *testing.T) { GitPath: "some/path/git", } headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") - ref := determineTrackingBranch(gitClient, tt.remotes, "feature", headBranchConfig) - tt.assert(ref, t) + ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig) + + tt.assert(ref, found, t) }) } } From 57ba5e56082e442e3e4e856ba8ca4aec703c8499 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:57:14 +0100 Subject: [PATCH 09/12] Rework tryDetermineTrackingRef tests --- pkg/cmd/pr/create/create_test.go | 40 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index acb35f557..4785a9d99 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1624,10 +1624,11 @@ func Test_createRun(t *testing.T) { func Test_tryDetermineTrackingRef(t *testing.T) { tests := []struct { - name string - cmdStubs func(*run.CommandStubber) - remotes context.Remotes - assert func(ref git.TrackingRef, found bool, t *testing.T) + name string + cmdStubs func(*run.CommandStubber) + remotes context.Remotes + expectedTrackingRef git.TrackingRef + expectedFound bool }{ { name: "empty", @@ -1635,10 +1636,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, { name: "no match", @@ -1656,10 +1655,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, { name: "match", @@ -1681,13 +1678,11 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("octocat", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Equal(t, git.TrackingRef{ - RemoteName: "upstream", - BranchName: "feature", - }, ref) - assert.True(t, found) + expectedTrackingRef: git.TrackingRef{ + RemoteName: "upstream", + BranchName: "feature", }, + expectedFound: true, }, { name: "respect tracking config", @@ -1707,10 +1702,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - assert: func(ref git.TrackingRef, found bool, t *testing.T) { - assert.Zero(t, ref) - assert.False(t, found) - }, + expectedTrackingRef: git.TrackingRef{}, + expectedFound: false, }, } for _, tt := range tests { @@ -1727,7 +1720,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) { headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig) - tt.assert(ref, found, t) + assert.Equal(t, tt.expectedTrackingRef, ref) + assert.Equal(t, tt.expectedFound, found) }) } } From 62ecb1c84dbb37ea908fa25c7f5e11a27faaf51c Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 16:10:12 +0100 Subject: [PATCH 10/12] Make tryDetermineTrackingRef tests more respective of reality Though it doesn't really matter, in practice upstream is always going to come before origin. --- pkg/cmd/pr/create/create.go | 4 +++- pkg/cmd/pr/create/create_test.go | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 00993e0fa..372daf86c 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -519,7 +519,9 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u } // tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out -// HEAD, i.e. the local branch. +// HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in +// practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin", +// and then everything else unstably sorted. func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (git.TrackingRef, bool) { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 4785a9d99..c5127bcc6 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1643,17 +1643,17 @@ func Test_tryDetermineTrackingRef(t *testing.T) { name: "no match", cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") - cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature", 0, "abc HEAD\nbca refs/remotes/origin/feature") + cs.Register("git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature", 0, "abc HEAD\nbca refs/remotes/upstream/feature") }, remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, &context.Remote{ Remote: &git.Remote{Name: "upstream"}, Repo: ghrepo.New("octocat", "Spoon-Knife"), }, + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("hubot", "Spoon-Knife"), + }, }, expectedTrackingRef: git.TrackingRef{}, expectedFound: false, @@ -1662,24 +1662,24 @@ func Test_tryDetermineTrackingRef(t *testing.T) { name: "match", cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature$`, 0, heredoc.Doc(` + cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature$`, 0, heredoc.Doc(` deadbeef HEAD - deadb00f refs/remotes/origin/feature - deadbeef refs/remotes/upstream/feature + deadb00f refs/remotes/upstream/feature + deadbeef refs/remotes/origin/feature `)) }, remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, &context.Remote{ Remote: &git.Remote{Name: "upstream"}, Repo: ghrepo.New("octocat", "Spoon-Knife"), }, + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("hubot", "Spoon-Knife"), + }, }, expectedTrackingRef: git.TrackingRef{ - RemoteName: "upstream", + RemoteName: "origin", BranchName: "feature", }, expectedFound: true, From 8b5073d6172f9cfe4ad4c8ba157ee35bff3f6ed6 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 16:58:35 +0100 Subject: [PATCH 11/12] Move trackingRef into pr create package --- git/objects.go | 10 ------ pkg/cmd/pr/create/create.go | 54 +++++++++++++++++++------------- pkg/cmd/pr/create/create_test.go | 14 ++++----- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/git/objects.go b/git/objects.go index ae058a01c..c09683042 100644 --- a/git/objects.go +++ b/git/objects.go @@ -54,16 +54,6 @@ type Ref struct { Name string } -// TrackingRef represents a ref for a remote tracking branch. -type TrackingRef struct { - RemoteName string - BranchName string -} - -func (r TrackingRef) String() string { - return "refs/remotes/" + r.RemoteName + "/" + r.BranchName -} - type Commit struct { Sha string Title string diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 372daf86c..46a4e284e 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,26 +518,47 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } +// trackingRef represents a ref for a remote tracking branch. +type trackingRef struct { + remoteName string + branchName string +} + +func (r trackingRef) String() string { + return "refs/remotes/" + r.remoteName + "/" + r.branchName +} + +func mustParseTrackingRef(text string) trackingRef { + parts := strings.SplitN(string(text), "/", 4) + if len(parts) != 4 { + panic(fmt.Errorf("invalid tracking ref: %s", text)) + } + return trackingRef{ + remoteName: parts[2], + branchName: parts[3], + } +} + // tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out // HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in // practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin", // and then everything else unstably sorted. -func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (git.TrackingRef, bool) { +func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (trackingRef, bool) { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { - tr := git.TrackingRef{ - RemoteName: headBranchConfig.RemoteName, - BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), + tr := trackingRef{ + remoteName: headBranchConfig.RemoteName, + branchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), } refsForLookup = append(refsForLookup, tr.String()) } for _, remote := range remotes { - tr := git.TrackingRef{ - RemoteName: remote.Name, - BranchName: localBranchName, + tr := trackingRef{ + remoteName: remote.Name, + branchName: localBranchName, } refsForLookup = append(refsForLookup, tr.String()) } @@ -563,18 +584,7 @@ func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, l } } - return git.TrackingRef{}, false -} - -func mustParseTrackingRef(text string) git.TrackingRef { - parts := strings.SplitN(string(text), "/", 4) - if len(parts) != 4 { - panic(fmt.Errorf("invalid tracking ref: %s", text)) - } - return git.TrackingRef{ - RemoteName: parts[2], - BranchName: parts[3], - } + return trackingRef{}, false } func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { @@ -668,12 +678,12 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { // determine whether the head branch is already pushed to a remote if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false - if r, err := remotes.FindByName(trackingRef.RemoteName); err == nil { + if r, err := remotes.FindByName(trackingRef.remoteName); err == nil { headRepo = r headRemote = r - headBranchLabel = trackingRef.BranchName + headBranchLabel = trackingRef.branchName if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.BranchName) + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.branchName) } } } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index c5127bcc6..6c0ff11e2 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1627,7 +1627,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { name string cmdStubs func(*run.CommandStubber) remotes context.Remotes - expectedTrackingRef git.TrackingRef + expectedTrackingRef trackingRef expectedFound bool }{ { @@ -1636,7 +1636,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, { @@ -1655,7 +1655,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, { @@ -1678,9 +1678,9 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{ - RemoteName: "origin", - BranchName: "feature", + expectedTrackingRef: trackingRef{ + remoteName: "origin", + branchName: "feature", }, expectedFound: true, }, @@ -1702,7 +1702,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, } From c3b41e87b89aad1f8ea3fcbe4c93c08c3a50b3f9 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 17:00:16 +0100 Subject: [PATCH 12/12] Panic mustParseTrackingRef if format is incorrect --- pkg/cmd/pr/create/create.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 46a4e284e..db160b419 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -530,9 +530,16 @@ func (r trackingRef) String() string { func mustParseTrackingRef(text string) trackingRef { parts := strings.SplitN(string(text), "/", 4) + // The only place this is called is tryDetermineTrackingRef, where we are reconstructing + // the same tracking ref we passed in. If it doesn't match the expected format, this is a + // programmer error we want to know about, so it's ok to panic. if len(parts) != 4 { - panic(fmt.Errorf("invalid tracking ref: %s", text)) + panic(fmt.Errorf("tracking ref should have four parts: %s", text)) } + if parts[0] != "refs" || parts[1] != "remotes" { + panic(fmt.Errorf("tracking ref should start with refs/remotes/: %s", text)) + } + return trackingRef{ remoteName: parts[2], branchName: parts[3],