Fix checkout when URL arg is from fork and cwd is upstream
This commit is contained in:
parent
69fff52026
commit
11b9496e17
4 changed files with 306 additions and 208 deletions
|
|
@ -27,13 +27,8 @@ type CheckoutOptions struct {
|
|||
Remotes func() (cliContext.Remotes, error)
|
||||
Branch func() (string, error)
|
||||
|
||||
Finder shared.PRFinder
|
||||
Prompter shared.Prompt
|
||||
Lister shared.PRLister
|
||||
PRResolver PRResolver
|
||||
|
||||
Interactive bool
|
||||
BaseRepo func() (ghrepo.Interface, error)
|
||||
SelectorArg string
|
||||
RecurseSubmodules bool
|
||||
Force bool
|
||||
Detach bool
|
||||
|
|
@ -48,8 +43,6 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
|
|||
Config: f.Config,
|
||||
Remotes: f.Remotes,
|
||||
Branch: f.Branch,
|
||||
Prompter: f.Prompter,
|
||||
BaseRepo: f.BaseRepo,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
|
|
@ -66,15 +59,30 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
|
|||
`),
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
opts.Finder = shared.NewFinder(f)
|
||||
opts.Lister = shared.NewLister(f)
|
||||
|
||||
if len(args) > 0 {
|
||||
opts.SelectorArg = args[0]
|
||||
} else if !opts.IO.CanPrompt() {
|
||||
return cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
|
||||
opts.PRResolver = &specificPRResolver{
|
||||
prFinder: shared.NewFinder(f),
|
||||
selector: args[0],
|
||||
}
|
||||
} else if opts.IO.CanPrompt() {
|
||||
baseRepo, err := f.BaseRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpClient, err := f.HttpClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.PRResolver = &promptingPRResolver{
|
||||
io: opts.IO,
|
||||
prompter: f.Prompter,
|
||||
prLister: shared.NewLister(httpClient),
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
} else {
|
||||
opts.Interactive = true
|
||||
return cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
|
|
@ -93,12 +101,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
|
|||
}
|
||||
|
||||
func checkoutRun(opts *CheckoutOptions) error {
|
||||
baseRepo, err := opts.BaseRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO)
|
||||
pr, baseRepo, err := opts.PRResolver.Resolve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -286,32 +289,47 @@ func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cm
|
|||
return nil
|
||||
}
|
||||
|
||||
func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompt, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) {
|
||||
// When non-interactive
|
||||
if pullRequestSelector != "" {
|
||||
pr, _, err := pullRequestFinder.Find(shared.FindOptions{
|
||||
Selector: pullRequestSelector,
|
||||
Fields: []string{
|
||||
"number",
|
||||
"headRefName",
|
||||
"headRepository",
|
||||
"headRepositoryOwner",
|
||||
"isCrossRepository",
|
||||
"maintainerCanModify",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pr, nil
|
||||
type PRResolver interface {
|
||||
Resolve() (*api.PullRequest, ghrepo.Interface, error)
|
||||
}
|
||||
|
||||
type specificPRResolver struct {
|
||||
prFinder shared.PRFinder
|
||||
selector string
|
||||
}
|
||||
|
||||
func (r *specificPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
|
||||
pr, baseRepo, err := r.prFinder.Find(shared.FindOptions{
|
||||
Selector: r.selector,
|
||||
Fields: []string{
|
||||
"number",
|
||||
"headRefName",
|
||||
"headRepository",
|
||||
"headRepositoryOwner",
|
||||
"isCrossRepository",
|
||||
"maintainerCanModify",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !isInteractive {
|
||||
return nil, cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
|
||||
}
|
||||
// When interactive
|
||||
io.StartProgressIndicator()
|
||||
listResult, err := prLister.List(shared.ListOptions{
|
||||
State: "open",
|
||||
return pr, baseRepo, nil
|
||||
}
|
||||
|
||||
type promptingPRResolver struct {
|
||||
io *iostreams.IOStreams
|
||||
prompter shared.Prompt
|
||||
|
||||
prLister shared.PRLister
|
||||
|
||||
baseRepo ghrepo.Interface
|
||||
}
|
||||
|
||||
func (r *promptingPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
|
||||
r.io.StartProgressIndicator()
|
||||
listResult, err := r.prLister.List(shared.ListOptions{
|
||||
BaseRepo: r.baseRepo,
|
||||
State: "open",
|
||||
Fields: []string{
|
||||
"number",
|
||||
"title",
|
||||
|
|
@ -325,21 +343,16 @@ func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompt, pullRequestSel
|
|||
"maintainerCanModify",
|
||||
},
|
||||
LimitResults: 10})
|
||||
io.StopProgressIndicator()
|
||||
r.io.StopProgressIndicator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(listResult.PullRequests) == 0 {
|
||||
return nil, shared.ListNoResults(ghrepo.FullName(baseRepo), "pull request", false)
|
||||
return nil, nil, shared.ListNoResults(ghrepo.FullName(r.baseRepo), "pull request", false)
|
||||
}
|
||||
|
||||
pr, err := promptForPR(prompter, *listResult)
|
||||
return pr, err
|
||||
}
|
||||
|
||||
func promptForPR(prompter shared.Prompt, jobs api.PullRequestAndTotalCount) (*api.PullRequest, error) {
|
||||
candidates := []string{}
|
||||
for _, pr := range jobs.PullRequests {
|
||||
for _, pr := range listResult.PullRequests {
|
||||
candidates = append(candidates, fmt.Sprintf("%d\t%s %s [%s]",
|
||||
pr.Number,
|
||||
shared.PrStateWithDraft(&pr),
|
||||
|
|
@ -348,14 +361,10 @@ func promptForPR(prompter shared.Prompt, jobs api.PullRequestAndTotalCount) (*ap
|
|||
))
|
||||
}
|
||||
|
||||
selected, err := prompter.Select("Select a pull request", "", candidates)
|
||||
selected, err := r.prompter.Select("Select a pull request", "", candidates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if selected >= 0 {
|
||||
return &jobs.PullRequests[selected], nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return &listResult.PullRequests[selected], r.baseRepo, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,8 +23,87 @@ import (
|
|||
"github.com/cli/cli/v2/test"
|
||||
"github.com/google/shlex"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCmdCheckout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args string
|
||||
wantsOpts CheckoutOptions
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "recurse submodules",
|
||||
args: "--recurse-submodules 123",
|
||||
wantsOpts: CheckoutOptions{
|
||||
RecurseSubmodules: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "force",
|
||||
args: "--force 123",
|
||||
wantsOpts: CheckoutOptions{
|
||||
Force: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "detach",
|
||||
args: "--detach 123",
|
||||
wantsOpts: CheckoutOptions{
|
||||
Detach: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "branch",
|
||||
args: "--branch test-branch 123",
|
||||
wantsOpts: CheckoutOptions{
|
||||
BranchName: "test-branch",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when there is no selector and no TTY, returns an error",
|
||||
args: "",
|
||||
wantErr: cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
f := &cmdutil.Factory{
|
||||
IOStreams: ios,
|
||||
}
|
||||
|
||||
ios.SetStdinTTY(false)
|
||||
|
||||
argv, err := shlex.Split(tt.args)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var spiedOpts *CheckoutOptions
|
||||
cmd := NewCmdCheckout(f, func(opts *CheckoutOptions) error {
|
||||
spiedOpts = opts
|
||||
return nil
|
||||
})
|
||||
cmd.SetArgs(argv)
|
||||
cmd.SetIn(&bytes.Buffer{})
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
|
||||
_, err = cmd.ExecuteC()
|
||||
if tt.wantErr != nil {
|
||||
require.Equal(t, tt.wantErr, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantsOpts.RecurseSubmodules, spiedOpts.RecurseSubmodules)
|
||||
require.Equal(t, tt.wantsOpts.Force, spiedOpts.Force)
|
||||
require.Equal(t, tt.wantsOpts.Detach, spiedOpts.Detach)
|
||||
require.Equal(t, tt.wantsOpts.BranchName, spiedOpts.BranchName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// repo: either "baseOwner/baseRepo" or "baseOwner/baseRepo:defaultBranch"
|
||||
// prHead: "headOwner/headRepo:headBranch"
|
||||
func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
|
||||
|
|
@ -70,6 +149,20 @@ func _stubPR(repo, prHead string, number int, title string, state string, isDraf
|
|||
}
|
||||
}
|
||||
|
||||
type stubPRResolver struct {
|
||||
pr *api.PullRequest
|
||||
baseRepo ghrepo.Interface
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
|
||||
if s.err != nil {
|
||||
return nil, nil, s.err
|
||||
}
|
||||
return s.pr, s.baseRepo, nil
|
||||
}
|
||||
|
||||
func Test_checkoutRun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -88,16 +181,13 @@ func Test_checkoutRun(t *testing.T) {
|
|||
{
|
||||
name: "checkout with ssh remote URL",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "123",
|
||||
Finder: func() shared.PRFinder {
|
||||
PRResolver: func() PRResolver {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
finder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
return finder
|
||||
return &stubPRResolver{
|
||||
pr: pr,
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
return baseRepo, nil
|
||||
},
|
||||
Config: func() (gh.Config, error) {
|
||||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
|
|
@ -117,18 +207,15 @@ func Test_checkoutRun(t *testing.T) {
|
|||
{
|
||||
name: "fork repo was deleted",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "123",
|
||||
Finder: func() shared.PRFinder {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
|
||||
PRResolver: func() PRResolver {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
pr.MaintainerCanModify = true
|
||||
pr.HeadRepository = nil
|
||||
finder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
return finder
|
||||
return &stubPRResolver{
|
||||
pr: pr,
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
return baseRepo, nil
|
||||
},
|
||||
Config: func() (gh.Config, error) {
|
||||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
|
|
@ -151,17 +238,14 @@ func Test_checkoutRun(t *testing.T) {
|
|||
{
|
||||
name: "with local branch rename and existing git remote",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "123",
|
||||
BranchName: "foobar",
|
||||
Finder: func() shared.PRFinder {
|
||||
BranchName: "foobar",
|
||||
PRResolver: func() PRResolver {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
finder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
return finder
|
||||
return &stubPRResolver{
|
||||
pr: pr,
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
return baseRepo, nil
|
||||
},
|
||||
Config: func() (gh.Config, error) {
|
||||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
|
|
@ -181,18 +265,15 @@ func Test_checkoutRun(t *testing.T) {
|
|||
{
|
||||
name: "with local branch name, no existing git remote",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "123",
|
||||
BranchName: "foobar",
|
||||
Finder: func() shared.PRFinder {
|
||||
BranchName: "foobar",
|
||||
PRResolver: func() PRResolver {
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
|
||||
pr.MaintainerCanModify = true
|
||||
finder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
return finder
|
||||
return &stubPRResolver{
|
||||
pr: pr,
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
baseRepo, _ := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
|
||||
return baseRepo, nil
|
||||
},
|
||||
Config: func() (gh.Config, error) {
|
||||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
|
|
@ -213,78 +294,14 @@ func Test_checkoutRun(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
name: "with no selected PR args and non tty, return error",
|
||||
name: "when the PR resolver errors, then that error is bubbled up",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "",
|
||||
Interactive: false,
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
PRResolver: &stubPRResolver{
|
||||
err: errors.New("expected test error"),
|
||||
},
|
||||
},
|
||||
remotes: map[string]string{
|
||||
"origin": "OWNER/REPO",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "pull request number, URL, or branch required when not running interactively",
|
||||
},
|
||||
{
|
||||
name: "with no selected PR args and stdin tty, prompts for choice",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "",
|
||||
Interactive: true,
|
||||
Lister: func() shared.PRLister {
|
||||
_, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
|
||||
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
|
||||
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
|
||||
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 3,
|
||||
PullRequests: []api.PullRequest{
|
||||
*pr1, *pr2, *pr3,
|
||||
}, SearchCapped: false}, nil)
|
||||
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
|
||||
return lister
|
||||
}(),
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
},
|
||||
Config: func() (gh.Config, error) {
|
||||
return config.NewBlankConfig(), nil
|
||||
},
|
||||
},
|
||||
promptStubs: func(pm *prompter.MockPrompter) {
|
||||
pm.RegisterSelect("Select a pull request",
|
||||
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
|
||||
func(_, _ string, opts []string) (int, error) {
|
||||
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
|
||||
})
|
||||
},
|
||||
runStubs: func(cs *run.CommandStubber) {
|
||||
cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "")
|
||||
cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature --no-tags`, 0, "")
|
||||
cs.Register(`git checkout -b feature --track origin/feature`, 0, "")
|
||||
},
|
||||
remotes: map[string]string{
|
||||
"origin": "OWNER/REPO",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with no select PR args and no open PR, return error",
|
||||
opts: &CheckoutOptions{
|
||||
SelectorArg: "",
|
||||
Interactive: true,
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
},
|
||||
Lister: shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 0,
|
||||
PullRequests: []api.PullRequest{},
|
||||
}, nil),
|
||||
},
|
||||
remotes: map[string]string{
|
||||
"origin": "OWNER/REPO",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "no open pull requests in OWNER/REPO",
|
||||
errMsg: "expected test error",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
|
@ -309,12 +326,6 @@ func Test_checkoutRun(t *testing.T) {
|
|||
tt.runStubs(cmdStubs)
|
||||
}
|
||||
|
||||
pm := prompter.NewMockPrompter(t)
|
||||
tt.opts.Prompter = pm
|
||||
if tt.promptStubs != nil {
|
||||
tt.promptStubs(pm)
|
||||
}
|
||||
|
||||
opts.Remotes = func() (context.Remotes, error) {
|
||||
if len(tt.remotes) == 0 {
|
||||
return nil, errors.New("no remotes")
|
||||
|
|
@ -351,6 +362,102 @@ func Test_checkoutRun(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSpecificPRResolver(t *testing.T) {
|
||||
t.Run("when the PR Finder returns results, those are returned", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
|
||||
mockFinder := shared.NewMockFinder("123", pr, baseRepo)
|
||||
mockFinder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
|
||||
|
||||
resolver := &specificPRResolver{
|
||||
prFinder: mockFinder,
|
||||
selector: "123",
|
||||
}
|
||||
|
||||
resolvedPR, resolvedBaseRepo, err := resolver.Resolve()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pr, resolvedPR)
|
||||
require.True(t, ghrepo.IsSame(baseRepo, resolvedBaseRepo), "expected repos to be the same")
|
||||
})
|
||||
|
||||
t.Run("when the PR Finder errors, that error is returned", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockFinder := shared.NewMockFinder("123", nil, nil)
|
||||
|
||||
resolver := &specificPRResolver{
|
||||
prFinder: mockFinder,
|
||||
selector: "123",
|
||||
}
|
||||
|
||||
_, _, err := resolver.Resolve()
|
||||
var notFoundErr *shared.NotFoundError
|
||||
require.ErrorAs(t, err, ¬FoundErr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPromptingPRResolver(t *testing.T) {
|
||||
t.Run("when the PR Lister has results, then we prompt for a choice", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
baseRepo, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
|
||||
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
|
||||
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
|
||||
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 3,
|
||||
PullRequests: []api.PullRequest{
|
||||
*pr1, *pr2, *pr3,
|
||||
}, SearchCapped: false}, nil)
|
||||
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
|
||||
|
||||
pm := prompter.NewMockPrompter(t)
|
||||
pm.RegisterSelect("Select a pull request",
|
||||
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
|
||||
func(_, _ string, opts []string) (int, error) {
|
||||
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
|
||||
})
|
||||
|
||||
resolver := &promptingPRResolver{
|
||||
io: ios,
|
||||
prompter: pm,
|
||||
|
||||
prLister: lister,
|
||||
|
||||
baseRepo: baseRepo,
|
||||
}
|
||||
|
||||
resolvedPR, resolvedBaseRepo, err := resolver.Resolve()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pr1, resolvedPR)
|
||||
require.True(t, ghrepo.IsSame(baseRepo, resolvedBaseRepo), "expected repos to be the same")
|
||||
})
|
||||
|
||||
t.Run("when the PR lister has no results, then we return an error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
|
||||
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
|
||||
TotalCount: 0,
|
||||
PullRequests: []api.PullRequest{},
|
||||
}, nil)
|
||||
|
||||
resolver := &promptingPRResolver{
|
||||
io: ios,
|
||||
prLister: lister,
|
||||
baseRepo: ghrepo.New("OWNER", "REPO"),
|
||||
}
|
||||
|
||||
_, _, err := resolver.Resolve()
|
||||
var noResultsErr cmdutil.NoResultsError
|
||||
require.ErrorAs(t, err, &noResultsErr)
|
||||
require.Equal(t, "no open pull requests in OWNER/REPO", noResultsErr.Error())
|
||||
})
|
||||
}
|
||||
|
||||
/** LEGACY TESTS **/
|
||||
|
||||
func runCommand(rt http.RoundTripper, remotes context.Remotes, branch string, cli string, baseRepo ghrepo.Interface) (*test.CmdOut, error) {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
)
|
||||
|
||||
func shouldUseSearch(filters prShared.FilterOptions) bool {
|
||||
|
|
@ -19,10 +18,8 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters pr
|
|||
return searchPullRequests(httpClient, repo, filters, limit)
|
||||
}
|
||||
|
||||
return prShared.NewLister(&cmdutil.Factory{
|
||||
HttpClient: func() (*http.Client, error) { return httpClient, nil },
|
||||
BaseRepo: func() (ghrepo.Interface, error) { return repo, nil },
|
||||
}).List(prShared.ListOptions{
|
||||
return prShared.NewLister(httpClient).List(prShared.ListOptions{
|
||||
BaseRepo: repo,
|
||||
LimitResults: limit,
|
||||
State: filters.State,
|
||||
BaseBranch: filters.BaseBranch,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
|
||||
api "github.com/cli/cli/v2/api"
|
||||
)
|
||||
|
|
@ -15,6 +14,8 @@ type PRLister interface {
|
|||
}
|
||||
|
||||
type ListOptions struct {
|
||||
BaseRepo ghrepo.Interface
|
||||
|
||||
LimitResults int
|
||||
|
||||
State string
|
||||
|
|
@ -25,32 +26,16 @@ type ListOptions struct {
|
|||
}
|
||||
|
||||
type lister struct {
|
||||
baseRepoFn func() (ghrepo.Interface, error)
|
||||
httpClient func() (*http.Client, error)
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewLister(factory *cmdutil.Factory) PRLister {
|
||||
func NewLister(httpClient *http.Client) PRLister {
|
||||
return &lister{
|
||||
baseRepoFn: factory.BaseRepo,
|
||||
httpClient: factory.HttpClient,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lister) List(opts ListOptions) (*api.PullRequestAndTotalCount, error) {
|
||||
repo, err := l.baseRepoFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := l.httpClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listPullRequests(client, repo, opts)
|
||||
}
|
||||
|
||||
func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters ListOptions) (*api.PullRequestAndTotalCount, error) {
|
||||
type response struct {
|
||||
Repository struct {
|
||||
PullRequests struct {
|
||||
|
|
@ -63,8 +48,8 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
|
|||
}
|
||||
}
|
||||
}
|
||||
limit := filters.LimitResults
|
||||
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(filters.Fields))
|
||||
limit := opts.LimitResults
|
||||
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(opts.Fields))
|
||||
query := fragment + `
|
||||
query PullRequestList(
|
||||
$owner: String!,
|
||||
|
|
@ -98,11 +83,11 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
|
|||
|
||||
pageLimit := min(limit, 100)
|
||||
variables := map[string]interface{}{
|
||||
"owner": repo.RepoOwner(),
|
||||
"repo": repo.RepoName(),
|
||||
"owner": opts.BaseRepo.RepoOwner(),
|
||||
"repo": opts.BaseRepo.RepoName(),
|
||||
}
|
||||
|
||||
switch filters.State {
|
||||
switch opts.State {
|
||||
case "open":
|
||||
variables["state"] = []string{"OPEN"}
|
||||
case "closed":
|
||||
|
|
@ -112,25 +97,25 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
|
|||
case "all":
|
||||
variables["state"] = []string{"OPEN", "CLOSED", "MERGED"}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid state: %s", filters.State)
|
||||
return nil, fmt.Errorf("invalid state: %s", opts.State)
|
||||
}
|
||||
|
||||
if filters.BaseBranch != "" {
|
||||
variables["baseBranch"] = filters.BaseBranch
|
||||
if opts.BaseBranch != "" {
|
||||
variables["baseBranch"] = opts.BaseBranch
|
||||
}
|
||||
if filters.HeadBranch != "" {
|
||||
variables["headBranch"] = filters.HeadBranch
|
||||
if opts.HeadBranch != "" {
|
||||
variables["headBranch"] = opts.HeadBranch
|
||||
}
|
||||
|
||||
res := api.PullRequestAndTotalCount{}
|
||||
var check = make(map[int]struct{})
|
||||
client := api.NewClientFromHTTP(httpClient)
|
||||
client := api.NewClientFromHTTP(l.httpClient)
|
||||
|
||||
loop:
|
||||
for {
|
||||
variables["limit"] = pageLimit
|
||||
var data response
|
||||
err := client.GraphQL(repo.RepoHost(), query, variables, &data)
|
||||
err := client.GraphQL(opts.BaseRepo.RepoHost(), query, variables, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue