diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index dc89542e5..04cf5d1aa 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -649,7 +649,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { // our local branch state. // If it is, we can use it as the head repo for the PR // and avoid prompting the user. - if prRefs.HeadRepo != nil && prRefs.BranchName != "" { + if prRefs.HasHead() { // Check if the head branch is up-to-date with the local branch headRemote, err := remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName()) if headRemote != nil && err == nil { @@ -662,7 +662,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { headRef := resolvedRefs[0] for _, r := range resolvedRefs[1:] { // If the head ref is the same as the remote head ref, - // then the remote head is current. + // then the remote head is current, and we can use it. if r.Hash == headRef.Hash { promptForHeadRepo = false break diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index dc9cb8fb9..d02377759 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -109,6 +109,10 @@ type PullRequestRefs struct { BaseRepo ghrepo.Interface } +func (s *PullRequestRefs) HasHead() bool { + return s.HeadRepo != nil && s.BranchName != "" +} + // GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is // either just the branch name or, if the PR is originating from a fork, the fork owner // and the branch name, like :. diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 36551ab42..b10c55e55 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -972,6 +972,52 @@ func TestPRRefs_GetPRHeadLabel(t *testing.T) { } } +func TestPullRequestRefs_HasHead(t *testing.T) { + tests := []struct { + name string + prRefs PullRequestRefs + want bool + }{ + { + name: "HeadRepo is nil and BranchName is empty, return false", + prRefs: PullRequestRefs{ + HeadRepo: nil, + BranchName: "", + }, + want: false, + }, + { + name: "HeadRepo is not nil and BranchName is empty, return false", + prRefs: PullRequestRefs{ + HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"), + BranchName: "", + }, + want: false, + }, + { + name: "HeadRepo is nil and BranchName is not empty, return false", + prRefs: PullRequestRefs{ + HeadRepo: nil, + BranchName: "feature-branch", + }, + want: false, + }, + { + name: "HeadRepo is not nil and BranchName is not empty, return true", + prRefs: PullRequestRefs{ + HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"), + BranchName: "feature-branch", + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.prRefs.HasHead()) + }) + } +} + func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) { return func(branch string) (git.BranchConfig, error) { return branchConfig, err