diff --git a/command/pr.go b/command/pr.go index aa774a2d3..2f59934a7 100644 --- a/command/pr.go +++ b/command/pr.go @@ -302,7 +302,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec}) // local branch already exists - if git.HasFile("refs", "heads", newBranchName) { + if git.VerifyRef("refs/heads/" + newBranchName) { cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName}) cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } else { diff --git a/command/pr_checkout_test.go b/command/pr_checkout_test.go index 99dc2281d..9f85f4f95 100644 --- a/command/pr_checkout_test.go +++ b/command/pr_checkout_test.go @@ -40,8 +40,13 @@ func TestPRCheckout_sameRepo(t *testing.T) { ranCommands := [][]string{} restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable { - ranCommands = append(ranCommands, cmd.Args) - return &outputStub{} + switch strings.Join(cmd.Args, " ") { + case "git show-ref --verify --quiet refs/heads/feature": + return &errorStub{"exit status: 1"} + default: + ranCommands = append(ranCommands, cmd.Args) + return &outputStub{} + } }) defer restoreCmd() @@ -49,13 +54,61 @@ func TestPRCheckout_sameRepo(t *testing.T) { _, err := prCheckoutCmd.ExecuteC() eq(t, err, nil) - eq(t, len(ranCommands), 6) - eq(t, strings.Join(ranCommands[0], " "), "git rev-parse -q --git-path refs/heads/feature") - eq(t, strings.Join(ranCommands[1], " "), "git rev-parse -q --git-dir") - eq(t, strings.Join(ranCommands[2], " "), "git fetch origin +refs/heads/feature:refs/remotes/origin/feature") - eq(t, strings.Join(ranCommands[3], " "), "git checkout -b feature --no-track origin/feature") - eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.remote origin") - eq(t, strings.Join(ranCommands[5], " "), "git config branch.feature.merge refs/heads/feature") + eq(t, len(ranCommands), 4) + eq(t, strings.Join(ranCommands[0], " "), "git fetch origin +refs/heads/feature:refs/remotes/origin/feature") + eq(t, strings.Join(ranCommands[1], " "), "git checkout -b feature --no-track origin/feature") + eq(t, strings.Join(ranCommands[2], " "), "git config branch.feature.remote origin") + eq(t, strings.Join(ranCommands[3], " "), "git config branch.feature.merge refs/heads/feature") +} + +func TestPRCheckout_existingBranch(t *testing.T) { + ctx := context.NewBlank() + ctx.SetBranch("master") + ctx.SetRemotes(map[string]string{ + "origin": "OWNER/REPO", + }) + initContext = func() context.Context { + return ctx + } + http := initFakeHTTP() + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "pullRequest": { + "headRefName": "feature", + "headRepositoryOwner": { + "login": "hubot" + }, + "headRepository": { + "name": "REPO", + "defaultBranchRef": { + "name": "master" + } + }, + "isCrossRepository": false, + "maintainerCanModify": false + } } } } + `)) + + ranCommands := [][]string{} + restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable { + switch strings.Join(cmd.Args, " ") { + case "git show-ref --verify --quiet refs/heads/feature": + return &outputStub{} + default: + ranCommands = append(ranCommands, cmd.Args) + return &outputStub{} + } + }) + defer restoreCmd() + + RootCmd.SetArgs([]string{"pr", "checkout", "123"}) + _, err := prCheckoutCmd.ExecuteC() + eq(t, err, nil) + + eq(t, len(ranCommands), 3) + eq(t, strings.Join(ranCommands[0], " "), "git fetch origin +refs/heads/feature:refs/remotes/origin/feature") + eq(t, strings.Join(ranCommands[1], " "), "git checkout feature") + eq(t, strings.Join(ranCommands[2], " "), "git merge --ff-only refs/remotes/origin/feature") } func TestPRCheckout_differentRepo(t *testing.T) { diff --git a/command/testing.go b/command/testing.go index 2e1cf9505..7758d4a8b 100644 --- a/command/testing.go +++ b/command/testing.go @@ -1,6 +1,8 @@ package command import ( + "errors" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" ) @@ -34,3 +36,15 @@ func (s outputStub) Output() ([]byte, error) { func (s outputStub) Run() error { return nil } + +type errorStub struct { + message string +} + +func (s errorStub) Output() ([]byte, error) { + return nil, errors.New(s.message) +} + +func (s errorStub) Run() error { + return errors.New(s.message) +} diff --git a/git/git.go b/git/git.go index 318a9ddcc..8824872a4 100644 --- a/git/git.go +++ b/git/git.go @@ -41,31 +41,10 @@ func WorkdirName() (string, error) { return dir, err } -func HasFile(segments ...string) bool { - // The blessed way to resolve paths within git dir since Git 2.5.0 - pathCmd := exec.Command("git", "rev-parse", "-q", "--git-path", filepath.Join(segments...)) - if output, err := utils.PrepareCmd(pathCmd).Output(); err == nil { - if lines := outputLines(output); len(lines) == 1 { - if _, err := os.Stat(lines[0]); err == nil { - return true - } - } - } - - // Fallback for older git versions - dir, err := Dir() - if err != nil { - return false - } - - s := []string{dir} - s = append(s, segments...) - path := filepath.Join(s...) - if _, err := os.Stat(path); err == nil { - return true - } - - return false +func VerifyRef(ref string) bool { + showRef := exec.Command("git", "show-ref", "--verify", "--quiet", ref) + err := utils.PrepareCmd(showRef).Run() + return err == nil } func BranchAtRef(paths ...string) (name string, err error) {