Merge pull request #2035 from cli/ghe-2.20-compat

GHE 2.20 compatibility for `pr` commands
This commit is contained in:
Mislav Marohnić 2020-10-06 12:50:08 +02:00 committed by GitHub
commit 115357c6af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 496 additions and 71 deletions

175
api/cache.go Normal file
View file

@ -0,0 +1,175 @@
package api
import (
"bufio"
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
return &http.Client{
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
}
}
func isCacheableRequest(req *http.Request) bool {
if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") {
return true
}
if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") {
return true
}
return false
}
func isCacheableResponse(res *http.Response) bool {
return res.StatusCode < 500 && res.StatusCode != 403
}
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
func CacheReponse(ttl time.Duration, dir string) ClientOption {
fs := fileStorage{
dir: dir,
ttl: ttl,
mu: &sync.RWMutex{},
}
return func(tr http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
if !isCacheableRequest(req) {
return tr.RoundTrip(req)
}
key, keyErr := cacheKey(req)
if keyErr == nil {
if res, err := fs.read(key); err == nil {
res.Request = req
return res, nil
}
}
res, err := tr.RoundTrip(req)
if err == nil && keyErr == nil && isCacheableResponse(res) {
_ = fs.store(key, res)
}
return res, err
}}
}
}
func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
b := &bytes.Buffer{}
nr := io.TeeReader(r, b)
return ioutil.NopCloser(b), &readCloser{
Reader: nr,
Closer: r,
}
}
type readCloser struct {
io.Reader
io.Closer
}
func cacheKey(req *http.Request) (string, error) {
h := sha256.New()
fmt.Fprintf(h, "%s:", req.Method)
fmt.Fprintf(h, "%s:", req.URL.String())
fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
if req.Body != nil {
var bodyCopy io.ReadCloser
req.Body, bodyCopy = copyStream(req.Body)
defer bodyCopy.Close()
if _, err := io.Copy(h, bodyCopy); err != nil {
return "", err
}
}
digest := h.Sum(nil)
return fmt.Sprintf("%x", digest), nil
}
type fileStorage struct {
dir string
ttl time.Duration
mu *sync.RWMutex
}
func (fs *fileStorage) filePath(key string) string {
if len(key) >= 6 {
return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
}
return filepath.Join(fs.dir, key)
}
func (fs *fileStorage) read(key string) (*http.Response, error) {
cacheFile := fs.filePath(key)
fs.mu.RLock()
defer fs.mu.RUnlock()
f, err := os.Open(cacheFile)
if err != nil {
return nil, err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return nil, err
}
age := time.Since(stat.ModTime())
if age > fs.ttl {
return nil, errors.New("cache expired")
}
body := &bytes.Buffer{}
_, err = io.Copy(body, f)
if err != nil {
return nil, err
}
res, err := http.ReadResponse(bufio.NewReader(body), nil)
return res, err
}
func (fs *fileStorage) store(key string, res *http.Response) error {
cacheFile := fs.filePath(key)
fs.mu.Lock()
defer fs.mu.Unlock()
err := os.MkdirAll(filepath.Dir(cacheFile), 0755)
if err != nil {
return err
}
f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer f.Close()
var origBody io.ReadCloser
origBody, res.Body = copyStream(res.Body)
defer res.Body.Close()
err = res.Write(f)
res.Body = origBody
return err
}

89
api/cache_test.go Normal file
View file

@ -0,0 +1,89 @@
package api
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_CacheReponse(t *testing.T) {
counter := 0
fakeHTTP := funcTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
counter += 1
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
status := 200
if req.URL.Path == "/error" {
status = 500
}
return &http.Response{
StatusCode: status,
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
}, nil
},
}
cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir))
do := func(method, url string, body io.Reader) (string, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return "", err
}
res, err := httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
resBody, err := ioutil.ReadAll(res.Body)
if err != nil {
err = fmt.Errorf("ReadAll: %w", err)
}
return string(resBody), err
}
var res string
var err error
res, err = do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)
res, err = do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)
res, err = do("GET", "http://example.com/path2", nil)
require.NoError(t, err)
assert.Equal(t, "2: GET http://example.com/path2", res)
res, err = do("POST", "http://example.com/path2", nil)
require.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path2", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
require.NoError(t, err)
assert.Equal(t, "5: POST http://example.com/graphql", res)
res, err = do("GET", "http://example.com/error", nil)
require.NoError(t, err)
assert.Equal(t, "6: GET http://example.com/error", res)
res, err = do("GET", "http://example.com/error", nil)
require.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}

View file

