Merge pull request #3663 from cli/cross-repo-pr-checkout

Fix `pr checkout` for cross-repository pull requests
This commit is contained in:
Mislav Marohnić 2021-05-19 13:35:57 +02:00 committed by GitHub
commit dee89a1b6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 7 deletions

View file

@ -72,7 +72,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
func checkoutRun(opts *CheckoutOptions) error {
findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner"},
Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"},
}
pr, baseRepo, err := opts.Finder.Find(findOptions)
if err != nil {

View file

@ -110,7 +110,8 @@ func TestPRCheckout_sameRepo(t *testing.T) {
defer http.Verify(t)
baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature")
shared.RunCommandFinder("123", pr, baseRepo)
finder := shared.RunCommandFinder("123", pr, baseRepo)
finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"})
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
@ -164,7 +165,8 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) {
defer http.Verify(t)
baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:feature")
shared.RunCommandFinder("123", pr, baseRepo)
finder := shared.RunCommandFinder("123", pr, baseRepo)
finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"})
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
@ -186,7 +188,8 @@ func TestPRCheckout_differentRepo(t *testing.T) {
defer http.Verify(t)
baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
shared.RunCommandFinder("123", pr, baseRepo)
finder := shared.RunCommandFinder("123", pr, baseRepo)
finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"})
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)

View file

@ -65,8 +65,10 @@ func NewFinder(factory *cmdutil.Factory) PRFinder {
var runCommandFinder PRFinder
// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) {
runCommandFinder = NewMockFinder(selector, pr, repo)
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
finder := NewMockFinder(selector, pr, repo)
runCommandFinder = finder
return finder
}
type FindOptions struct {
@ -460,7 +462,7 @@ func (err *NotFoundError) Unwrap() error {
return err.error
}
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) PRFinder {
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
var err error
if pr == nil {
err = &NotFoundError{errors.New("no pull requests found")}
@ -476,6 +478,7 @@ func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface)
type mockFinder struct {
called bool
expectSelector string
expectFields []string
pr *api.PullRequest
repo ghrepo.Interface
err error
@ -488,6 +491,9 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface,
if m.expectSelector != opts.Selector {
return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
}
if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
}
if m.called {
return nil, nil, errors.New("mockFinder used more than once")
}
@ -500,3 +506,27 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface,
return m.pr, m.repo, nil
}
func (m *mockFinder) ExpectFields(fields []string) {
m.expectFields = fields
}
func isEqualSet(a, b []string) bool {
if len(a) != len(b) {
return false
}
aCopy := make([]string, len(a))
copy(aCopy, a)
bCopy := make([]string, len(b))
copy(bCopy, b)
sort.Strings(aCopy)
sort.Strings(bCopy)
for i := range aCopy {
if aCopy[i] != bCopy[i] {
return false
}
}
return true
}