Extract feature detection package (#5494)

This commit is contained in:
Sam Coe 2022-05-17 21:07:44 +02:00 committed by GitHub
parent f8b3ff999f
commit 539b150833
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 492 additions and 488 deletions

View file

@ -4,23 +4,12 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
)
type PullRequestsPayload struct {
ViewerCreated PullRequestAndTotalCount
ReviewRequested PullRequestAndTotalCount
CurrentPR *PullRequest
DefaultBranch string
}
type PullRequestAndTotalCount struct {
TotalCount int
PullRequests []PullRequest
@ -269,275 +258,6 @@ func (pr *PullRequest) DisplayableReviews() PullRequestReviews {
return PullRequestReviews{Nodes: published, TotalCount: len(published)}
}
type pullRequestFeature struct {
HasReviewDecision bool
HasStatusCheckRollup bool
HasBranchProtectionRule bool
}
func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) {
if !ghinstance.IsEnterprise(hostname) {
prFeatures.HasReviewDecision = true
prFeatures.HasStatusCheckRollup = true
prFeatures.HasBranchProtectionRule = true
return
}
var featureDetection struct {
PullRequest struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"PullRequest: __type(name: \"PullRequest\")"`
Commit struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Commit: __type(name: \"Commit\")"`
}
// needs to be a separate query because the backend only supports 2 `__type` expressions in one query
var featureDetection2 struct {
Ref struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Ref: __type(name: \"Ref\")"`
}
v4 := graphQLClient(httpClient, hostname)
g := new(errgroup.Group)
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
})
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields2", &featureDetection2, nil)
})
err = g.Wait()
if err != nil {
return
}
for _, field := range featureDetection.PullRequest.Fields {
switch field.Name {
case "reviewDecision":
prFeatures.HasReviewDecision = true
}
}
for _, field := range featureDetection.Commit.Fields {
switch field.Name {
case "statusCheckRollup":
prFeatures.HasStatusCheckRollup = true
}
}
for _, field := range featureDetection2.Ref.Fields {
switch field.Name {
case "branchProtectionRule":
prFeatures.HasBranchProtectionRule = true
}
}
return
}
type StatusOptions struct {
CurrentPR int
HeadRef string
Username string
Fields []string
}
func PullRequestStatus(client *Client, repo ghrepo.Interface, options StatusOptions) (*PullRequestsPayload, error) {
type edges struct {
TotalCount int
Edges []struct {
Node PullRequest
}
}
type response struct {
Repository struct {
DefaultBranchRef struct {
Name string
}
PullRequests edges
PullRequest *PullRequest
}
ViewerCreated edges
ReviewRequested edges
}
var fragments string
if len(options.Fields) > 0 {
fields := set.NewStringSet()
fields.AddValues(options.Fields)
// these are always necessary to find the PR for the current branch
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
gr := PullRequestGraphQL(fields.ToSlice())
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
} else {
var err error
fragments, err = pullRequestFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
}
queryPrefix := `
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
totalCount
edges {
node {
...prWithReviews
}
}
}
}
`
if options.CurrentPR > 0 {
queryPrefix = `
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequest(number: $number) {
...prWithReviews
baseRef {
branchProtectionRule {
requiredApprovingReviewCount
}
}
}
}
`
}
query := fragments + queryPrefix + `
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...prWithReviews
}
}
}
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...pr
}
}
}
}
`
currentUsername := options.Username
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
var err error
currentUsername, err = CurrentLoginName(client, repo.RepoHost())
if err != nil {
return nil, err
}
}
viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
currentPRHeadRef := options.HeadRef
branchWithoutOwner := currentPRHeadRef
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
branchWithoutOwner = currentPRHeadRef[idx+1:]
}
variables := map[string]interface{}{
"viewerQuery": viewerQuery,
"reviewerQuery": reviewerQuery,
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"number": options.CurrentPR,
}
var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
var viewerCreated []PullRequest
for _, edge := range resp.ViewerCreated.Edges {
viewerCreated = append(viewerCreated, edge.Node)
}
var reviewRequested []PullRequest
for _, edge := range resp.ReviewRequested.Edges {
reviewRequested = append(reviewRequested, edge.Node)
}
var currentPR = resp.Repository.PullRequest
if currentPR == nil {
for _, edge := range resp.Repository.PullRequests.Edges {
if edge.Node.HeadLabel() == currentPRHeadRef {
currentPR = &edge.Node
break // Take the most recent PR for the current branch
}
}
}
payload := PullRequestsPayload{
ViewerCreated: PullRequestAndTotalCount{
PullRequests: viewerCreated,
TotalCount: resp.ViewerCreated.TotalCount,
},
ReviewRequested: PullRequestAndTotalCount{
PullRequests: reviewRequested,
TotalCount: resp.ReviewRequested.TotalCount,
},
CurrentPR: currentPR,
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
}
return &payload, nil
}
func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
cachedClient := NewCachedClient(httpClient, time.Hour*24)
prFeatures, err := determinePullRequestFeatures(cachedClient, hostname)
if err != nil {
return "", err
}
fields := []string{
"number", "title", "state", "url", "isDraft", "isCrossRepository",
"headRefName", "headRepositoryOwner", "mergeStateStatus",
}
if prFeatures.HasStatusCheckRollup {
fields = append(fields, "statusCheckRollup")
}
if prFeatures.HasBranchProtectionRule {
fields = append(fields, "requiresStrictStatusChecks")
}
var reviewFields []string
if prFeatures.HasReviewDecision {
reviewFields = append(reviewFields, "reviewDecision", "latestReviews")
}
fragments := fmt.Sprintf(`
fragment pr on PullRequest {%s}
fragment prWithReviews on PullRequest {...pr,%s}
`, PullRequestGraphQL(fields), PullRequestGraphQL(reviewFields))
return fragments, nil
}
// CreatePullRequest creates a pull request in a GitHub repository
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
query := `

