diff --git a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar new file mode 100644 index 000000000..00f3535a7 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar @@ -0,0 +1,27 @@ +# This test is the same as pr-create-basic, except that the git push doesn't include the -u argument +# This causes a git config read to fail during gh pr create, but it should not be fatal + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING + +# Clone the repo +exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING + +# Prepare a branch to PR +cd $SCRIPT_NAME-$RANDOM_STRING +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' + +# Check the PR is indeed created +exec gh pr view +stdout 'Feature Title' diff --git a/git/client.go b/git/client.go index 35b9cf16c..1a6d9ae7f 100644 --- a/git/client.go +++ b/git/client.go @@ -389,7 +389,6 @@ func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg Branc return } - cfg.LocalName = branch for _, line := range outputLines(out) { parts := strings.SplitN(line, " ", 2) if len(parts) < 2 { diff --git a/git/client_test.go b/git/client_test.go index 0ec4f7334..fff1397f3 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -737,7 +737,7 @@ func TestClientReadBranchConfig(t *testing.T) { name: "read branch config", cmdStdout: "branch.trunk.remote origin\nbranch.trunk.merge refs/heads/trunk\nbranch.trunk.gh-merge-base trunk", wantCmdArgs: `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|gh-merge-base)$`, - wantBranchConfig: BranchConfig{LocalName: "trunk", RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"}, + wantBranchConfig: BranchConfig{RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"}, }, } for _, tt := range tests { diff --git a/git/objects.go b/git/objects.go index f371429dd..c09683042 100644 --- a/git/objects.go +++ b/git/objects.go @@ -54,16 +54,6 @@ type Ref struct { Name string } -// TrackingRef represents a ref for a remote tracking branch. -type TrackingRef struct { - RemoteName string - BranchName string -} - -func (r TrackingRef) String() string { - return "refs/remotes/" + r.RemoteName + "/" + r.BranchName -} - type Commit struct { Sha string Title string @@ -71,8 +61,6 @@ type Commit struct { } type BranchConfig struct { - // LocalName of the branch. - LocalName string RemoteName string RemoteURL *url.URL // MergeBase is the optional base branch to target in a new PR if `--base` is not specified. diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 3bfa768b7..db160b419 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,44 +518,80 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, headBranchConfig *git.BranchConfig) *git.TrackingRef { - refsForLookup := []string{"HEAD"} - var trackingRefs []git.TrackingRef +// trackingRef represents a ref for a remote tracking branch. +type trackingRef struct { + remoteName string + branchName string +} - if headBranchConfig.RemoteName != "" { - tr := git.TrackingRef{ - RemoteName: headBranchConfig.RemoteName, - BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), +func (r trackingRef) String() string { + return "refs/remotes/" + r.remoteName + "/" + r.branchName +} + +func mustParseTrackingRef(text string) trackingRef { + parts := strings.SplitN(string(text), "/", 4) + // The only place this is called is tryDetermineTrackingRef, where we are reconstructing + // the same tracking ref we passed in. If it doesn't match the expected format, this is a + // programmer error we want to know about, so it's ok to panic. + if len(parts) != 4 { + panic(fmt.Errorf("tracking ref should have four parts: %s", text)) + } + if parts[0] != "refs" || parts[1] != "remotes" { + panic(fmt.Errorf("tracking ref should start with refs/remotes/: %s", text)) + } + + return trackingRef{ + remoteName: parts[2], + branchName: parts[3], + } +} + +// tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out +// HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in +// practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin", +// and then everything else unstably sorted. +func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (trackingRef, bool) { + // To try and determine the tracking ref for a local branch, we first construct a collection of refs + // that might be tracking, given the current branch's config, and the list of known remotes. + refsForLookup := []string{"HEAD"} + if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { + tr := trackingRef{ + remoteName: headBranchConfig.RemoteName, + branchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } for _, remote := range remotes { - tr := git.TrackingRef{ - RemoteName: remote.Name, - BranchName: headBranchConfig.LocalName, + tr := trackingRef{ + remoteName: remote.Name, + branchName: localBranchName, } - trackingRefs = append(trackingRefs, tr) refsForLookup = append(refsForLookup, tr.String()) } + // Then we ask git for details about these refs, for example, refs/remotes/origin/trunk might return a hash + // for the remote tracking branch, trunk, for the remote, origin. If there is no ref, the git client returns + // no ref information. + // + // We also first check for the HEAD ref, so that we have the hash of the currently checked out commit. resolvedRefs, _ := gitClient.ShowRefs(context.Background(), refsForLookup) + + // If there is more than one resolved ref, that means that at least one ref was found in addition to the HEAD. if len(resolvedRefs) > 1 { + headRef := resolvedRefs[0] for _, r := range resolvedRefs[1:] { - if r.Hash != resolvedRefs[0].Hash { + // If the hash of the remote ref doesn't match the hash of HEAD then the remote branch is not in the same + // state, so it can't be used. + if r.Hash != headRef.Hash { continue } - for _, tr := range trackingRefs { - if tr.String() != r.Name { - continue - } - return &tr - } + // Otherwise we can parse the returned ref into a tracking ref and return that + return mustParseTrackingRef(r.Name), true } } - return nil + return trackingRef{}, false } func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { @@ -647,14 +683,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch) if isPushEnabled { // determine whether the head branch is already pushed to a remote - if pushedTo := determineTrackingBranch(gitClient, remotes, &headBranchConfig); pushedTo != nil { + if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false - if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil { + if r, err := remotes.FindByName(trackingRef.remoteName); err == nil { headRepo = r headRemote = r - headBranchLabel = pushedTo.BranchName + headBranchLabel = trackingRef.branchName if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), pushedTo.BranchName) + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.branchName) } } } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 27220d052..6c0ff11e2 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1622,12 +1622,13 @@ func Test_createRun(t *testing.T) { } } -func Test_determineTrackingBranch(t *testing.T) { +func Test_tryDetermineTrackingRef(t *testing.T) { tests := []struct { - name string - cmdStubs func(*run.CommandStubber) - remotes context.Remotes - assert func(ref *git.TrackingRef, t *testing.T) + name string + cmdStubs func(*run.CommandStubber) + remotes context.Remotes + expectedTrackingRef trackingRef + expectedFound bool }{ { name: "empty", @@ -1635,54 +1636,53 @@ func Test_determineTrackingBranch(t *testing.T) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) - }, + expectedTrackingRef: trackingRef{}, + expectedFound: false, }, { name: "no match", cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") - cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature", 0, "abc HEAD\nbca refs/remotes/origin/feature") + cs.Register("git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature", 0, "abc HEAD\nbca refs/remotes/upstream/feature") }, remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, &context.Remote{ Remote: &git.Remote{Name: "upstream"}, Repo: ghrepo.New("octocat", "Spoon-Knife"), }, + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("hubot", "Spoon-Knife"), + }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) - }, + expectedTrackingRef: trackingRef{}, + expectedFound: false, }, { name: "match", cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature$`, 0, heredoc.Doc(` + cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature$`, 0, heredoc.Doc(` deadbeef HEAD - deadb00f refs/remotes/origin/feature - deadbeef refs/remotes/upstream/feature + deadb00f refs/remotes/upstream/feature + deadbeef refs/remotes/origin/feature `)) }, remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, &context.Remote{ Remote: &git.Remote{Name: "upstream"}, Repo: ghrepo.New("octocat", "Spoon-Knife"), }, + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("hubot", "Spoon-Knife"), + }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Equal(t, "upstream", ref.RemoteName) - assert.Equal(t, "feature", ref.BranchName) + expectedTrackingRef: trackingRef{ + remoteName: "origin", + branchName: "feature", }, + expectedFound: true, }, { name: "respect tracking config", @@ -1702,9 +1702,8 @@ func Test_determineTrackingBranch(t *testing.T) { Repo: ghrepo.New("hubot", "Spoon-Knife"), }, }, - assert: func(ref *git.TrackingRef, t *testing.T) { - assert.Nil(t, ref) - }, + expectedTrackingRef: trackingRef{}, + expectedFound: false, }, } for _, tt := range tests { @@ -1719,8 +1718,10 @@ func Test_determineTrackingBranch(t *testing.T) { GitPath: "some/path/git", } headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature") - ref := determineTrackingBranch(gitClient, tt.remotes, &headBranchConfig) - tt.assert(ref, t) + ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig) + + assert.Equal(t, tt.expectedTrackingRef, ref) + assert.Equal(t, tt.expectedFound, found) }) } }