diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 06a66e93d..4aa58af96 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -99,7 +99,7 @@ type PRRefs struct { // either just the branch name or, if the PR is originating from a fork, the fork owner // and the branch name, like :. func (s *PRRefs) GetPRLabel() string { - if s.HeadRepo == s.BaseRepo { + if ghrepo.IsSame(s.HeadRepo, s.BaseRepo) { return s.BranchName } return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName) @@ -296,6 +296,12 @@ func parsePRRefs(currentBranchName string, branchConfig git.BranchConfig, pushDe return prRefs, nil } } + + remoteNames := make([]string, len(rems)) + for i, r := range rems { + remoteNames[i] = r.Name + } + return PRRefs{}, fmt.Errorf("no remote for %q found in %q", branchConfig.Push, strings.Join(remoteNames, ", ")) } // To get the HeadRepo, we look to the git config. The PushRemote{Name | URL} comes from diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 7a0140088..fe01e6b6e 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -2,6 +2,7 @@ package shared import ( "errors" + "fmt" "net/http" "net/url" "testing" @@ -70,7 +71,7 @@ func TestFind(t *testing.T) { args: args{ selector: "13", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -92,7 +93,7 @@ func TestFind(t *testing.T) { selector: "13", baseBranch: "main", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -147,7 +148,7 @@ func TestFind(t *testing.T) { args: args{ selector: "13", fields: []string{"number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -163,7 +164,7 @@ func TestFind(t *testing.T) { args: args{ selector: "#13", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -207,7 +208,7 @@ func TestFind(t *testing.T) { args: args{ selector: "blueberries", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -370,7 +371,7 @@ func TestFind(t *testing.T) { args: args{ selector: "", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -404,7 +405,7 @@ func TestFind(t *testing.T) { args: args{ selector: "", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -438,7 +439,7 @@ func TestFind(t *testing.T) { args: args{ selector: "", fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(remoteOrigin.Repo, nil), + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, @@ -691,21 +692,18 @@ func Test_parsePRRefs(t *testing.T) { wantErr: nil, }, { - name: "When the branch name doesn't match a different branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name", + name: "When the push revision doesn't match a remote, it returns an error", branchConfig: git.BranchConfig{ Push: "origin/differentPushBranch", }, currentBranchName: "blueberries", baseRefRepo: remoteOrigin.Repo, rems: context.Remotes{ - &remoteOrigin, + &remoteUpstream, + &remoteOther, }, - wantPRRefs: PRRefs{ - BranchName: "differentPushBranch", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, + wantPRRefs: PRRefs{}, + wantErr: fmt.Errorf("no remote for %q found in %q", "origin/differentPushBranch", "upstream, other"), }, { name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name", @@ -823,12 +821,11 @@ func Test_parsePRRefs(t *testing.T) { t.Run(tt.name, func(t *testing.T) { prRefs, err := parsePRRefs(tt.currentBranchName, tt.branchConfig, tt.pushDefault, tt.baseRefRepo, tt.rems) if tt.wantErr != nil { - require.Error(t, err) - assert.Equal(t, tt.wantErr, err) + require.Equal(t, tt.wantErr, err) } else { require.NoError(t, err) } - assert.Equal(t, tt.wantPRRefs, prRefs) + require.Equal(t, tt.wantPRRefs, prRefs) }) } }