Merge pull request #10177 from cli/cmbrose/pr-create-upstream-fix

Handle missing upstream configs for `gh pr create`
This commit is contained in:
Tyler McGoffin 2025-01-06 09:32:00 -08:00 committed by GitHub
commit 2358fcee83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 120 additions and 69 deletions

View file

@ -0,0 +1,27 @@
# This test is the same as pr-create-basic, except that the git push doesn't include the -u argument
# This causes a git config read to fail during gh pr create, but it should not be fatal
# Use gh as a credential helper
exec gh auth setup-git
# Create a repository with a file so it has a default branch
exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private
# Defer repo cleanup
defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING
# Clone the repo
exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING
# Prepare a branch to PR
cd $SCRIPT_NAME-$RANDOM_STRING
exec git checkout -b feature-branch
exec git commit --allow-empty -m 'Empty Commit'
exec git push origin feature-branch
# Create the PR
exec gh pr create --title 'Feature Title' --body 'Feature Body'
# Check the PR is indeed created
exec gh pr view
stdout 'Feature Title'

View file

@ -389,7 +389,6 @@ func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg Branc
return
}
cfg.LocalName = branch
for _, line := range outputLines(out) {
parts := strings.SplitN(line, " ", 2)
if len(parts) < 2 {

View file

@ -737,7 +737,7 @@ func TestClientReadBranchConfig(t *testing.T) {
name: "read branch config",
cmdStdout: "branch.trunk.remote origin\nbranch.trunk.merge refs/heads/trunk\nbranch.trunk.gh-merge-base trunk",
wantCmdArgs: `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|gh-merge-base)$`,
wantBranchConfig: BranchConfig{LocalName: "trunk", RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"},
wantBranchConfig: BranchConfig{RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"},
},
}
for _, tt := range tests {

View file

@ -54,16 +54,6 @@ type Ref struct {
Name string
}
// TrackingRef represents a ref for a remote tracking branch.
type TrackingRef struct {
RemoteName string
BranchName string
}
func (r TrackingRef) String() string {
return "refs/remotes/" + r.RemoteName + "/" + r.BranchName
}
type Commit struct {
Sha string
Title string
@ -71,8 +61,6 @@ type Commit struct {
}
type BranchConfig struct {
// LocalName of the branch.
LocalName string
RemoteName string
RemoteURL *url.URL
// MergeBase is the optional base branch to target in a new PR if `--base` is not specified.

View file

@ -518,44 +518,80 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u
return nil
}
func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, headBranchConfig *git.BranchConfig) *git.TrackingRef {
refsForLookup := []string{"HEAD"}
var trackingRefs []git.TrackingRef
// trackingRef represents a ref for a remote tracking branch.
type trackingRef struct {
remoteName string
branchName string
}
if headBranchConfig.RemoteName != "" {
tr := git.TrackingRef{
RemoteName: headBranchConfig.RemoteName,
BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"),
func (r trackingRef) String() string {
return "refs/remotes/" + r.remoteName + "/" + r.branchName
}
func mustParseTrackingRef(text string) trackingRef {
parts := strings.SplitN(string(text), "/", 4)
// The only place this is called is tryDetermineTrackingRef, where we are reconstructing
// the same tracking ref we passed in. If it doesn't match the expected format, this is a
// programmer error we want to know about, so it's ok to panic.
if len(parts) != 4 {
panic(fmt.Errorf("tracking ref should have four parts: %s", text))
}
if parts[0] != "refs" || parts[1] != "remotes" {
panic(fmt.Errorf("tracking ref should start with refs/remotes/: %s", text))
}
return trackingRef{
remoteName: parts[2],
branchName: parts[3],
}
}
// tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out
// HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in
// practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin",
// and then everything else unstably sorted.
func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (trackingRef, bool) {
// To try and determine the tracking ref for a local branch, we first construct a collection of refs
// that might be tracking, given the current branch's config, and the list of known remotes.
refsForLookup := []string{"HEAD"}
if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" {
tr := trackingRef{
remoteName: headBranchConfig.RemoteName,
branchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"),
}
trackingRefs = append(trackingRefs, tr)
refsForLookup = append(refsForLookup, tr.String())
}
for _, remote := range remotes {
tr := git.TrackingRef{
RemoteName: remote.Name,
BranchName: headBranchConfig.LocalName,
tr := trackingRef{
remoteName: remote.Name,
branchName: localBranchName,
}
trackingRefs = append(trackingRefs, tr)
refsForLookup = append(refsForLookup, tr.String())
}
// Then we ask git for details about these refs, for example, refs/remotes/origin/trunk might return a hash
// for the remote tracking branch, trunk, for the remote, origin. If there is no ref, the git client returns
// no ref information.
//
// We also first check for the HEAD ref, so that we have the hash of the currently checked out commit.
resolvedRefs, _ := gitClient.ShowRefs(context.Background(), refsForLookup)
// If there is more than one resolved ref, that means that at least one ref was found in addition to the HEAD.
if len(resolvedRefs) > 1 {
headRef := resolvedRefs[0]
for _, r := range resolvedRefs[1:] {
if r.Hash != resolvedRefs[0].Hash {
// If the hash of the remote ref doesn't match the hash of HEAD then the remote branch is not in the same
// state, so it can't be used.
if r.Hash != headRef.Hash {
continue
}
for _, tr := range trackingRefs {
if tr.String() != r.Name {
continue
}
return &tr
}
// Otherwise we can parse the returned ref into a tracking ref and return that
return mustParseTrackingRef(r.Name), true
}
}
return nil
return trackingRef{}, false
}
func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) {
@ -647,14 +683,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch)
if isPushEnabled {
// determine whether the head branch is already pushed to a remote
if pushedTo := determineTrackingBranch(gitClient, remotes, &headBranchConfig); pushedTo != nil {
if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found {
isPushEnabled = false
if r, err := remotes.FindByName(pushedTo.RemoteName); err == nil {
if r, err := remotes.FindByName(trackingRef.remoteName); err == nil {
headRepo = r
headRemote = r
headBranchLabel = pushedTo.BranchName
headBranchLabel = trackingRef.branchName
if !ghrepo.IsSame(baseRepo, headRepo) {
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), pushedTo.BranchName)
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.branchName)
}
}
}

View file

@ -1622,12 +1622,13 @@ func Test_createRun(t *testing.T) {
}
}
func Test_determineTrackingBranch(t *testing.T) {
func Test_tryDetermineTrackingRef(t *testing.T) {
tests := []struct {
name string
cmdStubs func(*run.CommandStubber)
remotes context.Remotes
assert func(ref *git.TrackingRef, t *testing.T)
name string
cmdStubs func(*run.CommandStubber)
remotes context.Remotes
expectedTrackingRef trackingRef
expectedFound bool
}{
{
name: "empty",
@ -1635,54 +1636,53 @@ func Test_determineTrackingBranch(t *testing.T) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD")
},
assert: func(ref *git.TrackingRef, t *testing.T) {
assert.Nil(t, ref)
},
expectedTrackingRef: trackingRef{},
expectedFound: false,
},
{
name: "no match",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature", 0, "abc HEAD\nbca refs/remotes/origin/feature")
cs.Register("git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature", 0, "abc HEAD\nbca refs/remotes/upstream/feature")
},
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.New("octocat", "Spoon-Knife"),
},
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
},
assert: func(ref *git.TrackingRef, t *testing.T) {
assert.Nil(t, ref)
},
expectedTrackingRef: trackingRef{},
expectedFound: false,
},
{
name: "match",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature refs/remotes/upstream/feature$`, 0, heredoc.Doc(`
cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature$`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/origin/feature
deadbeef refs/remotes/upstream/feature
deadb00f refs/remotes/upstream/feature
deadbeef refs/remotes/origin/feature
`))
},
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.New("octocat", "Spoon-Knife"),
},
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
},
assert: func(ref *git.TrackingRef, t *testing.T) {
assert.Equal(t, "upstream", ref.RemoteName)
assert.Equal(t, "feature", ref.BranchName)
expectedTrackingRef: trackingRef{
remoteName: "origin",
branchName: "feature",
},
expectedFound: true,
},
{
name: "respect tracking config",
@ -1702,9 +1702,8 @@ func Test_determineTrackingBranch(t *testing.T) {
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
},
assert: func(ref *git.TrackingRef, t *testing.T) {
assert.Nil(t, ref)
},
expectedTrackingRef: trackingRef{},
expectedFound: false,
},
}
for _, tt := range tests {
@ -1719,8 +1718,10 @@ func Test_determineTrackingBranch(t *testing.T) {
GitPath: "some/path/git",
}
headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature")
ref := determineTrackingBranch(gitClient, tt.remotes, &headBranchConfig)
tt.assert(ref, t)
ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig)
assert.Equal(t, tt.expectedTrackingRef, ref)
assert.Equal(t, tt.expectedFound, found)
})
}
}