Fix checkout when URL arg is from fork and cwd is upstream

This commit is contained in:
William Martin 2025-02-27 15:07:09 +01:00
parent 69fff52026
commit 11b9496e17
4 changed files with 306 additions and 208 deletions

View file

@ -27,13 +27,8 @@ type CheckoutOptions struct {
Remotes func() (cliContext.Remotes, error)
Branch func() (string, error)
Finder shared.PRFinder
Prompter shared.Prompt
Lister shared.PRLister
PRResolver PRResolver
Interactive bool
BaseRepo func() (ghrepo.Interface, error)
SelectorArg string
RecurseSubmodules bool
Force bool
Detach bool
@ -48,8 +43,6 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
Config: f.Config,
Remotes: f.Remotes,
Branch: f.Branch,
Prompter: f.Prompter,
BaseRepo: f.BaseRepo,
}
cmd := &cobra.Command{
@ -66,15 +59,30 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.Finder = shared.NewFinder(f)
opts.Lister = shared.NewLister(f)
if len(args) > 0 {
opts.SelectorArg = args[0]
} else if !opts.IO.CanPrompt() {
return cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
opts.PRResolver = &specificPRResolver{
prFinder: shared.NewFinder(f),
selector: args[0],
}
} else if opts.IO.CanPrompt() {
baseRepo, err := f.BaseRepo()
if err != nil {
return err
}
httpClient, err := f.HttpClient()
if err != nil {
return err
}
opts.PRResolver = &promptingPRResolver{
io: opts.IO,
prompter: f.Prompter,
prLister: shared.NewLister(httpClient),
baseRepo: baseRepo,
}
} else {
opts.Interactive = true
return cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
}
if runF != nil {
@ -93,12 +101,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
}
func checkoutRun(opts *CheckoutOptions) error {
baseRepo, err := opts.BaseRepo()
if err != nil {
return err
}
pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO)
pr, baseRepo, err := opts.PRResolver.Resolve()
if err != nil {
return err
}
@ -286,32 +289,47 @@ func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cm
return nil
}
func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompt, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) {
// When non-interactive
if pullRequestSelector != "" {
pr, _, err := pullRequestFinder.Find(shared.FindOptions{
Selector: pullRequestSelector,
Fields: []string{
"number",
"headRefName",
"headRepository",
"headRepositoryOwner",
"isCrossRepository",
"maintainerCanModify",
},
})
if err != nil {
return nil, err
}
return pr, nil
type PRResolver interface {
Resolve() (*api.PullRequest, ghrepo.Interface, error)
}
type specificPRResolver struct {
prFinder shared.PRFinder
selector string
}
func (r *specificPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
pr, baseRepo, err := r.prFinder.Find(shared.FindOptions{
Selector: r.selector,
Fields: []string{
"number",
"headRefName",
"headRepository",
"headRepositoryOwner",
"isCrossRepository",
"maintainerCanModify",
},
})
if err != nil {
return nil, nil, err
}
if !isInteractive {
return nil, cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
}
// When interactive
io.StartProgressIndicator()
listResult, err := prLister.List(shared.ListOptions{
State: "open",
return pr, baseRepo, nil
}
type promptingPRResolver struct {
io *iostreams.IOStreams
prompter shared.Prompt
prLister shared.PRLister
baseRepo ghrepo.Interface
}
func (r *promptingPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
r.io.StartProgressIndicator()
listResult, err := r.prLister.List(shared.ListOptions{
BaseRepo: r.baseRepo,
State: "open",
Fields: []string{
"number",
"title",
@ -325,21 +343,16 @@ func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompt, pullRequestSel
"maintainerCanModify",
},
LimitResults: 10})
io.StopProgressIndicator()
r.io.StopProgressIndicator()
if err != nil {
return nil, err
return nil, nil, err
}
if len(listResult.PullRequests) == 0 {
return nil, shared.ListNoResults(ghrepo.FullName(baseRepo), "pull request", false)
return nil, nil, shared.ListNoResults(ghrepo.FullName(r.baseRepo), "pull request", false)
}
pr, err := promptForPR(prompter, *listResult)
return pr, err
}
func promptForPR(prompter shared.Prompt, jobs api.PullRequestAndTotalCount) (*api.PullRequest, error) {
candidates := []string{}
for _, pr := range jobs.PullRequests {
for _, pr := range listResult.PullRequests {
candidates = append(candidates, fmt.Sprintf("%d\t%s %s [%s]",
pr.Number,
shared.PrStateWithDraft(&pr),
@ -348,14 +361,10 @@ func promptForPR(prompter shared.Prompt, jobs api.PullRequestAndTotalCount) (*ap
))
}
selected, err := prompter.Select("Select a pull request", "", candidates)
selected, err := r.prompter.Select("Select a pull request", "", candidates)
if err != nil {
return nil, err
return nil, nil, err
}
if selected >= 0 {
return &jobs.PullRequests[selected], nil
}
return nil, nil
return &listResult.PullRequests[selected], r.baseRepo, nil
}

View file

@ -23,8 +23,87 @@ import (
"github.com/cli/cli/v2/test"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCmdCheckout(t *testing.T) {
tests := []struct {
name string
args string
wantsOpts CheckoutOptions
wantErr error
}{
{
name: "recurse submodules",
args: "--recurse-submodules 123",
wantsOpts: CheckoutOptions{
RecurseSubmodules: true,
},
},
{
name: "force",
args: "--force 123",
wantsOpts: CheckoutOptions{
Force: true,
},
},
{
name: "detach",
args: "--detach 123",
wantsOpts: CheckoutOptions{
Detach: true,
},
},
{
name: "branch",
args: "--branch test-branch 123",
wantsOpts: CheckoutOptions{
BranchName: "test-branch",
},
},
{
name: "when there is no selector and no TTY, returns an error",
args: "",
wantErr: cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
f := &cmdutil.Factory{
IOStreams: ios,
}
ios.SetStdinTTY(false)
argv, err := shlex.Split(tt.args)
assert.NoError(t, err)
var spiedOpts *CheckoutOptions
cmd := NewCmdCheckout(f, func(opts *CheckoutOptions) error {
spiedOpts = opts
return nil
})
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
_, err = cmd.ExecuteC()
if tt.wantErr != nil {
require.Equal(t, tt.wantErr, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantsOpts.RecurseSubmodules, spiedOpts.RecurseSubmodules)
require.Equal(t, tt.wantsOpts.Force, spiedOpts.Force)
require.Equal(t, tt.wantsOpts.Detach, spiedOpts.Detach)
require.Equal(t, tt.wantsOpts.BranchName, spiedOpts.BranchName)
})
}
}
// repo: either "baseOwner/baseRepo" or "baseOwner/baseRepo:defaultBranch"
// prHead: "headOwner/headRepo:headBranch"
func stubPR(repo, prHead string) (ghrepo.Interface, *api.PullRequest) {
@ -70,6 +149,20 @@ func _stubPR(repo, prHead string, number int, title string, state string, isDraf
}
}
type stubPRResolver struct {
pr *api.PullRequest
baseRepo ghrepo.Interface
err error
}
func (s *stubPRResolver) Resolve() (*api.PullRequest, ghrepo.Interface, error) {
if s.err != nil {
return nil, nil, s.err
}
return s.pr, s.baseRepo, nil
}
func Test_checkoutRun(t *testing.T) {
tests := []struct {
name string
@ -88,16 +181,13 @@ func Test_checkoutRun(t *testing.T) {
{
name: "checkout with ssh remote URL",
opts: &CheckoutOptions{
SelectorArg: "123",
Finder: func() shared.PRFinder {
PRResolver: func() PRResolver {
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
finder := shared.NewMockFinder("123", pr, baseRepo)
return finder
return &stubPRResolver{
pr: pr,
baseRepo: baseRepo,
}
}(),
BaseRepo: func() (ghrepo.Interface, error) {
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
return baseRepo, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
@ -117,18 +207,15 @@ func Test_checkoutRun(t *testing.T) {
{
name: "fork repo was deleted",
opts: &CheckoutOptions{
SelectorArg: "123",
Finder: func() shared.PRFinder {
baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
PRResolver: func() PRResolver {
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
pr.MaintainerCanModify = true
pr.HeadRepository = nil
finder := shared.NewMockFinder("123", pr, baseRepo)
return finder
return &stubPRResolver{
pr: pr,
baseRepo: baseRepo,
}
}(),
BaseRepo: func() (ghrepo.Interface, error) {
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
return baseRepo, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
@ -151,17 +238,14 @@ func Test_checkoutRun(t *testing.T) {
{
name: "with local branch rename and existing git remote",
opts: &CheckoutOptions{
SelectorArg: "123",
BranchName: "foobar",
Finder: func() shared.PRFinder {
BranchName: "foobar",
PRResolver: func() PRResolver {
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
finder := shared.NewMockFinder("123", pr, baseRepo)
return finder
return &stubPRResolver{
pr: pr,
baseRepo: baseRepo,
}
}(),
BaseRepo: func() (ghrepo.Interface, error) {
baseRepo, _ := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
return baseRepo, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
@ -181,18 +265,15 @@ func Test_checkoutRun(t *testing.T) {
{
name: "with local branch name, no existing git remote",
opts: &CheckoutOptions{
SelectorArg: "123",
BranchName: "foobar",
Finder: func() shared.PRFinder {
BranchName: "foobar",
PRResolver: func() PRResolver {
baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
pr.MaintainerCanModify = true
finder := shared.NewMockFinder("123", pr, baseRepo)
return finder
return &stubPRResolver{
pr: pr,
baseRepo: baseRepo,
}
}(),
BaseRepo: func() (ghrepo.Interface, error) {
baseRepo, _ := stubPR("OWNER/REPO:master", "hubot/REPO:feature")
return baseRepo, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
@ -213,78 +294,14 @@ func Test_checkoutRun(t *testing.T) {
},
},
{
name: "with no selected PR args and non tty, return error",
name: "when the PR resolver errors, then that error is bubbled up",
opts: &CheckoutOptions{
SelectorArg: "",
Interactive: false,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
PRResolver: &stubPRResolver{
err: errors.New("expected test error"),
},
},
remotes: map[string]string{
"origin": "OWNER/REPO",
},
wantErr: true,
errMsg: "pull request number, URL, or branch required when not running interactively",
},
{
name: "with no selected PR args and stdin tty, prompts for choice",
opts: &CheckoutOptions{
SelectorArg: "",
Interactive: true,
Lister: func() shared.PRLister {
_, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 3,
PullRequests: []api.PullRequest{
*pr1, *pr2, *pr3,
}, SearchCapped: false}, nil)
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
return lister
}(),
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
},
promptStubs: func(pm *prompter.MockPrompter) {
pm.RegisterSelect("Select a pull request",
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
})
},
runStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "")
cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature --no-tags`, 0, "")
cs.Register(`git checkout -b feature --track origin/feature`, 0, "")
},
remotes: map[string]string{
"origin": "OWNER/REPO",
},
},
{
name: "with no select PR args and no open PR, return error",
opts: &CheckoutOptions{
SelectorArg: "",
Interactive: true,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
Lister: shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 0,
PullRequests: []api.PullRequest{},
}, nil),
},
remotes: map[string]string{
"origin": "OWNER/REPO",
},
wantErr: true,
errMsg: "no open pull requests in OWNER/REPO",
errMsg: "expected test error",
},
}
for _, tt := range tests {
@ -309,12 +326,6 @@ func Test_checkoutRun(t *testing.T) {
tt.runStubs(cmdStubs)
}
pm := prompter.NewMockPrompter(t)
tt.opts.Prompter = pm
if tt.promptStubs != nil {
tt.promptStubs(pm)
}
opts.Remotes = func() (context.Remotes, error) {
if len(tt.remotes) == 0 {
return nil, errors.New("no remotes")
@ -351,6 +362,102 @@ func Test_checkoutRun(t *testing.T) {
}
}
func TestSpecificPRResolver(t *testing.T) {
t.Run("when the PR Finder returns results, those are returned", func(t *testing.T) {
t.Parallel()
baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature")
mockFinder := shared.NewMockFinder("123", pr, baseRepo)
mockFinder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
resolver := &specificPRResolver{
prFinder: mockFinder,
selector: "123",
}
resolvedPR, resolvedBaseRepo, err := resolver.Resolve()
require.NoError(t, err)
require.Equal(t, pr, resolvedPR)
require.True(t, ghrepo.IsSame(baseRepo, resolvedBaseRepo), "expected repos to be the same")
})
t.Run("when the PR Finder errors, that error is returned", func(t *testing.T) {
t.Parallel()
mockFinder := shared.NewMockFinder("123", nil, nil)
resolver := &specificPRResolver{
prFinder: mockFinder,
selector: "123",
}
_, _, err := resolver.Resolve()
var notFoundErr *shared.NotFoundError
require.ErrorAs(t, err, &notFoundErr)
})
}
func TestPromptingPRResolver(t *testing.T) {
t.Run("when the PR Lister has results, then we prompt for a choice", func(t *testing.T) {
t.Parallel()
ios, _, _, _ := iostreams.Test()
baseRepo, pr1 := _stubPR("OWNER/REPO:master", "OWNER/REPO:feature", 32, "New feature", "OPEN", false)
_, pr2 := _stubPR("OWNER/REPO:master", "OWNER/REPO:bug-fix", 29, "Fixed bad bug", "OPEN", false)
_, pr3 := _stubPR("OWNER/REPO:master", "OWNER/REPO:docs", 28, "Improve documentation", "OPEN", true)
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 3,
PullRequests: []api.PullRequest{
*pr1, *pr2, *pr3,
}, SearchCapped: false}, nil)
lister.ExpectFields([]string{"number", "title", "state", "isDraft", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"})
pm := prompter.NewMockPrompter(t)
pm.RegisterSelect("Select a pull request",
[]string{"32\tOPEN New feature [feature]", "29\tOPEN Fixed bad bug [bug-fix]", "28\tDRAFT Improve documentation [docs]"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "32\tOPEN New feature [feature]")
})
resolver := &promptingPRResolver{
io: ios,
prompter: pm,
prLister: lister,
baseRepo: baseRepo,
}
resolvedPR, resolvedBaseRepo, err := resolver.Resolve()
require.NoError(t, err)
require.Equal(t, pr1, resolvedPR)
require.True(t, ghrepo.IsSame(baseRepo, resolvedBaseRepo), "expected repos to be the same")
})
t.Run("when the PR lister has no results, then we return an error", func(t *testing.T) {
t.Parallel()
ios, _, _, _ := iostreams.Test()
lister := shared.NewMockLister(&api.PullRequestAndTotalCount{
TotalCount: 0,
PullRequests: []api.PullRequest{},
}, nil)
resolver := &promptingPRResolver{
io: ios,
prLister: lister,
baseRepo: ghrepo.New("OWNER", "REPO"),
}
_, _, err := resolver.Resolve()
var noResultsErr cmdutil.NoResultsError
require.ErrorAs(t, err, &noResultsErr)
require.Equal(t, "no open pull requests in OWNER/REPO", noResultsErr.Error())
})
}
/** LEGACY TESTS **/
func runCommand(rt http.RoundTripper, remotes context.Remotes, branch string, cli string, baseRepo ghrepo.Interface) (*test.CmdOut, error) {

View file

@ -7,7 +7,6 @@ import (
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghrepo"
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
)
func shouldUseSearch(filters prShared.FilterOptions) bool {
@ -19,10 +18,8 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters pr
return searchPullRequests(httpClient, repo, filters, limit)
}
return prShared.NewLister(&cmdutil.Factory{
HttpClient: func() (*http.Client, error) { return httpClient, nil },
BaseRepo: func() (ghrepo.Interface, error) { return repo, nil },
}).List(prShared.ListOptions{
return prShared.NewLister(httpClient).List(prShared.ListOptions{
BaseRepo: repo,
LimitResults: limit,
State: filters.State,
BaseBranch: filters.BaseBranch,

View file

@ -5,7 +5,6 @@ import (
"net/http"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
api "github.com/cli/cli/v2/api"
)
@ -15,6 +14,8 @@ type PRLister interface {
}
type ListOptions struct {
BaseRepo ghrepo.Interface
LimitResults int
State string
@ -25,32 +26,16 @@ type ListOptions struct {
}
type lister struct {
baseRepoFn func() (ghrepo.Interface, error)
httpClient func() (*http.Client, error)
httpClient *http.Client
}
func NewLister(factory *cmdutil.Factory) PRLister {
func NewLister(httpClient *http.Client) PRLister {
return &lister{
baseRepoFn: factory.BaseRepo,
httpClient: factory.HttpClient,
httpClient: httpClient,
}
}
func (l *lister) List(opts ListOptions) (*api.PullRequestAndTotalCount, error) {
repo, err := l.baseRepoFn()
if err != nil {
return nil, err
}
client, err := l.httpClient()
if err != nil {
return nil, err
}
return listPullRequests(client, repo, opts)
}
func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters ListOptions) (*api.PullRequestAndTotalCount, error) {
type response struct {
Repository struct {
PullRequests struct {
@ -63,8 +48,8 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
}
}
}
limit := filters.LimitResults
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(filters.Fields))
limit := opts.LimitResults
fragment := fmt.Sprintf("fragment pr on PullRequest{%s}", api.PullRequestGraphQL(opts.Fields))
query := fragment + `
query PullRequestList(
$owner: String!,
@ -98,11 +83,11 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
pageLimit := min(limit, 100)
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"owner": opts.BaseRepo.RepoOwner(),
"repo": opts.BaseRepo.RepoName(),
}
switch filters.State {
switch opts.State {
case "open":
variables["state"] = []string{"OPEN"}
case "closed":
@ -112,25 +97,25 @@ func listPullRequests(httpClient *http.Client, repo ghrepo.Interface, filters Li
case "all":
variables["state"] = []string{"OPEN", "CLOSED", "MERGED"}
default:
return nil, fmt.Errorf("invalid state: %s", filters.State)
return nil, fmt.Errorf("invalid state: %s", opts.State)
}
if filters.BaseBranch != "" {
variables["baseBranch"] = filters.BaseBranch
if opts.BaseBranch != "" {
variables["baseBranch"] = opts.BaseBranch
}
if filters.HeadBranch != "" {
variables["headBranch"] = filters.HeadBranch
if opts.HeadBranch != "" {
variables["headBranch"] = opts.HeadBranch
}
res := api.PullRequestAndTotalCount{}
var check = make(map[int]struct{})
client := api.NewClientFromHTTP(httpClient)
client := api.NewClientFromHTTP(l.httpClient)
loop:
for {
variables["limit"] = pageLimit
var data response
err := client.GraphQL(repo.RepoHost(), query, variables, &data)
err := client.GraphQL(opts.BaseRepo.RepoHost(), query, variables, &data)
if err != nil {
return nil, err
}