From 4c75c8bccc429cabe66bfd004f959e3db45dfd09 Mon Sep 17 00:00:00 2001 From: Corey Johnson Date: Wed, 3 Jun 2020 14:34:13 -0700 Subject: [PATCH] Reset base branch when URL is used --- api/queries_pr.go | 4 ++-- command/pr.go | 35 +++++------------------------------ command/pr_checkout.go | 19 ++----------------- command/pr_checkout_test.go | 5 +++-- command/pr_diff.go | 9 ++------- command/pr_lookup.go | 32 +++++++++++++++++++++----------- command/pr_review.go | 7 +------ 7 files changed, 36 insertions(+), 75 deletions(-) diff --git a/api/queries_pr.go b/api/queries_pr.go index d1dedb82d..ad42deb0f 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -205,9 +205,9 @@ func (pr *PullRequest) ChecksStatus() (summary PullRequestChecksStatus) { return } -func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, pr *PullRequest) (string, error) { +func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (string, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/pulls/%d", - ghrepo.FullName(baseRepo), pr.Number) + ghrepo.FullName(baseRepo), prNumber) req, err := http.NewRequest("GET", url, nil) if err != nil { return "", err diff --git a/command/pr.go b/command/pr.go index 4daa2db0a..5bdf51e3f 100644 --- a/command/pr.go +++ b/command/pr.go @@ -305,17 +305,12 @@ func prView(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - web, err := cmd.Flags().GetBool("web") if err != nil { return err } - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, _, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -337,12 +332,7 @@ func prClose(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -372,12 +362,7 @@ func prReopen(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -409,12 +394,7 @@ func prMerge(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -652,12 +632,7 @@ func prReady(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } diff --git a/command/pr_checkout.go b/command/pr_checkout.go index 757df0fd5..8ffa39210 100644 --- a/command/pr_checkout.go +++ b/command/pr_checkout.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "os/exec" - "regexp" "github.com/spf13/cobra" @@ -27,22 +26,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { return err } - var baseRepo ghrepo.Interface - prString := args[0] - r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`) - if m := r.FindStringSubmatch(prString); m != nil { - prString = m[3] - baseRepo = ghrepo.New(m[1], m[2]) - } - - if baseRepo == nil { - baseRepo, err = determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return err - } - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, []string{prString}) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -61,6 +45,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { var cmdQueue [][]string newBranchName := pr.HeadRefName + if headRemote != nil { // there is an existing git remote for PR head remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName) diff --git a/command/pr_checkout_test.go b/command/pr_checkout_test.go index 851d6ed22..501af2447 100644 --- a/command/pr_checkout_test.go +++ b/command/pr_checkout_test.go @@ -76,6 +76,7 @@ func TestPRCheckout_urlArg(t *testing.T) { return ctx } http := initFakeHTTP() + http.StubRepoResponse("hubot", "REPO") http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequest": { @@ -125,7 +126,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) { return ctx } http := initFakeHTTP() - + http.StubRepoResponse("OWNER", "REPO") http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequest": { "number": 123, @@ -160,7 +161,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) { eq(t, err, nil) eq(t, output.String(), "") - bodyBytes, _ := ioutil.ReadAll(http.Requests[0].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[1].Body) reqBody := struct { Variables struct { Owner string diff --git a/command/pr_diff.go b/command/pr_diff.go index 1f89d59cb..ae50a2618 100644 --- a/command/pr_diff.go +++ b/command/pr_diff.go @@ -36,17 +36,12 @@ func prDiff(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return fmt.Errorf("could not determine base repo: %w", err) - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return fmt.Errorf("could not find pull request: %w", err) } - diff, err := apiClient.PullRequestDiff(baseRepo, pr) + diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number) if err != nil { return fmt.Errorf("could not find pull request diff: %w", err) } diff --git a/command/pr_lookup.go b/command/pr_lookup.go index a5e380b47..7f9a86955 100644 --- a/command/pr_lookup.go +++ b/command/pr_lookup.go @@ -10,28 +10,36 @@ import ( "github.com/cli/cli/context" "github.com/cli/cli/git" "github.com/cli/cli/internal/ghrepo" + "github.com/spf13/cobra" ) -func prFromArgs(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, args []string) (*api.PullRequest, error) { - if len(args) == 0 { - return prForCurrentBranch(ctx, apiClient, repo) +func prFromArgs(ctx context.Context, apiClient *api.Client, cmd *cobra.Command, args []string) (*api.PullRequest, ghrepo.Interface, error) { + repo, err := determineBaseRepo(apiClient, cmd, ctx) + if err != nil { + return nil, nil, fmt.Errorf("could not determine base repo: %w", err) } - // First check to see if the prString is a url + if len(args) == 0 { + pr, err := prForCurrentBranch(ctx, apiClient, repo) + return pr, repo, err + } + + // First check to see if the prString is a url, return repo from url if found prString := args[0] - pr, err := prFromURL(ctx, apiClient, repo, prString) + pr, r, err := prFromURL(ctx, apiClient, prString) if pr != nil || err != nil { - return pr, err + return pr, r, err } // Next see if the prString is a number and use that to look up the url pr, err = prFromNumberString(ctx, apiClient, repo, prString) if pr != nil || err != nil { - return pr, err + return pr, repo, err } // Last see if it is a branch name - return api.PullRequestForBranch(apiClient, repo, "", prString) + pr, err = api.PullRequestForBranch(apiClient, repo, "", prString) + return pr, repo, err } func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) { @@ -42,14 +50,16 @@ func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo. return nil, nil } -func prFromURL(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) { +func prFromURL(ctx context.Context, apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) { r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`) if m := r.FindStringSubmatch(s); m != nil { + repo := ghrepo.New(m[1], m[2]) prNumberString := m[3] - return prFromNumberString(ctx, apiClient, repo, prNumberString) + pr, err := prFromNumberString(ctx, apiClient, repo, prNumberString) + return pr, repo, err } - return nil, nil + return nil, nil, nil } func prForCurrentBranch(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface) (*api.PullRequest, error) { diff --git a/command/pr_review.go b/command/pr_review.go index b0a1f2062..c48790b6a 100644 --- a/command/pr_review.go +++ b/command/pr_review.go @@ -89,12 +89,7 @@ func prReview(cmd *cobra.Command, args []string) error { return err } - baseRepo, err := determineBaseRepo(apiClient, cmd, ctx) - if err != nil { - return fmt.Errorf("could not determine base repo: %w", err) - } - - pr, err := prFromArgs(ctx, apiClient, baseRepo, args) + pr, _, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err }