From 93c8fc1e98105ff8b94b4e562f4d9346fa54ba3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 16:33:56 +0200 Subject: [PATCH] Add tests for GraphQL introspection --- api/cache.go | 24 ++++++++++- api/cache_test.go | 70 +++++++++++++++++++++++++++++++ api/queries_pr.go | 18 +++----- api/queries_pr_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 api/cache_test.go diff --git a/api/cache.go b/api/cache.go index f1dd4f7c2..79d446ae2 100644 --- a/api/cache.go +++ b/api/cache.go @@ -14,6 +14,13 @@ import ( "time" ) +func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client { + cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") + return &http.Client{ + Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), + } +} + // CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time func CacheReponse(ttl time.Duration, dir string) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { @@ -21,12 +28,14 @@ func CacheReponse(ttl time.Duration, dir string) ClientOption { key, keyErr := cacheKey(req) cacheFile := filepath.Join(dir, key) if keyErr == nil { + // TODO: make thread-safe if res, err := readCache(ttl, cacheFile, req); err == nil { return res, nil } } res, err := tr.RoundTrip(req) if err == nil && keyErr == nil { + // TODO: make thread-safe _ = writeCache(cacheFile, res) } return res, err @@ -53,12 +62,16 @@ func cacheKey(req *http.Request) (string, error) { return fmt.Sprintf("%x", digest), nil } +type readCloser struct { + io.Reader + io.Closer +} + func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { f, err := os.Open(cacheFile) if err != nil { return nil, err } - defer f.Close() fs, err := f.Stat() if err != nil { @@ -70,7 +83,14 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re return nil, errors.New("cache expired") } - return http.ReadResponse(bufio.NewReader(f), req) + res, err := http.ReadResponse(bufio.NewReader(f), req) + if res != nil { + res.Body = &readCloser{ + Reader: res.Body, + Closer: f, + } + } + return res, err } func writeCache(cacheFile string, res *http.Response) error { diff --git a/api/cache_test.go b/api/cache_test.go new file mode 100644 index 000000000..48255186f --- /dev/null +++ b/api/cache_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CacheReponse(t *testing.T) { + counter := 0 + fakeHTTP := funcTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + counter += 1 + body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(body)), + }, nil + }, + } + + cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir)) + + do := func(method, url string, body io.Reader) (string, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return "", err + } + res, err := httpClient.Do(req) + if err != nil { + return "", err + } + resBody, err := ioutil.ReadAll(res.Body) + if err != nil { + err = fmt.Errorf("ReadAll: %w", err) + } + return string(resBody), err + } + + res1, err := do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res1) + res2, err := do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res2) + + res3, err := do("GET", "http://example.com/path2", nil) + require.NoError(t, err) + assert.Equal(t, "2: GET http://example.com/path2", res3) + + res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "3: POST http://example.com/path", res4) + res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "3: POST http://example.com/path", res5) + + res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/path", res6) +} diff --git a/api/queries_pr.go b/api/queries_pr.go index 60e4a18b3..34011d5be 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "net/http" - "os" - "path/filepath" "strings" "time" @@ -268,14 +266,8 @@ func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF } `graphql:"Commit: __type(name: \"Commit\")"` } - cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") - cacheTTL := time.Duration(24 * time.Hour) - cachedClient := &http.Client{ - Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), - } - - v4 := graphQLClient(cachedClient, hostname) - err = v4.Query(context.Background(), &featureDetection, nil) + v4 := graphQLClient(httpClient, hostname) + err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil) if err != nil { return } @@ -315,7 +307,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu ReviewRequested edges } - prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost()) + cachedClient := makeCachedClient(client.http, time.Hour*24) + prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost()) if err != nil { return nil, err } @@ -483,7 +476,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) { - if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil { + cachedClient := makeCachedClient(httpClient, time.Hour*24) + if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil { return "", err } else if !prFeatures.HasStatusCheckRollup { return "", nil diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 3441534bf..e624b61b3 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -1,8 +1,10 @@ package api import ( + "reflect" "testing" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) @@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) { }) } } + +func Test_determinePullRequestFeatures(t *testing.T) { + tests := []struct { + name string + hostname string + queryResponse string + wantPrFeatures pullRequestFeature + wantErr bool + }{ + { + name: "github.com", + hostname: "github.com", + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + { + name: "GHE empty response", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": {}} + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has reviewDecision", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "PullRequest": { + "fields": [ + {"name": "foo"}, + {"name": "reviewDecision"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has statusCheckRollup", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "Commit": { + "fields": [ + {"name": "foo"}, + {"name": "statusCheckRollup"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeHTTP := &httpmock.Registry{} + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP)) + + if tt.queryResponse != "" { + fakeHTTP.Register( + httpmock.GraphQL(`query PullRequest_fields\b`), + httpmock.StringResponse(tt.queryResponse)) + } + + gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) { + t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures) + } + }) + } +}