cli/pkg/cmd/pr/shared/finder.go
2025-01-24 10:20:04 -08:00

648 lines
18 KiB
Go

package shared
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
"github.com/cli/cli/v2/api"
remotes "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
fd "github.com/cli/cli/v2/internal/featuredetection"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
)
// PRFinder looks up a single pull request described by FindOptions (e.g. a
// selector given on the command line) and also returns the repository the
// pull request was looked up in.
type PRFinder interface {
Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
// progressIndicator is the minimal progress-reporting surface the finder needs;
// Find starts it before its network lookups and stops it when it returns.
// NewFinder satisfies it with the factory's IOStreams.
type progressIndicator interface {
StartProgressIndicator()
StopProgressIndicator()
}
// finder is the real PRFinder implementation. Its function-valued fields are
// injected dependencies (so tests can stub git/HTTP access); the trailing
// fields hold lookup state that Find fills in while resolving a selector.
type finder struct {
// baseRepoFn resolves the base repository when the selector is not a PR URL.
baseRepoFn func() (ghrepo.Interface, error)
// branchFn returns the name of the currently checked-out branch.
branchFn func() (string, error)
// remotesFn lists the configured git remotes.
remotesFn func() (remotes.Remotes, error)
// httpClient returns the client used for API requests.
httpClient func() (*http.Client, error)
// pushDefault returns git's push.default setting (via GitClient.PushDefault in NewFinder).
pushDefault func() (string, error)
// parsePushRevision resolves a branch's @{push} revision, expected as "<remote>/<branch>".
parsePushRevision func(string) (string, error)
// branchConfig reads the git config for the named branch.
branchConfig func(string) (git.BranchConfig, error)
// progress is optional; when nil, no progress indicator is shown.
progress progressIndicator
// Lookup state written by Find.
baseRefRepo ghrepo.Interface
prNumber int
branchName string
}
// NewFinder builds the PRFinder used by pr commands. It is wired to the
// factory's base repository, current branch, remotes, HTTP client, git client
// and IO streams (for the progress indicator).
//
// If a test registered a stub through RunCommandFinder, that stub is returned
// instead. It is then swapped for a mock that always errors, so a second
// NewFinder call in the same test receives a finder that fails every lookup.
func NewFinder(factory *cmdutil.Factory) PRFinder {
	if stub := runCommandFinder; stub != nil {
		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
		return stub
	}

	// The git-backed hooks dereference factory.GitClient on each call rather
	// than capturing it up front.
	return &finder{
		baseRepoFn: factory.BaseRepo,
		branchFn:   factory.Branch,
		remotesFn:  factory.Remotes,
		httpClient: factory.HttpClient,
		progress:   factory.IOStreams,
		pushDefault: func() (string, error) {
			return factory.GitClient.PushDefault(context.Background())
		},
		parsePushRevision: func(branch string) (string, error) {
			return factory.GitClient.ParsePushRevision(context.Background(), branch)
		},
		branchConfig: func(branch string) (git.BranchConfig, error) {
			return factory.GitClient.ReadBranchConfig(context.Background(), branch)
		},
	}
}
// runCommandFinder holds a stub registered by RunCommandFinder. When non-nil,
// the next NewFinder call returns it (and replaces it with a failing mock).
var runCommandFinder PRFinder

// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
// It creates a mock via NewMockFinder and registers it so that the next
// NewFinder call returns it. The mock is also returned so the test can add
// further expectations such as ExpectFields.
func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
	stub := NewMockFinder(selector, pr, repo)
	runCommandFinder = stub
	return stub
}
// FindOptions controls how PRFinder.Find identifies a pull request and which
// data it fetches for it.
type FindOptions struct {
// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
// a PR URL. An empty Selector means "the PR for the current branch".
Selector string
// Fields lists the GraphQL fields to fetch for the PullRequest. It must not be empty.
Fields []string
// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
// When set, a numeric Selector is treated as a branch name rather than a PR number.
BaseBranch string
// States lists the possible PR states to scope the PR-for-branch lookup to.
States []string
}
// PRRefs identifies the head branch of a pull request together with the
// repository that branch lives in and the repository the PR targets.
type PRRefs struct {
// BranchName is the head branch name, without any "<owner>:" prefix.
BranchName string
// HeadRepo is the repository containing the head branch (a fork for cross-repo PRs).
HeadRepo ghrepo.Interface
// BaseRepo is the repository the pull request is opened against.
BaseRepo ghrepo.Interface
}
// GetPRLabel returns the head label the GitHub API uses for this PR. When the
// head and base are the same repository, it is the bare branch name. For a PR
// opened from a fork, it is "<head owner>:<branch>".
func (s *PRRefs) GetPRLabel() string {
	if !ghrepo.IsSame(s.HeadRepo, s.BaseRepo) {
		// Cross-repository PR: qualify the branch with the fork owner.
		return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName)
	}
	return s.BranchName
}
// prHeadMergeRefRE matches a branch merge ref of the form
// refs/pull/<number>/head and captures the pull request number.
var prHeadMergeRefRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)

