diff --git a/api/queries.go b/api/queries.go index 5560aa268..7bd121441 100644 --- a/api/queries.go +++ b/api/queries.go @@ -17,7 +17,20 @@ type PullRequest struct { State string URL string HeadRefName string - Reviews struct { + + HeadRepositoryOwner struct { + Login string + } + HeadRepository struct { + Name string + DefaultBranchRef struct { + Name string + } + } + IsCrossRepository bool + MaintainerCanModify bool + + Reviews struct { Nodes []struct { State string Author struct { @@ -355,6 +368,48 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st return &payload, nil } +func PullRequestByNumber(client *Client, ghRepo Repo, number int) (*PullRequest, error) { + type response struct { + Repository struct { + PullRequest PullRequest + } + } + + query := ` + query($owner: String!, $repo: String!, $pr_number: Int!) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $pr_number) { + headRefName + headRepositoryOwner { + login + } + headRepository { + name + defaultBranchRef { + name + } + } + isCrossRepository + maintainerCanModify + } + } + }` + + variables := map[string]interface{}{ + "owner": ghRepo.RepoOwner(), + "repo": ghRepo.RepoName(), + "pr_number": number, + } + + var resp response + err := client.GraphQL(query, variables, &resp) + if err != nil { + return nil, err + } + + return &resp.Repository.PullRequest, nil +} + func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) { type response struct { Repository struct { diff --git a/command/pr.go b/command/pr.go index a70eaab73..aa774a2d3 100644 --- a/command/pr.go +++ b/command/pr.go @@ -3,9 +3,11 @@ package command import ( "fmt" "os" + "os/exec" "strconv" "github.com/github/gh-cli/api" + "github.com/github/gh-cli/git" "github.com/github/gh-cli/utils" "github.com/spf13/cobra" "golang.org/x/crypto/ssh/terminal" @@ -13,6 +15,7 @@ import ( func init() { RootCmd.AddCommand(prCmd) + prCmd.AddCommand(prCheckoutCmd) prCmd.AddCommand(prCreateCmd) prCmd.AddCommand(prListCmd) prCmd.AddCommand(prStatusCmd) @@ -29,6 +32,12 @@ var prCmd = &cobra.Command{ Short: "Work with pull requests", Long: `Helps you work with pull requests.`, } +var prCheckoutCmd = &cobra.Command{ + Use: "checkout ", + Short: "check out a pull request in git", + Args: cobra.MinimumNArgs(1), + RunE: prCheckout, +} var prListCmd = &cobra.Command{ Use: "list", Short: "List pull requests", @@ -247,6 +256,103 @@ func prView(cmd *cobra.Command, args []string) error { return utils.OpenInBrowser(openURL) } +func prCheckout(cmd *cobra.Command, args []string) error { + prNumber, err := strconv.Atoi(args[0]) + if err != nil { + return err + } + + ctx := contextForCommand(cmd) + currentBranch, err := ctx.Branch() + if err != nil { + return err + } + remotes, err := ctx.Remotes() + if err != nil { + return err + } + // FIXME: duplicates logic from fsContext.BaseRepo + baseRemote, err := remotes.FindByName("upstream", "github", "origin", "*") + if err != nil { + return err + } + apiClient, err := apiClientForContext(ctx) + if err != nil { + return err + } + + pr, err := api.PullRequestByNumber(apiClient, baseRemote, prNumber) + if err != nil { + return err + } + + headRemote := baseRemote + if pr.IsCrossRepository { + headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name) + } + + cmdQueue := [][]string{} + + newBranchName := pr.HeadRefName + if headRemote != nil { + // there is an existing git remote for PR head + remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName) + refSpec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s", pr.HeadRefName, remoteBranch) + + cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec}) + + // local branch already exists + if git.HasFile("refs", "heads", newBranchName) { + cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName}) + cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) + } else { + cmdQueue = append(cmdQueue, []string{"git", "checkout", "-b", newBranchName, "--no-track", remoteBranch}) + cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), headRemote.Name}) + cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), "refs/heads/" + pr.HeadRefName}) + } + } else { + // no git remote for PR head + + // avoid naming the new branch the same as the default branch + if newBranchName == pr.HeadRepository.DefaultBranchRef.Name { + newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName) + } + + ref := fmt.Sprintf("refs/pull/%d/head", prNumber) + if newBranchName == currentBranch { + // PR head matches currently checked out branch + cmdQueue = append(cmdQueue, []string{"git", "fetch", baseRemote.Name, ref}) + cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", "FETCH_HEAD"}) + } else { + // create a new branch + cmdQueue = append(cmdQueue, []string{"git", "fetch", baseRemote.Name, fmt.Sprintf("%s:%s", ref, newBranchName)}) + cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName}) + } + + remote := baseRemote.Name + mergeRef := ref + if pr.MaintainerCanModify { + remote = fmt.Sprintf("https://github.com/%s/%s.git", pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name) + mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName) + } + if mc, err := git.Config(fmt.Sprintf("branch.%s.merge", newBranchName)); err != nil || mc == "" { + cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote}) + cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef}) + } + } + + for _, args := range cmdQueue { + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := utils.PrepareCmd(cmd).Run(); err != nil { + return err + } + } + + return nil +} + func printPrs(prs ...api.PullRequest) { for _, pr := range prs { prNumber := fmt.Sprintf("#%d", pr.Number) diff --git a/command/pr_checkout_test.go b/command/pr_checkout_test.go new file mode 100644 index 000000000..99dc2281d --- /dev/null +++ b/command/pr_checkout_test.go @@ -0,0 +1,153 @@ +package command + +import ( + "bytes" + "os/exec" + "strings" + "testing" + + "github.com/github/gh-cli/context" + "github.com/github/gh-cli/utils" +) + +func TestPRCheckout_sameRepo(t *testing.T) { + ctx := context.NewBlank() + ctx.SetBranch("master") + ctx.SetRemotes(map[string]string{ + "origin": "OWNER/REPO", + }) + initContext = func() context.Context { + return ctx + } + http := initFakeHTTP() + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "pullRequest": { + "headRefName": "feature", + "headRepositoryOwner": { + "login": "hubot" + }, + "headRepository": { + "name": "REPO", + "defaultBranchRef": { + "name": "master" + } + }, + "isCrossRepository": false, + "maintainerCanModify": false + } } } } + `)) + + ranCommands := [][]string{} + restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable { + ranCommands = append(ranCommands, cmd.Args) + return &outputStub{} + }) + defer restoreCmd() + + RootCmd.SetArgs([]string{"pr", "checkout", "123"}) + _, err := prCheckoutCmd.ExecuteC() + eq(t, err, nil) + + eq(t, len(ranCommands), 6) + eq(t, strings.Join(ranCommands[0], " "), "git rev-parse -q --git-path refs/heads/feature") + eq(t, strings.Join(ranCommands[1], " "), "git rev-parse -q --git-dir") + eq(t, strings.Join(ranCommands[2], " "), "git fetch origin +refs/heads/feature:refs/remotes/origin/feature") + eq(t, strings.Join(ranCommands[3], " "), "git checkout -b feature --no-track origin/feature") + eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.remote origin") + eq(t, strings.Join(ranCommands[5], " "), "git config branch.feature.merge refs/heads/feature") +} + +func TestPRCheckout_differentRepo(t *testing.T) { + ctx := context.NewBlank() + ctx.SetBranch("master") + ctx.SetRemotes(map[string]string{ + "origin": "OWNER/REPO", + }) + initContext = func() context.Context { + return ctx + } + http := initFakeHTTP() + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "pullRequest": { + "headRefName": "feature", + "headRepositoryOwner": { + "login": "hubot" + }, + "headRepository": { + "name": "REPO", + "defaultBranchRef": { + "name": "master" + } + }, + "isCrossRepository": true, + "maintainerCanModify": false + } } } } + `)) + + ranCommands := [][]string{} + restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable { + ranCommands = append(ranCommands, cmd.Args) + return &outputStub{} + }) + defer restoreCmd() + + RootCmd.SetArgs([]string{"pr", "checkout", "123"}) + _, err := prCheckoutCmd.ExecuteC() + eq(t, err, nil) + + eq(t, len(ranCommands), 5) + eq(t, strings.Join(ranCommands[0], " "), "git config branch.feature.merge") + eq(t, strings.Join(ranCommands[1], " "), "git fetch origin refs/pull/123/head:feature") + eq(t, strings.Join(ranCommands[2], " "), "git checkout feature") + eq(t, strings.Join(ranCommands[3], " "), "git config branch.feature.remote origin") + eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.merge refs/pull/123/head") +} + +func TestPRCheckout_maintainerCanModify(t *testing.T) { + ctx := context.NewBlank() + ctx.SetBranch("master") + ctx.SetRemotes(map[string]string{ + "origin": "OWNER/REPO", + }) + initContext = func() context.Context { + return ctx + } + http := initFakeHTTP() + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "pullRequest": { + "headRefName": "feature", + "headRepositoryOwner": { + "login": "hubot" + }, + "headRepository": { + "name": "REPO", + "defaultBranchRef": { + "name": "master" + } + }, + "isCrossRepository": true, + "maintainerCanModify": true + } } } } + `)) + + ranCommands := [][]string{} + restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable { + ranCommands = append(ranCommands, cmd.Args) + return &outputStub{} + }) + defer restoreCmd() + + RootCmd.SetArgs([]string{"pr", "checkout", "123"}) + _, err := prCheckoutCmd.ExecuteC() + eq(t, err, nil) + + eq(t, len(ranCommands), 5) + eq(t, strings.Join(ranCommands[0], " "), "git config branch.feature.merge") + eq(t, strings.Join(ranCommands[1], " "), "git fetch origin refs/pull/123/head:feature") + eq(t, strings.Join(ranCommands[2], " "), "git checkout feature") + eq(t, strings.Join(ranCommands[3], " "), "git config branch.feature.remote https://github.com/hubot/REPO.git") + eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.merge refs/heads/feature") +} diff --git a/context/remote.go b/context/remote.go index 9f3b228dc..a30ac6c45 100644 --- a/context/remote.go +++ b/context/remote.go @@ -25,6 +25,16 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) { return nil, fmt.Errorf("no GitHub remotes found") } +// FindByRepo returns the first Remote that points to a specific GitHub repository +func (r Remotes) FindByRepo(owner, name string) (*Remote, error) { + for _, rem := range r { + if strings.EqualFold(rem.Owner, owner) && strings.EqualFold(rem.Name, name) { + return rem, nil + } + } + return nil, fmt.Errorf("no matching remote found") +} + // Remote represents a git remote mapped to a GitHub repository type Remote struct { *git.Remote