diff --git a/api/queries_pr.go b/api/queries_pr.go index 5dc6c7396..c1afc6bc9 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -143,6 +143,14 @@ type PullRequestReviewStatus struct { ReviewRequired bool } +type PullRequestMergeMethod int + +const ( + PullRequestMergeMethodMerge PullRequestMergeMethod = iota + PullRequestMergeMethodRebase + PullRequestMergeMethodSquash +) + func (pr *PullRequest) ReviewStatus() PullRequestReviewStatus { var status PullRequestReviewStatus switch pr.ReviewDecision { @@ -466,6 +474,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea type response struct { Repository struct { PullRequests struct { + ID githubv4.ID Nodes []PullRequest } } @@ -915,6 +924,34 @@ func PullRequestReopen(client *Client, repo ghrepo.Interface, pr *PullRequest) e return err } +func PullRequestMerge(client *Client, repo ghrepo.Interface, pr *PullRequest, m PullRequestMergeMethod) error { + mergeMethod := githubv4.PullRequestMergeMethodMerge + switch m { + case PullRequestMergeMethodRebase: + mergeMethod = githubv4.PullRequestMergeMethodRebase + case PullRequestMergeMethodSquash: + mergeMethod = githubv4.PullRequestMergeMethodSquash + } + + var mutation struct { + MergePullRequest struct { + PullRequest struct { + ID githubv4.ID + } + } `graphql:"mergePullRequest(input: $input)"` + } + + input := githubv4.MergePullRequestInput{ + PullRequestID: pr.ID, + MergeMethod: &mergeMethod, + } + + v4 := githubv4.NewClient(client.http) + err := v4.Mutate(context.Background(), &mutation, input, nil) + + return err +} + func min(a, b int) int { if a < b { return a diff --git a/command/pr.go b/command/pr.go index 13c42883e..2407e0867 100644 --- a/command/pr.go +++ b/command/pr.go @@ -25,6 +25,10 @@ func init() { prCmd.AddCommand(prStatusCmd) prCmd.AddCommand(prCloseCmd) prCmd.AddCommand(prReopenCmd) + prCmd.AddCommand(prMergeCmd) + prMergeCmd.Flags().BoolP("merge", "m", true, "Merge the commits with the base branch") + prMergeCmd.Flags().BoolP("rebase", "r", false, "Rebase the commits onto the base branch") + prMergeCmd.Flags().BoolP("squash", "s", false, "Squash the commits into one commit and merge it into the base branch") prCmd.AddCommand(prListCmd) prListCmd.Flags().IntP("limit", "L", 30, "Maximum number of items to fetch") @@ -58,7 +62,7 @@ var prStatusCmd = &cobra.Command{ RunE: prStatus, } var prViewCmd = &cobra.Command{ - Use: "view [{ | | }]", + Use: "view [ | | ]", Short: "View a pull request", Long: `Display the title, body, and other information about a pull request. @@ -81,6 +85,13 @@ var prReopenCmd = &cobra.Command{ RunE: prReopen, } +var prMergeCmd = &cobra.Command{ + Use: "merge [ | | ]", + Short: "Merge a pull request", + Args: cobra.MaximumNArgs(1), + RunE: prMerge, +} + func prStatus(cmd *cobra.Command, args []string) error { ctx := contextForCommand(cmd) apiClient, err := apiClientForContext(ctx) @@ -100,6 +111,7 @@ func prStatus(cmd *cobra.Command, args []string) error { repoOverride, _ := cmd.Flags().GetString("repo") currentPRNumber, currentPRHeadRef, err := prSelectorForCurrentBranch(ctx, baseRepo) + if err != nil && repoOverride == "" && err.Error() != "git: not on any branch" { return fmt.Errorf("could not query for pull request for current branch: %w", err) } @@ -419,6 +431,75 @@ func prReopen(cmd *cobra.Command, args []string) error { return nil } +func prMerge(cmd *cobra.Command, args []string) error { + ctx := contextForCommand(cmd) + apiClient, err := apiClientForContext(ctx) + if err != nil { + return err + } + + baseRepo, err := determineBaseRepo(cmd, ctx) + if err != nil { + return err + } + + var pr *api.PullRequest + if len(args) > 0 { + pr, err = prFromArg(apiClient, baseRepo, args[0]) + if err != nil { + return err + } + } else { + prNumber, branchWithOwner, err := prSelectorForCurrentBranch(ctx, baseRepo) + if err != nil { + return err + } + + if prNumber != 0 { + pr, err = api.PullRequestByNumber(apiClient, baseRepo, prNumber) + } else { + pr, err = api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) + } + if err != nil { + return err + } + } + + if pr.State == "MERGED" { + err := fmt.Errorf("%s Pull request #%d was already merged", utils.Red("!"), pr.Number) + return err + } + + rebase, err := cmd.Flags().GetBool("rebase") + if err != nil { + return err + } + squash, err := cmd.Flags().GetBool("squash") + if err != nil { + return err + } + + var output string + if rebase { + output = fmt.Sprintf("%s Rebased and merged pull request #%d\n", utils.Green("✔"), pr.Number) + err = api.PullRequestMerge(apiClient, baseRepo, pr, api.PullRequestMergeMethodRebase) + } else if squash { + output = fmt.Sprintf("%s Squashed and merged pull request #%d\n", utils.Green("✔"), pr.Number) + err = api.PullRequestMerge(apiClient, baseRepo, pr, api.PullRequestMergeMethodSquash) + } else { + output = fmt.Sprintf("%s Merged pull request #%d\n", utils.Green("✔"), pr.Number) + err = api.PullRequestMerge(apiClient, baseRepo, pr, api.PullRequestMergeMethodMerge) + } + + if err != nil { + return fmt.Errorf("API call failed: %w", err) + } + + fmt.Fprint(colorableOut(cmd), output) + + return nil +} + func printPrPreview(out io.Writer, pr *api.PullRequest) error { // Header (Title and State) fmt.Fprintln(out, utils.Bold(pr.Title)) diff --git a/command/pr_test.go b/command/pr_test.go index 98a09b725..528e2ee77 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -3,6 +3,7 @@ package command import ( "bytes" "encoding/json" + "io" "io/ioutil" "os" "os/exec" @@ -970,5 +971,125 @@ func TestPRReopen_alreadyMerged(t *testing.T) { if !r.MatchString(err.Error()) { t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) } - +} + +type stubResponse struct { + ResponseCode int + ResponseBody io.Reader +} + +func initWithStubs(branch string, stubs ...stubResponse) { + initBlankContext("", "OWNER/REPO", branch) + http := initFakeHTTP() + http.StubRepoResponse("OWNER", "REPO") + + for _, s := range stubs { + http.StubResponse(s.ResponseCode, s.ResponseBody) + } +} + +func TestPrMerge(t *testing.T) { + initWithStubs("master", + stubResponse{200, bytes.NewBufferString(`{ "data": { "repository": { + "pullRequest": { "number": 1, "closed": false, "state": "OPEN"} + } } }`)}, + stubResponse{200, bytes.NewBufferString(`{"id": "THE-ID"}`)}, + ) + + output, err := RunCommand("pr merge 1") + if err != nil { + t.Fatalf("error running command `pr merge`: %v", err) + } + + r := regexp.MustCompile(`Merged pull request #1`) + + if !r.MatchString(output.String()) { + t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) + } +} + +func TestPrMerge_noPrNumberGiven(t *testing.T) { + cs, cmdTeardown := test.InitCmdStubber() + defer cmdTeardown() + + cs.Stub("branch.blueberries.remote origin\nbranch.blueberries.merge refs/heads/blueberries") // git config --get-regexp ^branch\.master\.(remote|merge) + + jsonFile, _ := os.Open("../test/fixtures/prViewPreviewWithMetadataByBranch.json") + defer jsonFile.Close() + + initWithStubs("blueberries", + stubResponse{200, jsonFile}, + stubResponse{200, bytes.NewBufferString(`{"id": "THE-ID"}`)}, + ) + + output, err := RunCommand("pr merge") + if err != nil { + t.Fatalf("error running command `pr merge`: %v", err) + } + + r := regexp.MustCompile(`Merged pull request #10`) + + if !r.MatchString(output.String()) { + t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) + } +} + +func TestPrMerge_rebase(t *testing.T) { + initWithStubs("master", + stubResponse{200, bytes.NewBufferString(`{ "data": { "repository": { + "pullRequest": { "number": 2, "closed": false, "state": "OPEN"} + } } }`)}, + stubResponse{200, bytes.NewBufferString(`{"id": "THE-ID"}`)}, + ) + + output, err := RunCommand("pr merge 2 --rebase") + if err != nil { + t.Fatalf("error running command `pr merge`: %v", err) + } + + r := regexp.MustCompile(`Rebased and merged pull request #2`) + + if !r.MatchString(output.String()) { + t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) + } +} + +func TestPrMerge_squash(t *testing.T) { + initWithStubs("master", + stubResponse{200, bytes.NewBufferString(`{ "data": { "repository": { + "pullRequest": { "number": 3, "closed": false, "state": "OPEN"} + } } }`)}, + stubResponse{200, bytes.NewBufferString(`{"id": "THE-ID"}`)}, + ) + + output, err := RunCommand("pr merge 3 --squash") + if err != nil { + t.Fatalf("error running command `pr merge`: %v", err) + } + + r := regexp.MustCompile(`Squashed and merged pull request #3`) + + if !r.MatchString(output.String()) { + t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) + } +} + +func TestPrMerge_alreadyMerged(t *testing.T) { + initWithStubs("master", + stubResponse{200, bytes.NewBufferString(`{ "data": { "repository": { + "pullRequest": { "number": 4, "closed": true, "state": "MERGED"} + } } }`)}, + stubResponse{200, bytes.NewBufferString(`{"id": "THE-ID"}`)}, + ) + + output, err := RunCommand("pr merge 4") + if err == nil { + t.Fatalf("expected an error running command `pr merge`: %v", err) + } + + r := regexp.MustCompile(`Pull request #4 was already merged`) + + if !r.MatchString(err.Error()) { + t.Fatalf("output did not match regexp /%s/\n> output\n%q\n", r, output.Stderr()) + } }