// Find locates a single pull request and fetches opts.Fields for it.
//
// The pull request is resolved from opts.Selector:
//   - a pull request URL fixes both the base repository and the PR number;
//   - a number (optionally prefixed with "#") is a PR number, unless
//     opts.BaseBranch is set, in which case it is treated as a branch name;
//   - any other non-empty selector is a branch name;
//   - an empty selector means the currently checked-out branch.
//
// When resolving by branch, a merge ref of refs/pull/<number>/head in that
// branch's git config maps directly to that PR number. Otherwise the head
// repository and branch are derived from git config and remotes
// (parsePRRefs) and searched for with findForBranch.
//
// It returns the pull request and the base repository it was looked up in.
// When requested, reviews, comments, status checks and project items are
// loaded concurrently after the initial lookup.
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	if len(opts.Fields) == 0 {
		return nil, nil, errors.New("Find error: no fields specified")
	}

	if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
		f.prNumber = prNumber
		f.baseRefRepo = repo
	}

	if f.baseRefRepo == nil {
		repo, err := f.baseRepoFn()
		if err != nil {
			return nil, nil, err
		}
		f.baseRefRepo = repo
	}

	if f.prNumber == 0 && opts.Selector != "" {
		// If opts.Selector is a valid number then assume it is the
		// PR number unless opts.BaseBranch is specified. This is a
		// special case for PR create command which will always want
		// to assume that a numerical selector is a branch name rather
		// than PR number.
		prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#"))
		if opts.BaseBranch == "" && err == nil {
			f.prNumber = prNumber
		} else {
			f.branchName = opts.Selector
		}
	} else if f.prNumber == 0 {
		// No selector: fall back to the current branch. This is skipped when a
		// PR URL already supplied the number, so URL lookups do not require a
		// git checkout or an attached HEAD.
		currentBranchName, err := f.branchFn()
		if err != nil {
			return nil, nil, err
		}
		f.branchName = currentBranchName
	}

	// The branch config only matters when resolving the PR by branch. Reading
	// it for an explicit PR number/URL would let the current branch's merge
	// ref silently replace the PR the user asked for.
	var branchConfig git.BranchConfig
	if f.prNumber == 0 {
		cfg, err := f.branchConfig(f.branchName)
		if err != nil {
			return nil, nil, err
		}
		branchConfig = cfg

		// Determine if the branch is configured to merge to a special PR ref.
		if m := prHeadMergeRefRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
			// \d+ guarantees digits; on overflow Atoi yields 0 and we fall
			// back to the branch-based lookup below.
			prNumber, _ := strconv.Atoi(m[1])
			f.prNumber = prNumber
		}
	}

	// Set up HTTP client
	httpClient, err := f.httpClient()
	if err != nil {
		return nil, nil, err
	}

	// TODO(josebalius): Should we be guarding here?
	if f.progress != nil {
		f.progress.StartProgressIndicator()
		defer f.progress.StopProgressIndicator()
	}

	fields := set.NewStringSet()
	fields.AddValues(opts.Fields)
	numberFieldOnly := fields.Len() == 1 && fields.Contains("number")
	fields.AddValues([]string{"id", "number"}) // for additional preload queries below

	if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") {
		cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24)
		detector := fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost())
		prFeatures, err := detector.PullRequestFeatures()
		if err != nil {
			return nil, nil, err
		}
		// Hosts without merge queue support reject these fields.
		if !prFeatures.MergeQueue {
			fields.Remove("isInMergeQueue")
			fields.Remove("isMergeQueueEnabled")
		}
	}

	// projectItems is loaded by a separate query below, not the main one.
	var getProjectItems bool
	if fields.Contains("projectItems") {
		getProjectItems = true
		fields.Remove("projectItems")
	}

	var pr *api.PullRequest
	if f.prNumber > 0 {
		if numberFieldOnly {
			// avoid hitting the API if we already have all the information
			return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil
		}
		pr, err = findByNumber(httpClient, f.baseRefRepo, f.prNumber, fields.ToSlice())
		if err != nil {
			return pr, f.baseRefRepo, err
		}
	} else {
		rems, err := f.remotesFn()
		if err != nil {
			return nil, nil, err
		}

		pushDefault, err := f.pushDefault()
		if err != nil {
			return nil, nil, err
		}

		// Suppressing the error as we have other means of computing the PRRefs if this fails.
		parsedPushRevision, _ := f.parsePushRevision(f.branchName)

		prRefs, err := parsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, f.baseRefRepo, rems)
		if err != nil {
			return nil, nil, err
		}

		pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRLabel(), opts.States, fields.ToSlice())
		if err != nil {
			return pr, f.baseRefRepo, err
		}
	}

	// Preload paginated/auxiliary data in parallel; the first error wins.
	var g errgroup.Group
	if fields.Contains("reviews") {
		g.Go(func() error {
			return preloadPrReviews(httpClient, f.baseRefRepo, pr)
		})
	}
	if fields.Contains("comments") {
		g.Go(func() error {
			return preloadPrComments(httpClient, f.baseRefRepo, pr)
		})
	}
	if fields.Contains("statusCheckRollup") {
		g.Go(func() error {
			return preloadPrChecks(httpClient, f.baseRefRepo, pr)
		})
	}
	if getProjectItems {
		g.Go(func() error {
			apiClient := api.NewClientFromHTTP(httpClient)
			err := api.ProjectsV2ItemsForPullRequest(apiClient, f.baseRefRepo, pr)
			if err != nil && !api.ProjectsV2IgnorableError(err) {
				return err
			}
			return nil
		})
	}

	return pr, f.baseRefRepo, g.Wait()
}
// pullURLRE matches the path of a pull request URL ("/OWNER/REPO/pull/123...")
// and captures the owner, repository name and PR number.
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)

