From 3ae4e5da20ebd9f9f46342906c0a930b75310d3b Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 15:14:43 +0100 Subject: [PATCH] Document and rework pr create tracking branch lookup --- git/objects.go | 12 ++++++++++++ pkg/cmd/pr/create/create.go | 30 +++++++++++++++++++----------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/git/objects.go b/git/objects.go index ae058a01c..924a3059e 100644 --- a/git/objects.go +++ b/git/objects.go @@ -1,6 +1,7 @@ package git import ( + "fmt" "net/url" "strings" ) @@ -64,6 +65,17 @@ func (r TrackingRef) String() string { return "refs/remotes/" + r.RemoteName + "/" + r.BranchName } +func ParseTrackingRef(text string) (TrackingRef, error) { + parts := strings.SplitN(string(text), "/", 4) + if len(parts) != 4 { + return TrackingRef{}, fmt.Errorf("invalid tracking ref: %s", text) + } + return TrackingRef{ + RemoteName: parts[2], + BranchName: parts[3], + }, nil +} + type Commit struct { Sha string Title string diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 6a1e13849..df62dc889 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -519,15 +519,14 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u } func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig *git.BranchConfig) *git.TrackingRef { + // To try and determine the tracking ref for a local branch, we first construct a collection of refs + // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} - var trackingRefs []git.TrackingRef - - if headBranchConfig.RemoteName != "" { + if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { tr := git.TrackingRef{ RemoteName: headBranchConfig.RemoteName, BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } @@ -536,22 +535,31 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, l RemoteName: remote.Name, BranchName: localBranchName, } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } + // Then we ask git for details about these refs, for example, refs/remotes/origin/trunk might return a hash + // for the remote tracking branch, trunk, for the remote, origin. If there is no ref, the git client returns + // no ref information. + // + // We also first check for the HEAD ref, so that we have the hash of the currently checked out commit. resolvedRefs, _ := gitClient.ShowRefs(context.Background(), refsForLookup) + + // If there is more than one resolved ref, that means that at least one ref was found in addition to the HEAD. if len(resolvedRefs) > 1 { + headRef := resolvedRefs[0] for _, r := range resolvedRefs[1:] { - if r.Hash != resolvedRefs[0].Hash { + // If the hash of the remote ref doesn't match the hash of HEAD then the remote branch is not in the same + // state, so it can't be used. + if r.Hash != headRef.Hash { continue } - for _, tr := range trackingRefs { - if tr.String() != r.Name { - continue - } - return &tr + // Otherwise we can parse the returned ref into a tracking ref and return that + trackingRef, err := git.ParseTrackingRef(r.Name) + if err != nil { + return nil } + return &trackingRef } }