diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index bc697ae70..651a7b3f4 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -33,6 +33,7 @@ type CheckoutOptions struct { RecurseSubmodules bool Force bool Detach bool + BranchName string } func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command { @@ -65,6 +66,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout") cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request") cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD") + cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)") return cmd } @@ -139,7 +141,6 @@ func checkoutRun(opts *CheckoutOptions) error { func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string { var cmds [][]string - remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName) refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName) @@ -149,11 +150,16 @@ func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *Ch cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec}) + localBranch := pr.HeadRefName + if opts.BranchName != "" { + localBranch = opts.BranchName + } + switch { case opts.Detach: cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) - case localBranchExists(pr.HeadRefName): - cmds = append(cmds, []string{"git", "checkout", pr.HeadRefName}) + case localBranchExists(localBranch): + cmds = append(cmds, []string{"git", "checkout", localBranch}) if opts.Force { cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } else { @@ -161,9 +167,7 @@ func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *Ch cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } default: - cmds = append(cmds, []string{"git", "checkout", "-b", pr.HeadRefName, "--no-track", remoteBranch}) - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", pr.HeadRefName), remote.Name}) - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", pr.HeadRefName), "refs/heads/" + pr.HeadRefName}) + cmds = append(cmds, []string{"git", "checkout", "-b", localBranch, "--track", remoteBranch}) } return cmds @@ -171,13 +175,6 @@ func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *Ch func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string { var cmds [][]string - - newBranchName := pr.HeadRefName - // avoid naming the new branch the same as the default branch - if newBranchName == defaultBranch { - newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName) - } - ref := fmt.Sprintf("refs/pull/%d/head", pr.Number) if opts.Detach { @@ -186,8 +183,16 @@ func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultB return cmds } + localBranch := pr.HeadRefName + if opts.BranchName != "" { + localBranch = opts.BranchName + } else if pr.HeadRefName == defaultBranch { + // avoid naming the new branch the same as the default branch + localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch) + } + currentBranch, _ := opts.Branch() - if newBranchName == currentBranch { + if localBranch == currentBranch { // PR head matches currently checked out branch cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref}) if opts.Force { @@ -197,14 +202,14 @@ func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultB cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"}) } } else { - // create a new branch if opts.Force { - cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName), "--force"}) + cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"}) } else { // TODO: check if non-fast-forward and suggest to use `--force` - cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)}) + cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)}) } - cmds = append(cmds, []string{"git", "checkout", newBranchName}) + + cmds = append(cmds, []string{"git", "checkout", localBranch}) } remote := baseURLOrName @@ -214,9 +219,9 @@ func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultB remote = ghrepo.FormatRemoteURL(headRepo, protocol) mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName) } - if missingMergeConfigForBranch(newBranchName) { - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote}) - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef}) + if missingMergeConfigForBranch(localBranch) { + cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote}) + cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef}) } return cmds diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 9f4acea25..7386cf2fc 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -100,6 +100,61 @@ func Test_checkoutRun(t *testing.T) { cs.Register(`git config branch\.feature\.merge refs/pull/123/head`, 0, "") }, }, + { + name: "with local branch rename and existing git remote", + opts: &CheckoutOptions{ + SelectorArg: "123", + BranchName: "foobar", + Finder: func() shared.PRFinder { + baseRepo, pr := stubPR("OWNER/REPO:master", "OWNER/REPO:feature") + finder := shared.NewMockFinder("123", pr, baseRepo) + return finder + }(), + Config: func() (config.Config, error) { + return config.NewBlankConfig(), nil + }, + Branch: func() (string, error) { + return "main", nil + }, + }, + remotes: map[string]string{ + "origin": "OWNER/REPO", + }, + runStubs: func(cs *run.CommandStubber) { + cs.Register(`git show-ref --verify -- refs/heads/foobar`, 1, "") + cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") + cs.Register(`git checkout -b foobar --track origin/feature`, 0, "") + }, + }, + { + name: "with local branch name, no existing git remote", + opts: &CheckoutOptions{ + SelectorArg: "123", + BranchName: "foobar", + Finder: func() shared.PRFinder { + baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") + pr.MaintainerCanModify = true + finder := shared.NewMockFinder("123", pr, baseRepo) + return finder + }(), + Config: func() (config.Config, error) { + return config.NewBlankConfig(), nil + }, + Branch: func() (string, error) { + return "main", nil + }, + }, + remotes: map[string]string{ + "origin": "OWNER/REPO", + }, + runStubs: func(cs *run.CommandStubber) { + cs.Register(`git config branch\.foobar\.merge`, 1, "") + cs.Register(`git fetch origin refs/pull/123/head:foobar`, 0, "") + cs.Register(`git checkout foobar`, 0, "") + cs.Register(`git config branch\.foobar\.remote https://github.com/hubot/REPO.git`, 0, "") + cs.Register(`git config branch\.foobar\.merge refs/heads/feature`, 0, "") + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -212,9 +267,7 @@ func TestPRCheckout_sameRepo(t *testing.T) { cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "") - cs.Register(`git checkout -b feature --no-track origin/feature`, 0, "") - cs.Register(`git config branch\.feature\.remote origin`, 0, "") - cs.Register(`git config branch\.feature\.merge refs/heads/feature`, 0, "") + cs.Register(`git checkout -b feature --track origin/feature`, 0, "") output, err := runCommand(http, nil, "master", `123`) assert.NoError(t, err) @@ -267,9 +320,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { cs.Register(`git fetch robot-fork \+refs/heads/feature:refs/remotes/robot-fork/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "") - cs.Register(`git checkout -b feature --no-track robot-fork/feature`, 0, "") - cs.Register(`git config branch\.feature\.remote robot-fork`, 0, "") - cs.Register(`git config branch\.feature\.merge refs/heads/feature`, 0, "") + cs.Register(`git checkout -b feature --track robot-fork/feature`, 0, "") output, err := runCommand(http, remotes, "master", `123`) assert.NoError(t, err)