From 8b5073d6172f9cfe4ad4c8ba157ee35bff3f6ed6 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 6 Jan 2025 16:58:35 +0100 Subject: [PATCH] Move trackingRef into pr create package --- git/objects.go | 10 ------ pkg/cmd/pr/create/create.go | 54 +++++++++++++++++++------------- pkg/cmd/pr/create/create_test.go | 14 ++++----- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/git/objects.go b/git/objects.go index ae058a01c..c09683042 100644 --- a/git/objects.go +++ b/git/objects.go @@ -54,16 +54,6 @@ type Ref struct { Name string } -// TrackingRef represents a ref for a remote tracking branch. -type TrackingRef struct { - RemoteName string - BranchName string -} - -func (r TrackingRef) String() string { - return "refs/remotes/" + r.RemoteName + "/" + r.BranchName -} - type Commit struct { Sha string Title string diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 372daf86c..46a4e284e 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,26 +518,47 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } +// trackingRef represents a ref for a remote tracking branch. +type trackingRef struct { + remoteName string + branchName string +} + +func (r trackingRef) String() string { + return "refs/remotes/" + r.remoteName + "/" + r.branchName +} + +func mustParseTrackingRef(text string) trackingRef { + parts := strings.SplitN(string(text), "/", 4) + if len(parts) != 4 { + panic(fmt.Errorf("invalid tracking ref: %s", text)) + } + return trackingRef{ + remoteName: parts[2], + branchName: parts[3], + } +} + // tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out // HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in // practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin", // and then everything else unstably sorted. -func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (git.TrackingRef, bool) { +func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (trackingRef, bool) { // To try and determine the tracking ref for a local branch, we first construct a collection of refs // that might be tracking, given the current branch's config, and the list of known remotes. refsForLookup := []string{"HEAD"} if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { - tr := git.TrackingRef{ - RemoteName: headBranchConfig.RemoteName, - BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), + tr := trackingRef{ + remoteName: headBranchConfig.RemoteName, + branchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), } refsForLookup = append(refsForLookup, tr.String()) } for _, remote := range remotes { - tr := git.TrackingRef{ - RemoteName: remote.Name, - BranchName: localBranchName, + tr := trackingRef{ + remoteName: remote.Name, + branchName: localBranchName, } refsForLookup = append(refsForLookup, tr.String()) } @@ -563,18 +584,7 @@ func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, l } } - return git.TrackingRef{}, false -} - -func mustParseTrackingRef(text string) git.TrackingRef { - parts := strings.SplitN(string(text), "/", 4) - if len(parts) != 4 { - panic(fmt.Errorf("invalid tracking ref: %s", text)) - } - return git.TrackingRef{ - RemoteName: parts[2], - BranchName: parts[3], - } + return trackingRef{}, false } func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { @@ -668,12 +678,12 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { // determine whether the head branch is already pushed to a remote if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false - if r, err := remotes.FindByName(trackingRef.RemoteName); err == nil { + if r, err := remotes.FindByName(trackingRef.remoteName); err == nil { headRepo = r headRemote = r - headBranchLabel = trackingRef.BranchName + headBranchLabel = trackingRef.branchName if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.BranchName) + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.branchName) } } } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index c5127bcc6..6c0ff11e2 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1627,7 +1627,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { name string cmdStubs func(*run.CommandStubber) remotes context.Remotes - expectedTrackingRef git.TrackingRef + expectedTrackingRef trackingRef expectedFound bool }{ { @@ -1636,7 +1636,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, { @@ -1655,7 +1655,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, { @@ -1678,9 +1678,9 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{ - RemoteName: "origin", - BranchName: "feature", + expectedTrackingRef: trackingRef{ + remoteName: "origin", + branchName: "feature", }, expectedFound: true, }, @@ -1702,7 +1702,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - expectedTrackingRef: git.TrackingRef{}, + expectedTrackingRef: trackingRef{}, expectedFound: false, }, }