@ -241,6 +241,52 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (io.Rea
return resp.Body, nil
}
type pullRequestFeature struct {
HasReviewDecision bool
HasStatusCheckRollup bool
}
func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) {
if !ghinstance.IsEnterprise(hostname) {
prFeatures.HasReviewDecision = true
prFeatures.HasStatusCheckRollup = true
return
}
var featureDetection struct {
PullRequest struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"PullRequest: __type(name: \"PullRequest\")"`
Commit struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Commit: __type(name: \"Commit\")"`
}
v4 := graphQLClient(httpClient, hostname)
err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
if err != nil {
return
}
for _, field := range featureDetection.PullRequest.Fields {
switch field.Name {
case "reviewDecision":
prFeatures.HasReviewDecision = true
}
}
for _, field := range featureDetection.Commit.Fields {
switch field.Name {
case "statusCheckRollup":
prFeatures.HasStatusCheckRollup = true
}
}
return
}
func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, currentPRHeadRef, currentUsername string) (*PullRequestsPayload, error) {
type edges struct {
TotalCount int
@ -261,18 +307,20 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
ReviewRequested edges
}
fragments := `
fragment pr on PullRequest {
number
title
state
url
headRefName
headRepositoryOwner {
login
}
isCrossRepository
isDraft
cachedClient := makeCachedClient(client.http, time.Hour*24)
prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost())
if err != nil {
return nil, err
}
var reviewsFragment string
if prFeatures.HasReviewDecision {
reviewsFragment = "reviewDecision"
}
var statusesFragment string
if prFeatures.HasStatusCheckRollup {
statusesFragment = `
commits(last: 1) {
nodes {
commit {
@ -292,12 +340,28 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
}
}
}
`
}
fragments := fmt.Sprintf(`
fragment pr on PullRequest {
number
title
state
url
headRefName
headRepositoryOwner {
login
}
isCrossRepository
isDraft
%s
}
fragment prWithReviews on PullRequest {
...pr
reviewDecision
%s
}
`
`, statusesFragment, reviewsFragment)
queryPrefix := `
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
@ -345,6 +409,13 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
}
`
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
currentUsername, err = CurrentLoginName(client, repo.RepoHost())
if err != nil {
return nil, err
}
}
viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
@ -363,7 +434,7 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
}
var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
err = client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -404,6 +475,45 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
return &payload, nil
}
func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) {
cachedClient := makeCachedClient(httpClient, time.Hour*24)
if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil {
return "", err
} else if !prFeatures.HasStatusCheckRollup {
return "", nil
}
return `
commits(last: 1) {
totalCount
nodes {
commit {
oid
statusCheckRollup {
contexts(last: 100) {
nodes {
...on StatusContext {
context
state
targetUrl
}
...on CheckRun {
name
status
conclusion
startedAt
completedAt
detailsUrl
}
}
}
}
}
}
}
`, nil
}
func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*PullRequest, error) {
type response struct {
Repository struct {
@ -411,6 +521,11 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu
}
}
statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
query := `
query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
repository(owner: $owner, name: $repo) {
@ -426,33 +541,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu
author {
login
}
commits(last: 1) {
totalCount
nodes {
commit {
oid
statusCheckRollup {
contexts(last: 100) {
nodes {
...on StatusContext {
context
state
targetUrl
}
...on CheckRun {
name
status
conclusion
startedAt
completedAt
detailsUrl
}
}
}
}
}
}
}
` + statusesFragment + `
baseRefName
headRefName
headRepositoryOwner {
@ -524,7 +613,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu
}
var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
err = client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -542,6 +631,11 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea
}
}
statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
query := `
query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!) {
repository(owner: $owner, name: $repo) {
@ -556,33 +650,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea
author {
login
}
commits(last: 1) {
totalCount
nodes {
commit {
oid
statusCheckRollup {
contexts(last: 100) {
nodes {
...on StatusContext {
context
state
targetUrl
}
...on CheckRun {
name
status
conclusion
startedAt
completedAt
detailsUrl
}
}
}
}
}
}
}
` + statusesFragment + `
url
baseRefName
headRefName
@ -661,7 +729,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea
}
var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
err = client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}

View file

@ -1,8 +1,10 @@
package api
import (
"reflect"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/httpmock"
)
@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) {
})
}
}
func Test_determinePullRequestFeatures(t *testing.T) {
tests := []struct {
name string
hostname string
queryResponse string
wantPrFeatures pullRequestFeature
wantErr bool
}{
{
name: "github.com",
hostname: "github.com",
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: true,
},
wantErr: false,
},
{
name: "GHE empty response",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {}}
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: false,
},
wantErr: false,
},
{
name: "GHE has reviewDecision",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {
"PullRequest": {
"fields": [
{"name": "foo"},
{"name": "reviewDecision"}
]
}
} }
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: false,
},
wantErr: false,
},
{
name: "GHE has statusCheckRollup",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {
"Commit": {
"fields": [
{"name": "foo"},
{"name": "statusCheckRollup"}
]
}
} }
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fakeHTTP := &httpmock.Registry{}
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
if tt.queryResponse != "" {
fakeHTTP.Register(
httpmock.GraphQL(`query PullRequest_fields\b`),
httpmock.StringResponse(tt.queryResponse))
}
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
if (err != nil) != tt.wantErr {
t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) {
t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures)
}
})
}
}