diff --git a/api/queries_pr.go b/api/queries_pr.go index cca77be99..188e708a8 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -205,9 +205,9 @@ func (pr *PullRequest) ChecksStatus() (summary PullRequestChecksStatus) { return } -func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNum int) (string, error) { +func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (string, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/pulls/%d", - ghrepo.FullName(baseRepo), prNum) + ghrepo.FullName(baseRepo), prNumber) req, err := http.NewRequest("GET", url, nil) if err != nil { return "", err @@ -235,7 +235,6 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNum int) (string, e } return "", errors.New("pull request diff lookup failed") - } func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, currentPRHeadRef, currentUsername string) (*PullRequestsPayload, error) { diff --git a/command/pr.go b/command/pr.go index 386a5e8ce..e92fa7cf0 100644 --- a/command/pr.go +++ b/command/pr.go @@ -302,59 +302,16 @@ func prView(cmd *cobra.Command, args []string) error { return err } - var baseRepo ghrepo.Interface - var prArg string - if len(args) > 0 { - prArg = args[0] - if prNum, repo := prFromURL(prArg); repo != nil { - prArg = prNum - baseRepo = repo - } - } - - if baseRepo == nil { - baseRepo, err = determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - } - web, err := cmd.Flags().GetBool("web") if err != nil { return err } - var openURL string - var pr *api.PullRequest - if len(args) > 0 { - pr, err = prFromArg(apiClient, baseRepo, prArg) - if err != nil { - return err - } - openURL = pr.URL - } else { - prNumber, branchWithOwner, err := prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil { - return err - } - - if prNumber > 0 { - openURL = fmt.Sprintf("https://github.com/%s/pull/%d", ghrepo.FullName(baseRepo), prNumber) - if !web { - pr, err = api.PullRequestByNumber(apiClient, baseRepo, prNumber) - if err != nil { - return err - } - } - } else { - pr, err = api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) - if err != nil { - return err - } - - openURL = pr.URL - } + pr, _, err := prFromArgs(ctx, apiClient, cmd, args) + if err != nil { + return err } + openURL := pr.URL if web { fmt.Fprintf(cmd.ErrOrStderr(), "Opening %s in your browser.\n", openURL) @@ -372,12 +329,7 @@ func prClose(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArg(apiClient, baseRepo, args[0]) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -407,12 +359,7 @@ func prReopen(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArg(apiClient, baseRepo, args[0]) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -444,41 +391,11 @@ func prMerge(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } - var pr *api.PullRequest - if len(args) > 0 { - var prNumber string - n, _ := prFromURL(args[0]) - if n != "" { - prNumber = n - } else { - prNumber = args[0] - } - - pr, err = prFromArg(apiClient, baseRepo, prNumber) - if err != nil { - return err - } - } else { - prNumber, branchWithOwner, err := prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil { - return err - } - - if prNumber != 0 { - pr, err = api.PullRequestByNumber(apiClient, baseRepo, prNumber) - } else { - pr, err = api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) - } - if err != nil { - return err - } - } - if pr.Mergeable == "CONFLICTING" { err := fmt.Errorf("%s Pull request #%d has conflicts and isn't mergeable ", utils.Red("!"), pr.Number) return err @@ -712,41 +629,11 @@ func prReady(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } - var pr *api.PullRequest - if len(args) > 0 { - var prNumber string - n, _ := prFromURL(args[0]) - if n != "" { - prNumber = n - } else { - prNumber = args[0] - } - - pr, err = prFromArg(apiClient, baseRepo, prNumber) - if err != nil { - return err - } - } else { - prNumber, branchWithOwner, err := prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil { - return err - } - - if prNumber != 0 { - pr, err = api.PullRequestByNumber(apiClient, baseRepo, prNumber) - } else { - pr, err = api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) - } - if err != nil { - return err - } - } - if pr.Closed { err := fmt.Errorf("%s Pull request #%d is closed. Only draft pull requests can be marked as \"ready for review\"", utils.Red("!"), pr.Number) return err @@ -941,23 +828,6 @@ func prProjectList(pr api.PullRequest) string { return list } -var prURLRE = regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`) - -func prFromURL(arg string) (string, ghrepo.Interface) { - if m := prURLRE.FindStringSubmatch(arg); m != nil { - return m[3], ghrepo.New(m[1], m[2]) - } - return "", nil -} - -func prFromArg(apiClient *api.Client, baseRepo ghrepo.Interface, arg string) (*api.PullRequest, error) { - if prNumber, err := strconv.Atoi(strings.TrimPrefix(arg, "#")); err == nil { - return api.PullRequestByNumber(apiClient, baseRepo, prNumber) - } - - return api.PullRequestForBranch(apiClient, baseRepo, "", arg) -} - func prSelectorForCurrentBranch(ctx context.Context, baseRepo ghrepo.Interface) (prNumber int, prHeadRef string, err error) { prHeadRef, err = ctx.Branch() if err != nil { diff --git a/command/pr_checkout.go b/command/pr_checkout.go index aa9e3afcc..8ffa39210 100644 --- a/command/pr_checkout.go +++ b/command/pr_checkout.go @@ -26,21 +26,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { return err } - var baseRepo ghrepo.Interface - prArg := args[0] - if prNum, repo := prFromURL(prArg); repo != nil { - prArg = prNum - baseRepo = repo - } - - if baseRepo == nil { - baseRepo, err = determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - } - - pr, err := prFromArg(apiClient, baseRepo, prArg) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -59,6 +45,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { var cmdQueue [][]string newBranchName := pr.HeadRefName + if headRemote != nil { // there is an existing git remote for PR head remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName) diff --git a/command/pr_checkout_test.go b/command/pr_checkout_test.go index 851d6ed22..16b7429be 100644 --- a/command/pr_checkout_test.go +++ b/command/pr_checkout_test.go @@ -76,7 +76,6 @@ func TestPRCheckout_urlArg(t *testing.T) { return ctx } http := initFakeHTTP() - http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequest": { "number": 123, @@ -125,7 +124,6 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) { return ctx } http := initFakeHTTP() - http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequest": { "number": 123, diff --git a/command/pr_diff.go b/command/pr_diff.go index 6e4bb4bdb..ae50a2618 100644 --- a/command/pr_diff.go +++ b/command/pr_diff.go @@ -1,13 +1,10 @@ package command import ( - "errors" "fmt" "os" - "strconv" "strings" - "github.com/cli/cli/api" "github.com/cli/cli/utils" "github.com/spf13/cobra" ) @@ -39,43 +36,12 @@ func prDiff(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { - return fmt.Errorf("could not determine base repo: %w", err) + return fmt.Errorf("could not find pull request: %w", err) } - // begin pr resolution boilerplate - var prNum int - branchWithOwner := "" - - if len(args) == 0 { - prNum, branchWithOwner, err = prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil { - return fmt.Errorf("could not query for pull request for current branch: %w", err) - } - } else { - prArg, repo := prFromURL(args[0]) - if repo != nil { - baseRepo = repo - } else { - prArg = strings.TrimPrefix(args[0], "#") - } - prNum, err = strconv.Atoi(prArg) - if err != nil { - return errors.New("could not parse pull request argument") - } - } - - if prNum < 1 { - pr, err := api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) - if err != nil { - return fmt.Errorf("could not find pull request: %w", err) - } - prNum = pr.Number - } - // end pr resolution boilerplate - - diff, err := apiClient.PullRequestDiff(baseRepo, prNum) + diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number) if err != nil { return fmt.Errorf("could not find pull request diff: %w", err) } diff --git a/command/pr_diff_test.go b/command/pr_diff_test.go index ab883dcab..725b21f03 100644 --- a/command/pr_diff_test.go +++ b/command/pr_diff_test.go @@ -37,6 +37,11 @@ func TestPRDiff_argument_not_found(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { + "pullRequest": { "number": 123 } + } } } +`)) http.StubResponse(404, bytes.NewBufferString("")) _, err := RunCommand("pr diff 123") if err == nil { diff --git a/command/pr_lookup.go b/command/pr_lookup.go new file mode 100644 index 000000000..a1b0c9638 --- /dev/null +++ b/command/pr_lookup.go @@ -0,0 +1,109 @@ +package command + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/cli/cli/api" + "github.com/cli/cli/context" + "github.com/cli/cli/git" + "github.com/cli/cli/internal/ghrepo" + "github.com/spf13/cobra" +) + +func prFromArgs(ctx context.Context, apiClient *api.Client, cmd *cobra.Command, args []string) (*api.PullRequest, ghrepo.Interface, error) { + if len(args) == 1 { + // First check to see if the prString is a url, return repo from url if found. This + // is run first because we don't need to run determineBaseRepo for this path + prString := args[0] + pr, r, err := prFromURL(ctx, apiClient, prString) + if pr != nil || err != nil { + return pr, r, err + } + } + + repo, err := determineBaseRepo(apiClient, cmd, ctx) + if err != nil { + return nil, nil, fmt.Errorf("could not determine base repo: %w", err) + } + + // If there are no args see if we can guess the PR from the current branch + if len(args) == 0 { + pr, err := prForCurrentBranch(ctx, apiClient, repo) + return pr, repo, err + } else { + prString := args[0] + // Next see if the prString is a number and use that to look up the url + pr, err := prFromNumberString(ctx, apiClient, repo, prString) + if pr != nil || err != nil { + return pr, repo, err + } + + // Last see if it is a branch name + pr, err = api.PullRequestForBranch(apiClient, repo, "", prString) + return pr, repo, err + } +} + +func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) { + if prNumber, err := strconv.Atoi(strings.TrimPrefix(s, "#")); err == nil { + return api.PullRequestByNumber(apiClient, repo, prNumber) + } + + return nil, nil +} + +func prFromURL(ctx context.Context, apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) { + r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`) + if m := r.FindStringSubmatch(s); m != nil { + repo := ghrepo.New(m[1], m[2]) + prNumberString := m[3] + pr, err := prFromNumberString(ctx, apiClient, repo, prNumberString) + return pr, repo, err + } + + return nil, nil, nil +} + +func prForCurrentBranch(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface) (*api.PullRequest, error) { + prHeadRef, err := ctx.Branch() + if err != nil { + return nil, err + } + + branchConfig := git.ReadBranchConfig(prHeadRef) + + // the branch is configured to merge a special PR head ref + prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + return prFromNumberString(ctx, apiClient, repo, m[1]) + } + + var branchOwner string + if branchConfig.RemoteURL != nil { + // the branch merges from a remote specified by URL + if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { + branchOwner = r.RepoOwner() + } + } else if branchConfig.RemoteName != "" { + // the branch merges from a remote specified by name + rem, _ := ctx.Remotes() + if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { + branchOwner = r.RepoOwner() + } + } + + if branchOwner != "" { + if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { + prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") + } + // prepend `OWNER:` if this branch is pushed to a fork + if !strings.EqualFold(branchOwner, repo.RepoOwner()) { + prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) + } + } + + return api.PullRequestForBranch(apiClient, repo, "", prHeadRef) +} diff --git a/command/pr_review.go b/command/pr_review.go index 37a8e002a..642c9ac24 100644 --- a/command/pr_review.go +++ b/command/pr_review.go @@ -3,8 +3,6 @@ package command import ( "errors" "fmt" - "strconv" - "strings" "github.com/AlecAivazis/survey/v2" "github.com/spf13/cobra" @@ -91,30 +89,9 @@ func prReview(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) + pr, _, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { - return fmt.Errorf("could not determine base repo: %w", err) - } - - var prNum int - branchWithOwner := "" - - if len(args) == 0 { - prNum, branchWithOwner, err = prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil { - return fmt.Errorf("could not query for pull request for current branch: %w", err) - } - } else { - prArg, repo := prFromURL(args[0]) - if repo != nil { - baseRepo = repo - } else { - prArg = strings.TrimPrefix(args[0], "#") - } - prNum, err = strconv.Atoi(prArg) - if err != nil { - return errors.New("could not parse pull request argument") - } + return err } reviewData, err := processReviewOpt(cmd) @@ -122,20 +99,6 @@ func prReview(cmd *cobra.Command, args []string) error { return fmt.Errorf("did not understand desired review action: %w", err) } - var pr *api.PullRequest - if prNum > 0 { - pr, err = api.PullRequestByNumber(apiClient, baseRepo, prNum) - if err != nil { - return fmt.Errorf("could not find pull request: %w", err) - } - } else { - pr, err = api.PullRequestForBranch(apiClient, baseRepo, "", branchWithOwner) - if err != nil { - return fmt.Errorf("could not find pull request: %w", err) - } - prNum = pr.Number - } - out := colorableOut(cmd) if reviewData == nil { @@ -156,11 +119,11 @@ func prReview(cmd *cobra.Command, args []string) error { switch reviewData.State { case api.ReviewComment: - fmt.Fprintf(out, "%s Reviewed pull request #%d\n", utils.Gray("-"), prNum) + fmt.Fprintf(out, "%s Reviewed pull request #%d\n", utils.Gray("-"), pr.Number) case api.ReviewApprove: - fmt.Fprintf(out, "%s Approved pull request #%d\n", utils.Green("✓"), prNum) + fmt.Fprintf(out, "%s Approved pull request #%d\n", utils.Green("✓"), pr.Number) case api.ReviewRequestChanges: - fmt.Fprintf(out, "%s Requested changes to pull request #%d\n", utils.Red("+"), prNum) + fmt.Fprintf(out, "%s Requested changes to pull request #%d\n", utils.Red("+"), pr.Number) } return nil diff --git a/command/pr_review_test.go b/command/pr_review_test.go index c6c56c00d..612f6c574 100644 --- a/command/pr_review_test.go +++ b/command/pr_review_test.go @@ -18,6 +18,11 @@ func TestPRReview_validation(t *testing.T) { `pr review --approve --comment -b"hey" 123`, } { http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { + "pullRequest": { "number": 123 } + } } } + `)) _, err := RunCommand(cmd) if err == nil { t.Fatal("expected error") @@ -30,7 +35,12 @@ func TestPRReview_bad_body(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") - _, err := RunCommand(`pr review -b "radical"`) + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { + "pullRequest": { "number": 123 } + } } } + `)) + _, err := RunCommand(`pr review 123 -b "radical"`) if err == nil { t.Fatal("expected error") } @@ -40,7 +50,6 @@ func TestPRReview_bad_body(t *testing.T) { func TestPRReview_url_arg(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() - http.StubRepoResponse("OWNER", "REPO") http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequest": { "id": "foobar123", @@ -67,7 +76,7 @@ func TestPRReview_url_arg(t *testing.T) { test.ExpectLines(t, output.String(), "Approved pull request #123") - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[1].Body) reqBody := struct { Variables struct { Input struct { @@ -173,6 +182,11 @@ func TestPRReview_blank_comment(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { + "pullRequest": { "number": 123 } + } } } + `)) _, err := RunCommand(`pr review --comment 123`) eq(t, err.Error(), "did not understand desired review action: body cannot be blank for comment review") @@ -182,6 +196,11 @@ func TestPRReview_blank_request_changes(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { + "pullRequest": { "number": 123 } + } } } + `)) _, err := RunCommand(`pr review -r 123`) eq(t, err.Error(), "did not understand desired review action: body cannot be blank for request-changes review") diff --git a/command/pr_test.go b/command/pr_test.go index 86e2d4ecd..a11c60820 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -722,11 +722,10 @@ func TestPRView_web_numberArgWithHash(t *testing.T) { func TestPRView_web_urlArg(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() - http.StubResponse(200, bytes.NewBufferString(` - { "data": { "repository": { "pullRequest": { - "url": "https://github.com/OWNER/REPO/pull/23" - } } } } + { "data": { "repository": { "pullRequest": { + "url": "https://github.com/OWNER/REPO/pull/23" + } } } } `)) var seenCmd *exec.Cmd