package checkout

import (
	"context"
	"fmt"
	"net/http"
	"strings"

	"github.com/MakeNowJust/heredoc"
	"github.com/cli/cli/v2/api"
	cliContext "github.com/cli/cli/v2/context"
	"github.com/cli/cli/v2/git"
	"github.com/cli/cli/v2/internal/gh"
	"github.com/cli/cli/v2/internal/ghrepo"
	"github.com/cli/cli/v2/internal/text"
	"github.com/cli/cli/v2/pkg/cmd/pr/shared"
	"github.com/cli/cli/v2/pkg/cmdutil"
	"github.com/cli/cli/v2/pkg/iostreams"
	"github.com/spf13/cobra"
)

// CheckoutOptions carries the dependencies and user-supplied settings for a
// single `pr checkout` invocation.
type CheckoutOptions struct {
	HttpClient func() (*http.Client, error)
	GitClient  *git.Client
	Config     func() (gh.Config, error)
	IO         *iostreams.IOStreams
	Remotes    func() (cliContext.Remotes, error)
	Branch     func() (string, error)

	// Finder resolves an explicit selector; Lister and Prompter drive the
	// interactive selection when no selector was given.
	Finder   shared.PRFinder
	Prompter shared.Prompt
	Lister   shared.PRLister

	// Interactive is set when no selector argument was given and the
	// terminal can prompt.
	Interactive bool

	BaseRepo func() (ghrepo.Interface, error)

	// SelectorArg is the optional positional argument identifying the PR.
	SelectorArg string

	RecurseSubmodules bool // --recurse-submodules
	Force             bool // --force
	Detach            bool // --detach
	BranchName        string // --branch; empty means "derive from the PR head"
}

// NewCmdCheckout builds the `checkout` cobra command. When runF is non-nil it
// is invoked instead of checkoutRun, which lets callers intercept the fully
// populated options.
func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
	opts := &CheckoutOptions{
		IO:         f.IOStreams,
		HttpClient: f.HttpClient,
		GitClient:  f.GitClient,
		Config:     f.Config,
		Remotes:    f.Remotes,
		Branch:     f.Branch,
		Prompter:   f.Prompter,
		BaseRepo:   f.BaseRepo,
	}

	cmd := &cobra.Command{
		// The placeholders name the three selector forms accepted by the
		// command (see the error message below and the examples).
		Use:   "checkout [<number> | <url> | <branch>]",
		Short: "Check out a pull request in git",
		Example: heredoc.Doc(`
			# Interactively select a PR from the 10 most recent to check out
			$ gh pr checkout

			# Checkout a specific PR
			$ gh pr checkout 32
			$ gh pr checkout https://github.com/OWNER/REPO/pull/32
			$ gh pr checkout feature
		`),
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			opts.Finder = shared.NewFinder(f)
			opts.Lister = shared.NewLister(f)

			switch {
			case len(args) > 0:
				opts.SelectorArg = args[0]
			case !opts.IO.CanPrompt():
				return cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
			default:
				opts.Interactive = true
			}

			if runF != nil {
				return runF(opts)
			}
			return checkoutRun(opts)
		},
	}

	cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
	cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
	cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
	cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default [the name of the head branch])")

	return cmd
}

// checkoutRun resolves the pull request to check out, builds the queue of git
// commands needed to fetch and check it out, and executes that queue.
func checkoutRun(opts *CheckoutOptions) error {
	baseRepo, err := opts.BaseRepo()
	if err != nil {
		return err
	}

	pr, err := resolvePR(baseRepo, opts.Prompter, opts.SelectorArg, opts.Interactive, opts.Finder, opts.Lister, opts.IO)
	if err != nil {
		return err
	}
	if pr == nil {
		// promptForPR reports (nil, nil) when the prompter yields no usable
		// selection; fail cleanly instead of dereferencing a nil PR below.
		return fmt.Errorf("no pull request selected")
	}

	cfg, err := opts.Config()
	if err != nil {
		return err
	}
	protocol := cfg.GitProtocol(baseRepo.RepoHost()).Value

	remotes, err := opts.Remotes()
	if err != nil {
		return err
	}

	// Lookup errors are deliberately ignored: a missing remote is handled by
	// falling back to the repository URL.
	baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
	baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
	if baseRemote != nil {
		baseURLOrName = baseRemote.Name
	}

	// Pick the remote that hosts the PR head, if one is configured locally.
	headRemote := baseRemote
	if pr.HeadRepository == nil {
		headRemote = nil
	} else if pr.IsCrossRepository {
		headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
	}

	// The head ref name is passed to git as an argument; refuse anything that
	// git could interpret as an option.
	if strings.HasPrefix(pr.HeadRefName, "-") {
		return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
	}

	var cmdQueue [][]string
	if headRemote != nil {
		cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
	} else {
		httpClient, err := opts.HttpClient()
		if err != nil {
			return err
		}
		apiClient := api.NewClientFromHTTP(httpClient)

		defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
		if err != nil {
			return err
		}
		cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
	}

	if opts.RecurseSubmodules {
		cmdQueue = append(cmdQueue,
			[]string{"submodule", "sync", "--recursive"},
			[]string{"submodule", "update", "--init", "--recursive"},
		)
	}

	// Note that although we will probably be fetching from the head, in practice, PR checkout can only
	// ever point to one host, and we know baseRepo must be populated.
	return executeCmds(opts.GitClient, git.CredentialPatternFromHost(baseRepo.RepoHost()), cmdQueue)
}

