Extract feature detection package (#5494)
This commit is contained in:
parent
f8b3ff999f
commit
539b150833
9 changed files with 492 additions and 488 deletions
|
|
@ -4,23 +4,12 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/ghinstance"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/set"
|
||||
"github.com/shurcooL/githubv4"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type PullRequestsPayload struct {
|
||||
ViewerCreated PullRequestAndTotalCount
|
||||
ReviewRequested PullRequestAndTotalCount
|
||||
CurrentPR *PullRequest
|
||||
DefaultBranch string
|
||||
}
|
||||
|
||||
type PullRequestAndTotalCount struct {
|
||||
TotalCount int
|
||||
PullRequests []PullRequest
|
||||
|
|
@ -269,275 +258,6 @@ func (pr *PullRequest) DisplayableReviews() PullRequestReviews {
|
|||
return PullRequestReviews{Nodes: published, TotalCount: len(published)}
|
||||
}
|
||||
|
||||
type pullRequestFeature struct {
|
||||
HasReviewDecision bool
|
||||
HasStatusCheckRollup bool
|
||||
HasBranchProtectionRule bool
|
||||
}
|
||||
|
||||
func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) {
|
||||
if !ghinstance.IsEnterprise(hostname) {
|
||||
prFeatures.HasReviewDecision = true
|
||||
prFeatures.HasStatusCheckRollup = true
|
||||
prFeatures.HasBranchProtectionRule = true
|
||||
return
|
||||
}
|
||||
|
||||
var featureDetection struct {
|
||||
PullRequest struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
} `graphql:"fields(includeDeprecated: true)"`
|
||||
} `graphql:"PullRequest: __type(name: \"PullRequest\")"`
|
||||
Commit struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
} `graphql:"fields(includeDeprecated: true)"`
|
||||
} `graphql:"Commit: __type(name: \"Commit\")"`
|
||||
}
|
||||
|
||||
// needs to be a separate query because the backend only supports 2 `__type` expressions in one query
|
||||
var featureDetection2 struct {
|
||||
Ref struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
} `graphql:"fields(includeDeprecated: true)"`
|
||||
} `graphql:"Ref: __type(name: \"Ref\")"`
|
||||
}
|
||||
|
||||
v4 := graphQLClient(httpClient, hostname)
|
||||
|
||||
g := new(errgroup.Group)
|
||||
g.Go(func() error {
|
||||
return v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return v4.QueryNamed(context.Background(), "PullRequest_fields2", &featureDetection2, nil)
|
||||
})
|
||||
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, field := range featureDetection.PullRequest.Fields {
|
||||
switch field.Name {
|
||||
case "reviewDecision":
|
||||
prFeatures.HasReviewDecision = true
|
||||
}
|
||||
}
|
||||
for _, field := range featureDetection.Commit.Fields {
|
||||
switch field.Name {
|
||||
case "statusCheckRollup":
|
||||
prFeatures.HasStatusCheckRollup = true
|
||||
}
|
||||
}
|
||||
for _, field := range featureDetection2.Ref.Fields {
|
||||
switch field.Name {
|
||||
case "branchProtectionRule":
|
||||
prFeatures.HasBranchProtectionRule = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type StatusOptions struct {
|
||||
CurrentPR int
|
||||
HeadRef string
|
||||
Username string
|
||||
Fields []string
|
||||
}
|
||||
|
||||
func PullRequestStatus(client *Client, repo ghrepo.Interface, options StatusOptions) (*PullRequestsPayload, error) {
|
||||
type edges struct {
|
||||
TotalCount int
|
||||
Edges []struct {
|
||||
Node PullRequest
|
||||
}
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Repository struct {
|
||||
DefaultBranchRef struct {
|
||||
Name string
|
||||
}
|
||||
PullRequests edges
|
||||
PullRequest *PullRequest
|
||||
}
|
||||
ViewerCreated edges
|
||||
ReviewRequested edges
|
||||
}
|
||||
|
||||
var fragments string
|
||||
if len(options.Fields) > 0 {
|
||||
fields := set.NewStringSet()
|
||||
fields.AddValues(options.Fields)
|
||||
// these are always necessary to find the PR for the current branch
|
||||
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
|
||||
gr := PullRequestGraphQL(fields.ToSlice())
|
||||
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
|
||||
} else {
|
||||
var err error
|
||||
fragments, err = pullRequestFragment(client.http, repo.RepoHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queryPrefix := `
|
||||
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
defaultBranchRef {
|
||||
name
|
||||
}
|
||||
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
|
||||
totalCount
|
||||
edges {
|
||||
node {
|
||||
...prWithReviews
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
if options.CurrentPR > 0 {
|
||||
queryPrefix = `
|
||||
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
defaultBranchRef {
|
||||
name
|
||||
}
|
||||
pullRequest(number: $number) {
|
||||
...prWithReviews
|
||||
baseRef {
|
||||
branchProtectionRule {
|
||||
requiredApprovingReviewCount
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
}
|
||||
|
||||
query := fragments + queryPrefix + `
|
||||
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
|
||||
totalCount: issueCount
|
||||
edges {
|
||||
node {
|
||||
...prWithReviews
|
||||
}
|
||||
}
|
||||
}
|
||||
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
|
||||
totalCount: issueCount
|
||||
edges {
|
||||
node {
|
||||
...pr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
currentUsername := options.Username
|
||||
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
|
||||
var err error
|
||||
currentUsername, err = CurrentLoginName(client, repo.RepoHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
|
||||
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
|
||||
|
||||
currentPRHeadRef := options.HeadRef
|
||||
branchWithoutOwner := currentPRHeadRef
|
||||
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
|
||||
branchWithoutOwner = currentPRHeadRef[idx+1:]
|
||||
}
|
||||
|
||||
variables := map[string]interface{}{
|
||||
"viewerQuery": viewerQuery,
|
||||
"reviewerQuery": reviewerQuery,
|
||||
"owner": repo.RepoOwner(),
|
||||
"repo": repo.RepoName(),
|
||||
"headRefName": branchWithoutOwner,
|
||||
"number": options.CurrentPR,
|
||||
}
|
||||
|
||||
var resp response
|
||||
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var viewerCreated []PullRequest
|
||||
for _, edge := range resp.ViewerCreated.Edges {
|
||||
viewerCreated = append(viewerCreated, edge.Node)
|
||||
}
|
||||
|
||||
var reviewRequested []PullRequest
|
||||
for _, edge := range resp.ReviewRequested.Edges {
|
||||
reviewRequested = append(reviewRequested, edge.Node)
|
||||
}
|
||||
|
||||
var currentPR = resp.Repository.PullRequest
|
||||
if currentPR == nil {
|
||||
for _, edge := range resp.Repository.PullRequests.Edges {
|
||||
if edge.Node.HeadLabel() == currentPRHeadRef {
|
||||
currentPR = &edge.Node
|
||||
break // Take the most recent PR for the current branch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
payload := PullRequestsPayload{
|
||||
ViewerCreated: PullRequestAndTotalCount{
|
||||
PullRequests: viewerCreated,
|
||||
TotalCount: resp.ViewerCreated.TotalCount,
|
||||
},
|
||||
ReviewRequested: PullRequestAndTotalCount{
|
||||
PullRequests: reviewRequested,
|
||||
TotalCount: resp.ReviewRequested.TotalCount,
|
||||
},
|
||||
CurrentPR: currentPR,
|
||||
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
|
||||
cachedClient := NewCachedClient(httpClient, time.Hour*24)
|
||||
prFeatures, err := determinePullRequestFeatures(cachedClient, hostname)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fields := []string{
|
||||
"number", "title", "state", "url", "isDraft", "isCrossRepository",
|
||||
"headRefName", "headRepositoryOwner", "mergeStateStatus",
|
||||
}
|
||||
if prFeatures.HasStatusCheckRollup {
|
||||
fields = append(fields, "statusCheckRollup")
|
||||
}
|
||||
if prFeatures.HasBranchProtectionRule {
|
||||
fields = append(fields, "requiresStrictStatusChecks")
|
||||
}
|
||||
|
||||
var reviewFields []string
|
||||
if prFeatures.HasReviewDecision {
|
||||
reviewFields = append(reviewFields, "reviewDecision", "latestReviews")
|
||||
}
|
||||
|
||||
fragments := fmt.Sprintf(`
|
||||
fragment pr on PullRequest {%s}
|
||||
fragment prWithReviews on PullRequest {...pr,%s}
|
||||
`, PullRequestGraphQL(fields), PullRequestGraphQL(reviewFields))
|
||||
return fragments, nil
|
||||
}
|
||||
|
||||
// CreatePullRequest creates a pull request in a GitHub repository
|
||||
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
|
||||
query := `
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/httpmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -49,117 +48,6 @@ func TestBranchDeleteRemote(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_determinePullRequestFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
queryResponse map[string]string
|
||||
wantPrFeatures pullRequestFeature
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "github.com",
|
||||
hostname: "github.com",
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: true,
|
||||
HasStatusCheckRollup: true,
|
||||
HasBranchProtectionRule: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE empty response",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query PullRequest_fields\b`: `{"data": {}}`,
|
||||
`query PullRequest_fields2\b`: `{"data": {}}`,
|
||||
},
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: false,
|
||||
HasStatusCheckRollup: false,
|
||||
HasBranchProtectionRule: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has reviewDecision",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query PullRequest_fields\b`: heredoc.Doc(`
|
||||
{ "data": { "PullRequest": { "fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "reviewDecision"}
|
||||
] } } }
|
||||
`),
|
||||
`query PullRequest_fields2\b`: `{"data": {}}`,
|
||||
},
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: true,
|
||||
HasStatusCheckRollup: false,
|
||||
HasBranchProtectionRule: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has statusCheckRollup",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query PullRequest_fields\b`: heredoc.Doc(`
|
||||
{ "data": { "Commit": { "fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "statusCheckRollup"}
|
||||
] } } }
|
||||
`),
|
||||
`query PullRequest_fields2\b`: `{"data": {}}`,
|
||||
},
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: false,
|
||||
HasStatusCheckRollup: true,
|
||||
HasBranchProtectionRule: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has branchProtectionRule",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query PullRequest_fields\b`: `{"data": {}}`,
|
||||
`query PullRequest_fields2\b`: heredoc.Doc(`
|
||||
{ "data": { "Ref": { "fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "branchProtectionRule"}
|
||||
] } } }
|
||||
`),
|
||||
},
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: false,
|
||||
HasStatusCheckRollup: false,
|
||||
HasBranchProtectionRule: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fakeHTTP := &httpmock.Registry{}
|
||||
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
|
||||
|
||||
for query, resp := range tt.queryResponse {
|
||||
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
|
||||
}
|
||||
|
||||
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.wantPrFeatures, gotPrFeatures)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Logins(t *testing.T) {
|
||||
rr := ReviewRequests{}
|
||||
var tests = []struct {
|
||||
|
|
|
|||
29
internal/featuredetection/detector_mock.go
Normal file
29
internal/featuredetection/detector_mock.go
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
package featuredetection
|
||||
|
||||
type DisabledDetectorMock struct{}
|
||||
|
||||
func (md *DisabledDetectorMock) IssueFeatures() (IssueFeatures, error) {
|
||||
return IssueFeatures{}, nil
|
||||
}
|
||||
|
||||
func (md *DisabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) {
|
||||
return PullRequestFeatures{}, nil
|
||||
}
|
||||
|
||||
func (md *DisabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) {
|
||||
return RepositoryFeatures{}, nil
|
||||
}
|
||||
|
||||
type EnabledDetectorMock struct{}
|
||||
|
||||
func (md *EnabledDetectorMock) IssueFeatures() (IssueFeatures, error) {
|
||||
return allIssueFeatures, nil
|
||||
}
|
||||
|
||||
func (md *EnabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error) {
|
||||
return allPullRequestFeatures, nil
|
||||
}
|
||||
|
||||
func (md *EnabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) {
|
||||
return allRepositoryFeatures, nil
|
||||
}
|
||||
108
internal/featuredetection/feature_detection.go
Normal file
108
internal/featuredetection/feature_detection.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package featuredetection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/internal/ghinstance"
|
||||
graphql "github.com/cli/shurcooL-graphql"
|
||||
)
|
||||
|
||||
type Detector interface {
|
||||
IssueFeatures() (IssueFeatures, error)
|
||||
PullRequestFeatures() (PullRequestFeatures, error)
|
||||
RepositoryFeatures() (RepositoryFeatures, error)
|
||||
}
|
||||
|
||||
type IssueFeatures struct{}
|
||||
|
||||
var allIssueFeatures = IssueFeatures{}
|
||||
|
||||
type PullRequestFeatures struct {
|
||||
ReviewDecision bool
|
||||
StatusCheckRollup bool
|
||||
BranchProtectionRule bool
|
||||
}
|
||||
|
||||
var allPullRequestFeatures = PullRequestFeatures{
|
||||
ReviewDecision: true,
|
||||
StatusCheckRollup: true,
|
||||
BranchProtectionRule: true,
|
||||
}
|
||||
|
||||
type RepositoryFeatures struct {
|
||||
IssueTemplateMutation bool
|
||||
IssueTemplateQuery bool
|
||||
PullRequestTemplateQuery bool
|
||||
}
|
||||
|
||||
var allRepositoryFeatures = RepositoryFeatures{
|
||||
IssueTemplateMutation: true,
|
||||
IssueTemplateQuery: true,
|
||||
PullRequestTemplateQuery: true,
|
||||
}
|
||||
|
||||
type detector struct {
|
||||
host string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewDetector(httpClient *http.Client, host string) Detector {
|
||||
cachedClient := api.NewCachedClient(httpClient, time.Hour*48)
|
||||
return &detector{
|
||||
httpClient: cachedClient,
|
||||
host: host,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *detector) IssueFeatures() (IssueFeatures, error) {
|
||||
if !ghinstance.IsEnterprise(d.host) {
|
||||
return allIssueFeatures, nil
|
||||
}
|
||||
|
||||
return allIssueFeatures, nil
|
||||
}
|
||||
|
||||
func (d *detector) PullRequestFeatures() (PullRequestFeatures, error) {
|
||||
if !ghinstance.IsEnterprise(d.host) {
|
||||
return allPullRequestFeatures, nil
|
||||
}
|
||||
|
||||
return allPullRequestFeatures, nil
|
||||
}
|
||||
|
||||
func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) {
|
||||
if !ghinstance.IsEnterprise(d.host) {
|
||||
return allRepositoryFeatures, nil
|
||||
}
|
||||
|
||||
features := RepositoryFeatures{
|
||||
IssueTemplateQuery: true,
|
||||
IssueTemplateMutation: true,
|
||||
}
|
||||
|
||||
var featureDetection struct {
|
||||
Repository struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
} `graphql:"fields(includeDeprecated: true)"`
|
||||
} `graphql:"Repository: __type(name: \"Repository\")"`
|
||||
}
|
||||
|
||||
gql := graphql.NewClient(ghinstance.GraphQLEndpoint(d.host), d.httpClient)
|
||||
|
||||
err := gql.QueryNamed(context.Background(), "Repository_fields", &featureDetection, nil)
|
||||
if err != nil {
|
||||
return features, err
|
||||
}
|
||||
|
||||
for _, field := range featureDetection.Repository.Fields {
|
||||
if field.Name == "pullRequestTemplates" {
|
||||
features.PullRequestTemplateQuery = true
|
||||
}
|
||||
}
|
||||
|
||||
return features, nil
|
||||
}
|
||||
127
internal/featuredetection/feature_detection_test.go
Normal file
127
internal/featuredetection/feature_detection_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package featuredetection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/pkg/httpmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPullRequestFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
queryResponse map[string]string
|
||||
wantFeatures PullRequestFeatures
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "github.com",
|
||||
hostname: "github.com",
|
||||
wantFeatures: PullRequestFeatures{
|
||||
ReviewDecision: true,
|
||||
StatusCheckRollup: true,
|
||||
BranchProtectionRule: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE",
|
||||
hostname: "git.my.org",
|
||||
wantFeatures: PullRequestFeatures{
|
||||
ReviewDecision: true,
|
||||
StatusCheckRollup: true,
|
||||
BranchProtectionRule: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fakeHTTP := &httpmock.Registry{}
|
||||
httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP))
|
||||
for query, resp := range tt.queryResponse {
|
||||
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
|
||||
}
|
||||
detector := detector{host: tt.hostname, httpClient: httpClient}
|
||||
gotPrFeatures, err := detector.PullRequestFeatures()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantFeatures, gotPrFeatures)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepositoryFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
queryResponse map[string]string
|
||||
wantFeatures RepositoryFeatures
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "github.com",
|
||||
hostname: "github.com",
|
||||
wantFeatures: RepositoryFeatures{
|
||||
IssueTemplateMutation: true,
|
||||
IssueTemplateQuery: true,
|
||||
PullRequestTemplateQuery: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE empty response",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query Repository_fields\b`: `{"data": {}}`,
|
||||
},
|
||||
wantFeatures: RepositoryFeatures{
|
||||
IssueTemplateMutation: true,
|
||||
IssueTemplateQuery: true,
|
||||
PullRequestTemplateQuery: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has pull request template query",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: map[string]string{
|
||||
`query Repository_fields\b`: heredoc.Doc(`
|
||||
{ "data": { "Repository": { "fields": [
|
||||
{"name": "pullRequestTemplates"}
|
||||
] } } }
|
||||
`),
|
||||
},
|
||||
wantFeatures: RepositoryFeatures{
|
||||
IssueTemplateMutation: true,
|
||||
IssueTemplateQuery: true,
|
||||
PullRequestTemplateQuery: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fakeHTTP := &httpmock.Registry{}
|
||||
httpClient := api.NewHTTPClient(api.ReplaceTripper(fakeHTTP))
|
||||
for query, resp := range tt.queryResponse {
|
||||
fakeHTTP.Register(httpmock.GraphQL(query), httpmock.StringResponse(resp))
|
||||
}
|
||||
detector := detector{host: tt.hostname, httpClient: httpClient}
|
||||
gotPrFeatures, err := detector.RepositoryFeatures()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantFeatures, gotPrFeatures)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -4,11 +4,10 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/git"
|
||||
fd "github.com/cli/cli/v2/internal/featuredetection"
|
||||
"github.com/cli/cli/v2/internal/ghinstance"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/githubtemplate"
|
||||
|
|
@ -109,55 +108,6 @@ func listPullRequestTemplates(httpClient *http.Client, repo ghrepo.Interface) ([
|
|||
return templates, nil
|
||||
}
|
||||
|
||||
func hasTemplateSupport(httpClient *http.Client, hostname string, isPR bool) (bool, error) {
|
||||
if !ghinstance.IsEnterprise(hostname) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var featureDetection struct {
|
||||
Repository struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
} `graphql:"fields(includeDeprecated: true)"`
|
||||
} `graphql:"Repository: __type(name: \"Repository\")"`
|
||||
CreateIssueInput struct {
|
||||
InputFields []struct {
|
||||
Name string
|
||||
}
|
||||
} `graphql:"CreateIssueInput: __type(name: \"CreateIssueInput\")"`
|
||||
}
|
||||
|
||||
gql := graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), httpClient)
|
||||
err := gql.QueryNamed(context.Background(), "IssueTemplates_fields", &featureDetection, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var hasIssueQuerySupport bool
|
||||
var hasIssueMutationSupport bool
|
||||
var hasPullRequestQuerySupport bool
|
||||
|
||||
for _, field := range featureDetection.Repository.Fields {
|
||||
if field.Name == "issueTemplates" {
|
||||
hasIssueQuerySupport = true
|
||||
}
|
||||
if field.Name == "pullRequestTemplates" {
|
||||
hasPullRequestQuerySupport = true
|
||||
}
|
||||
}
|
||||
for _, field := range featureDetection.CreateIssueInput.InputFields {
|
||||
if field.Name == "issueTemplate" {
|
||||
hasIssueMutationSupport = true
|
||||
}
|
||||
}
|
||||
|
||||
if isPR {
|
||||
return hasPullRequestQuerySupport, nil
|
||||
} else {
|
||||
return hasIssueQuerySupport && hasIssueMutationSupport, nil
|
||||
}
|
||||
}
|
||||
|
||||
type Template interface {
|
||||
Name() string
|
||||
NameForSubmit() string
|
||||
|
|
@ -170,8 +120,8 @@ type templateManager struct {
|
|||
allowFS bool
|
||||
isPR bool
|
||||
httpClient *http.Client
|
||||
detector fd.Detector
|
||||
|
||||
cachedClient *http.Client
|
||||
templates []Template
|
||||
legacyTemplate Template
|
||||
|
||||
|
|
@ -186,14 +136,21 @@ func NewTemplateManager(httpClient *http.Client, repo ghrepo.Interface, dir stri
|
|||
allowFS: allowFS,
|
||||
isPR: isPR,
|
||||
httpClient: httpClient,
|
||||
detector: fd.NewDetector(httpClient, repo.RepoHost()),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *templateManager) hasAPI() (bool, error) {
|
||||
if m.cachedClient == nil {
|
||||
m.cachedClient = api.NewCachedClient(m.httpClient, time.Hour*24)
|
||||
if !m.isPR {
|
||||
return true, nil
|
||||
}
|
||||
return hasTemplateSupport(m.cachedClient, m.repo.RepoHost(), m.isPR)
|
||||
|
||||
features, err := m.detector.RepositoryFeatures()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return features.PullRequestTemplateQuery, nil
|
||||
}
|
||||
|
||||
func (m *templateManager) HasTemplates() (bool, error) {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
fd "github.com/cli/cli/v2/internal/featuredetection"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/httpmock"
|
||||
"github.com/cli/cli/v2/pkg/prompt"
|
||||
|
|
@ -22,22 +23,6 @@ func TestTemplateManager_hasAPI(t *testing.T) {
|
|||
httpClient := &http.Client{Transport: &tr}
|
||||
defer tr.Verify(t)
|
||||
|
||||
tr.Register(
|
||||
httpmock.GraphQL(`query IssueTemplates_fields\b`),
|
||||
httpmock.StringResponse(`{"data":{
|
||||
"Repository": {
|
||||
"fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "issueTemplates"}
|
||||
]
|
||||
},
|
||||
"CreateIssueInput": {
|
||||
"inputFields": [
|
||||
{"name": "foo"},
|
||||
{"name": "issueTemplate"}
|
||||
]
|
||||
}
|
||||
}}`))
|
||||
tr.Register(
|
||||
httpmock.GraphQL(`query IssueTemplates\b`),
|
||||
httpmock.StringResponse(`{"data":{"repository":{
|
||||
|
|
@ -48,12 +33,12 @@ func TestTemplateManager_hasAPI(t *testing.T) {
|
|||
}}}`))
|
||||
|
||||
m := templateManager{
|
||||
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
|
||||
rootDir: rootDir,
|
||||
allowFS: true,
|
||||
isPR: false,
|
||||
httpClient: httpClient,
|
||||
cachedClient: httpClient,
|
||||
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
|
||||
rootDir: rootDir,
|
||||
allowFS: true,
|
||||
isPR: false,
|
||||
httpClient: httpClient,
|
||||
detector: &fd.EnabledDetectorMock{},
|
||||
}
|
||||
|
||||
hasTemplates, err := m.HasTemplates()
|
||||
|
|
@ -84,16 +69,6 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) {
|
|||
httpClient := &http.Client{Transport: &tr}
|
||||
defer tr.Verify(t)
|
||||
|
||||
tr.Register(
|
||||
httpmock.GraphQL(`query IssueTemplates_fields\b`),
|
||||
httpmock.StringResponse(`{"data":{
|
||||
"Repository": {
|
||||
"fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "pullRequestTemplates"}
|
||||
]
|
||||
}
|
||||
}}`))
|
||||
tr.Register(
|
||||
httpmock.GraphQL(`query PullRequestTemplates\b`),
|
||||
httpmock.StringResponse(`{"data":{"repository":{
|
||||
|
|
@ -104,12 +79,12 @@ func TestTemplateManager_hasAPI_PullRequest(t *testing.T) {
|
|||
}}}`))
|
||||
|
||||
m := templateManager{
|
||||
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
|
||||
rootDir: rootDir,
|
||||
allowFS: true,
|
||||
isPR: true,
|
||||
httpClient: httpClient,
|
||||
cachedClient: httpClient,
|
||||
repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"),
|
||||
rootDir: rootDir,
|
||||
allowFS: true,
|
||||
isPR: true,
|
||||
httpClient: httpClient,
|
||||
detector: &fd.EnabledDetectorMock{},
|
||||
}
|
||||
|
||||
hasTemplates, err := m.HasTemplates()
|
||||
|
|
|
|||
201
pkg/cmd/pr/status/http.go
Normal file
201
pkg/cmd/pr/status/http.go
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
package status
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/internal/ghinstance"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/set"
|
||||
)
|
||||
|
||||
type requestOptions struct {
|
||||
CurrentPR int
|
||||
HeadRef string
|
||||
Username string
|
||||
Fields []string
|
||||
}
|
||||
|
||||
type pullRequestsPayload struct {
|
||||
ViewerCreated api.PullRequestAndTotalCount
|
||||
ReviewRequested api.PullRequestAndTotalCount
|
||||
CurrentPR *api.PullRequest
|
||||
DefaultBranch string
|
||||
}
|
||||
|
||||
func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options requestOptions) (*pullRequestsPayload, error) {
|
||||
apiClient := api.NewClientFromHTTP(httpClient)
|
||||
type edges struct {
|
||||
TotalCount int
|
||||
Edges []struct {
|
||||
Node api.PullRequest
|
||||
}
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Repository struct {
|
||||
DefaultBranchRef struct {
|
||||
Name string
|
||||
}
|
||||
PullRequests edges
|
||||
PullRequest *api.PullRequest
|
||||
}
|
||||
ViewerCreated edges
|
||||
ReviewRequested edges
|
||||
}
|
||||
|
||||
var fragments string
|
||||
if len(options.Fields) > 0 {
|
||||
fields := set.NewStringSet()
|
||||
fields.AddValues(options.Fields)
|
||||
// these are always necessary to find the PR for the current branch
|
||||
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
|
||||
gr := api.PullRequestGraphQL(fields.ToSlice())
|
||||
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
|
||||
} else {
|
||||
var err error
|
||||
fragments, err = pullRequestFragment(httpClient, repo.RepoHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queryPrefix := `
|
||||
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
defaultBranchRef {
|
||||
name
|
||||
}
|
||||
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
|
||||
totalCount
|
||||
edges {
|
||||
node {
|
||||
...prWithReviews
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
if options.CurrentPR > 0 {
|
||||
queryPrefix = `
|
||||
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
defaultBranchRef {
|
||||
name
|
||||
}
|
||||
pullRequest(number: $number) {
|
||||
...prWithReviews
|
||||
baseRef {
|
||||
branchProtectionRule {
|
||||
requiredApprovingReviewCount
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
}
|
||||
|
||||
query := fragments + queryPrefix + `
|
||||
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
|
||||
totalCount: issueCount
|
||||
edges {
|
||||
node {
|
||||
...prWithReviews
|
||||
}
|
||||
}
|
||||
}
|
||||
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
|
||||
totalCount: issueCount
|
||||
edges {
|
||||
node {
|
||||
...pr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
currentUsername := options.Username
|
||||
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
|
||||
var err error
|
||||
currentUsername, err = api.CurrentLoginName(apiClient, repo.RepoHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
|
||||
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
|
||||
|
||||
currentPRHeadRef := options.HeadRef
|
||||
branchWithoutOwner := currentPRHeadRef
|
||||
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
|
||||
branchWithoutOwner = currentPRHeadRef[idx+1:]
|
||||
}
|
||||
|
||||
variables := map[string]interface{}{
|
||||
"viewerQuery": viewerQuery,
|
||||
"reviewerQuery": reviewerQuery,
|
||||
"owner": repo.RepoOwner(),
|
||||
"repo": repo.RepoName(),
|
||||
"headRefName": branchWithoutOwner,
|
||||
"number": options.CurrentPR,
|
||||
}
|
||||
|
||||
var resp response
|
||||
err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var viewerCreated []api.PullRequest
|
||||
for _, edge := range resp.ViewerCreated.Edges {
|
||||
viewerCreated = append(viewerCreated, edge.Node)
|
||||
}
|
||||
|
||||
var reviewRequested []api.PullRequest
|
||||
for _, edge := range resp.ReviewRequested.Edges {
|
||||
reviewRequested = append(reviewRequested, edge.Node)
|
||||
}
|
||||
|
||||
var currentPR = resp.Repository.PullRequest
|
||||
if currentPR == nil {
|
||||
for _, edge := range resp.Repository.PullRequests.Edges {
|
||||
if edge.Node.HeadLabel() == currentPRHeadRef {
|
||||
currentPR = &edge.Node
|
||||
break // Take the most recent PR for the current branch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
payload := pullRequestsPayload{
|
||||
ViewerCreated: api.PullRequestAndTotalCount{
|
||||
PullRequests: viewerCreated,
|
||||
TotalCount: resp.ViewerCreated.TotalCount,
|
||||
},
|
||||
ReviewRequested: api.PullRequestAndTotalCount{
|
||||
PullRequests: reviewRequested,
|
||||
TotalCount: resp.ReviewRequested.TotalCount,
|
||||
},
|
||||
CurrentPR: currentPR,
|
||||
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
|
||||
fields := []string{
|
||||
"number", "title", "state", "url", "isDraft", "isCrossRepository",
|
||||
"headRefName", "headRepositoryOwner", "mergeStateStatus",
|
||||
"statusCheckRollup", "requiresStrictStatusChecks",
|
||||
}
|
||||
reviewFields := []string{"reviewDecision", "latestReviews"}
|
||||
fragments := fmt.Sprintf(`
|
||||
fragment pr on PullRequest {%s}
|
||||
fragment prWithReviews on PullRequest {...pr,%s}
|
||||
`, api.PullRequestGraphQL(fields), api.PullRequestGraphQL(reviewFields))
|
||||
return fragments, nil
|
||||
}
|
||||
|
|
@ -67,7 +67,6 @@ func statusRun(opts *StatusOptions) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiClient := api.NewClientFromHTTP(httpClient)
|
||||
|
||||
baseRepo, err := opts.BaseRepo()
|
||||
if err != nil {
|
||||
|
|
@ -91,7 +90,7 @@ func statusRun(opts *StatusOptions) error {
|
|||
}
|
||||
}
|
||||
|
||||
options := api.StatusOptions{
|
||||
options := requestOptions{
|
||||
Username: "@me",
|
||||
CurrentPR: currentPRNumber,
|
||||
HeadRef: currentPRHeadRef,
|
||||
|
|
@ -99,7 +98,7 @@ func statusRun(opts *StatusOptions) error {
|
|||
if opts.Exporter != nil {
|
||||
options.Fields = opts.Exporter.Fields()
|
||||
}
|
||||
prPayload, err := api.PullRequestStatus(apiClient, baseRepo, options)
|
||||
prPayload, err := pullRequestStatus(httpClient, baseRepo, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue