cli/pkg/cmd/pr/shared/finder.go
Mislav Marohnić 9bdc63c4ca Eliminate API overfetching in pr commands
This completely rewrites the PR lookup mechanism so that the caller
must specify the GraphQL fields to query for each PR. Additionally, this
fixes some export problems with `pr view --json`.

Features:

- Each pr command now gets assigned a concept of a Finder. This makes it
  easier to stub the PR in tests without having to stub the underlying
  HTTP calls or git invocations.

- `pr view --web` is much faster since it only fetches the "url" field.

- `pr diff 123` now skips a whole API call where a whole PR was
  unnecessarily preloaded just to access its diff in a subsequent call.

- PullRequestGraphQL query builder is now used to construct queries.

- A bunch of individual commands are now freed of having to know about
  concepts such as BaseRepo, Branch, Config, or Remotes.
2021-04-30 20:34:36 +02:00

328 lines
8.5 KiB
Go

package shared
import (
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/set"
)
// PRFinder looks up a single pull request described by FindOptions.
type PRFinder interface {
// Find returns the matching pull request together with the repository that
// the lookup was performed against.
Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
// progressIndicator is the subset of behavior the finder needs to signal that
// a lookup is in flight. Find starts it before querying the API and stops it
// on return.
type progressIndicator interface {
StartProgressIndicator()
StopProgressIndicator()
}
// finder is the PRFinder implementation that resolves a selector by
// consulting the base repository, the current git branch and its remotes, and
// the API.
type finder struct {
// baseRepoFn supplies the repository to query when the selector is not a PR URL.
baseRepoFn func() (ghrepo.Interface, error)
// branchFn supplies the branch to resolve when the selector is empty.
branchFn func() (string, error)
// remotesFn supplies the remotes used to find the owner of a named remote.
remotesFn func() (context.Remotes, error)
// httpClient supplies the HTTP client used for the GraphQL requests.
httpClient func() (*http.Client, error)
// progress, when non-nil, is started before the API lookup and stopped when
// Find returns.
progress progressIndicator
// repo, prNumber and branchName hold what has been resolved from the
// selector; they are filled in by Find, parseURL and parseCurrentBranch.
repo ghrepo.Interface
prNumber int
branchName string
}
// NewFinder returns the PRFinder that pr commands use to look up pull requests.
//
// When a stub has been registered (see runCommandFinder), that stub is
// returned and the slot is replaced by a finder whose Find always fails, so a
// second lookup without a fresh stub reports an error instead of silently
// reusing the previous one. Otherwise a real finder wired to the factory is
// returned.
func NewFinder(factory *cmdutil.Factory) PRFinder {
	if runCommandFinder != nil {
		f := runCommandFinder
		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
		return f
	}

	f := &finder{
		baseRepoFn: factory.BaseRepo,
		branchFn:   factory.Branch,
		remotesFn:  factory.Remotes,
		httpClient: factory.HttpClient,
	}
	// Only store the progress indicator when one is actually present. Putting a
	// nil pointer into the interface-typed field would produce a non-nil
	// interface, so the `f.progress != nil` guard in Find would pass and then
	// invoke a method on a nil receiver.
	if factory.IOStreams != nil {
		f.progress = factory.IOStreams
	}
	return f
}
// runCommandFinder, when non-nil, is handed out by NewFinder instead of a real
// finder. It is set by RunCommandFinder and replaced with an always-failing
// stub once NewFinder has returned it.
var runCommandFinder PRFinder
// RunCommandFinder is the NewMockFinder substitute to be used ONLY in
// runCommand-style tests. It registers a stub expecting the given selector;
// the next NewFinder call returns that stub instead of a real finder.
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) {
	stub := &mockFinder{
		expectSelector: selector,
		pr:             pr,
		repo:           repo,
	}
	runCommandFinder = stub
}
// FindOptions describes which pull request Find should look up and which
// fields to fetch for it.
type FindOptions struct {
// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
// a PR URL.
Selector string
// Fields lists the GraphQL fields to fetch for the PullRequest.
Fields []string
// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
BaseBranch string
// States lists the possible PR states to scope the PR-for-branch lookup to.
States []string
}
// Find resolves opts.Selector to a pull request and returns it together with
// the repository it was looked up in.
//
// The selector is interpreted in this order:
//   - a pull request URL, which also determines the repository;
//   - an empty string, meaning the pull request for the current branch;
//   - a positive number with an optional "#" prefix;
//   - anything else, which is taken as a head branch name.
//
// opts.Fields must be non-empty; an error is returned otherwise. When the
// pull request number is known and "number" is the only requested field, the
// result is built locally without querying the API. What is resolved from the
// selector is stored on the finder itself.
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	if len(opts.Fields) == 0 {
		return nil, nil, errors.New("Find error: no fields specified")
	}

	// A selector that is not a PR URL is not a failure at this point: parseURL
	// only fills in f.repo and f.prNumber when it recognizes one, so its error
	// is deliberately dropped.
	_ = f.parseURL(opts.Selector)

	if f.repo == nil {
		repo, err := f.baseRepoFn()
		if err != nil {
			return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
		}
		f.repo = repo
	}

	switch {
	case opts.Selector == "":
		if err := f.parseCurrentBranch(); err != nil {
			return nil, nil, err
		}
	case f.prNumber == 0:
		// strconv.Atoi also accepts "0" and signed values. Those cannot be
		// pull request numbers, and accepting them would leave both prNumber
		// and branchName unusable and trigger a lookup for an empty branch
		// name. Treat them as branch names instead.
		if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil && prNumber > 0 {
			f.prNumber = prNumber
		} else {
			f.branchName = opts.Selector
		}
	}

	httpClient, err := f.httpClient()
	if err != nil {
		return nil, nil, err
	}

	if f.progress != nil {
		f.progress.StartProgressIndicator()
		defer f.progress.StopProgressIndicator()
	}

	if f.prNumber > 0 {
		if len(opts.Fields) == 1 && opts.Fields[0] == "number" {
			// avoid hitting the API if we already have all the information
			return &api.PullRequest{Number: f.prNumber}, f.repo, nil
		}
		pr, err := findByNumber(httpClient, f.repo, f.prNumber, opts.Fields)
		return pr, f.repo, err
	}

	pr, err := findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, opts.Fields)
	// TODO: preload view: api.ReviewsForPullRequest, api.CommentsForPullRequest
	// TODO: preload checks: get all checks
	return pr, f.repo, err
}
// pullURLRE matches a URL path that begins with "/<segment>/<segment>/pull/<digits>"
// and captures the two segments and the digits; parseURL uses them as the
// repository owner, repository name and pull request number.
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
// parseURL tries to read prURL as an http(s) pull request URL. On success it
// records the repository (including the URL's host) and the pull request
// number on the finder and returns nil. If prURL is empty, cannot be parsed,
// uses another scheme, or its path does not match pullURLRE, an error is
// returned and the finder is left untouched.
func (f *finder) parseURL(prURL string) error {
	if prURL == "" {
		return fmt.Errorf("invalid URL: %q", prURL)
	}

	parsed, parseErr := url.Parse(prURL)
	if parseErr != nil {
		return parseErr
	}

	switch parsed.Scheme {
	case "https", "http":
		// supported
	default:
		return fmt.Errorf("invalid scheme: %s", parsed.Scheme)
	}

	match := pullURLRE.FindStringSubmatch(parsed.Path)
	if match == nil {
		return fmt.Errorf("not a pull request URL: %s", prURL)
	}

	owner, name, number := match[1], match[2], match[3]
	f.repo = ghrepo.NewWithHost(owner, name, parsed.Hostname())
	// The pattern captures digits only; the conversion error is ignored.
	f.prNumber, _ = strconv.Atoi(number)
	return nil
}
// prHeadRE matches a ref of the exact form "refs/pull/<digits>/head" and
// captures the digits; parseCurrentBranch uses them as the pull request number.
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
// parseCurrentBranch works out which pull request the branch returned by
// branchFn refers to, and records either a pull request number or a head
// branch name on the finder.
//
// If the branch's merge ref matches prHeadRE, the number from it is stored.
// Otherwise the branch name is stored. When the owner of the branch's remote
// can be determined, a "refs/heads/" merge ref replaces the local name, and an
// "OWNER:" prefix is added if that owner differs from the owner of f.repo.
// Only an error from branchFn is returned; failures to resolve the remote are
// ignored.
func (f *finder) parseCurrentBranch() error {
	currentBranch, err := f.branchFn()
	if err != nil {
		return err
	}

	cfg := git.ReadBranchConfig(currentBranch)

	// the branch is configured to merge a special PR head ref
	if match := prHeadRE.FindStringSubmatch(cfg.MergeRef); match != nil {
		f.prNumber, _ = strconv.Atoi(match[1])
		return nil
	}

	owner := ""
	switch {
	case cfg.RemoteURL != nil:
		// the branch merges from a remote specified by URL
		if remoteRepo, err := ghrepo.FromURL(cfg.RemoteURL); err == nil {
			owner = remoteRepo.RepoOwner()
		}
	case cfg.RemoteName != "":
		// the branch merges from a remote specified by name
		remotes, _ := f.remotesFn()
		if remote, err := remotes.FindByName(cfg.RemoteName); err == nil {
			owner = remote.RepoOwner()
		}
	}

	headRef := currentBranch
	if owner != "" {
		const headsPrefix = "refs/heads/"
		if strings.HasPrefix(cfg.MergeRef, headsPrefix) {
			headRef = cfg.MergeRef[len(headsPrefix):]
		}
		// prepend `OWNER:` if this branch is pushed to a fork
		if !strings.EqualFold(owner, f.repo.RepoOwner()) {
			headRef = fmt.Sprintf("%s:%s", owner, headRef)
		}
	}

	f.branchName = headRef
	return nil
}
// findByNumber fetches the pull request with the given number from repo,
// requesting only the listed GraphQL fields. It returns the decoded pull
// request, or the error reported by the GraphQL call.
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
	var result struct {
		Repository struct {
			PullRequest api.PullRequest
		}
	}

	gql := fmt.Sprintf(`
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pr_number) {%s}
}
}`, api.PullRequestGraphQL(fields))

	vars := map[string]interface{}{
		"owner":     repo.RepoOwner(),
		"repo":      repo.RepoName(),
		"pr_number": number,
	}

	apiClient := api.NewClientFromHTTP(httpClient)
	if err := apiClient.GraphQL(repo.RepoHost(), gql, vars, &result); err != nil {
		return nil, err
	}
	return &result.Repository.PullRequest, nil
}
// findForBranch looks up the pull request in repo whose head is headBranch.
//
// headBranch may carry an "OWNER:" prefix; the part after the first colon is
// sent as the headRefName query variable, while the full value is compared
// against each candidate's HeadLabel. stateFilters is passed through as the
// states query variable. When baseBranch is non-empty, only pull requests
// with that base branch match. The requested fields are extended with the
// ones needed for this filtering. Open pull requests are considered before
// any others. A *NotFoundError is returned when nothing matches.
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
	var result struct {
		Repository struct {
			PullRequests struct {
				Nodes []api.PullRequest
			}
		}
	}

	wanted := set.NewStringSet()
	wanted.AddValues(fields)
	// these fields are required for filtering below
	wanted.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})

	gql := fmt.Sprintf(`
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
repository(owner: $owner, name: $repo) {
pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
nodes {%s}
}
}
}`, api.PullRequestGraphQL(wanted.ToSlice()))

	// the query variable takes the bare branch name, without any `OWNER:` prefix
	bareBranch := headBranch
	if colon := strings.IndexByte(headBranch, ':'); colon != -1 {
		bareBranch = headBranch[colon+1:]
	}

	vars := map[string]interface{}{
		"owner":       repo.RepoOwner(),
		"repo":        repo.RepoName(),
		"headRefName": bareBranch,
		"states":      stateFilters,
	}

	apiClient := api.NewClientFromHTTP(httpClient)
	if err := apiClient.GraphQL(repo.RepoHost(), gql, vars, &result); err != nil {
		return nil, err
	}

	candidates := result.Repository.PullRequests.Nodes
	// move open pull requests to the front while keeping the returned order otherwise
	sort.SliceStable(candidates, func(i, j int) bool {
		return candidates[i].State == "OPEN" && candidates[j].State != "OPEN"
	})

	for i := range candidates {
		candidate := candidates[i]
		if candidate.HeadLabel() != headBranch {
			continue
		}
		if baseBranch != "" && candidate.BaseRefName != baseBranch {
			continue
		}
		return &candidate, nil
	}

	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
}
// NotFoundError wraps the error returned by findForBranch when no pull request
// matches the requested branch. The embedded error supplies the Error method.
type NotFoundError struct {
error
}
// Unwrap returns the error wrapped by the NotFoundError.
func (e *NotFoundError) Unwrap() error {
	wrapped := e.error
	return wrapped
}
// NewMockFinder returns a PRFinder stub whose Find yields pr and repo when
// called with exactly the given selector.
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) PRFinder {
	stub := new(mockFinder)
	stub.expectSelector = selector
	stub.pr = pr
	stub.repo = repo
	return stub
}
// mockFinder is a PRFinder stub that returns a preset pull request and
// repository for one expected selector.
type mockFinder struct {
// called records that Find has already succeeded once; a second use is an error.
called bool
// expectSelector is the selector Find must be called with.
expectSelector string
// pr and repo are the values Find returns on success.
pr *api.PullRequest
repo ghrepo.Interface
// err, when non-nil, is returned by every Find call.
err error
}
// Find returns the stubbed pull request and repository. It fails if the stub
// was created with an error, if opts.Selector differs from the expected
// selector, or if the stub has already been used once. When the stubbed pull
// request has no head repository owner login, it is set to the owner of the
// stubbed repository.
func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	switch {
	case m.err != nil:
		return nil, nil, m.err
	case m.expectSelector != opts.Selector:
		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
	case m.called:
		return nil, nil, errors.New("mockFinder used more than once")
	}

	m.called = true

	// pose as same-repo PR by default
	if m.pr.HeadRepositoryOwner.Login == "" {
		m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
	}

	return m.pr, m.repo, nil
}