diff --git a/api/queries_pr.go b/api/queries_pr.go index b89e20609..6cda72919 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -4,23 +4,12 @@ import ( "context" "fmt" "net/http" - "strings" "time" - "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/internal/ghrepo" - "github.com/cli/cli/v2/pkg/set" "github.com/shurcooL/githubv4" - "golang.org/x/sync/errgroup" ) -type PullRequestsPayload struct { - ViewerCreated PullRequestAndTotalCount - ReviewRequested PullRequestAndTotalCount - CurrentPR *PullRequest - DefaultBranch string -} - type PullRequestAndTotalCount struct { TotalCount int PullRequests []PullRequest @@ -269,275 +258,6 @@ func (pr *PullRequest) DisplayableReviews() PullRequestReviews { return PullRequestReviews{Nodes: published, TotalCount: len(published)} } -type pullRequestFeature struct { - HasReviewDecision bool - HasStatusCheckRollup bool - HasBranchProtectionRule bool -} - -func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) { - if !ghinstance.IsEnterprise(hostname) { - prFeatures.HasReviewDecision = true - prFeatures.HasStatusCheckRollup = true - prFeatures.HasBranchProtectionRule = true - return - } - - var featureDetection struct { - PullRequest struct { - Fields []struct { - Name string - } `graphql:"fields(includeDeprecated: true)"` - } `graphql:"PullRequest: __type(name: \"PullRequest\")"` - Commit struct { - Fields []struct { - Name string - } `graphql:"fields(includeDeprecated: true)"` - } `graphql:"Commit: __type(name: \"Commit\")"` - } - - // needs to be a separate query because the backend only supports 2 `__type` expressions in one query - var featureDetection2 struct { - Ref struct { - Fields []struct { - Name string - } `graphql:"fields(includeDeprecated: true)"` - } `graphql:"Ref: __type(name: \"Ref\")"` - } - - v4 := graphQLClient(httpClient, hostname) - - g := new(errgroup.Group) - g.Go(func() error { - return v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil) - }) - g.Go(func() error { - return v4.QueryNamed(context.Background(), "PullRequest_fields2", &featureDetection2, nil) - }) - - err = g.Wait() - if err != nil { - return - } - - for _, field := range featureDetection.PullRequest.Fields { - switch field.Name { - case "reviewDecision": - prFeatures.HasReviewDecision = true - } - } - for _, field := range featureDetection.Commit.Fields { - switch field.Name { - case "statusCheckRollup": - prFeatures.HasStatusCheckRollup = true - } - } - for _, field := range featureDetection2.Ref.Fields { - switch field.Name { - case "branchProtectionRule": - prFeatures.HasBranchProtectionRule = true - } - } - return -} - -type StatusOptions struct { - CurrentPR int - HeadRef string - Username string - Fields []string -} - -func PullRequestStatus(client *Client, repo ghrepo.Interface, options StatusOptions) (*PullRequestsPayload, error) { - type edges struct { - TotalCount int - Edges []struct { - Node PullRequest - } - } - - type response struct { - Repository struct { - DefaultBranchRef struct { - Name string - } - PullRequests edges - PullRequest *PullRequest - } - ViewerCreated edges - ReviewRequested edges - } - - var fragments string - if len(options.Fields) > 0 { - fields := set.NewStringSet() - fields.AddValues(options.Fields) - // these are always necessary to find the PR for the current branch - fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"}) - gr := PullRequestGraphQL(fields.ToSlice()) - fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr) - } else { - var err error - fragments, err = pullRequestFragment(client.http, repo.RepoHost()) - if err != nil { - return nil, err - } - } - - queryPrefix := ` - query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { - repository(owner: $owner, name: $repo) { - defaultBranchRef { - name - } - pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) { - totalCount - edges { - node { - ...prWithReviews - } - } - } - } - ` - if options.CurrentPR > 0 { - queryPrefix = ` - query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { - repository(owner: $owner, name: $repo) { - defaultBranchRef { - name - } - pullRequest(number: $number) { - ...prWithReviews - baseRef { - branchProtectionRule { - requiredApprovingReviewCount - } - } - } - } - ` - } - - query := fragments + queryPrefix + ` - viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) { - totalCount: issueCount - edges { - node { - ...prWithReviews - } - } - } - reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) { - totalCount: issueCount - edges { - node { - ...pr - } - } - } - } - ` - - currentUsername := options.Username - if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) { - var err error - currentUsername, err = CurrentLoginName(client, repo.RepoHost()) - if err != nil { - return nil, err - } - } - - viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername) - reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername) - - currentPRHeadRef := options.HeadRef - branchWithoutOwner := currentPRHeadRef - if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 { - branchWithoutOwner = currentPRHeadRef[idx+1:] - } - - variables := map[string]interface{}{ - "viewerQuery": viewerQuery, - "reviewerQuery": reviewerQuery, - "owner": repo.RepoOwner(), - "repo": repo.RepoName(), - "headRefName": branchWithoutOwner, - "number": options.CurrentPR, - } - - var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) - if err != nil { - return nil, err - } - - var viewerCreated []PullRequest - for _, edge := range resp.ViewerCreated.Edges { - viewerCreated = append(viewerCreated, edge.Node) - } - - var reviewRequested []PullRequest - for _, edge := range resp.ReviewRequested.Edges { - reviewRequested = append(reviewRequested, edge.Node) - } - - var currentPR = resp.Repository.PullRequest - if currentPR == nil { - for _, edge := range resp.Repository.PullRequests.Edges { - if edge.Node.HeadLabel() == currentPRHeadRef { - currentPR = &edge.Node - break // Take the most recent PR for the current branch - } - } - } - - payload := PullRequestsPayload{ - ViewerCreated: PullRequestAndTotalCount{ - PullRequests: viewerCreated, - TotalCount: resp.ViewerCreated.TotalCount, - }, - ReviewRequested: PullRequestAndTotalCount{ - PullRequests: reviewRequested, - TotalCount: resp.ReviewRequested.TotalCount, - }, - CurrentPR: currentPR, - DefaultBranch: resp.Repository.DefaultBranchRef.Name, - } - - return &payload, nil -} - -func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) { - cachedClient := NewCachedClient(httpClient, time.Hour*24) - prFeatures, err := determinePullRequestFeatures(cachedClient, hostname) - if err != nil { - return "", err - } - - fields := []string{ - "number", "title", "state", "url", "isDraft", "isCrossRepository", - "headRefName", "headRepositoryOwner", "mergeStateStatus", - } - if prFeatures.HasStatusCheckRollup { - fields = append(fields, "statusCheckRollup") - } - if prFeatures.HasBranchProtectionRule { - fields = append(fields, "requiresStrictStatusChecks") - } - - var reviewFields []string - if prFeatures.HasReviewDecision { - reviewFields = append(reviewFields, "reviewDecision", "latestReviews") - } - - fragments := fmt.Sprintf(` - fragment pr on PullRequest {%s} - fragment prWithReviews on PullRequest {...pr,%s} - `, PullRequestGraphQL(fields), PullRequestGraphQL(reviewFields)) - return fragments, nil -} - // CreatePullRequest creates a pull request in a GitHub repository func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) { query := ` diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 0d692b451..1aefdf019 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "testing" - "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" @@ -49,117 +48,6 @@ func TestBranchDeleteRemote(t *testing.T) { } } -func Test_determinePullRequestFeatures(t *testing.T) { - tests := []struct { - name string - hostname string - queryResponse map[string]string - wantPrFeatures pullRequestFeature - wantErr bool - }{ - { - name: "github.com", - hostname: "github.com", - wantPrFeatures: pullRequestFeature{ - HasReviewDecision: true, - HasStatusCheckRollup: true, - HasBranchProtectionRule: true, - }, - wantErr: false, - }, - { - name: "GHE empty response", - hostname: "git.my.org", - queryResponse: map[string]string{ - `query PullRequest_fields\b`: `{"data": {}}`, - `query PullRequest_fields2\b`: `{"data": {}}`, - }, - wantPrFeatures: pullRequestFeature{ - HasReviewDecision: false, - HasStatusCheckRollup: false, - HasBranchProtectionRule: false, - }, - wantErr: false, - }, - { - name: "GHE has reviewDecision", - hostname: "git.my.org", - queryResponse: map[string]string{ - `query PullRequest_fields\b`: heredoc.Doc(` - { "data": { "PullRequest": { "fields": [ - {"name": "foo"}, - {"name": "reviewDecision"} - ] } } } - `), - `query PullRequest_fields2\b`: `{"data": {}}`, - }, - wantPrFeatures: pullRequestFeature{ - HasReviewDecision: true, - HasStatusCheckRollup: false, - HasBranchProtectionRule: false, - }, - wantErr: false, - }, - { - name: "GHE has statusCheckRollup", - hostname: "git.my.org", - queryResponse: map[string]string{ - `query PullRequest_fields\b`: heredoc.Doc(` - { "data": { "Commit": { "fields": [ - {"name": "foo"}, - {"name": "statusCheckRollup"} - ] } } } - `), - `query PullRequest_fields2\b`: `{"data": {}}`, - }, - wantPrFeatures: pullRequestFeature{ - HasReviewDecision: false, - HasStatusCheckRollup: true, - HasBranchProtectionRule: false, - }, - wantErr: false, - }, - { - name: "GHE has branchProtectionRule", - hostname: "git.my.org", - queryResponse: map[string]string{ - `query PullRequest_fields\b`: `{"data": {}}`, - `query PullRequest_fields2\b`: heredoc.Doc(` - { "data": { "Ref": { "fields": [ - {"name": "foo"}, - {"name": "branchProtectionRule"} - ] } } } - `), - }, - wantPrFeatures: pullRequestFeature{ - HasReviewDecision: false, - HasStatusCheckRollup: false, - HasBranchProtectionRule: true, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fakeHTTP := &httpmock.Registry{} - httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP)) - - for query, resp := range tt.queryResponse { - fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp)) - } - - gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname) - if tt.wantErr { - assert.Error(t, err) - return - } else { - assert.NoError(t, err) - } - assert.Equal(t, tt.wantPrFeatures, gotPrFeatures) - }) - } -} - func Test_Logins(t *testing.T) { rr := ReviewRequests{} var tests = []struct { diff --git a/internal/featuredetection/detector_mock.go b/internal/featuredetection/detector_mock.go new file mode 100644 index 000000000..6f36dd3fc --- /dev/null +++ b/internal/featuredetection/detector_mock.go @@ -0,0 +1,29 @@ +package featuredetection + +type DisabledDetectorMock struct{} + +func (md *DisabledDetectorMock) IssueFeatures() (IssueFeatures, error) { + return IssueFeatures{}, nil +} + +func (md *DisabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) { + return PullRequestFeatures{}, nil +} + +func (md *DisabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) { + return RepositoryFeatures{}, nil +} + +type EnabledDetectorMock struct{} + +func (md *EnabledDetectorMock) IssueFeatures() (IssueFeatures, error) { + return allIssueFeatures, nil +} + +func (md *EnabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) { + return allPullRequestFeatures, nil +} + +func (md *EnabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) { + return allRepositoryFeatures, nil +} diff --git a/internal/featuredetection/feature_detection.go b/internal/featuredetection/feature_detection.go new file mode 100644 index 000000000..ee96656f1 --- /dev/null +++ b/internal/featuredetection/feature_detection.go @@ -0,0 +1,108 @@ +package featuredetection + +import ( + "context" + "net/http" + "time" + + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghinstance" + graphql "github.com/cli/shurcooL-graphql" +) + +type Detector interface { + IssueFeatures() (IssueFeatures, error) + PullRequestFeatures() (PullRequestFeatures, error) + RepositoryFeatures() (RepositoryFeatures, error) +} + +type IssueFeatures struct{} + +var allIssueFeatures = IssueFeatures{} + +type PullRequestFeatures struct { + ReviewDecision bool + StatusCheckRollup bool + BranchProtectionRule bool +} + +var allPullRequestFeatures = PullRequestFeatures{ + ReviewDecision: true, + StatusCheckRollup: true, + BranchProtectionRule: true, +} + +type RepositoryFeatures struct { + IssueTemplateMutation bool + IssueTemplateQuery bool + PullRequestTemplateQuery bool +} + +var allRepositoryFeatures = RepositoryFeatures{ + IssueTemplateMutation: true, + IssueTemplateQuery: true, + PullRequestTemplateQuery: true, +} + +type detector struct { + host string + httpClient *http.Client +} + +func NewDetector(httpClient *http.Client, host string) Detector { + cachedClient := api.NewCachedClient(httpClient, time.Hour*48) + return &detector{ + httpClient: cachedClient, + host: host, + } +} + +func (d *detector) IssueFeatures() (IssueFeatures, error) { + if !ghinstance.IsEnterprise(d.host) { + return allIssueFeatures, nil + } + + return allIssueFeatures, nil +} + +func (d *detector) PullRequestFeatures() (PullRequestFeatures, error) { + if !ghinstance.IsEnterprise(d.host) { + return allPullRequestFeatures, nil + } + + return allPullRequestFeatures, nil +} + +func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) { + if !ghinstance.IsEnterprise(d.host) { + return allRepositoryFeatures, nil + } + + features := RepositoryFeatures{ + IssueTemplateQuery: true, + IssueTemplateMutation: true, + } + + var featureDetection struct { + Repository struct { + Fields []struct { + Name string + } `graphql:"fields(includeDeprecated: true)"` + } `graphql:"Repository: __type(name: \"Repository\")"` + } + + gql := graphql.NewClient(ghinstance.GraphQLEndpoint(d.host), d.httpClient) + + err := gql.QueryNamed(context.Background(), "Repository_fields", &featureDetection, nil) + if err != nil { + return features, err + } + + for _, field := range featureDetection.Repository.Fields { + if field.Name == "pullRequestTemplates" { + features.PullRequestTemplateQuery = true + } + } + + return features, nil +} diff --git a/internal/featuredetection/feature_detection_test.go b/internal/featuredetection/feature_detection_test.go new file mode 100644 index 000000000..1b6c26273 --- /dev/null +++ b/internal/featuredetection/feature_detection_test.go @@ -0,0 +1,127 @@ +package featuredetection + +import ( + "testing" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/pkg/httpmock" + "github.com/stretchr/testify/assert" +) + +func TestPullRequestFeatures(t *testing.T) { + tests := []struct { + name string + hostname string + queryResponse map[string]string + wantFeatures PullRequestFeatures + wantErr bool + }{ + { + name: "github.com", + hostname: "github.com", + wantFeatures: PullRequestFeatures{ + ReviewDecision: true, + StatusCheckRollup: true, + BranchProtectionRule: true, + }, + wantErr: false, + }, + { + name: "GHE", + hostname: "git.my.org", + wantFeatures: PullRequestFeatures{ + ReviewDecision: true, + StatusCheckRollup: true, + BranchProtectionRule: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeHTTP := &httpmock.Registry{} + httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP)) + for query, resp := range tt.queryResponse { + fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp)) + } + detector := detector{host: tt.hostname, httpClient: httpClient} + gotPrFeatures, err := detector.PullRequestFeatures() + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantFeatures, gotPrFeatures) + }) + } +} + +func TestRepositoryFeatures(t *testing.T) { + tests := []struct { + name string + hostname string + queryResponse map[string]string + wantFeatures RepositoryFeatures + wantErr bool + }{ + { + name: "github.com", + hostname: "github.com", + wantFeatures: RepositoryFeatures{ + IssueTemplateMutation: true, + IssueTemplateQuery: true, + PullRequestTemplateQuery: true, + }, + wantErr: false, + }, + { + name: "GHE empty response", + hostname: "git.my.org", + queryResponse: map[string]string{ + `query Repository_fields\b`: `{"data": {}}`, + }, + wantFeatures: RepositoryFeatures{ + IssueTemplateMutation: true, + IssueTemplateQuery: true, + PullRequestTemplateQuery: false, + }, + wantErr: false, + }, + { + name: "GHE has pull request template query", + hostname: "git.my.org", + queryResponse: map[string]string{ + `query Repository_fields\b`: heredoc.Doc(` + { "data": { "Repository": { "fields": [ + {"name": "pullRequestTemplates"} + ] } } } + `), + }, + wantFeatures: RepositoryFeatures{ + IssueTemplateMutation: true, + IssueTemplateQuery: true, + PullRequestTemplateQuery: true, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeHTTP := &httpmock.Registry{} + httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP)) + for query, resp := range tt.queryResponse { + fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp)) + } + detector := detector{host: tt.hostname, httpClient: httpClient} + gotPrFeatures, err := detector.RepositoryFeatures() + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantFeatures, gotPrFeatures) + }) + } +} diff --git a/pkg/cmd/pr/shared/templates.go b/pkg/cmd/pr/shared/templates.go index 6750bd6d6..741fe160e 100644 --- a/pkg/cmd/pr/shared/templates.go +++ b/pkg/cmd/pr/shared/templates.go @@ -4,11 +4,10 @@ import ( "context" "fmt" "net/http" - "time" "github.com/AlecAivazis/survey/v2" - "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/git" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/githubtemplate" @@ -109,55 +108,6 @@ func listPullRequestTemplates(httpClient *http.Client, repo ghrepo.Interface) ([ return templates, nil } -func hasTemplateSupport(httpClient *http.Client, hostname string, isPR bool) (bool, error) { - if !ghinstance.IsEnterprise(hostname) { - return true, nil - } - - var featureDetection struct { - Repository struct { - Fields []struct { - Name string - } `graphql:"fields(includeDeprecated: true)"` - } `graphql:"Repository: __type(name: \"Repository\")"` - CreateIssueInput struct { - InputFields []struct { - Name string - } - } `graphql:"CreateIssueInput: __type(name: \"CreateIssueInput\")"` - } - - gql := graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), httpClient) - err := gql.QueryNamed(context.Background(), "IssueTemplates_fields", &featureDetection, nil) - if err != nil { - return false, err - } - - var hasIssueQuerySupport bool - var hasIssueMutationSupport bool - var hasPullRequestQuerySupport bool - - for _, field := range featureDetection.Repository.Fields { - if field.Name == "issueTemplates" { - hasIssueQuerySupport = true - } - if field.Name == "pullRequestTemplates" { - hasPullRequestQuerySupport = true - } - } - for _, field := range featureDetection.CreateIssueInput.InputFields { - if field.Name == "issueTemplate" { - hasIssueMutationSupport = true - } - } - - if isPR { - return hasPullRequestQuerySupport, nil - } else { - return hasIssueQuerySupport && hasIssueMutationSupport, nil - } -} - type Template interface { Name() string NameForSubmit() string @@ -170,8 +120,8 @@ type templateManager struct { allowFS bool isPR bool httpClient *http.Client + detector fd.Detector - cachedClient *http.Client templates []Template legacyTemplate Template @@ -186,14 +136,21 @@ func NewTemplateManager(httpClient *http.Client, repo ghrepo.Interface, dir stri allowFS: allowFS, isPR: isPR, httpClient: httpClient, + detector: fd.NewDetector(httpClient, repo.RepoHost()), } } func (m *templateManager) hasAPI() (bool, error) { - if m.cachedClient == nil { - m.cachedClient = api.NewCachedClient(m.httpClient, time.Hour*24) + if !m.isPR { + return true, nil } - return hasTemplateSupport(m.cachedClient, m.repo.RepoHost(), m.isPR) + + features, err := m.detector.RepositoryFeatures() + if err != nil { + return false, err + } + + return features.PullRequestTemplateQuery, nil } func (m *templateManager) HasTemplates() (bool, error) { diff --git a/pkg/cmd/pr/shared/templates_test.go b/pkg/cmd/pr/shared/templates_test.go index 829b63b7a..752de47c1 100644 --- a/pkg/cmd/pr/shared/templates_test.go +++ b/pkg/cmd/pr/shared/templates_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/prompt" @@ -22,22 +23,6 @@ func TestTemplateManager_hasAPI(t *testing.T) { httpClient := &http.Client{Transport: &tr} defer tr.Verify(t) - tr.Register( - httpmock.GraphQL(`query IssueTemplates_fields\b`), - httpmock.StringResponse(`{"data":{ - "Repository": { - "fields": [ - {"name": "foo"}, - {"name": "issueTemplates"} - ] - }, - "CreateIssueInput": { - "inputFields": [ - {"name": "foo"}, - {"name": "issueTemplate"} - ] - } - }}`)) tr.Register( httpmock.GraphQL(`query IssueTemplates\b`), httpmock.StringResponse(`{"data":{"repository":{ @@ -48,12 +33,12 @@ func TestTemplateManager_hasAPI(t *testing.T) { }}}`)) m := templateManager{ - repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"), - rootDir: rootDir, - allowFS: true, - isPR: false, - httpClient: httpClient, - cachedClient: httpClient, + repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"), + rootDir: rootDir, + allowFS: true, + isPR: false, + httpClient: httpClient, + detector: &fd.EnabledDetectorMock{}, } hasTemplates, err := m.HasTemplates() @@ -84,16 +69,6 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) { httpClient := &http.Client{Transport: &tr} defer tr.Verify(t) - tr.Register( - httpmock.GraphQL(`query IssueTemplates_fields\b`), - httpmock.StringResponse(`{"data":{ - "Repository": { - "fields": [ - {"name": "foo"}, - {"name": "pullRequestTemplates"} - ] - } - }}`)) tr.Register( httpmock.GraphQL(`query PullRequestTemplates\b`), httpmock.StringResponse(`{"data":{"repository":{ @@ -104,12 +79,12 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) { }}}`)) m := templateManager{ - repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"), - rootDir: rootDir, - allowFS: true, - isPR: true, - httpClient: httpClient, - cachedClient: httpClient, + repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"), + rootDir: rootDir, + allowFS: true, + isPR: true, + httpClient: httpClient, + detector: &fd.EnabledDetectorMock{}, } hasTemplates, err := m.HasTemplates() diff --git a/pkg/cmd/pr/status/http.go b/pkg/cmd/pr/status/http.go new file mode 100644 index 000000000..04dfebdcf --- /dev/null +++ b/pkg/cmd/pr/status/http.go @@ -0,0 +1,201 @@ +package status + +import ( + "fmt" + "net/http" + "strings" + + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/set" +) + +type requestOptions struct { + CurrentPR int + HeadRef string + Username string + Fields []string +} + +type pullRequestsPayload struct { + ViewerCreated api.PullRequestAndTotalCount + ReviewRequested api.PullRequestAndTotalCount + CurrentPR *api.PullRequest + DefaultBranch string +} + +func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options requestOptions) (*pullRequestsPayload, error) { + apiClient := api.NewClientFromHTTP(httpClient) + type edges struct { + TotalCount int + Edges []struct { + Node api.PullRequest + } + } + + type response struct { + Repository struct { + DefaultBranchRef struct { + Name string + } + PullRequests edges + PullRequest *api.PullRequest + } + ViewerCreated edges + ReviewRequested edges + } + + var fragments string + if len(options.Fields) > 0 { + fields := set.NewStringSet() + fields.AddValues(options.Fields) + // these are always necessary to find the PR for the current branch + fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"}) + gr := api.PullRequestGraphQL(fields.ToSlice()) + fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr) + } else { + var err error + fragments, err = pullRequestFragment(httpClient, repo.RepoHost()) + if err != nil { + return nil, err + } + } + + queryPrefix := ` + query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { + repository(owner: $owner, name: $repo) { + defaultBranchRef { + name + } + pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) { + totalCount + edges { + node { + ...prWithReviews + } + } + } + } + ` + if options.CurrentPR > 0 { + queryPrefix = ` + query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { + repository(owner: $owner, name: $repo) { + defaultBranchRef { + name + } + pullRequest(number: $number) { + ...prWithReviews + baseRef { + branchProtectionRule { + requiredApprovingReviewCount + } + } + } + } + ` + } + + query := fragments + queryPrefix + ` + viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) { + totalCount: issueCount + edges { + node { + ...prWithReviews + } + } + } + reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) { + totalCount: issueCount + edges { + node { + ...pr + } + } + } + } + ` + + currentUsername := options.Username + if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) { + var err error + currentUsername, err = api.CurrentLoginName(apiClient, repo.RepoHost()) + if err != nil { + return nil, err + } + } + + viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername) + reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername) + + currentPRHeadRef := options.HeadRef + branchWithoutOwner := currentPRHeadRef + if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 { + branchWithoutOwner = currentPRHeadRef[idx+1:] + } + + variables := map[string]interface{}{ + "viewerQuery": viewerQuery, + "reviewerQuery": reviewerQuery, + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + "headRefName": branchWithoutOwner, + "number": options.CurrentPR, + } + + var resp response + err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp) + if err != nil { + return nil, err + } + + var viewerCreated []api.PullRequest + for _, edge := range resp.ViewerCreated.Edges { + viewerCreated = append(viewerCreated, edge.Node) + } + + var reviewRequested []api.PullRequest + for _, edge := range resp.ReviewRequested.Edges { + reviewRequested = append(reviewRequested, edge.Node) + } + + var currentPR = resp.Repository.PullRequest + if currentPR == nil { + for _, edge := range resp.Repository.PullRequests.Edges { + if edge.Node.HeadLabel() == currentPRHeadRef { + currentPR = &edge.Node + break // Take the most recent PR for the current branch + } + } + } + + payload := pullRequestsPayload{ + ViewerCreated: api.PullRequestAndTotalCount{ + PullRequests: viewerCreated, + TotalCount: resp.ViewerCreated.TotalCount, + }, + ReviewRequested: api.PullRequestAndTotalCount{ + PullRequests: reviewRequested, + TotalCount: resp.ReviewRequested.TotalCount, + }, + CurrentPR: currentPR, + DefaultBranch: resp.Repository.DefaultBranchRef.Name, + } + + return &payload, nil +} + +func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) { + fields := []string{ + "number", "title", "state", "url", "isDraft", "isCrossRepository", + "headRefName", "headRepositoryOwner", "mergeStateStatus", + "statusCheckRollup", "requiresStrictStatusChecks", + } + reviewFields := []string{"reviewDecision", "latestReviews"} + fragments := fmt.Sprintf(` + fragment pr on PullRequest {%s} + fragment prWithReviews on PullRequest {...pr,%s} + `, api.PullRequestGraphQL(fields), api.PullRequestGraphQL(reviewFields)) + return fragments, nil +} diff --git a/pkg/cmd/pr/status/status.go b/pkg/cmd/pr/status/status.go index a0f521724..37cc388ef 100644 --- a/pkg/cmd/pr/status/status.go +++ b/pkg/cmd/pr/status/status.go @@ -67,7 +67,6 @@ func statusRun(opts *StatusOptions) error { if err != nil { return err } - apiClient := api.NewClientFromHTTP(httpClient) baseRepo, err := opts.BaseRepo() if err != nil { @@ -91,7 +90,7 @@ func statusRun(opts *StatusOptions) error { } } - options := api.StatusOptions{ + options := requestOptions{ Username: "@me", CurrentPR: currentPRNumber, HeadRef: currentPRHeadRef, @@ -99,7 +98,7 @@ func statusRun(opts *StatusOptions) error { if opts.Exporter != nil { options.Fields = opts.Exporter.Fields() } - prPayload, err := api.PullRequestStatus(apiClient, baseRepo, options) + prPayload, err := pullRequestStatus(httpClient, baseRepo, options) if err != nil { return err }