refactor(pr create): add PullRequestRefs HasHead

This commit is contained in:
Kynan Ware 2025-03-10 13:28:46 -06:00
parent 54da786bec
commit e999976b3d
3 changed files with 52 additions and 2 deletions

View file

@ -649,7 +649,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
// our local branch state.
// If it is, we can use it as the head repo for the PR
// and avoid prompting the user.
if prRefs.HeadRepo != nil && prRefs.BranchName != "" {
if prRefs.HasHead() {
// Check if the head branch is up-to-date with the local branch
headRemote, err := remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName())
if headRemote != nil && err == nil {
@ -662,7 +662,7 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
headRef := resolvedRefs[0]
for _, r := range resolvedRefs[1:] {
// If the head ref is the same as the remote head ref,
// then the remote head is current.
// then the remote head is current, and we can use it.
if r.Hash == headRef.Hash {
promptForHeadRepo = false
break

View file

@ -109,6 +109,10 @@ type PullRequestRefs struct {
BaseRepo ghrepo.Interface
}
func (s *PullRequestRefs) HasHead() bool {
return s.HeadRepo != nil && s.BranchName != ""
}
// GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is
// either just the branch name or, if the PR is originating from a fork, the fork owner
// and the branch name, like <owner>:<branch>.

View file

@ -972,6 +972,52 @@ func TestPRRefs_GetPRHeadLabel(t *testing.T) {
}
}
func TestPullRequestRefs_HasHead(t *testing.T) {
tests := []struct {
name string
prRefs PullRequestRefs
want bool
}{
{
name: "HeadRepo is nil and BranchName is empty, return false",
prRefs: PullRequestRefs{
HeadRepo: nil,
BranchName: "",
},
want: false,
},
{
name: "HeadRepo is not nil and BranchName is empty, return false",
prRefs: PullRequestRefs{
HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"),
BranchName: "",
},
want: false,
},
{
name: "HeadRepo is nil and BranchName is not empty, return false",
prRefs: PullRequestRefs{
HeadRepo: nil,
BranchName: "feature-branch",
},
want: false,
},
{
name: "HeadRepo is not nil and BranchName is not empty, return true",
prRefs: PullRequestRefs{
HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"),
BranchName: "feature-branch",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, tt.prRefs.HasHead())
})
}
}
func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) {
return func(branch string) (git.BranchConfig, error) {
return branchConfig, err