Fix pr checkout <owner>:<branch> when it matches the default branch

First, consolidate the functionality between `pr merge` and `pr
checkout` that resolves the default branch name of the base repo. With
an added bonus, the new approach avoids an API request when one isn't
necessary.

Then, ensure that checking out 3rd-party PRs will result in local branch
name such as `<owner>/<branch>` when the head branch of the repository
matches the default branch of the base repository. We already have had
code in place to take care of this, but it only took effect in the `pr
checkout <number>`-style invocation.
This commit is contained in:
Mislav Marohnić 2020-07-15 15:35:42 +02:00
parent 6825944cad
commit 305cd290ee
7 changed files with 81 additions and 145 deletions

View file

@ -432,9 +432,6 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu
}
headRepository {
name
defaultBranchRef {
name
}
}
isCrossRepository
isDraft

View file

@ -114,6 +114,18 @@ func GitHubRepo(client *Client, repo ghrepo.Interface) (*Repository, error) {
return initRepoHostname(&result.Repository, repo.RepoHost()), nil
}
func RepoDefaultBranch(client *Client, repo ghrepo.Interface) (string, error) {
if r, ok := repo.(*Repository); ok && r.DefaultBranchRef.Name != "" {
return r.DefaultBranchRef.Name, nil
}
r, err := GitHubRepo(client, repo)
if err != nil {
return "", err
}
return r.DefaultBranchRef.Name, nil
}
// RepoParent finds out the parent repository of a fork
func RepoParent(client *Client, repo ghrepo.Interface) (ghrepo.Interface, error) {
var query struct {

View file

@ -493,23 +493,21 @@ func prMerge(cmd *cobra.Command, args []string) error {
fmt.Fprintf(colorableOut(cmd), "%s %s pull request #%d (%s)\n", utils.Magenta("✔"), action, pr.Number, pr.Title)
if deleteBranch {
repo, err := api.GitHubRepo(apiClient, baseRepo)
if err != nil {
return err
}
currentBranch, err := ctx.Branch()
if err != nil {
return err
}
branchSwitchString := ""
if deleteLocalBranch && !crossRepoPR {
currentBranch, err := ctx.Branch()
if err != nil {
return err
}
var branchToSwitchTo string
if currentBranch == pr.HeadRefName {
branchToSwitchTo = repo.DefaultBranchRef.Name
err = git.CheckoutBranch(repo.DefaultBranchRef.Name)
branchToSwitchTo, err = api.RepoDefaultBranch(apiClient, baseRepo)
if err != nil {
return err
}
err = git.CheckoutBranch(branchToSwitchTo)
if err != nil {
return err
}

View file

@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra"
"github.com/cli/cli/api"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/internal/run"
@ -65,8 +66,13 @@ func prCheckout(cmd *cobra.Command, args []string) error {
} else {
// no git remote for PR head
defaultBranchName, err := api.RepoDefaultBranch(apiClient, baseRepo)
if err != nil {
return err
}
// avoid naming the new branch the same as the default branch
if newBranchName == pr.HeadRepository.DefaultBranchRef.Name {
if newBranchName == defaultBranchName {
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
}

View file

@ -33,10 +33,7 @@ func TestPRCheckout_sameRepo(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": false,
"maintainerCanModify": false
@ -84,10 +81,7 @@ func TestPRCheckout_urlArg(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": false,
"maintainerCanModify": false
@ -132,15 +126,17 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "POE",
"defaultBranchRef": {
"name": "master"
}
"name": "POE"
},
"isCrossRepository": false,
"maintainerCanModify": false
} } } }
`))
http.Register(httpmock.GraphQL(`query RepositoryInfo\b`), httpmock.StringResponse(`
{ "data": { "repository": {
"defaultBranchRef": {"name": "master"}
} } }
`))
ranCommands := [][]string{}
restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable {
@ -195,10 +191,7 @@ func TestPRCheckout_branchArg(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": false }
@ -245,10 +238,7 @@ func TestPRCheckout_existingBranch(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": false,
"maintainerCanModify": false
@ -298,10 +288,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": false
@ -351,10 +338,7 @@ func TestPRCheckout_differentRepo(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": false
@ -404,10 +388,7 @@ func TestPRCheckout_differentRepo_existingBranch(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": false
@ -455,10 +436,7 @@ func TestPRCheckout_differentRepo_currentBranch(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": false
@ -506,10 +484,7 @@ func TestPRCheckout_maintainerCanModify(t *testing.T) {
"login": "hubot"
},
"headRepository": {
"name": "REPO",
"defaultBranchRef": {
"name": "master"
}
"name": "REPO"
},
"isCrossRepository": true,
"maintainerCanModify": true

View file

@ -937,28 +937,25 @@ func TestPRReopen_alreadyMerged(t *testing.T) {
func TestPrMerge(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": {
"pullRequest": { "number": 1, "title": "The title of the PR", "state": "OPEN", "id": "THE-ID"}
} } }`))
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 1,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
assert.Equal(t, "THE-ID", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -987,6 +984,7 @@ func TestPrMerge(t *testing.T) {
func TestPrMerge_withRepoFlag(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
defer http.Verify(t)
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.GraphQLQuery(`
@ -1002,18 +1000,6 @@ func TestPrMerge_withRepoFlag(t *testing.T) {
assert.Equal(t, "THE-ID", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
cs, cmdTeardown := test.InitCmdStubber()
defer cmdTeardown()
@ -1035,6 +1021,7 @@ func TestPrMerge_withRepoFlag(t *testing.T) {
func TestPrMerge_deleteBranch(t *testing.T) {
initBlankContext("", "OWNER/REPO", "blueberries")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
@ -1045,15 +1032,6 @@ func TestPrMerge_deleteBranch(t *testing.T) {
assert.Equal(t, "PR_10", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -1078,6 +1056,7 @@ func TestPrMerge_deleteBranch(t *testing.T) {
func TestPrMerge_deleteNonCurrentBranch(t *testing.T) {
initBlankContext("", "OWNER/REPO", "another-branch")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
@ -1088,15 +1067,6 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) {
assert.Equal(t, "PR_10", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -1119,6 +1089,7 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) {
func TestPrMerge_noPrNumberGiven(t *testing.T) {
initBlankContext("", "OWNER/REPO", "blueberries")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
@ -1129,15 +1100,6 @@ func TestPrMerge_noPrNumberGiven(t *testing.T) {
assert.Equal(t, "PR_10", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -1166,28 +1128,25 @@ func TestPrMerge_noPrNumberGiven(t *testing.T) {
func TestPrMerge_rebase(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": {
"pullRequest": { "number": 2, "title": "The title of the PR", "state": "OPEN", "id": "THE-ID"}
} } }`))
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 2,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
assert.Equal(t, "THE-ID", input["pullRequestId"].(string))
assert.Equal(t, "REBASE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -1215,28 +1174,25 @@ func TestPrMerge_rebase(t *testing.T) {
func TestPrMerge_squash(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
httpmock.StringResponse(`
{ "data": { "repository": {
"pullRequest": { "number": 3, "title": "The title of the PR", "state": "OPEN", "id": "THE-ID"}
} } }`))
{ "data": { "repository": { "pullRequest": {
"id": "THE-ID",
"number": 3,
"title": "The title of the PR",
"state": "OPEN",
"headRefName": "blueberries",
"headRepositoryOwner": {"login": "OWNER"}
} } } }`))
http.Register(
httpmock.GraphQL(`mutation PullRequestMerge\b`),
httpmock.GraphQLMutation(`{}`, func(input map[string]interface{}) {
assert.Equal(t, "THE-ID", input["pullRequestId"].(string))
assert.Equal(t, "SQUASH", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))
@ -1254,16 +1210,14 @@ func TestPrMerge_squash(t *testing.T) {
t.Fatalf("error running command `pr merge`: %v", err)
}
r := regexp.MustCompile(`Squashed and merged pull request #3`)
if !r.MatchString(output.String()) {
t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.String())
}
expected := "✔ Squashed and merged pull request #3 (The title of the PR)\n✔ Deleted branch blueberries\n"
assert.Equal(t, expected, output.String())
}
func TestPrMerge_alreadyMerged(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestByNumber\b`),
@ -1295,6 +1249,7 @@ func TestPrMerge_alreadyMerged(t *testing.T) {
func TestPRMerge_interactive(t *testing.T) {
initBlankContext("", "OWNER/REPO", "blueberries")
http := initFakeHTTP()
defer http.Verify(t)
http.StubRepoResponse("OWNER", "REPO")
http.Register(
httpmock.GraphQL(`query PullRequestForBranch\b`),
@ -1311,15 +1266,6 @@ func TestPRMerge_interactive(t *testing.T) {
assert.Equal(t, "THE-ID", input["pullRequestId"].(string))
assert.Equal(t, "MERGE", input["mergeMethod"].(string))
}))
http.Register(
httpmock.GraphQL(`query RepositoryInfo\b`),
httpmock.StringResponse(`{
"data": {
"repository": {
"defaultBranchRef": {"name": "master"}
}
}
}`))
http.Register(
httpmock.REST("DELETE", "repos/OWNER/REPO/git/refs/heads/blueberries"),
httpmock.StringResponse(`{}`))

View file

@ -21,6 +21,7 @@ func (r *Registry) Register(m Matcher, resp Responder) {
type Testing interface {
Errorf(string, ...interface{})
Helper()
}
func (r *Registry) Verify(t Testing) {
@ -31,6 +32,7 @@ func (r *Registry) Verify(t Testing) {
}
}
if n > 0 {
t.Helper()
// NOTE: stubs offer no useful reflection, so we can't print details
// about dead stubs and what they were trying to match
t.Errorf("%d unmatched HTTP stubs", n)