diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index bc697ae70..5d45856cd 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -33,6 +33,7 @@ type CheckoutOptions struct { RecurseSubmodules bool Force bool Detach bool + BranchName string } func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command { @@ -65,6 +66,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout") cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request") cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD") + cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)") return cmd } @@ -139,7 +141,6 @@ func checkoutRun(opts *CheckoutOptions) error { func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string { var cmds [][]string - remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName) refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName) @@ -149,11 +150,16 @@ func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *Ch cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec}) + localBranch := pr.HeadRefName + if opts.BranchName != "" { + localBranch = opts.BranchName + } + switch { case opts.Detach: cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) - case localBranchExists(pr.HeadRefName): - cmds = append(cmds, []string{"git", "checkout", pr.HeadRefName}) + case localBranchExists(localBranch): + cmds = append(cmds, []string{"git", "checkout", localBranch}) if opts.Force { cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } else { @@ -161,9 +167,12 @@ func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *Ch cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } default: - cmds = append(cmds, []string{"git", "checkout", "-b", pr.HeadRefName, "--no-track", remoteBranch}) - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", pr.HeadRefName), remote.Name}) - cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", pr.HeadRefName), "refs/heads/" + pr.HeadRefName}) + cmds = append( + cmds, + []string{"git", "checkout", "-b", localBranch, "--no-track", remoteBranch}, + []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote.Name}, + []string{"git", "config", fmt.Sprintf("branch.%s.merge", pr.HeadRefName), "refs/heads/" + pr.HeadRefName}, + ) } return cmds @@ -204,7 +213,12 @@ func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultB // TODO: check if non-fast-forward and suggest to use `--force` cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)}) } - cmds = append(cmds, []string{"git", "checkout", newBranchName}) + + if opts.BranchName != "" { + cmds = append(cmds, []string{"git", "checkout", "-b", opts.BranchName, "--track", newBranchName}) + } else { + cmds = append(cmds, []string{"git", "checkout", newBranchName}) + } } remote := baseURLOrName diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 9f4acea25..73e7141b6 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -100,6 +100,66 @@ func Test_checkoutRun(t *testing.T) { cs.Register(`git config branch\.feature\.merge refs/pull/123/head`, 0, "") }, }, + { + name: "with local branch rename", + opts: &CheckoutOptions{ + SelectorArg: "123", + BranchName: "foobar", + Finder: func() shared.PRFinder { + baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") + pr.MaintainerCanModify = true + pr.HeadRepository = nil + finder := shared.NewMockFinder("123", pr, baseRepo) + return finder + }(), + Config: func() (config.Config, error) { + return config.NewBlankConfig(), nil + }, + Branch: func() (string, error) { + return "main", nil + }, + }, + remotes: map[string]string{ + "origin": "OWNER/REPO", + }, + runStubs: func(cs *run.CommandStubber) { + cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") + cs.Register(`git config branch\.feature\.merge`, 1, "") + cs.Register(`git checkout -b foobar --track feature`, 0, "") + cs.Register(`git config branch\.feature\.remote origin`, 0, "") + cs.Register(`git config branch\.feature\.merge refs/pull/123/head`, 0, "") + }, + }, + { + name: "with local branch name, no existing git remote", + opts: &CheckoutOptions{ + SelectorArg: "123", + BranchName: "foobar", + Finder: func() shared.PRFinder { + baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") + pr.MaintainerCanModify = true + pr.HeadRepository = nil + finder := shared.NewMockFinder("123", pr, baseRepo) + return finder + }(), + Config: func() (config.Config, error) { + return config.NewBlankConfig(), nil + }, + Branch: func() (string, error) { + return "main", nil + }, + }, + remotes: map[string]string{ + "origin": "OWNER/REPO", + }, + runStubs: func(cs *run.CommandStubber) { + cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") + cs.Register(`git config branch\.feature\.merge`, 1, "") + cs.Register(`git checkout -b foobar --track feature`, 0, "") + cs.Register(`git config branch\.feature\.remote origin`, 0, "") + cs.Register(`git config branch\.feature\.merge refs/pull/123/head`, 0, "") + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {