cli/pkg/cmd/pr/shared/finder.go

590 lines
16 KiB
Go

package shared
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
"github.com/cli/cli/v2/api"
remotes "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
fd "github.com/cli/cli/v2/internal/featuredetection"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
)
// PRFinder resolves a user-supplied selector to a pull request and the
// repository it was found in.
type PRFinder interface {
Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
// progressIndicator is the subset of the I/O streams API that Find uses to
// show progress while it performs its API requests.
type progressIndicator interface {
StartProgressIndicator()
StopProgressIndicator()
}
// finder is the default PRFinder. The function-valued fields are dependencies
// injected by NewFinder; repo, prNumber and branchName hold the state that Find
// fills in while resolving a selector.
type finder struct {
baseRepoFn func() (ghrepo.Interface, error)
// branchFn returns the name of the currently checked-out branch.
branchFn func() (string, error)
remotesFn func() (remotes.Remotes, error)
httpClient func() (*http.Client, error)
// branchConfig returns the git configuration for the named local branch.
branchConfig func(string) git.BranchConfig
// progress may be nil, in which case no progress indicator is shown.
progress progressIndicator
// Per-lookup state populated by Find.
repo ghrepo.Interface
prNumber int
branchName string
}
// NewFinder returns the PRFinder that commands should use.
//
// If a test has registered a stub through RunCommandFinder, that stub is
// returned exactly once. It is then replaced by a finder whose Find always
// fails, so any later NewFinder call in the same test is reported instead of
// silently hitting the network. Otherwise a real finder wired to the
// factory's dependencies is returned.
func NewFinder(factory *cmdutil.Factory) PRFinder {
	if stub := runCommandFinder; stub != nil {
		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
		return stub
	}

	readBranchConfig := func(branch string) git.BranchConfig {
		// The error is deliberately ignored; whatever config value
		// ReadBranchConfig returns is used as-is.
		cfg, _ := factory.GitClient.ReadBranchConfig(context.Background(), branch)
		return cfg
	}

	return &finder{
		baseRepoFn:   factory.BaseRepo,
		branchFn:     factory.Branch,
		remotesFn:    factory.Remotes,
		httpClient:   factory.HttpClient,
		progress:     factory.IOStreams,
		branchConfig: readBranchConfig,
	}
}
// runCommandFinder holds a stub registered by RunCommandFinder. NewFinder hands
// it out once and then replaces it.
var runCommandFinder PRFinder

// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
// It registers the stub so the next NewFinder call returns it, and also returns
// it so the test can add expectations (for example with ExpectFields).
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
	stub := NewMockFinder(selector, pr, repo)
	runCommandFinder = stub
	return stub
}
// FindOptions configures a single PRFinder.Find lookup.
type FindOptions struct {
// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
// a PR URL. When empty, the pull request for the current git branch is looked up.
Selector string
// Fields lists the GraphQL fields to fetch for the PullRequest. The default
// finder rejects an empty list.
Fields []string
// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
// When set, a numeric Selector is treated as a branch name rather than a PR number.
BaseBranch string
// States lists the possible PR states to scope the PR-for-branch lookup to.
States []string
}
// Find resolves opts.Selector to a pull request and returns it together with
// the repository it was looked up in.
//
// Selector resolution:
//   - a pull request URL supplies both the repository and the PR number;
//   - otherwise the repository comes from the base repo resolver;
//   - an empty Selector uses the current branch, or the PR number when the
//     branch is configured to merge a refs/pull/N/head ref;
//   - a number with optional "#" prefix is a PR number, unless opts.BaseBranch
//     is set, in which case it is treated as a branch name;
//   - anything else is a branch name (optionally "<owner>:<branch>").
//
// Only opts.Fields are requested, plus "id" and "number", which the preload
// queries need. Reviews, comments, statusCheckRollup and projectItems may need
// follow-up queries; those run concurrently after the main query. If only
// "number" is requested and the PR number is already known, no PR query is
// made.
//
// Every call starts from a clean per-lookup state, so a finder may be reused.
// When the main lookup fails, the returned *api.PullRequest may be nil.
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	if len(opts.Fields) == 0 {
		return nil, nil, errors.New("Find error: no fields specified")
	}

	// Reset per-lookup state. Without this, a second Find on the same finder
	// would keep the previous repo and PR number. A branch selector would then
	// be ignored because of the prNumber == 0 check below.
	f.repo = nil
	f.prNumber = 0
	f.branchName = ""

	// A full PR URL pins both the repository and the number.
	if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
		f.prNumber = prNumber
		f.repo = repo
	}

	if f.repo == nil {
		repo, err := f.baseRepoFn()
		if err != nil {
			return nil, nil, err
		}
		f.repo = repo
	}

	if opts.Selector == "" {
		branch, prNumber, err := f.parseCurrentBranch()
		if err != nil {
			return nil, nil, err
		}
		if prNumber > 0 {
			f.prNumber = prNumber
		} else {
			f.branchName = branch
		}
	} else if f.prNumber == 0 {
		// If opts.Selector is a valid number then assume it is the
		// PR number unless opts.BaseBranch is specified. This is a
		// special case for PR create command which will always want
		// to assume that a numerical selector is a branch name rather
		// than PR number.
		prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#"))
		if opts.BaseBranch == "" && err == nil {
			f.prNumber = prNumber
		} else {
			f.branchName = opts.Selector
		}
	}

	httpClient, err := f.httpClient()
	if err != nil {
		return nil, nil, err
	}

	// TODO(josebalius): Should we be guarding here?
	if f.progress != nil {
		f.progress.StartProgressIndicator()
		defer f.progress.StopProgressIndicator()
	}

	fields := set.NewStringSet()
	fields.AddValues(opts.Fields)
	numberFieldOnly := fields.Len() == 1 && fields.Contains("number")
	fields.AddValues([]string{"id", "number"}) // for additional preload queries below

	// Drop the merge queue fields when feature detection for the repository's
	// host reports no merge queue support.
	if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") {
		cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24)
		detector := fd.NewDetector(cachedClient, f.repo.RepoHost())
		prFeatures, err := detector.PullRequestFeatures()
		if err != nil {
			return nil, nil, err
		}
		if !prFeatures.MergeQueue {
			fields.Remove("isInMergeQueue")
			fields.Remove("isMergeQueueEnabled")
		}
	}

	// projectItems is loaded by a separate query below, not by the main one.
	getProjectItems := fields.Contains("projectItems")
	if getProjectItems {
		fields.Remove("projectItems")
	}

	var pr *api.PullRequest
	if f.prNumber > 0 {
		if numberFieldOnly {
			// avoid hitting the API if we already have all the information
			return &api.PullRequest{Number: f.prNumber}, f.repo, nil
		}
		pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice())
	} else {
		pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice())
	}
	if err != nil {
		return pr, f.repo, err
	}

	// Run the follow-up loads concurrently. The reviews, comments and checks
	// preloads each write a different field of pr. The group's zero value is
	// enough because no goroutine uses a shared context.
	var g errgroup.Group
	if fields.Contains("reviews") {
		g.Go(func() error {
			return preloadPrReviews(httpClient, f.repo, pr)
		})
	}
	if fields.Contains("comments") {
		g.Go(func() error {
			return preloadPrComments(httpClient, f.repo, pr)
		})
	}
	if fields.Contains("statusCheckRollup") {
		g.Go(func() error {
			return preloadPrChecks(httpClient, f.repo, pr)
		})
	}
	if getProjectItems {
		g.Go(func() error {
			apiClient := api.NewClientFromHTTP(httpClient)
			err := api.ProjectsV2ItemsForPullRequest(apiClient, f.repo, pr)
			if err != nil && !api.ProjectsV2IgnorableError(err) {
				return err
			}
			return nil
		})
	}

	return pr, f.repo, g.Wait()
}
// pullURLRE matches the path of a pull request URL, /OWNER/REPO/pull/NUMBER.
// It is not anchored at the end, so trailing segments such as /files are
// accepted.
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)

