diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index 84b16709f..0647aebb2 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -72,7 +72,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr func checkoutRun(opts *CheckoutOptions) error { findOptions := shared.FindOptions{ Selector: opts.SelectorArg, - Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner"}, + Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"}, } pr, baseRepo, err := opts.Finder.Find(findOptions) if err != nil { diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 932db4778..48150ff84 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -110,7 +110,8 @@ func TestPRCheckout_sameRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.RunCommandFinder("123", pr, baseRepo) + finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"}) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -164,7 +165,8 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.RunCommandFinder("123", pr, baseRepo) + finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"}) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -186,7 +188,8 @@ func TestPRCheckout_differentRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.RunCommandFinder("123", pr, baseRepo) + finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository"}) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 1ad253be2..7b90e146c 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -65,8 +65,10 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { var runCommandFinder PRFinder // RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. -func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) { - runCommandFinder = NewMockFinder(selector, pr, repo) +func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { + finder := NewMockFinder(selector, pr, repo) + runCommandFinder = finder + return finder } type FindOptions struct { @@ -460,7 +462,7 @@ func (err *NotFoundError) Unwrap() error { return err.error } -func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) PRFinder { +func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { var err error if pr == nil { err = &NotFoundError{errors.New("no pull requests found")} @@ -476,6 +478,7 @@ func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) type mockFinder struct { called bool expectSelector string + expectFields []string pr *api.PullRequest repo ghrepo.Interface err error @@ -488,6 +491,9 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, if m.expectSelector != opts.Selector { return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) } + if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) { + return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields) + } if m.called { return nil, nil, errors.New("mockFinder used more than once") } @@ -500,3 +506,27 @@ func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, return m.pr, m.repo, nil } + +func (m *mockFinder) ExpectFields(fields []string) { + m.expectFields = fields +} + +func isEqualSet(a, b []string) bool { + if len(a) != len(b) { + return false + } + + aCopy := make([]string, len(a)) + copy(aCopy, a) + bCopy := make([]string, len(b)) + copy(bCopy, b) + sort.Strings(aCopy) + sort.Strings(bCopy) + + for i := range aCopy { + if aCopy[i] != bCopy[i] { + return false + } + } + return true +}