From bed9d11f7a37c456919ae01e5c2194f368f88e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 30 Sep 2020 19:04:09 +0200 Subject: [PATCH 1/7] Avoid querying `statusCheckRollup` or `reviewDecision` on unsupported GHE We first ask the GHE server for whether it supports these fields. --- api/queries_pr.go | 201 ++++++++++++++++++++++++++++++---------------- 1 file changed, 130 insertions(+), 71 deletions(-) diff --git a/api/queries_pr.go b/api/queries_pr.go index 4e93cb293..1f00beb13 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -241,6 +241,52 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (io.Rea return resp.Body, nil } +type pullRequestFeature struct { + HasReviewDecision bool + HasStatusCheckRollup bool +} + +func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) { + if !ghinstance.IsEnterprise(hostname) { + prFeatures.HasReviewDecision = true + prFeatures.HasStatusCheckRollup = true + return + } + + var featureDetection struct { + PullRequest struct { + Fields []struct { + Name string + } `graphql:"fields(includeDeprecated: true)"` + } `graphql:"PullRequest: __type(name: \"PullRequest\")"` + Commit struct { + Fields []struct { + Name string + } `graphql:"fields(includeDeprecated: true)"` + } `graphql:"Commit: __type(name: \"Commit\")"` + } + + v4 := graphQLClient(httpClient, hostname) + err = v4.Query(context.Background(), &featureDetection, nil) + if err != nil { + return + } + + for _, field := range featureDetection.PullRequest.Fields { + switch field.Name { + case "reviewDecision": + prFeatures.HasReviewDecision = true + } + } + for _, field := range featureDetection.Commit.Fields { + switch field.Name { + case "statusCheckRollup": + prFeatures.HasStatusCheckRollup = true + } + } + return +} + func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, currentPRHeadRef, currentUsername string) 
(*PullRequestsPayload, error) { type edges struct { TotalCount int @@ -261,18 +307,19 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu ReviewRequested edges } - fragments := ` - fragment pr on PullRequest { - number - title - state - url - headRefName - headRepositoryOwner { - login - } - isCrossRepository - isDraft + prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost()) + if err != nil { + return nil, err + } + + var reviewsFragment string + if prFeatures.HasReviewDecision { + reviewsFragment = "reviewDecision" + } + + var statusesFragment string + if prFeatures.HasStatusCheckRollup { + statusesFragment = ` commits(last: 1) { nodes { commit { @@ -292,12 +339,28 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } } } + ` + } + + fragments := fmt.Sprintf(` + fragment pr on PullRequest { + number + title + state + url + headRefName + headRepositoryOwner { + login + } + isCrossRepository + isDraft + %s } fragment prWithReviews on PullRequest { ...pr - reviewDecision + %s } - ` + `, statusesFragment, reviewsFragment) queryPrefix := ` query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { @@ -363,7 +426,7 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -404,6 +467,44 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu return &payload, nil } +func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) { + if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil { + return "", err + } else if !prFeatures.HasStatusCheckRollup { + return "", nil + } + + return ` + commits(last: 1) { + 
totalCount + nodes { + commit { + oid + statusCheckRollup { + contexts(last: 100) { + nodes { + ...on StatusContext { + context + state + targetUrl + } + ...on CheckRun { + name + status + conclusion + startedAt + completedAt + detailsUrl + } + } + } + } + } + } + } + `, nil +} + func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*PullRequest, error) { type response struct { Repository struct { @@ -411,6 +512,11 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu } } + statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) + if err != nil { + return nil, err + } + query := ` query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { repository(owner: $owner, name: $repo) { @@ -426,33 +532,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu author { login } - commits(last: 1) { - totalCount - nodes { - commit { - oid - statusCheckRollup { - contexts(last: 100) { - nodes { - ...on StatusContext { - context - state - targetUrl - } - ...on CheckRun { - name - status - conclusion - startedAt - completedAt - detailsUrl - } - } - } - } - } - } - } + ` + statusesFragment + ` baseRefName headRefName headRepositoryOwner { @@ -524,7 +604,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -542,6 +622,11 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } } + statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) + if err != nil { + return nil, err + } + query := ` query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!) 
{ repository(owner: $owner, name: $repo) { @@ -556,33 +641,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea author { login } - commits(last: 1) { - totalCount - nodes { - commit { - oid - statusCheckRollup { - contexts(last: 100) { - nodes { - ...on StatusContext { - context - state - targetUrl - } - ...on CheckRun { - name - status - conclusion - startedAt - completedAt - detailsUrl - } - } - } - } - } - } - } + ` + statusesFragment + ` url baseRefName headRefName @@ -661,7 +720,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } From ff925fb480647741eec19ea6d42ca33781fb03e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 30 Sep 2020 19:05:06 +0200 Subject: [PATCH 2/7] Resolve `@me` to current username on GHE It looks like GHE v2.20 does not support `@me` in search yet. 
--- api/queries_pr.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/queries_pr.go b/api/queries_pr.go index 1f00beb13..7b63aa642 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -408,6 +408,13 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } ` + if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) { + currentUsername, err = CurrentLoginName(client, repo.RepoHost()) + if err != nil { + return nil, err + } + } + viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername) reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername) From 0ef2863ede1ac5ef407e022d0a943285d6c5811b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 30 Sep 2020 19:08:35 +0200 Subject: [PATCH 3/7] Cache GHE responses for schema queries This speeds up `pr`-related commands for GHE by caching schema introspection queries for 24h. 
--- api/cache.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++ api/queries_pr.go | 10 ++++- 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 api/cache.go diff --git a/api/cache.go b/api/cache.go new file mode 100644 index 000000000..f1dd4f7c2 --- /dev/null +++ b/api/cache.go @@ -0,0 +1,94 @@ +package api + +import ( + "bufio" + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "time" +) + +// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time +func CacheReponse(ttl time.Duration, dir string) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + key, keyErr := cacheKey(req) + cacheFile := filepath.Join(dir, key) + if keyErr == nil { + if res, err := readCache(ttl, cacheFile, req); err == nil { + return res, nil + } + } + res, err := tr.RoundTrip(req) + if err == nil && keyErr == nil { + _ = writeCache(cacheFile, res) + } + return res, err + }} + } +} + +func cacheKey(req *http.Request) (string, error) { + h := sha256.New() + fmt.Fprintf(h, "%s:", req.Method) + fmt.Fprintf(h, "%s:", req.URL.String()) + + if req.Body != nil { + bodyCopy := &bytes.Buffer{} + defer req.Body.Close() + _, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy)) + req.Body = ioutil.NopCloser(bodyCopy) + if err != nil { + return "", err + } + } + + digest := h.Sum(nil) + return fmt.Sprintf("%x", digest), nil +} + +func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { + f, err := os.Open(cacheFile) + if err != nil { + return nil, err + } + defer f.Close() + + fs, err := f.Stat() + if err != nil { + return nil, err + } + + age := time.Since(fs.ModTime()) + if age > ttl { + return nil, errors.New("cache expired") + } + + return http.ReadResponse(bufio.NewReader(f), req) +} + +func writeCache(cacheFile string, res 
*http.Response) error { + err := os.MkdirAll(filepath.Dir(cacheFile), 0755) + if err != nil { + return err + } + + f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer f.Close() + + bodyCopy := &bytes.Buffer{} + defer res.Body.Close() + res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy)) + err = res.Write(f) + res.Body = ioutil.NopCloser(bodyCopy) + return err +} diff --git a/api/queries_pr.go b/api/queries_pr.go index 7b63aa642..60e4a18b3 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "time" @@ -266,7 +268,13 @@ func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF } `graphql:"Commit: __type(name: \"Commit\")"` } - v4 := graphQLClient(httpClient, hostname) + cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") + cacheTTL := time.Duration(24 * time.Hour) + cachedClient := &http.Client{ + Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), + } + + v4 := graphQLClient(cachedClient, hostname) err = v4.Query(context.Background(), &featureDetection, nil) if err != nil { return From 93c8fc1e98105ff8b94b4e562f4d9346fa54ba3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 16:33:56 +0200 Subject: [PATCH 4/7] Add tests for GraphQL introspection --- api/cache.go | 24 ++++++++++- api/cache_test.go | 70 +++++++++++++++++++++++++++++++ api/queries_pr.go | 18 +++----- api/queries_pr_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 api/cache_test.go diff --git a/api/cache.go b/api/cache.go index f1dd4f7c2..79d446ae2 100644 --- a/api/cache.go +++ b/api/cache.go @@ -14,6 +14,13 @@ import ( "time" ) +func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client { + cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") + return 
&http.Client{ + Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), + } +} + // CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time func CacheReponse(ttl time.Duration, dir string) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { @@ -21,12 +28,14 @@ func CacheReponse(ttl time.Duration, dir string) ClientOption { key, keyErr := cacheKey(req) cacheFile := filepath.Join(dir, key) if keyErr == nil { + // TODO: make thread-safe if res, err := readCache(ttl, cacheFile, req); err == nil { return res, nil } } res, err := tr.RoundTrip(req) if err == nil && keyErr == nil { + // TODO: make thread-safe _ = writeCache(cacheFile, res) } return res, err @@ -53,12 +62,16 @@ func cacheKey(req *http.Request) (string, error) { return fmt.Sprintf("%x", digest), nil } +type readCloser struct { + io.Reader + io.Closer +} + func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { f, err := os.Open(cacheFile) if err != nil { return nil, err } - defer f.Close() fs, err := f.Stat() if err != nil { @@ -70,7 +83,14 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re return nil, errors.New("cache expired") } - return http.ReadResponse(bufio.NewReader(f), req) + res, err := http.ReadResponse(bufio.NewReader(f), req) + if res != nil { + res.Body = &readCloser{ + Reader: res.Body, + Closer: f, + } + } + return res, err } func writeCache(cacheFile string, res *http.Response) error { diff --git a/api/cache_test.go b/api/cache_test.go new file mode 100644 index 000000000..48255186f --- /dev/null +++ b/api/cache_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CacheReponse(t *testing.T) { + counter := 0 + fakeHTTP := funcTripper{ + roundTrip: func(req 
*http.Request) (*http.Response, error) { + counter += 1 + body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(body)), + }, nil + }, + } + + cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir)) + + do := func(method, url string, body io.Reader) (string, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return "", err + } + res, err := httpClient.Do(req) + if err != nil { + return "", err + } + resBody, err := ioutil.ReadAll(res.Body) + if err != nil { + err = fmt.Errorf("ReadAll: %w", err) + } + return string(resBody), err + } + + res1, err := do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res1) + res2, err := do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res2) + + res3, err := do("GET", "http://example.com/path2", nil) + require.NoError(t, err) + assert.Equal(t, "2: GET http://example.com/path2", res3) + + res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "3: POST http://example.com/path", res4) + res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "3: POST http://example.com/path", res5) + + res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/path", res6) +} diff --git a/api/queries_pr.go b/api/queries_pr.go index 60e4a18b3..34011d5be 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "net/http" - "os" - "path/filepath" "strings" "time" @@ -268,14 +266,8 @@ func 
determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF } `graphql:"Commit: __type(name: \"Commit\")"` } - cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") - cacheTTL := time.Duration(24 * time.Hour) - cachedClient := &http.Client{ - Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), - } - - v4 := graphQLClient(cachedClient, hostname) - err = v4.Query(context.Background(), &featureDetection, nil) + v4 := graphQLClient(httpClient, hostname) + err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil) if err != nil { return } @@ -315,7 +307,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu ReviewRequested edges } - prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost()) + cachedClient := makeCachedClient(client.http, time.Hour*24) + prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost()) if err != nil { return nil, err } @@ -483,7 +476,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) { - if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil { + cachedClient := makeCachedClient(httpClient, time.Hour*24) + if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil { return "", err } else if !prFeatures.HasStatusCheckRollup { return "", nil diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 3441534bf..e624b61b3 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -1,8 +1,10 @@ package api import ( + "reflect" "testing" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) @@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) { }) } } + +func Test_determinePullRequestFeatures(t *testing.T) { + tests := []struct { + name string + hostname string + 
queryResponse string + wantPrFeatures pullRequestFeature + wantErr bool + }{ + { + name: "github.com", + hostname: "github.com", + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + { + name: "GHE empty response", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": {}} + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has reviewDecision", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "PullRequest": { + "fields": [ + {"name": "foo"}, + {"name": "reviewDecision"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has statusCheckRollup", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "Commit": { + "fields": [ + {"name": "foo"}, + {"name": "statusCheckRollup"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeHTTP := &httpmock.Registry{} + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP)) + + if tt.queryResponse != "" { + fakeHTTP.Register( + httpmock.GraphQL(`query PullRequest_fields\b`), + httpmock.StringResponse(tt.queryResponse)) + } + + gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) { + t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures) + } + }) + } +} From 1d435a3e2e0542f9be0ca753ddafb3d70a48a6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 19:18:11 +0200 
Subject: [PATCH 5/7] Ensure that cache file is closed after reading --- api/cache.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/cache.go b/api/cache.go index 79d446ae2..006362504 100644 --- a/api/cache.go +++ b/api/cache.go @@ -84,11 +84,13 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re } res, err := http.ReadResponse(bufio.NewReader(f), req) - if res != nil { + if err == nil { res.Body = &readCloser{ Reader: res.Body, Closer: f, } + } else { + f.Close() } return res, err } From f7a82a216b2b05a46a23ab2fec9ef8a368743eb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 22:06:15 +0200 Subject: [PATCH 6/7] Avoid error from cache file being prematurely closed I have no idea what's going on there, so I'll just give up the streaming approach and read the entire contents of the cache file to memory. https://github.com/cli/cli/pull/2035/checks?check_run_id=1194798056 --- api/cache.go | 20 +++++++------------- api/cache_test.go | 1 + 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/api/cache.go b/api/cache.go index 006362504..9d6ee7ea0 100644 --- a/api/cache.go +++ b/api/cache.go @@ -62,16 +62,12 @@ func cacheKey(req *http.Request) (string, error) { return fmt.Sprintf("%x", digest), nil } -type readCloser struct { - io.Reader - io.Closer -} - func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { f, err := os.Open(cacheFile) if err != nil { return nil, err } + defer f.Close() fs, err := f.Stat() if err != nil { @@ -83,15 +79,13 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re return nil, errors.New("cache expired") } - res, err := http.ReadResponse(bufio.NewReader(f), req) - if err == nil { - res.Body = &readCloser{ - Reader: res.Body, - Closer: f, - } - } else { - f.Close() + body := &bytes.Buffer{} + _, err = io.Copy(body, f) + if err != nil { + return nil, err + } + + res, err := 
http.ReadResponse(bufio.NewReader(body), req) return res, err } diff --git a/api/cache_test.go b/api/cache_test.go index 48255186f..8540e7d44 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -39,6 +39,7 @@ func Test_CacheReponse(t *testing.T) { if err != nil { return "", err } + defer res.Body.Close() resBody, err := ioutil.ReadAll(res.Body) if err != nil { err = fmt.Errorf("ReadAll: %w", err) From 7663acdc295b3a7d0c76ebade4ab0389644ac6b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 2 Oct 2020 15:19:40 +0200 Subject: [PATCH 7/7] Improve HTTP caching layer - make thread-safe - only cache GET, HEAD, and GraphQL requests - only cache non-5xx, non-403 responses - include `Accept` and `Authorization` headers in cache key --- api/cache.go | 105 +++++++++++++++++++++++++++++++++++++--------- api/cache_test.go | 50 +++++++++++++++------- 2 files changed, 119 insertions(+), 36 deletions(-) diff --git a/api/cache.go b/api/cache.go index 9d6ee7ea0..620660c15 100644 --- a/api/cache.go +++ b/api/cache.go @@ -11,6 +11,8 @@ import ( "net/http" "os" "path/filepath" + "strings" + "sync" "time" ) @@ -21,39 +23,79 @@ func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Cli } } +func isCacheableRequest(req *http.Request) bool { + if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { + return true + } + + if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { + return true + } + + return false +} + +func isCacheableResponse(res *http.Response) bool { + return res.StatusCode < 500 && res.StatusCode != 403 +} + // CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time func CacheReponse(ttl time.Duration, dir string) ClientOption { + fs := fileStorage{ + dir: dir, + ttl: ttl, + mu: &sync.RWMutex{}, + } + return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req 
*http.Request) (*http.Response, error) { + if !isCacheableRequest(req) { + return tr.RoundTrip(req) + } + key, keyErr := cacheKey(req) - cacheFile := filepath.Join(dir, key) if keyErr == nil { - // TODO: make thread-safe - if res, err := readCache(ttl, cacheFile, req); err == nil { + if res, err := fs.read(key); err == nil { + res.Request = req return res, nil } } + res, err := tr.RoundTrip(req) - if err == nil && keyErr == nil { - // TODO: make thread-safe - _ = writeCache(cacheFile, res) + if err == nil && keyErr == nil && isCacheableResponse(res) { + _ = fs.store(key, res) } return res, err }} } } +func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { + b := &bytes.Buffer{} + nr := io.TeeReader(r, b) + return ioutil.NopCloser(b), &readCloser{ + Reader: nr, + Closer: r, + } +} + +type readCloser struct { + io.Reader + io.Closer +} + func cacheKey(req *http.Request) (string, error) { h := sha256.New() fmt.Fprintf(h, "%s:", req.Method) fmt.Fprintf(h, "%s:", req.URL.String()) + fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) + fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) if req.Body != nil { - bodyCopy := &bytes.Buffer{} - defer req.Body.Close() - _, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy)) - req.Body = ioutil.NopCloser(bodyCopy) - if err != nil { + var bodyCopy io.ReadCloser + req.Body, bodyCopy = copyStream(req.Body) + defer bodyCopy.Close() + if _, err := io.Copy(h, bodyCopy); err != nil { return "", err } } @@ -62,20 +104,38 @@ func cacheKey(req *http.Request) (string, error) { return fmt.Sprintf("%x", digest), nil } -func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { +type fileStorage struct { + dir string + ttl time.Duration + mu *sync.RWMutex +} + +func (fs *fileStorage) filePath(key string) string { + if len(key) >= 6 { + return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) + } + return filepath.Join(fs.dir, key) +} + +func (fs *fileStorage) read(key string) (*http.Response, 
error) { + cacheFile := fs.filePath(key) + + fs.mu.RLock() + defer fs.mu.RUnlock() + f, err := os.Open(cacheFile) if err != nil { return nil, err } defer f.Close() - fs, err := f.Stat() + stat, err := f.Stat() if err != nil { return nil, err } - age := time.Since(fs.ModTime()) - if age > ttl { + age := time.Since(stat.ModTime()) + if age > fs.ttl { return nil, errors.New("cache expired") } @@ -85,11 +145,16 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re return nil, err } - res, err := http.ReadResponse(bufio.NewReader(body), req) + res, err := http.ReadResponse(bufio.NewReader(body), nil) return res, err } -func writeCache(cacheFile string, res *http.Response) error { +func (fs *fileStorage) store(key string, res *http.Response) error { + cacheFile := fs.filePath(key) + + fs.mu.Lock() + defer fs.mu.Unlock() + err := os.MkdirAll(filepath.Dir(cacheFile), 0755) if err != nil { return err @@ -101,10 +166,10 @@ func writeCache(cacheFile string, res *http.Response) error { } defer f.Close() - bodyCopy := &bytes.Buffer{} + var origBody io.ReadCloser + origBody, res.Body = copyStream(res.Body) defer res.Body.Close() - res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy)) err = res.Write(f) - res.Body = ioutil.NopCloser(bodyCopy) + res.Body = origBody return err } diff --git a/api/cache_test.go b/api/cache_test.go index 8540e7d44..d1039d71b 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -20,8 +20,12 @@ func Test_CacheReponse(t *testing.T) { roundTrip: func(req *http.Request) (*http.Response, error) { counter += 1 body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) + status := 200 + if req.URL.Path == "/error" { + status = 500 + } return &http.Response{ - StatusCode: 200, + StatusCode: status, Body: ioutil.NopCloser(bytes.NewBufferString(body)), }, nil }, @@ -47,25 +51,39 @@ func Test_CacheReponse(t *testing.T) { return string(resBody), err } - res1, err := do("GET", "http://example.com/path", nil) - 
require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res1) - res2, err := do("GET", "http://example.com/path", nil) - require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res2) + var res string + var err error - res3, err := do("GET", "http://example.com/path2", nil) + res, err = do("GET", "http://example.com/path", nil) require.NoError(t, err) - assert.Equal(t, "2: GET http://example.com/path2", res3) + assert.Equal(t, "1: GET http://example.com/path", res) + res, err = do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res) - res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + res, err = do("GET", "http://example.com/path2", nil) require.NoError(t, err) - assert.Equal(t, "3: POST http://example.com/path", res4) - res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) - require.NoError(t, err) - assert.Equal(t, "3: POST http://example.com/path", res5) + assert.Equal(t, "2: GET http://example.com/path2", res) - res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`)) + res, err = do("POST", "http://example.com/path2", nil) require.NoError(t, err) - assert.Equal(t, "4: POST http://example.com/path", res6) + assert.Equal(t, "3: POST http://example.com/path2", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) + require.NoError(t, err) + assert.Equal(t, "5: POST http://example.com/graphql", res) + + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + 
assert.Equal(t, "6: GET http://example.com/error", res) + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + assert.Equal(t, "7: GET http://example.com/error", res) }