Add tests for GraphQL introspection
This commit is contained in:
parent
0ef2863ede
commit
93c8fc1e98
4 changed files with 191 additions and 14 deletions
24
api/cache.go
24
api/cache.go
|
|
@ -14,6 +14,13 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
|
||||
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
|
||||
return &http.Client{
|
||||
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
|
||||
}
|
||||
}
|
||||
|
||||
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
|
||||
func CacheReponse(ttl time.Duration, dir string) ClientOption {
|
||||
return func(tr http.RoundTripper) http.RoundTripper {
|
||||
|
|
@ -21,12 +28,14 @@ func CacheReponse(ttl time.Duration, dir string) ClientOption {
|
|||
key, keyErr := cacheKey(req)
|
||||
cacheFile := filepath.Join(dir, key)
|
||||
if keyErr == nil {
|
||||
// TODO: make thread-safe
|
||||
if res, err := readCache(ttl, cacheFile, req); err == nil {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err == nil && keyErr == nil {
|
||||
// TODO: make thread-safe
|
||||
_ = writeCache(cacheFile, res)
|
||||
}
|
||||
return res, err
|
||||
|
|
@ -53,12 +62,16 @@ func cacheKey(req *http.Request) (string, error) {
|
|||
return fmt.Sprintf("%x", digest), nil
|
||||
}
|
||||
|
||||
type readCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
|
||||
f, err := os.Open(cacheFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fs, err := f.Stat()
|
||||
if err != nil {
|
||||
|
|
@ -70,7 +83,14 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
|
|||
return nil, errors.New("cache expired")
|
||||
}
|
||||
|
||||
return http.ReadResponse(bufio.NewReader(f), req)
|
||||
res, err := http.ReadResponse(bufio.NewReader(f), req)
|
||||
if res != nil {
|
||||
res.Body = &readCloser{
|
||||
Reader: res.Body,
|
||||
Closer: f,
|
||||
}
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
func writeCache(cacheFile string, res *http.Response) error {
|
||||
|
|
|
|||
70
api/cache_test.go
Normal file
70
api/cache_test.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_CacheReponse(t *testing.T) {
|
||||
counter := 0
|
||||
fakeHTTP := funcTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
counter += 1
|
||||
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
|
||||
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir))
|
||||
|
||||
do := func(method, url string, body io.Reader) (string, error) {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resBody, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ReadAll: %w", err)
|
||||
}
|
||||
return string(resBody), err
|
||||
}
|
||||
|
||||
res1, err := do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res1)
|
||||
res2, err := do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res2)
|
||||
|
||||
res3, err := do("GET", "http://example.com/path2", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "2: GET http://example.com/path2", res3)
|
||||
|
||||
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "3: POST http://example.com/path", res4)
|
||||
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "3: POST http://example.com/path", res5)
|
||||
|
||||
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "4: POST http://example.com/path", res6)
|
||||
}
|
||||
|
|
@ -6,8 +6,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -268,14 +266,8 @@ func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF
|
|||
} `graphql:"Commit: __type(name: \"Commit\")"`
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
|
||||
cacheTTL := time.Duration(24 * time.Hour)
|
||||
cachedClient := &http.Client{
|
||||
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
|
||||
}
|
||||
|
||||
v4 := graphQLClient(cachedClient, hostname)
|
||||
err = v4.Query(context.Background(), &featureDetection, nil)
|
||||
v4 := graphQLClient(httpClient, hostname)
|
||||
err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -315,7 +307,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
|
|||
ReviewRequested edges
|
||||
}
|
||||
|
||||
prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost())
|
||||
cachedClient := makeCachedClient(client.http, time.Hour*24)
|
||||
prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -483,7 +476,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
|
|||
}
|
||||
|
||||
func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) {
|
||||
if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil {
|
||||
cachedClient := makeCachedClient(httpClient, time.Hour*24)
|
||||
if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil {
|
||||
return "", err
|
||||
} else if !prFeatures.HasStatusCheckRollup {
|
||||
return "", nil
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/internal/ghrepo"
|
||||
"github.com/cli/cli/pkg/httpmock"
|
||||
)
|
||||
|
|
@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_determinePullRequestFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
queryResponse string
|
||||
wantPrFeatures pullRequestFeature
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "github.com",
|
||||
hostname: "github.com",
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: true,
|
||||
HasStatusCheckRollup: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE empty response",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: heredoc.Doc(`
|
||||
{"data": {}}
|
||||
`),
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: false,
|
||||
HasStatusCheckRollup: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has reviewDecision",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: heredoc.Doc(`
|
||||
{"data": {
|
||||
"PullRequest": {
|
||||
"fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "reviewDecision"}
|
||||
]
|
||||
}
|
||||
} }
|
||||
`),
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: true,
|
||||
HasStatusCheckRollup: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "GHE has statusCheckRollup",
|
||||
hostname: "git.my.org",
|
||||
queryResponse: heredoc.Doc(`
|
||||
{"data": {
|
||||
"Commit": {
|
||||
"fields": [
|
||||
{"name": "foo"},
|
||||
{"name": "statusCheckRollup"}
|
||||
]
|
||||
}
|
||||
} }
|
||||
`),
|
||||
wantPrFeatures: pullRequestFeature{
|
||||
HasReviewDecision: false,
|
||||
HasStatusCheckRollup: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fakeHTTP := &httpmock.Registry{}
|
||||
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
|
||||
|
||||
if tt.queryResponse != "" {
|
||||
fakeHTTP.Register(
|
||||
httpmock.GraphQL(`query PullRequest_fields\b`),
|
||||
httpmock.StringResponse(tt.queryResponse))
|
||||
}
|
||||
|
||||
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) {
|
||||
t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue