This commit is contained in:
Corey Johnson 2020-05-27 08:38:30 -07:00
parent 30eb2d3ad6
commit e341d7b49f
6 changed files with 22 additions and 32 deletions

View file

@ -315,7 +315,7 @@ func prView(cmd *cobra.Command, args []string) error {
return err
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}
@ -342,7 +342,7 @@ func prClose(cmd *cobra.Command, args []string) error {
return err
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}
@ -377,7 +377,7 @@ func prReopen(cmd *cobra.Command, args []string) error {
return err
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}
@ -414,7 +414,7 @@ func prMerge(cmd *cobra.Command, args []string) error {
return err
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}
@ -657,7 +657,7 @@ func prReady(cmd *cobra.Command, args []string) error {
return err
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}

View file

@ -42,7 +42,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
}
}
pr, err := prFromArgs(ctx, baseRepo, prString)
pr, err := prFromArgs(ctx, apiClient, baseRepo, prString)
if err != nil {
return err
}

View file

@ -41,7 +41,7 @@ func prDiff(cmd *cobra.Command, args []string) error {
return fmt.Errorf("could not determine base repo: %w", err)
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return fmt.Errorf("could not find pull request: %w", err)
}

View file

@ -37,6 +37,11 @@ func TestPRDiff_argument_not_found(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
http := initFakeHTTP()
http.StubRepoResponse("OWNER", "REPO")
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repository": {
"pullRequest": { "number": 123 }
} } }
`))
http.StubResponse(404, bytes.NewBufferString(""))
_, err := RunCommand("pr diff 123")
if err == nil {

View file

@ -12,25 +12,20 @@ import (
"github.com/cli/cli/internal/ghrepo"
)
func prFromArgs(ctx context.Context, repo ghrepo.Interface, args ...string) (*api.PullRequest, error) {
apiClient, err := apiClientForContext(ctx)
if err != nil {
return nil, err
}
func prFromArgs(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, args ...string) (*api.PullRequest, error) {
if len(args) == 0 {
return prForCurrentBranch(ctx, repo)
return prForCurrentBranch(ctx, apiClient, repo)
}
// First check to see if the prString is a url
prString := args[0]
pr, err := prFromURL(ctx, repo, prString)
pr, err := prFromURL(ctx, apiClient, repo, prString)
if pr != nil || err != nil {
return pr, err
}
// Next see if the prString is a number and use that to look up the url
pr, err = prFromNumberString(ctx, repo, prString)
pr, err = prFromNumberString(ctx, apiClient, repo, prString)
if pr != nil || err != nil {
return pr, err
}
@ -39,12 +34,7 @@ func prFromArgs(ctx context.Context, repo ghrepo.Interface, args ...string) (*ap
return api.PullRequestForBranch(apiClient, repo, "", prString)
}
func prFromNumberString(ctx context.Context, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
apiClient, err := apiClientForContext(ctx)
if err != nil {
return nil, err
}
func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
if prNumber, err := strconv.Atoi(strings.TrimPrefix(s, "#")); err == nil {
return api.PullRequestByNumber(apiClient, repo, prNumber)
}
@ -52,22 +42,17 @@ func prFromNumberString(ctx context.Context, repo ghrepo.Interface, s string) (*
return nil, nil
}
func prFromURL(ctx context.Context, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
func prFromURL(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`)
if m := r.FindStringSubmatch(s); m != nil {
prNumberString := m[3]
return prFromNumberString(ctx, repo, prNumberString)
return prFromNumberString(ctx, apiClient, repo, prNumberString)
}
return nil, nil
}
func prForCurrentBranch(ctx context.Context, repo ghrepo.Interface) (*api.PullRequest, error) {
apiClient, err := apiClientForContext(ctx)
if err != nil {
return nil, err
}
func prForCurrentBranch(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface) (*api.PullRequest, error) {
prHeadRef, err := ctx.Branch()
if err != nil {
return nil, err
@ -78,7 +63,7 @@ func prForCurrentBranch(ctx context.Context, repo ghrepo.Interface) (*api.PullRe
// the branch is configured to merge a special PR head ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
return prFromNumberString(ctx, repo, m[1])
return prFromNumberString(ctx, apiClient, repo, m[1])
}
var branchOwner string

View file

@ -94,7 +94,7 @@ func prReview(cmd *cobra.Command, args []string) error {
return fmt.Errorf("could not determine base repo: %w", err)
}
pr, err := prFromArgs(ctx, baseRepo, args...)
pr, err := prFromArgs(ctx, apiClient, baseRepo, args...)
if err != nil {
return err
}