// parseURL extracts the repository and PR number from an http(s) pull request
// URL such as https://HOST/OWNER/REPO/pull/123. The returned repository uses
// the URL's host name.
//
// It returns an error for an empty string, an unparsable URL, a non-http(s)
// scheme, a path that is not a pull request path, or a PR number too large
// for an int.
func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
	if prURL == "" {
		return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
	}

	u, err := url.Parse(prURL)
	if err != nil {
		return nil, 0, err
	}

	if u.Scheme != "https" && u.Scheme != "http" {
		return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme)
	}

	m := pullURLRE.FindStringSubmatch(u.Path)
	if m == nil {
		return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
	}

	// The regexp guarantees digits, so this can only fail on int overflow.
	// Previously the error was dropped and a bogus PR number of 0 returned.
	prNumber, err := strconv.Atoi(m[3])
	if err != nil {
		return nil, 0, fmt.Errorf("invalid pull request number in URL %s: %w", prURL, err)
	}

	repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
	return repo, prNumber, nil
}
// parsePRRefs determines the head repository and head branch name that
// identify the pull request for currentBranchName, with baseRefRepo as base.
//
// If parsedPushRevision (the branch's resolved @{push}, "<remote>/<branch>")
// is non-empty, it is authoritative. The remote in rems whose name prefixes
// it supplies the head repo, and the remainder is the branch name. An error
// is returned when no remote matches.
//
// Otherwise the head repo comes from branchConfig's push remote (by name,
// looked up in rems, or by URL) and falls back to baseRefRepo. The branch
// name is currentBranchName unless pushDefault is "upstream" or "tracking"
// and the branch has a merge ref, in which case the merge ref's branch is used.
func parsePRRefs(currentBranchName string, branchConfig git.BranchConfig, parsedPushRevision string, pushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PRRefs, error) {
	prRefs := PRRefs{
		BaseRepo: baseRefRepo,
	}

	// If @{push} resolves, then we have all the information we need to determine the head repo
	// and branch name. It is of the form <remote>/<branch>.
	if parsedPushRevision != "" {
		for _, r := range rems {
			// Find the remote whose name matches the push <remote> prefix
			if strings.HasPrefix(parsedPushRevision, r.Name+"/") {
				prRefs.BranchName = strings.TrimPrefix(parsedPushRevision, r.Name+"/")
				prRefs.HeadRepo = r.Repo
				return prRefs, nil
			}
		}

		remoteNames := make([]string, len(rems))
		for i, r := range rems {
			remoteNames[i] = r.Name
		}
		return PRRefs{}, fmt.Errorf("no remote for %q found in %q", parsedPushRevision, strings.Join(remoteNames, ", "))
	}

	// To get the HeadRepo, we look to the git config. The PushRemote{Name | URL} comes from
	// one of the following, in order of precedence:
	// 1. branch.<name>.pushRemote
	// 2. remote.pushDefault
	// 3. branch.<name>.remote
	// NOTE(review): that precedence is presumably resolved when branchConfig is
	// read; this function only consults PushRemoteName / PushRemoteURL.
	if branchConfig.PushRemoteName != "" {
		if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil {
			prRefs.HeadRepo = r.Repo
		}
	} else if branchConfig.PushRemoteURL != nil {
		if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil {
			prRefs.HeadRepo = r
		}
	}

	// We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier,
	// unless the user has specified push.default = upstream or tracking, in which case we use
	// the branch name from the merge ref. If the branch has no merge ref (no upstream is
	// configured), keep the current branch name rather than searching for an empty one.
	prRefs.BranchName = currentBranchName
	if (pushDefault == "upstream" || pushDefault == "tracking") && branchConfig.MergeRef != "" {
		prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
	}

	// The PR merges from a branch in the same repo as the base branch (usually the default branch)
	if prRefs.HeadRepo == nil {
		prRefs.HeadRepo = baseRefRepo
	}

	return prRefs, nil
}
// findByNumber fetches pull request `number` from repo with a single GraphQL
// query that selects only the given fields.
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
	var resp struct {
		Repository struct {
			PullRequest api.PullRequest
		}
	}

	query := fmt.Sprintf(`
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pr_number) {%s}
}
}`, api.PullRequestGraphQL(fields))

	variables := map[string]interface{}{
		"owner":     repo.RepoOwner(),
		"repo":      repo.RepoName(),
		"pr_number": number,
	}

	if err := api.NewClientFromHTTP(httpClient).GraphQL(repo.RepoHost(), query, variables, &resp); err != nil {
		return nil, err
	}
	return &resp.Repository.PullRequest, nil
}
// findForBranch returns the most relevant pull request in repo whose head
// label equals headBranchWithOwnerIfFork ("branch" or "owner:branch").
//
// Up to 30 of the most recently created PRs with a matching head ref name
// (optionally filtered by stateFilters) are fetched. Open PRs are ordered
// ahead of the rest, otherwise keeping the API's order. The first PR whose
// label matches, and whose base matches baseBranch when one is given, is
// returned. Non-open PRs are skipped when the head label equals the
// repository's default branch name. A *NotFoundError is returned when no PR
// qualifies.
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) {
	var resp struct {
		Repository struct {
			PullRequests struct {
				Nodes []api.PullRequest
			}
			DefaultBranchRef struct {
				Name string
			}
		}
	}

	requested := set.NewStringSet()
	requested.AddValues(fields)
	// The client-side filtering below reads these regardless of what the caller asked for.
	requested.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})

	query := fmt.Sprintf(`
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
repository(owner: $owner, name: $repo) {
pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
nodes {%s}
}
defaultBranchRef { name }
}
}`, api.PullRequestGraphQL(requested.ToSlice()))

	// The API filters on the bare ref name; the owner part is checked via HeadLabel below.
	headRefName := headBranchWithOwnerIfFork
	if _, branch, found := strings.Cut(headBranchWithOwnerIfFork, ":"); found {
		headRefName = branch
	}

	variables := map[string]interface{}{
		"owner":       repo.RepoOwner(),
		"repo":        repo.RepoName(),
		"headRefName": headRefName,
		"states":      stateFilters,
	}

	if err := api.NewClientFromHTTP(httpClient).GraphQL(repo.RepoHost(), query, variables, &resp); err != nil {
		return nil, err
	}

	candidates := resp.Repository.PullRequests.Nodes
	sort.SliceStable(candidates, func(i, j int) bool {
		return candidates[i].State == "OPEN" && candidates[j].State != "OPEN"
	})

	defaultBranch := resp.Repository.DefaultBranchRef.Name
	for _, candidate := range candidates {
		if candidate.HeadLabel() != headBranchWithOwnerIfFork {
			continue
		}
		if baseBranch != "" && candidate.BaseRefName != baseBranch {
			continue
		}
		// When the head is the default branch, it doesn't really make sense to show merged or closed PRs.
		// https://github.com/cli/cli/issues/4263
		if candidate.State != "OPEN" && defaultBranch == headBranchWithOwnerIfFork {
			continue
		}
		return &candidate, nil
	}

	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)}
}
// preloadPrReviews fetches the review pages that follow the page already
// loaded on pr (100 reviews per request) and appends them to pr.Reviews.
// On success, TotalCount equals the number of loaded reviews and
// HasNextPage is cleared. It does nothing when no further page exists.
func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if !pr.Reviews.PageInfo.HasNextPage {
		return nil
	}

	type response struct {
		Node struct {
			PullRequest struct {
				Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"`
			} `graphql:"...on PullRequest"`
		} `graphql:"node(id: $id)"`
	}

	apiClient := api.NewClientFromHTTP(httpClient)
	cursor := pr.Reviews.PageInfo.EndCursor
	for more := true; more; {
		vars := map[string]interface{}{
			"id":        githubv4.ID(pr.ID),
			"endCursor": githubv4.String(cursor),
		}
		var page response
		if err := apiClient.Query(repo.RepoHost(), "ReviewsForPullRequest", &page, vars); err != nil {
			return err
		}

		batch := page.Node.PullRequest.Reviews
		pr.Reviews.Nodes = append(pr.Reviews.Nodes, batch.Nodes...)
		pr.Reviews.TotalCount = len(pr.Reviews.Nodes)

		more = batch.PageInfo.HasNextPage
		cursor = batch.PageInfo.EndCursor
	}

	pr.Reviews.PageInfo.HasNextPage = false
	return nil
}
// preloadPrComments pages through the comments of pr that follow the page
// already loaded (100 per request) and appends them to pr.Comments.
// Afterwards TotalCount is the number of loaded comments and HasNextPage is
// false. Nothing is fetched when the loaded page was the last one.
func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if !pr.Comments.PageInfo.HasNextPage {
		return nil
	}

	type response struct {
		Node struct {
			PullRequest struct {
				Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"`
			} `graphql:"...on PullRequest"`
		} `graphql:"node(id: $id)"`
	}

	apiClient := api.NewClientFromHTTP(client)
	vars := map[string]interface{}{
		"id":        githubv4.ID(pr.ID),
		"endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor),
	}

	for {
		var page response
		if err := apiClient.Query(repo.RepoHost(), "CommentsForPullRequest", &page, vars); err != nil {
			return err
		}

		fetched := page.Node.PullRequest.Comments
		pr.Comments.Nodes = append(pr.Comments.Nodes, fetched.Nodes...)
		pr.Comments.TotalCount = len(pr.Comments.Nodes)

		if !fetched.PageInfo.HasNextPage {
			break
		}
		vars["endCursor"] = githubv4.String(fetched.PageInfo.EndCursor)
	}

	pr.Comments.PageInfo.HasNextPage = false
	return nil
}
// preloadPrChecks fetches the remaining pages of status check contexts for
// the commit in pr.StatusCheckRollup.Nodes[0] and appends them in place.
// On success HasNextPage is cleared. It does nothing when pr has no rollup
// nodes or no further page.
//
// An error is returned if a request fails, or if a response does not include
// the pull request's status check rollup. The original indexed such a
// response directly and would panic.
func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
	if len(pr.StatusCheckRollup.Nodes) == 0 {
		return nil
	}
	statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
	if !statusCheckRollup.PageInfo.HasNextPage {
		return nil
	}

	endCursor := statusCheckRollup.PageInfo.EndCursor

	type response struct {
		Node *api.PullRequest
	}

	query := fmt.Sprintf(`
query PullRequestStatusChecks($id: ID!, $endCursor: String!) {
node(id: $id) {
...on PullRequest {
%s
}
}
}`, api.StatusCheckRollupGraphQLWithoutCountByState("$endCursor"))

	variables := map[string]interface{}{
		"id": pr.ID,
	}

	apiClient := api.NewClientFromHTTP(client)
	for {
		variables["endCursor"] = endCursor
		var resp response
		err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
		if err != nil {
			return err
		}

		// Guard the index below: a null node or an empty commits list would
		// otherwise cause an index-out-of-range panic.
		if resp.Node == nil || len(resp.Node.StatusCheckRollup.Nodes) == 0 {
			return fmt.Errorf("unexpected response while loading status checks for pull request %v", pr.ID)
		}

		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
		statusCheckRollup.Nodes = append(
			statusCheckRollup.Nodes,
			result.Nodes...,
		)

		if !result.PageInfo.HasNextPage {
			break
		}
		endCursor = result.PageInfo.EndCursor
	}

	statusCheckRollup.PageInfo.HasNextPage = false
	return nil
}
// NotFoundError reports that no pull request matched a lookup. It embeds the
// underlying error, whose message it reuses, and exposes it through Unwrap.
type NotFoundError struct {
	error
}

