Reset base branch when URL is used

This commit is contained in:
Corey Johnson 2020-06-03 14:34:13 -07:00
parent de6b1e0786
commit 4c75c8bccc
7 changed files with 36 additions and 75 deletions

View file

@ -205,9 +205,9 @@ func (pr *PullRequest) ChecksStatus() (summary PullRequestChecksStatus) {
return
}
func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, pr *PullRequest) (string, error) {
func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (string, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/pulls/%d",
ghrepo.FullName(baseRepo), pr.Number)
ghrepo.FullName(baseRepo), prNumber)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", err

View file

@ -305,17 +305,12 @@ func prView(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
web, err := cmd.Flags().GetBool("web")
if err != nil {
return err
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, _, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}
@ -337,12 +332,7 @@ func prClose(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}
@ -372,12 +362,7 @@ func prReopen(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}
@ -409,12 +394,7 @@ func prMerge(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}
@ -652,12 +632,7 @@ func prReady(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}

View file

@ -5,7 +5,6 @@ import (
"fmt"
"os"
"os/exec"
"regexp"
"github.com/spf13/cobra"
@ -27,22 +26,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
return err
}
var baseRepo ghrepo.Interface
prString := args[0]
r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`)
if m := r.FindStringSubmatch(prString); m != nil {
prString = m[3]
baseRepo = ghrepo.New(m[1], m[2])
}
if baseRepo == nil {
baseRepo, err = determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return err
}
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, []string{prString})
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}
@ -61,6 +45,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
var cmdQueue [][]string
newBranchName := pr.HeadRefName
if headRemote != nil {
// there is an existing git remote for PR head
remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName)

View file

@ -76,6 +76,7 @@ func TestPRCheckout_urlArg(t *testing.T) {
return ctx
}
http := initFakeHTTP()
http.StubRepoResponse("hubot", "REPO")
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repository": { "pullRequest": {
@ -125,7 +126,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
return ctx
}
http := initFakeHTTP()
http.StubRepoResponse("OWNER", "REPO")
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repository": { "pullRequest": {
"number": 123,
@ -160,7 +161,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
eq(t, err, nil)
eq(t, output.String(), "")
bodyBytes, _ := ioutil.ReadAll(http.Requests[0].Body)
bodyBytes, _ := ioutil.ReadAll(http.Requests[1].Body)
reqBody := struct {
Variables struct {
Owner string

View file

@ -36,17 +36,12 @@ func prDiff(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return fmt.Errorf("could not determine base repo: %w", err)
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return fmt.Errorf("could not find pull request: %w", err)
}
diff, err := apiClient.PullRequestDiff(baseRepo, pr)
diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number)
if err != nil {
return fmt.Errorf("could not find pull request diff: %w", err)
}

View file

@ -10,28 +10,36 @@ import (
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/spf13/cobra"
)
func prFromArgs(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, args []string) (*api.PullRequest, error) {
if len(args) == 0 {
return prForCurrentBranch(ctx, apiClient, repo)
func prFromArgs(ctx context.Context, apiClient *api.Client, cmd *cobra.Command, args []string) (*api.PullRequest, ghrepo.Interface, error) {
repo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
}
// First check to see if the prString is a url
if len(args) == 0 {
pr, err := prForCurrentBranch(ctx, apiClient, repo)
return pr, repo, err
}
// First check to see if the prString is a url, return repo from url if found
prString := args[0]
pr, err := prFromURL(ctx, apiClient, repo, prString)
pr, r, err := prFromURL(ctx, apiClient, prString)
if pr != nil || err != nil {
return pr, err
return pr, r, err
}
// Next see if the prString is a number and use that to look up the url
pr, err = prFromNumberString(ctx, apiClient, repo, prString)
if pr != nil || err != nil {
return pr, err
return pr, repo, err
}
// Last see if it is a branch name
return api.PullRequestForBranch(apiClient, repo, "", prString)
pr, err = api.PullRequestForBranch(apiClient, repo, "", prString)
return pr, repo, err
}
func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
@ -42,14 +50,16 @@ func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.
return nil, nil
}
func prFromURL(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
func prFromURL(ctx context.Context, apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) {
r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`)
if m := r.FindStringSubmatch(s); m != nil {
repo := ghrepo.New(m[1], m[2])
prNumberString := m[3]
return prFromNumberString(ctx, apiClient, repo, prNumberString)
pr, err := prFromNumberString(ctx, apiClient, repo, prNumberString)
return pr, repo, err
}
return nil, nil
return nil, nil, nil
}
func prForCurrentBranch(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface) (*api.PullRequest, error) {

View file

@ -89,12 +89,7 @@ func prReview(cmd *cobra.Command, args []string) error {
return err
}
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
if err != nil {
return fmt.Errorf("could not determine base repo: %w", err)
}
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
pr, _, err := prFromArgs(ctx, apiClient, cmd, args)
if err != nil {
return err
}