// cmdsForExistingRemote returns the git commands that check out the PR head
// branch from a remote that is already configured locally.
func cmdsForExistingRemote(remote *cliContext.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
	var cmds [][]string
	remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)

	// With --detach only FETCH_HEAD is needed, so the remote-tracking ref is
	// left out of the refspec.
	refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
	if !opts.Detach {
		refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
	}

	cmds = append(cmds, []string{"fetch", remote.Name, refSpec})

	localBranch := pr.HeadRefName
	if opts.BranchName != "" {
		localBranch = opts.BranchName
	}

	switch {
	case opts.Detach:
		cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"})
	case localBranchExists(opts.GitClient, localBranch):
		cmds = append(cmds, []string{"checkout", localBranch})
		if opts.Force {
			cmds = append(cmds, []string{"reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
		} else {
			// TODO: check if non-fast-forward and suggest to use `--force`
			cmds = append(cmds, []string{"merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
		}
	default:
		cmds = append(cmds, []string{"checkout", "-b", localBranch, "--track", remoteBranch})
	}

	return cmds
}

// cmdsForMissingRemote returns the git commands that check out the PR via its
// refs/pull/<number>/head ref, fetched from baseURLOrName, for the case where
// no local remote points at the PR head repository.
func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
	var cmds [][]string
	ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)

	if opts.Detach {
		cmds = append(cmds, []string{"fetch", baseURLOrName, ref})
		cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"})
		return cmds
	}

	localBranch := pr.HeadRefName
	if opts.BranchName != "" {
		localBranch = opts.BranchName
	} else if pr.HeadRefName == defaultBranch {
		// avoid naming the new branch the same as the default branch
		localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch)
	}

	// Best effort: the error is deliberately ignored. If the current branch
	// cannot be determined, currentBranch is presumably empty and the
	// fetch-into-branch path below is taken.
	currentBranch, _ := opts.Branch()
	if localBranch == currentBranch {
		// PR head matches currently checked out branch
		cmds = append(cmds, []string{"fetch", baseURLOrName, ref})
		if opts.Force {
			cmds = append(cmds, []string{"reset", "--hard", "FETCH_HEAD"})
		} else {
			// TODO: check if non-fast-forward and suggest to use `--force`
			cmds = append(cmds, []string{"merge", "--ff-only", "FETCH_HEAD"})
		}
	} else {
		if opts.Force {
			cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"})
		} else {
			// TODO: check if non-fast-forward and suggest to use `--force`
			cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)})
		}
		cmds = append(cmds, []string{"checkout", localBranch})
	}

	// When maintainers may modify the PR and the head repository is known,
	// track the head repository's branch rather than the pull ref.
	remote := baseURLOrName
	mergeRef := ref
	if pr.MaintainerCanModify && pr.HeadRepository != nil {
		headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
		remote = ghrepo.FormatRemoteURL(headRepo, protocol)
		mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
	}

	if missingMergeConfigForBranch(opts.GitClient, localBranch) {
		// .remote is needed for `git pull` to work
		// .pushRemote is needed for `git push` to work, if user has set `remote.pushDefault`.
		// see https://git-scm.com/docs/git-config#Documentation/git-config.txt-branchltnamegtremote
		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.remote", localBranch), remote})
		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.pushRemote", localBranch), remote})
		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef})
	}

	return cmds
}