// Unwrap returns the wrapped error so errors.Is and errors.As can inspect it.
func (e *NotFoundError) Unwrap() error { return e.error }
// NewMockFinder returns a test PRFinder that expects to be called with
// selector and then yields pr and repo. Passing a nil pr makes every Find
// call fail with a *NotFoundError.
func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
	m := &mockFinder{
		expectSelector: selector,
		pr:             pr,
		repo:           repo,
	}
	// A nil PR models a lookup that matched nothing.
	if pr == nil {
		m.err = &NotFoundError{errors.New("no pull requests found")}
	}
	return m
}
// mockFinder is a PRFinder test double. Its Find checks the selector (and,
// optionally, the fields) it receives and returns a canned PR and repository.
// It succeeds at most once.
type mockFinder struct {
called bool // set by the first successful Find
expectSelector string // selector Find must be called with
expectFields []string // if non-empty, fields Find must be called with (order-insensitive)
pr *api.PullRequest // PR returned on success
repo ghrepo.Interface // repository returned on success
err error // if set, returned by Find before any other check
}
// Find implements PRFinder for tests. It returns the preset error if any, then
// validates the selector, the fields (when ExpectFields was used) and that it
// has not been called before. On success it returns the canned PR and repo.
// If the PR has no head repository owner, the repo's owner is filled in so
// the PR looks like a same-repository PR.
func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
	switch {
	case m.err != nil:
		return nil, nil, m.err
	case m.expectSelector != opts.Selector:
		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
	case len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields):
		return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
	case m.called:
		return nil, nil, errors.New("mockFinder used more than once")
	}

	m.called = true

	// pose as same-repo PR by default
	if owner := &m.pr.HeadRepositoryOwner; owner.Login == "" {
		owner.Login = m.repo.RepoOwner()
	}
	return m.pr, m.repo, nil
}
// ExpectFields makes Find fail unless it receives exactly these fields, in any
// order. An empty slice disables the check.
func (m *mockFinder) ExpectFields(fields []string) { m.expectFields = fields }
// isEqualSet reports whether a and b contain the same strings with the same
// multiplicities, ignoring order. Neither input slice is modified.
func isEqualSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Sort private copies so the callers' slices keep their order.
	sortedA := append([]string(nil), a...)
	sortedB := append([]string(nil), b...)
	sort.Strings(sortedA)
	sort.Strings(sortedB)

	for i, v := range sortedA {
		if v != sortedB[i] {
			return false
		}
	}
	return true
}