// parseURL extracts the repository (including the URL's host) and the pull
// request number from an http(s) pull request URL. It returns an error for an
// empty string, an unparseable URL, a scheme other than http/https, or a path
// that does not look like a pull request.
func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
	if prURL == "" {
		return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
	}

	parsed, err := url.Parse(prURL)
	if err != nil {
		return nil, 0, err
	}

	switch parsed.Scheme {
	case "https", "http":
		// accepted
	default:
		return nil, 0, fmt.Errorf("invalid scheme: %s", parsed.Scheme)
	}

	match := pullURLRE.FindStringSubmatch(parsed.Path)
	if len(match) == 0 {
		return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
	}

	owner, name, digits := match[1], match[2], match[3]
	// digits matched \d+, so the only possible Atoi failure is overflow; that
	// error is ignored, as before.
	number, _ := strconv.Atoi(digits)
	return ghrepo.NewWithHost(owner, name, parsed.Hostname()), number, nil
}
// prHeadRE matches a merge ref that points directly at a pull request head,
// such as refs/pull/123/head.
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)

// parseCurrentBranch works out what the checked-out branch refers to.
//
// If the branch's configured merge ref is refs/pull/N/head, it returns PR
// number N and an empty branch name. Otherwise it returns a head ref name and
// 0. When the branch tracks a remote (by URL or by name) that resolves to a
// repository, the name comes from the merge ref (minus refs/heads/). It is
// prefixed with "OWNER:" unless that repository has the same owner and name
// as f.repo, compared case-insensitively.
//
// f.repo must already be set. Failures to resolve the tracking remote are
// ignored, and the local branch name is used unchanged.
func (f *finder) parseCurrentBranch() (string, int, error) {
	headRef, err := f.branchFn()
	if err != nil {
		return "", 0, err
	}

	cfg := f.branchConfig(headRef)

	// The branch is configured to merge a special PR head ref.
	if match := prHeadRE.FindStringSubmatch(cfg.MergeRef); match != nil {
		number, _ := strconv.Atoi(match[1])
		return "", number, nil
	}

	var pushRepo ghrepo.Interface
	switch {
	case cfg.RemoteURL != nil:
		// The branch merges from a remote specified by URL.
		if r, urlErr := ghrepo.FromURL(cfg.RemoteURL); urlErr == nil {
			pushRepo = r
		}
	case cfg.RemoteName != "":
		// The branch merges from a remote specified by name; a failure to list
		// remotes is deliberately ignored.
		knownRemotes, _ := f.remotesFn()
		if r, nameErr := knownRemotes.FindByName(cfg.RemoteName); nameErr == nil {
			pushRepo = r
		}
	}

	if pushRepo == nil {
		return headRef, 0, nil
	}

	if strings.HasPrefix(cfg.MergeRef, "refs/heads/") {
		headRef = strings.TrimPrefix(cfg.MergeRef, "refs/heads/")
	}

	// The branch lives in the base repository only if both owner and name match.
	// A different owner means a fork. The same owner with a different name
	// covers private org forks. This may misfire after a repo rename when the
	// local remote was not updated, a known weakness shared with
	// SmartBaseRepoFunc.
	sameRepo := strings.EqualFold(pushRepo.RepoOwner(), f.repo.RepoOwner()) &&
		strings.EqualFold(pushRepo.RepoName(), f.repo.RepoName())
	if sameRepo {
		return headRef, 0, nil
	}
	return fmt.Sprintf("%s:%s", pushRepo.RepoOwner(), headRef), 0, nil
}
// findByNumber fetches pull request number from repo with a single GraphQL
// query, requesting only the given fields. It returns the API error unchanged
// on failure.
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
	query := fmt.Sprintf(`
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pr_number) {%s}
}
}`, api.PullRequestGraphQL(fields))

	vars := map[string]interface{}{
		"owner":     repo.RepoOwner(),
		"repo":      repo.RepoName(),
		"pr_number": number,
	}

	var result struct {
		Repository struct {
			PullRequest api.PullRequest
		}
	}
	if err := api.NewClientFromHTTP(httpClient).GraphQL(repo.RepoHost(), query, vars, &result); err != nil {
		return nil, err
	}
	return &result.Repository.PullRequest, nil
}
// findForBranch looks up the pull request whose head is headBranch, which may
// carry an "<owner>:" prefix, in repo.
//
// Up to 30 candidates are fetched, newest first. They are then stably
// reordered so open PRs come first. The first candidate is returned whose
// head label equals headBranch, whose base branch equals baseBranch (when
// baseBranch is non-empty), and which is not a closed or merged PR headed by
// the repository's default branch. stateFilters is passed through as the
// $states query variable. A *NotFoundError is returned when nothing matches.
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
	type response struct {
		Repository struct {
			PullRequests struct {
				Nodes []api.PullRequest
			}
			DefaultBranchRef struct {
				Name string
			}
		}
	}

	// The matching logic below needs these fields whatever the caller asked for.
	requested := set.NewStringSet()
	requested.AddValues(fields)
	requested.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})

	query := fmt.Sprintf(`
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
repository(owner: $owner, name: $repo) {
pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
nodes {%s}
}
defaultBranchRef { name }
}
}`, api.PullRequestGraphQL(requested.ToSlice()))

	// The API filters on the bare branch name (everything after the first ":").
	// The owner part is checked client-side through HeadLabel.
	parts := strings.SplitN(headBranch, ":", 2)
	variables := map[string]interface{}{
		"owner":       repo.RepoOwner(),
		"repo":        repo.RepoName(),
		"headRefName": parts[len(parts)-1],
		"states":      stateFilters,
	}

	var resp response
	if err := api.NewClientFromHTTP(httpClient).GraphQL(repo.RepoHost(), query, variables, &resp); err != nil {
		return nil, err
	}

	candidates := resp.Repository.PullRequests.Nodes
	defaultBranch := resp.Repository.DefaultBranchRef.Name

	isOpen := func(i int) bool { return candidates[i].State == "OPEN" }
	sort.SliceStable(candidates, func(i, j int) bool {
		return isOpen(i) && !isOpen(j)
	})

	for i := range candidates {
		candidate := candidates[i]
		if candidate.HeadLabel() != headBranch {
			continue
		}
		if baseBranch != "" && candidate.BaseRefName != baseBranch {
			continue
		}
		// When the head is the default branch, it doesn't really make sense to
		// show merged or closed PRs. https://github.com/cli/cli/issues/4263
		if candidate.State != "OPEN" && defaultBranch == headBranch {
			continue
		}
		return &candidate, nil
	}

	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
}
// preloadPrReviews fetches the remaining pages of pr's reviews, 100 per
// request, when the initial query reported more. It appends them to
// pr.Reviews.Nodes and sets TotalCount to the number of reviews now loaded.
// When the loop finishes, HasNextPage is cleared.
func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if !pr.Reviews.PageInfo.HasNextPage {
		return nil
	}

	type response struct {
		Node struct {
			PullRequest struct {
				Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"`
			} `graphql:"...on PullRequest"`
		} `graphql:"node(id: $id)"`
	}

	gql := api.NewClientFromHTTP(httpClient)
	cursor := pr.Reviews.PageInfo.EndCursor
	for more := true; more; {
		vars := map[string]interface{}{
			"id":        githubv4.ID(pr.ID),
			"endCursor": githubv4.String(cursor),
		}
		var page response
		if err := gql.Query(repo.RepoHost(), "ReviewsForPullRequest", &page, vars); err != nil {
			return err
		}

		batch := page.Node.PullRequest.Reviews
		pr.Reviews.Nodes = append(pr.Reviews.Nodes, batch.Nodes...)
		pr.Reviews.TotalCount = len(pr.Reviews.Nodes)

		more = batch.PageInfo.HasNextPage
		cursor = batch.PageInfo.EndCursor
	}

	pr.Reviews.PageInfo.HasNextPage = false
	return nil
}
// preloadPrComments fetches the remaining pages of pr's comments, 100 per
// request, when the initial query reported more. It appends them to
// pr.Comments.Nodes and sets TotalCount to the number of comments now loaded.
// When the loop finishes, HasNextPage is cleared.
func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if !pr.Comments.PageInfo.HasNextPage {
		return nil
	}

	type response struct {
		Node struct {
			PullRequest struct {
				Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"`
			} `graphql:"...on PullRequest"`
		} `graphql:"node(id: $id)"`
	}

	gql := api.NewClientFromHTTP(client)
	cursor := pr.Comments.PageInfo.EndCursor
	for more := true; more; {
		vars := map[string]interface{}{
			"id":        githubv4.ID(pr.ID),
			"endCursor": githubv4.String(cursor),
		}
		var page response
		if err := gql.Query(repo.RepoHost(), "CommentsForPullRequest", &page, vars); err != nil {
			return err
		}

		batch := page.Node.PullRequest.Comments
		pr.Comments.Nodes = append(pr.Comments.Nodes, batch.Nodes...)
		pr.Comments.TotalCount = len(pr.Comments.Nodes)

		more = batch.PageInfo.HasNextPage
		cursor = batch.PageInfo.EndCursor
	}

	pr.Comments.PageInfo.HasNextPage = false
	return nil
}
// preloadPrChecks fetches the remaining pages of status check contexts for the
// first commit node in pr.StatusCheckRollup. It appends them in place, so
// callers see the complete list, and clears HasNextPage afterwards. It is a
// no-op when there is no commit node or no further page.
//
// An unexpected response (null node or no commit node) is reported as an
// error rather than causing a panic. This matters because the function runs
// inside an errgroup goroutine, where a panic would crash the process.
func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if len(pr.StatusCheckRollup.Nodes) == 0 {
		return nil
	}
	statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
	if !statusCheckRollup.PageInfo.HasNextPage {
		return nil
	}
	endCursor := statusCheckRollup.PageInfo.EndCursor

	type response struct {
		Node *api.PullRequest
	}

	query := fmt.Sprintf(`
query PullRequestStatusChecks($id: ID!, $endCursor: String!) {
node(id: $id) {
...on PullRequest {
%s
}
}
}`, api.StatusCheckRollupGraphQLWithoutCountByState("$endCursor"))

	variables := map[string]interface{}{
		"id": pr.ID,
	}

	apiClient := api.NewClientFromHTTP(client)
	for {
		variables["endCursor"] = endCursor
		var resp response
		if err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp); err != nil {
			return err
		}
		// Guard against a null node or a missing commit node; indexing blindly
		// would panic.
		if resp.Node == nil || len(resp.Node.StatusCheckRollup.Nodes) == 0 {
			return fmt.Errorf("unexpected empty response while paginating status checks for pull request %v", pr.ID)
		}

		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
		statusCheckRollup.Nodes = append(statusCheckRollup.Nodes, result.Nodes...)

		if !result.PageInfo.HasNextPage {
			break
		}
		endCursor = result.PageInfo.EndCursor
	}

	statusCheckRollup.PageInfo.HasNextPage = false
	return nil
}
// NotFoundError reports that no pull request matched a lookup. It embeds the
// underlying error, so its Error() message is that error's message.
type NotFoundError struct {
error
}
// Unwrap returns the embedded error so that errors.Is and errors.As can inspect it.
func (err *NotFoundError) Unwrap() error {
return err.error
}
// NewMockFinder returns a test double whose Find expects the given selector
// and returns pr and repo. When pr is nil, every Find call fails with a
// *NotFoundError instead.
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
	m := &mockFinder{
		expectSelector: selector,
		pr:             pr,
		repo:           repo,
	}
	if pr == nil {
		m.err = &NotFoundError{errors.New("no pull requests found")}
	}
	return m
}
// mockFinder is the PRFinder test double built by NewMockFinder and RunCommandFinder.
type mockFinder struct {
// called records whether Find has already run; a second call fails.
called bool
expectSelector string
// expectFields, when non-empty, must match the fields passed to Find (order-insensitive).
expectFields []string
pr *api.PullRequest
repo ghrepo.Interface
// err, when set, is returned by every Find call before any other check.
err error
}
// Find implements PRFinder for tests. Checks run in this order, and the first
// failure is returned:
//   - the preset error, if any;
//   - the selector;
//   - the requested fields, when set via ExpectFields;
//   - a second call.
//
// On success, a PR with no head repository owner is given the repo's owner,
// so it looks like a same-repository PR.
func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	switch {
	case m.err != nil:
		return nil, nil, m.err
	case opts.Selector != m.expectSelector:
		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
	case len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields):
		return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
	case m.called:
		return nil, nil, errors.New("mockFinder used more than once")
	}

	m.called = true

	// Pose as a same-repo PR by default.
	if m.pr.HeadRepositoryOwner.Login == "" {
		m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
	}
	return m.pr, m.repo, nil
}
// ExpectFields makes Find fail unless it is called with exactly these fields,
// ignoring order. An empty list disables the check.
func (m *mockFinder) ExpectFields(fields []string) {
m.expectFields = fields
}
// isEqualSet reports whether a and b contain the same strings with the same
// multiplicities, ignoring order. It sorts private copies, so neither input is
// modified.
func isEqualSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	sortedA := sort.StringSlice(append([]string(nil), a...))
	sortedB := sort.StringSlice(append([]string(nil), b...))
	sortedA.Sort()
	sortedB.Sort()

	for i, s := range sortedA {
		if sortedB[i] != s {
			return false
		}
	}
	return true
}