// missingMergeConfigForBranch reports whether branch.<b>.merge is unset or
// could not be read from git config.
func missingMergeConfigForBranch(client *git.Client, b string) bool {
	mc, err := client.Config(context.Background(), fmt.Sprintf("branch.%s.merge", b))
	return err != nil || mc == ""
}

// localBranchExists reports whether showing refs/heads/<b> succeeds.
func localBranchExists(client *git.Client, b string) bool {
	_, err := client.ShowRefs(context.Background(), []string{"refs/heads/" + b})
	return err == nil
}

// executeCmds runs each queued git command in order and stops at the first
// failure. "submodule" commands are authenticated with credentialPattern,
// "fetch" commands with git.AllMatchingCredentialsPattern, and every other
// command runs without authentication. Each entry in cmdQueue must be
// non-empty.
func executeCmds(client *git.Client, credentialPattern git.CredentialPattern, cmdQueue [][]string) error {
	for _, args := range cmdQueue {
		var err error
		var cmd *git.Command
		switch args[0] {
		case "submodule":
			cmd, err = client.AuthenticatedCommand(context.Background(), credentialPattern, args...)
		case "fetch":
			cmd, err = client.AuthenticatedCommand(context.Background(), git.AllMatchingCredentialsPattern, args...)
		default:
			cmd, err = client.Command(context.Background(), args...)
		}
		if err != nil {
			return err
		}
		if err := cmd.Run(); err != nil {
			return err
		}
	}
	return nil
}

// resolvePR returns the pull request to check out. A non-empty
// pullRequestSelector is looked up through pullRequestFinder; otherwise, when
// isInteractive is set, up to 10 open pull requests are listed and the user
// is prompted to pick one. The returned pull request may be nil with a nil
// error when the prompt yields no usable selection.
func resolvePR(baseRepo ghrepo.Interface, prompter shared.Prompt, pullRequestSelector string, isInteractive bool, pullRequestFinder shared.PRFinder, prLister shared.PRLister, io *iostreams.IOStreams) (*api.PullRequest, error) {
	// When non-interactive
	if pullRequestSelector != "" {
		pr, _, err := pullRequestFinder.Find(shared.FindOptions{
			Selector: pullRequestSelector,
			Fields: []string{
				"number",
				"headRefName",
				"headRepository",
				"headRepositoryOwner",
				"isCrossRepository",
				"maintainerCanModify",
			},
		})
		if err != nil {
			return nil, err
		}
		return pr, nil
	}

	if !isInteractive {
		return nil, cmdutil.FlagErrorf("pull request number, URL, or branch required when not running interactively")
	}

	// When interactive
	io.StartProgressIndicator()
	listResult, err := prLister.List(shared.ListOptions{
		State: "open",
		Fields: []string{
			"number",
			"title",
			"state",
			"isDraft",
			"headRefName",
			"headRepository",
			"headRepositoryOwner",
			"isCrossRepository",
			"maintainerCanModify",
		},
		LimitResults: 10,
	})
	// Stop the indicator before any error handling or prompting so it does
	// not overlap with subsequent output.
	io.StopProgressIndicator()
	if err != nil {
		return nil, err
	}
	if len(listResult.PullRequests) == 0 {
		return nil, shared.ListNoResults(ghrepo.FullName(baseRepo), "pull request", false)
	}

	return promptForPR(prompter, *listResult)
}

// promptForPR asks the user to choose one of the listed pull requests and
// returns a pointer to the chosen entry. It returns (nil, nil) when the
// prompter reports an index outside the list.
func promptForPR(prompter shared.Prompt, jobs api.PullRequestAndTotalCount) (*api.PullRequest, error) {
	candidates := make([]string, 0, len(jobs.PullRequests))
	for i := range jobs.PullRequests {
		pr := &jobs.PullRequests[i]
		candidates = append(candidates, fmt.Sprintf("%d\t%s %s [%s]",
			pr.Number,
			shared.PrStateWithDraft(pr),
			text.RemoveExcessiveWhitespace(pr.Title),
			pr.HeadLabel(),
		))
	}

	selected, err := prompter.Select("Select a pull request", "", candidates)
	if err != nil {
		return nil, err
	}

	// Guard both bounds so an unexpected index cannot panic.
	if selected >= 0 && selected < len(jobs.PullRequests) {
		return &jobs.PullRequests[selected], nil
	}
	return nil, nil
}