Add tests for GraphQL introspection

This commit is contained in:
Mislav Marohnić 2020-10-01 16:33:56 +02:00
parent 0ef2863ede
commit 93c8fc1e98
4 changed files with 191 additions and 14 deletions

View file

@ -14,6 +14,13 @@ import (
"time"
)
func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
return &http.Client{
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
}
}
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
func CacheReponse(ttl time.Duration, dir string) ClientOption {
return func(tr http.RoundTripper) http.RoundTripper {
@ -21,12 +28,14 @@ func CacheReponse(ttl time.Duration, dir string) ClientOption {
key, keyErr := cacheKey(req)
cacheFile := filepath.Join(dir, key)
if keyErr == nil {
// TODO: make thread-safe
if res, err := readCache(ttl, cacheFile, req); err == nil {
return res, nil
}
}
res, err := tr.RoundTrip(req)
if err == nil && keyErr == nil {
// TODO: make thread-safe
_ = writeCache(cacheFile, res)
}
return res, err
@ -53,12 +62,16 @@ func cacheKey(req *http.Request) (string, error) {
return fmt.Sprintf("%x", digest), nil
}
type readCloser struct {
io.Reader
io.Closer
}
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
f, err := os.Open(cacheFile)
if err != nil {
return nil, err
}
defer f.Close()
fs, err := f.Stat()
if err != nil {
@ -70,7 +83,14 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
return nil, errors.New("cache expired")
}
return http.ReadResponse(bufio.NewReader(f), req)
res, err := http.ReadResponse(bufio.NewReader(f), req)
if res != nil {
res.Body = &readCloser{
Reader: res.Body,
Closer: f,
}
}
return res, err
}
func writeCache(cacheFile string, res *http.Response) error {

70
api/cache_test.go Normal file
View file

@ -0,0 +1,70 @@
package api
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_CacheReponse(t *testing.T) {
counter := 0
fakeHTTP := funcTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
counter += 1
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
}, nil
},
}
cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir))
do := func(method, url string, body io.Reader) (string, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return "", err
}
res, err := httpClient.Do(req)
if err != nil {
return "", err
}
resBody, err := ioutil.ReadAll(res.Body)
if err != nil {
err = fmt.Errorf("ReadAll: %w", err)
}
return string(resBody), err
}
res1, err := do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res1)
res2, err := do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res2)
res3, err := do("GET", "http://example.com/path2", nil)
require.NoError(t, err)
assert.Equal(t, "2: GET http://example.com/path2", res3)
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path", res4)
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path", res5)
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/path", res6)
}

View file

@ -6,8 +6,6 @@ import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
@ -268,14 +266,8 @@ func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF
} `graphql:"Commit: __type(name: \"Commit\")"`
}
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
cacheTTL := time.Duration(24 * time.Hour)
cachedClient := &http.Client{
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
}
v4 := graphQLClient(cachedClient, hostname)
err = v4.Query(context.Background(), &featureDetection, nil)
v4 := graphQLClient(httpClient, hostname)
err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
if err != nil {
return
}
@ -315,7 +307,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
ReviewRequested edges
}
prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost())
cachedClient := makeCachedClient(client.http, time.Hour*24)
prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost())
if err != nil {
return nil, err
}
@ -483,7 +476,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
}
func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) {
if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil {
cachedClient := makeCachedClient(httpClient, time.Hour*24)
if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil {
return "", err
} else if !prFeatures.HasStatusCheckRollup {
return "", nil

View file

@ -1,8 +1,10 @@
package api
import (
"reflect"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/httpmock"
)
@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) {
})
}
}
func Test_determinePullRequestFeatures(t *testing.T) {
tests := []struct {
name string
hostname string
queryResponse string
wantPrFeatures pullRequestFeature
wantErr bool
}{
{
name: "github.com",
hostname: "github.com",
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: true,
},
wantErr: false,
},
{
name: "GHE empty response",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {}}
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: false,
},
wantErr: false,
},
{
name: "GHE has reviewDecision",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {
"PullRequest": {
"fields": [
{"name": "foo"},
{"name": "reviewDecision"}
]
}
} }
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: true,
HasStatusCheckRollup: false,
},
wantErr: false,
},
{
name: "GHE has statusCheckRollup",
hostname: "git.my.org",
queryResponse: heredoc.Doc(`
{"data": {
"Commit": {
"fields": [
{"name": "foo"},
{"name": "statusCheckRollup"}
]
}
} }
`),
wantPrFeatures: pullRequestFeature{
HasReviewDecision: false,
HasStatusCheckRollup: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fakeHTTP := &httpmock.Registry{}
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
if tt.queryResponse != "" {
fakeHTTP.Register(
httpmock.GraphQL(`query PullRequest_fields\b`),
httpmock.StringResponse(tt.queryResponse))
}
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
if (err != nil) != tt.wantErr {
t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) {
t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures)
}
})
}
}