View file

@ -4,7 +4,6 @@ import (
"encoding/json"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
@ -49,117 +48,6 @@ func TestBranchDeleteRemote(t *testing.T) {
}
}
func Test_determinePullRequestFeatures(t *testing.T) {
tests := []struct {
name string
hostname string
queryResponse map[string]string
wantPrFeatures pullRequestFeature
wantErr bool
}{
{
name: "github.com",
hostname: "github.com",
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: true,
HasBranchProtectionRule: true,
},
wantErr: false,
},
{
name: "GHE empty response",
hostname: "git.my.org",
queryResponse: map[string]string{
`query PullRequest_fields\b`: `{"data": {}}`,
`query PullRequest_fields2\b`: `{"data": {}}`,
},
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: false,
HasBranchProtectionRule: false,
},
wantErr: false,
},
{
name: "GHE has reviewDecision",
hostname: "git.my.org",
queryResponse: map[string]string{
`query PullRequest_fields\b`: heredoc.Doc(`
{ "data": { "PullRequest": { "fields": [
{"name": "foo"},
{"name": "reviewDecision"}
] } } }
`),
`query PullRequest_fields2\b`: `{"data": {}}`,
},
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: false,
HasBranchProtectionRule: false,
},
wantErr: false,
},
{
name: "GHE has statusCheckRollup",
hostname: "git.my.org",
queryResponse: map[string]string{
`query PullRequest_fields\b`: heredoc.Doc(`
{ "data": { "Commit": { "fields": [
{"name": "foo"},
{"name": "statusCheckRollup"}
] } } }
`),
`query PullRequest_fields2\b`: `{"data": {}}`,
},
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: true,
HasBranchProtectionRule: false,
},
wantErr: false,
},
{
name: "GHE has branchProtectionRule",
hostname: "git.my.org",
queryResponse: map[string]string{
`query PullRequest_fields\b`: `{"data": {}}`,
`query PullRequest_fields2\b`: heredoc.Doc(`
{ "data": { "Ref": { "fields": [
{"name": "foo"},
{"name": "branchProtectionRule"}
] } } }
`),
},
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: false,
HasBranchProtectionRule: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fakeHTTP := &httpmock.Registry{}
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
for query, resp := range tt.queryResponse {
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
}
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
if tt.wantErr {
assert.Error(t, err)
return
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.wantPrFeatures, gotPrFeatures)
})
}
}
func Test_Logins(t *testing.T) {
rr := ReviewRequests{}
var tests = []struct {

View file

@ -0,0 +1,29 @@
package featuredetection
type DisabledDetectorMock struct{}
func (md *DisabledDetectorMock) IssueFeatures() (IssueFeatures, error) {
return IssueFeatures{}, nil
}
func (md *DisabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) {
return PullRequestFeatures{}, nil
}
func (md *DisabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) {
return RepositoryFeatures{}, nil
}
type EnabledDetectorMock struct{}
func (md *EnabledDetectorMock) IssueFeatures() (IssueFeatures, error) {
return allIssueFeatures, nil
}
func (md *EnabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) {
return allPullRequestFeatures, nil
}
func (md *EnabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) {
return allRepositoryFeatures, nil
}

View file

@ -0,0 +1,108 @@
package featuredetection
import (
"context"
"net/http"
"time"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
graphql "github.com/cli/shurcooL-graphql"
)
type Detector interface {
IssueFeatures() (IssueFeatures, error)
PullRequestFeatures() (PullRequestFeatures, error)
RepositoryFeatures() (RepositoryFeatures, error)
}
type IssueFeatures struct{}
var allIssueFeatures = IssueFeatures{}
type PullRequestFeatures struct {
ReviewDecision bool
StatusCheckRollup bool
BranchProtectionRule bool
}
var allPullRequestFeatures = PullRequestFeatures{
ReviewDecision: true,
StatusCheckRollup: true,
BranchProtectionRule: true,
}
type RepositoryFeatures struct {
IssueTemplateMutation bool
IssueTemplateQuery bool
PullRequestTemplateQuery bool
}
var allRepositoryFeatures = RepositoryFeatures{
IssueTemplateMutation: true,
IssueTemplateQuery: true,
PullRequestTemplateQuery: true,
}
type detector struct {
host string
httpClient *http.Client
}
func NewDetector(httpClient *http.Client, host string) Detector {
cachedClient := api.NewCachedClient(httpClient, time.Hour*48)
return &detector{
httpClient: cachedClient,
host: host,
}
}
func (d *detector) IssueFeatures() (IssueFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
return allIssueFeatures, nil
}
return allIssueFeatures, nil
}
func (d *detector) PullRequestFeatures() (PullRequestFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
return allPullRequestFeatures, nil
}
return allPullRequestFeatures, nil
}
func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
return allRepositoryFeatures, nil
}
features := RepositoryFeatures{
IssueTemplateQuery: true,
IssueTemplateMutation: true,
}
var featureDetection struct {
Repository struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Repository: __type(name: \"Repository\")"`
}
gql := graphql.NewClient(ghinstance.GraphQLEndpoint(d.host), d.httpClient)
err := gql.QueryNamed(context.Background(), "Repository_fields", &featureDetection, nil)
if err != nil {
return features, err
}
for _, field := range featureDetection.Repository.Fields {
if field.Name == "pullRequestTemplates" {
features.PullRequestTemplateQuery = true
}
}
return features, nil
}

View file

@ -0,0 +1,127 @@
package featuredetection
import (
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
)
func TestPullRequestFeatures(t *testing.T) {
tests := []struct {
name string
hostname string
queryResponse map[string]string
wantFeatures PullRequestFeatures
wantErr bool
}{
{
name: "github.com",
hostname: "github.com",
wantFeatures: PullRequestFeatures{
ReviewDecision: true,
StatusCheckRollup: true,
BranchProtectionRule: true,
},
wantErr: false,
},
{
name: "GHE",
hostname: "git.my.org",
wantFeatures: PullRequestFeatures{
ReviewDecision: true,
StatusCheckRollup: true,
BranchProtectionRule: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fakeHTTP := &httpmock.Registry{}
httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP))
for query, resp := range tt.queryResponse {
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
}
detector := detector{host: tt.hostname, httpClient: httpClient}
gotPrFeatures, err := detector.PullRequestFeatures()
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantFeatures, gotPrFeatures)
})
}
}
func TestRepositoryFeatures(t *testing.T) {
tests := []struct {
name string
hostname string
queryResponse map[string]string
wantFeatures RepositoryFeatures
wantErr bool
}{
{
name: "github.com",
hostname: "github.com",
wantFeatures: RepositoryFeatures{
IssueTemplateMutation: true,
IssueTemplateQuery: true,
PullRequestTemplateQuery: true,
},
wantErr: false,
},
{
name: "GHE empty response",
hostname: "git.my.org",
queryResponse: map[string]string{
`query Repository_fields\b`: `{"data": {}}`,
},
wantFeatures: RepositoryFeatures{
IssueTemplateMutation: true,
IssueTemplateQuery: true,
PullRequestTemplateQuery: false,
},
wantErr: false,
},
{
name: "GHE has pull request template query",
hostname: "git.my.org",
queryResponse: map[string]string{
`query Repository_fields\b`: heredoc.Doc(`
{ "data": { "Repository": { "fields": [
{"name": "pullRequestTemplates"}
] } } }
`),
},
wantFeatures: RepositoryFeatures{
IssueTemplateMutation: true,
IssueTemplateQuery: true,
PullRequestTemplateQuery: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fakeHTTP := &httpmock.Registry{}
httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP))
for query, resp := range tt.queryResponse {
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
}
detector := detector{host: tt.hostname, httpClient: httpClient}
gotPrFeatures, err := detector.RepositoryFeatures()
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantFeatures, gotPrFeatures)
})
}
}

View file

@ -4,11 +4,10 @@ import (
"context"
"fmt"
"net/http"
"time"
"github.com/AlecAivazis/survey/v2"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/git"
fd "github.com/cli/cli/v2/internal/featuredetection"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/githubtemplate"
@ -109,55 +108,6 @@ func listPullRequestTemplates(httpClient *http.Client, repo ghrepo.Interface) ([
return templates, nil
}
func hasTemplateSupport(httpClient *http.Client, hostname string, isPR bool) (bool, error) {
if !ghinstance.IsEnterprise(hostname) {
return true, nil
}
var featureDetection struct {
Repository struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Repository: __type(name: \"Repository\")"`
CreateIssueInput struct {
InputFields []struct {
Name string
}
} `graphql:"CreateIssueInput: __type(name: \"CreateIssueInput\")"`
}
gql := graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), httpClient)
err := gql.QueryNamed(context.Background(), "IssueTemplates_fields", &featureDetection, nil)
if err != nil {
return false, err
}
var hasIssueQuerySupport bool
var hasIssueMutationSupport bool
var hasPullRequestQuerySupport bool
for _, field := range featureDetection.Repository.Fields {
if field.Name == "issueTemplates" {
hasIssueQuerySupport = true
}
if field.Name == "pullRequestTemplates" {
hasPullRequestQuerySupport = true
}
}
for _, field := range featureDetection.CreateIssueInput.InputFields {
if field.Name == "issueTemplate" {
hasIssueMutationSupport = true
}
}
if isPR {
return hasPullRequestQuerySupport, nil
} else {
return hasIssueQuerySupport && hasIssueMutationSupport, nil
}
}
type Template interface {
Name() string
NameForSubmit() string
@ -170,8 +120,8 @@ type templateManager struct {
allowFS bool
isPR bool
httpClient *http.Client
detector fd.Detector
cachedClient *http.Client
templates []Template
legacyTemplate Template
@ -186,14 +136,21 @@ func NewTemplateManager(httpClient *http.Client, repo ghrepo.Interface, dir stri
allowFS: allowFS,
isPR: isPR,
httpClient: httpClient,
detector: fd.NewDetector(httpClient, repo.RepoHost()),
}
}
func (m *templateManager) hasAPI() (bool, error) {
if m.cachedClient == nil {
m.cachedClient = api.NewCachedClient(m.httpClient, time.Hour*24)
if !m.isPR {
return true, nil
}
return hasTemplateSupport(m.cachedClient, m.repo.RepoHost(), m.isPR)
features, err := m.detector.RepositoryFeatures()
if err != nil {
return false, err
}
return features.PullRequestTemplateQuery, nil
}
func (m *templateManager) HasTemplates() (bool, error) {

View file

@ -6,6 +6,7 @@ import (
"path/filepath"
"testing"
fd "github.com/cli/cli/v2/internal/featuredetection"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/cli/cli/v2/pkg/prompt"
@ -22,22 +23,6 @@ func TestTemplateManager_hasAPI(t *testing.T) {
httpClient := &http.Client{Transport: &tr}
defer tr.Verify(t)
tr.Register(
httpmock.GraphQL(`query IssueTemplates_fields\b`),
httpmock.StringResponse(`{"data":{
"Repository": {
"fields": [
{"name": "foo"},
{"name": "issueTemplates"}
]
},
"CreateIssueInput": {
"inputFields": [
{"name": "foo"},
{"name": "issueTemplate"}
]
}
}}`))
tr.Register(
httpmock.GraphQL(`query IssueTemplates\b`),
httpmock.StringResponse(`{"data":{"repository":{
@ -48,12 +33,12 @@ func TestTemplateManager_hasAPI(t *testing.T) {
}}}`))
m := templateManager{
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
rootDir: rootDir,
allowFS: true,
isPR: false,
httpClient: httpClient,
cachedClient: httpClient,
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
rootDir: rootDir,
allowFS: true,
isPR: false,
httpClient: httpClient,
detector: &fd.EnabledDetectorMock{},
}
hasTemplates, err := m.HasTemplates()
@ -84,16 +69,6 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) {
httpClient := &http.Client{Transport: &tr}
defer tr.Verify(t)
tr.Register(
httpmock.GraphQL(`query IssueTemplates_fields\b`),
httpmock.StringResponse(`{"data":{
"Repository": {
"fields": [
{"name": "foo"},
{"name": "pullRequestTemplates"}
]
}
}}`))
tr.Register(
httpmock.GraphQL(`query PullRequestTemplates\b`),
httpmock.StringResponse(`{"data":{"repository":{
@ -104,12 +79,12 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) {
}}}`))
m := templateManager{
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
rootDir: rootDir,
allowFS: true,
isPR: true,
httpClient: httpClient,
cachedClient: httpClient,
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
rootDir: rootDir,
allowFS: true,
isPR: true,
httpClient: httpClient,
detector: &fd.EnabledDetectorMock{},
}
hasTemplates, err := m.HasTemplates()

201
pkg/cmd/pr/status/http.go Normal file
View file

@ -0,0 +1,201 @@
package status
import (
"fmt"
"net/http"
"strings"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/set"
)
type requestOptions struct {
CurrentPR int
HeadRef string
Username string
Fields []string
}
type pullRequestsPayload struct {
ViewerCreated api.PullRequestAndTotalCount
ReviewRequested api.PullRequestAndTotalCount
CurrentPR *api.PullRequest
DefaultBranch string
}
func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options requestOptions) (*pullRequestsPayload, error) {
apiClient := api.NewClientFromHTTP(httpClient)
type edges struct {
TotalCount int
Edges []struct {
Node api.PullRequest
}
}
type response struct {
Repository struct {
DefaultBranchRef struct {
Name string
}
PullRequests edges
PullRequest *api.PullRequest
}
ViewerCreated edges
ReviewRequested edges
}
var fragments string
if len(options.Fields) > 0 {
fields := set.NewStringSet()
fields.AddValues(options.Fields)
// these are always necessary to find the PR for the current branch
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
gr := api.PullRequestGraphQL(fields.ToSlice())
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
} else {
var err error
fragments, err = pullRequestFragment(httpClient, repo.RepoHost())
if err != nil {
return nil, err
}
}
queryPrefix := `
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
totalCount
edges {
node {
...prWithReviews
}
}
}
}
`
if options.CurrentPR > 0 {
queryPrefix = `
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequest(number: $number) {
...prWithReviews
baseRef {
branchProtectionRule {
requiredApprovingReviewCount
}
}
}
}
`
}
query := fragments + queryPrefix + `
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...prWithReviews
}
}
}
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...pr
}
}
}
}
`
currentUsername := options.Username
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
var err error
currentUsername, err = api.CurrentLoginName(apiClient, repo.RepoHost())
if err != nil {
return nil, err
}
}
viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
currentPRHeadRef := options.HeadRef
branchWithoutOwner := currentPRHeadRef
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
branchWithoutOwner = currentPRHeadRef[idx+1:]
}
variables := map[string]interface{}{
"viewerQuery": viewerQuery,
"reviewerQuery": reviewerQuery,
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"number": options.CurrentPR,
}
var resp response
err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
var viewerCreated []api.PullRequest
for _, edge := range resp.ViewerCreated.Edges {
viewerCreated = append(viewerCreated, edge.Node)
}
var reviewRequested []api.PullRequest
for _, edge := range resp.ReviewRequested.Edges {
reviewRequested = append(reviewRequested, edge.Node)
}
var currentPR = resp.Repository.PullRequest
if currentPR == nil {
for _, edge := range resp.Repository.PullRequests.Edges {
if edge.Node.HeadLabel() == currentPRHeadRef {
currentPR = &edge.Node
break // Take the most recent PR for the current branch
}
}
}
payload := pullRequestsPayload{
ViewerCreated: api.PullRequestAndTotalCount{
PullRequests: viewerCreated,
TotalCount: resp.ViewerCreated.TotalCount,
},
ReviewRequested: api.PullRequestAndTotalCount{
PullRequests: reviewRequested,
TotalCount: resp.ReviewRequested.TotalCount,
},
CurrentPR: currentPR,
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
}
return &payload, nil
}
func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
fields := []string{
"number", "title", "state", "url", "isDraft", "isCrossRepository",
"headRefName", "headRepositoryOwner", "mergeStateStatus",
"statusCheckRollup", "requiresStrictStatusChecks",
}
reviewFields := []string{"reviewDecision", "latestReviews"}
fragments := fmt.Sprintf(`
fragment pr on PullRequest {%s}
fragment prWithReviews on PullRequest {...pr,%s}
`, api.PullRequestGraphQL(fields), api.PullRequestGraphQL(reviewFields))
return fragments, nil
}

View file

@ -67,7 +67,6 @@ func statusRun(opts *StatusOptions) error {
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
baseRepo, err := opts.BaseRepo()
if err != nil {
@ -91,7 +90,7 @@ func statusRun(opts *StatusOptions) error {
}
}
options := api.StatusOptions{
options := requestOptions{
Username: "@me",
CurrentPR: currentPRNumber,
HeadRef: currentPRHeadRef,
@ -99,7 +98,7 @@ func statusRun(opts *StatusOptions) error {
if opts.Exporter != nil {
options.Fields = opts.Exporter.Fields()
}
prPayload, err := api.PullRequestStatus(apiClient, baseRepo, options)
prPayload, err := pullRequestStatus(httpClient, baseRepo, options)
if err != nil {
return err
}