From 75e35e2ddee2afe846ae29a5b95588752167ce57 Mon Sep 17 00:00:00 2001 From: William Martin Date: Mon, 22 May 2023 17:28:15 +0200 Subject: [PATCH] Use pseudo field for statusCheckRollupWithCountByState --- api/query_builder.go | 34 ++++++---------------------------- pkg/cmd/pr/status/http.go | 29 +++++++++++++++++++---------- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/api/query_builder.go b/api/query_builder.go index 4274e45e8..76eb5294a 100644 --- a/api/query_builder.go +++ b/api/query_builder.go @@ -292,28 +292,8 @@ var PullRequestFields = append(IssueFields, "statusCheckRollup", ) -type issueGraphQLOpts struct { - useCheckRunAndStatusContextCounts bool -} - -type IssueGraphQLOptFn func(*issueGraphQLOpts) - -func WithUseCheckRunAndStatusContextCounts() IssueGraphQLOptFn { - return func(opts *issueGraphQLOpts) { - opts.useCheckRunAndStatusContextCounts = true - } -} - // IssueGraphQL constructs a GraphQL query fragment for a set of issue fields. -func IssueGraphQL(fields []string, opts ...IssueGraphQLOptFn) string { - issueGraphQLOpts := issueGraphQLOpts{ - useCheckRunAndStatusContextCounts: false, - } - - for _, opt := range opts { - opt(&issueGraphQLOpts) - } - +func IssueGraphQL(fields []string) string { var q []string for _, field := range fields { switch field { @@ -364,11 +344,9 @@ func IssueGraphQL(fields []string, opts ...IssueGraphQLOptFn) string { case "requiresStrictStatusChecks": // pseudo-field q = append(q, `baseRef{branchProtectionRule{requiresStrictStatusChecks}}`) case "statusCheckRollup": - if issueGraphQLOpts.useCheckRunAndStatusContextCounts { - q = append(q, StatusCheckRollupGraphQLWithCountByState()) - } else { - q = append(q, StatusCheckRollupGraphQLWithoutCountByState("")) - } + q = append(q, StatusCheckRollupGraphQLWithoutCountByState("")) + case "statusCheckRollupWithCountByState": // pseudo-field + q = append(q, StatusCheckRollupGraphQLWithCountByState()) default: q = append(q, field) } @@ -378,12 +356,12 @@ func IssueGraphQL(fields []string, opts ...IssueGraphQLOptFn) string { // PullRequestGraphQL constructs a GraphQL query fragment for a set of pull request fields. // It will try to sanitize the fields to just those available on pull request. -func PullRequestGraphQL(fields []string, opts ...IssueGraphQLOptFn) string { +func PullRequestGraphQL(fields []string) string { invalidFields := []string{"isPinned", "stateReason"} s := set.NewStringSet() s.AddValues(fields) s.RemoveValues(invalidFields) - return IssueGraphQL(s.ToSlice(), opts...) + return IssueGraphQL(s.ToSlice()) } var RepositoryFields = []string{ diff --git a/pkg/cmd/pr/status/http.go b/pkg/cmd/pr/status/http.go index 1c38e2221..b240ee8ca 100644 --- a/pkg/cmd/pr/status/http.go +++ b/pkg/cmd/pr/status/http.go @@ -57,22 +57,24 @@ func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options r return nil, err } - var prGraphQLOpts []api.IssueGraphQLOptFn - if prFeatures.CheckRunAndStatusContextCounts { - prGraphQLOpts = append(prGraphQLOpts, api.WithUseCheckRunAndStatusContextCounts()) - } - var fragments string if len(options.Fields) > 0 { fields := set.NewStringSet() fields.AddValues(options.Fields) // these are always necessary to find the PR for the current branch fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"}) - gr := api.PullRequestGraphQL(fields.ToSlice(), prGraphQLOpts...) + + if prFeatures.CheckRunAndStatusContextCounts { + fields.Add("statusCheckRollupWithCountByState") + } else { + fields.Add("statusCheckRollup") + } + + gr := api.PullRequestGraphQL(fields.ToSlice()) fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr) } else { var err error - fragments, err = pullRequestFragment(repo.RepoHost(), options.ConflictStatus, prGraphQLOpts...) + fragments, err = pullRequestFragment(repo.RepoHost(), options.ConflictStatus, prFeatures.CheckRunAndStatusContextCounts) if err != nil { return nil, err } @@ -201,20 +203,27 @@ func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options r return &payload, nil } -func pullRequestFragment(hostname string, conflictStatus bool, opts ...api.IssueGraphQLOptFn) (string, error) { +func pullRequestFragment(hostname string, conflictStatus bool, statusCheckRollupWithCountByState bool) (string, error) { fields := []string{ "number", "title", "state", "url", "isDraft", "isCrossRepository", "headRefName", "headRepositoryOwner", "mergeStateStatus", - "statusCheckRollup", "requiresStrictStatusChecks", "autoMergeRequest", + "requiresStrictStatusChecks", "autoMergeRequest", } if conflictStatus { fields = append(fields, "mergeable") } + + if statusCheckRollupWithCountByState { + fields = append(fields, "statusCheckRollupWithCountByState") + } else { + fields = append(fields, "statusCheckRollup") + } + reviewFields := []string{"reviewDecision", "latestReviews"} fragments := fmt.Sprintf(` fragment pr on PullRequest {%s} fragment prWithReviews on PullRequest {...pr,%s} - `, api.PullRequestGraphQL(fields, opts...), api.PullRequestGraphQL(reviewFields, opts...)) + `, api.PullRequestGraphQL(fields), api.PullRequestGraphQL(reviewFields)) return fragments, nil }