Support detach head for pr checkout

This commit is contained in:
Sam Coe 2021-01-25 11:15:30 -08:00
parent d0a46399b7
commit cec3aa294e
No known key found for this signature in database
GPG key ID: 8E322C20F811D086
2 changed files with 150 additions and 72 deletions

View file

@ -31,6 +31,7 @@ type CheckoutOptions struct {
SelectorArg string
RecurseSubmodules bool
Force bool
Detach bool
}
func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
@ -63,6 +64,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
return cmd
}
@ -88,10 +90,9 @@ func checkoutRun(opts *CheckoutOptions) error {
if err != nil {
return err
}
protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
// baseRemoteSpec is a repository URL or a remote name to be used in git fetch
baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
if baseRemote != nil {
baseURLOrName = baseRemote.Name
@ -102,82 +103,20 @@ func checkoutRun(opts *CheckoutOptions) error {
headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
}
var cmdQueue [][]string
newBranchName := pr.HeadRefName
if strings.HasPrefix(newBranchName, "-") {
return fmt.Errorf("invalid branch name: %q", newBranchName)
if strings.HasPrefix(pr.HeadRefName, "-") {
return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
}
var cmdQueue [][]string
if headRemote != nil {
// there is an existing git remote for PR head
remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName)
refSpec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s", pr.HeadRefName, remoteBranch)
cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec})
// local branch already exists
if _, err := git.ShowRefs("refs/heads/" + newBranchName); err == nil {
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
if opts.Force {
cmdQueue = append(cmdQueue, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
}
} else {
cmdQueue = append(cmdQueue, []string{"git", "checkout", "-b", newBranchName, "--no-track", remoteBranch})
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), headRemote.Name})
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), "refs/heads/" + pr.HeadRefName})
}
cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
} else {
// no git remote for PR head
currentBranch, _ := opts.Branch()
defaultBranchName, err := api.RepoDefaultBranch(apiClient, baseRepo)
defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
if err != nil {
return err
}
// avoid naming the new branch the same as the default branch
if newBranchName == defaultBranchName {
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
}
ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
if newBranchName == currentBranch {
// PR head matches currently checked out branch
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, ref})
if opts.Force {
cmdQueue = append(cmdQueue, []string{"git", "reset", "--hard", "FETCH_HEAD"})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
}
} else {
// create a new branch
if opts.Force {
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName), "--force"})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)})
}
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
}
remote := baseURLOrName
mergeRef := ref
if pr.MaintainerCanModify {
headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, baseRepo.RepoHost())
remote = ghrepo.FormatRemoteURL(headRepo, protocol)
mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
}
if mc, err := git.Config(fmt.Sprintf("branch.%s.merge", newBranchName)); err != nil || mc == "" {
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote})
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef})
}
cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
}
if opts.RecurseSubmodules {
@ -185,6 +124,110 @@ func checkoutRun(opts *CheckoutOptions) error {
cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"})
}
err = executeCmds(cmdQueue)
if err != nil {
return err
}
return nil
}
func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
var cmds [][]string
remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)
refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
if !opts.Detach {
refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
}
cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec})
switch {
case opts.Detach:
cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
case localBranchExists(pr.HeadRefName):
cmds = append(cmds, []string{"git", "checkout", pr.HeadRefName})
if opts.Force {
cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
}
default:
cmds = append(cmds, []string{"git", "checkout", "-b", pr.HeadRefName, "--no-track", remoteBranch})
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", pr.HeadRefName), remote.Name})
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", pr.HeadRefName), "refs/heads/" + pr.HeadRefName})
}
return cmds
}
func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
var cmds [][]string
newBranchName := pr.HeadRefName
// avoid naming the new branch the same as the default branch
if newBranchName == defaultBranch {
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
}
ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
if opts.Detach {
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
return cmds
}
currentBranch, _ := opts.Branch()
if newBranchName == currentBranch {
// PR head matches currently checked out branch
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
if opts.Force {
cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
}
} else {
// create a new branch
if opts.Force {
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName), "--force"})
} else {
// TODO: check if non-fast-forward and suggest to use `--force`
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)})
}
cmds = append(cmds, []string{"git", "checkout", newBranchName})
}
remote := baseURLOrName
mergeRef := ref
if pr.MaintainerCanModify {
headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
remote = ghrepo.FormatRemoteURL(headRepo, protocol)
mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
}
if missingMergeConfigForBranch(newBranchName) {
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote})
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef})
}
return cmds
}
func missingMergeConfigForBranch(b string) bool {
mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b))
return err != nil || mc == ""
}
func localBranchExists(b string) bool {
_, err := git.ShowRefs("refs/heads/" + b)
return err == nil
}
func executeCmds(cmdQueue [][]string) error {
for _, args := range cmdQueue {
// TODO: reuse the result of this lookup across loop iteration
exe, err := safeexec.LookPath(args[0])
@ -198,6 +241,5 @@ func checkoutRun(opts *CheckoutOptions) error {
return err
}
}
return nil
}

View file

@ -688,3 +688,39 @@ func TestPRCheckout_force(t *testing.T) {
assert.Equal(t, "git checkout feature", strings.Join(ranCommands[1], " "))
assert.Equal(t, "git reset --hard refs/remotes/origin/feature", strings.Join(ranCommands[2], " "))
}
func TestPRCheckout_detach(t *testing.T) {
http := &httpmock.Registry{}
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`
{ "data": { "repository": { "pullRequest": {
"number": 123,
"headRef": "f8f8f8",
"headRepositoryOwner": {
"login": "hubot"
},
"headRepository": {
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": true
} } } }
`))
ranCommands := [][]string{}
//nolint:staticcheck // SA1019 TODO: rewrite to use run.Stub
restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable {
ranCommands = append(ranCommands, cmd.Args)
return &test.OutputStub{}
})
defer restoreCmd()
output, err := runCommand(http, nil, "", `123 --detach`)
assert.Nil(t, err)
assert.Equal(t, "", output.String())
assert.Equal(t, 2, len(ranCommands))
assert.Equal(t, "git fetch origin refs/pull/123/head", strings.Join(ranCommands[0], " "))
assert.Equal(t, "git checkout --detach FETCH_HEAD", strings.Join(ranCommands[1], " "))
}