From 7663acdc295b3a7d0c76ebade4ab0389644ac6b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 2 Oct 2020 15:19:40 +0200 Subject: [PATCH] Improve HTTP caching layer - make thread-safe - only cache GET, HEAD, and GraphQL requests - only cache non-5xx, non-403 responses - include `Accept` and `Authorization` headers in cache key --- api/cache.go | 105 +++++++++++++++++++++++++++++++++++++--------- api/cache_test.go | 50 +++++++++++++++------- 2 files changed, 119 insertions(+), 36 deletions(-) diff --git a/api/cache.go b/api/cache.go index 9d6ee7ea0..620660c15 100644 --- a/api/cache.go +++ b/api/cache.go @@ -11,6 +11,8 @@ import ( "net/http" "os" "path/filepath" + "strings" + "sync" "time" ) @@ -21,39 +23,79 @@ func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Cli } } +func isCacheableRequest(req *http.Request) bool { + if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { + return true + } + + if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { + return true + } + + return false +} + +func isCacheableResponse(res *http.Response) bool { + return res.StatusCode < 500 && res.StatusCode != 403 +} + // CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time func CacheReponse(ttl time.Duration, dir string) ClientOption { + fs := fileStorage{ + dir: dir, + ttl: ttl, + mu: &sync.RWMutex{}, + } + return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + if !isCacheableRequest(req) { + return tr.RoundTrip(req) + } + key, keyErr := cacheKey(req) - cacheFile := filepath.Join(dir, key) if keyErr == nil { - // TODO: make thread-safe - if res, err := readCache(ttl, cacheFile, req); err == nil { + if res, err := fs.read(key); err == nil { + res.Request = req return res, nil } } + res, err := tr.RoundTrip(req) - if err == nil && keyErr == nil { - // TODO: make thread-safe - _ = writeCache(cacheFile, res) + if err == nil && keyErr == nil && isCacheableResponse(res) { + _ = fs.store(key, res) } return res, err }} } } +func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { + b := &bytes.Buffer{} + nr := io.TeeReader(r, b) + return ioutil.NopCloser(b), &readCloser{ + Reader: nr, + Closer: r, + } +} + +type readCloser struct { + io.Reader + io.Closer +} + func cacheKey(req *http.Request) (string, error) { h := sha256.New() fmt.Fprintf(h, "%s:", req.Method) fmt.Fprintf(h, "%s:", req.URL.String()) + fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) + fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) if req.Body != nil { - bodyCopy := &bytes.Buffer{} - defer req.Body.Close() - _, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy)) - req.Body = ioutil.NopCloser(bodyCopy) - if err != nil { + var bodyCopy io.ReadCloser + req.Body, bodyCopy = copyStream(req.Body) + defer bodyCopy.Close() + if _, err := io.Copy(h, bodyCopy); err != nil { return "", err } } @@ -62,20 +104,38 @@ func cacheKey(req *http.Request) (string, error) { return fmt.Sprintf("%x", digest), nil } -func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) { +type fileStorage struct { + dir string + ttl time.Duration + mu *sync.RWMutex +} + +func (fs *fileStorage) filePath(key string) string { + if len(key) >= 6 { + return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) + } + return filepath.Join(fs.dir, key) +} + +func (fs *fileStorage) read(key string) (*http.Response, error) { + cacheFile := fs.filePath(key) + + fs.mu.RLock() + defer fs.mu.RUnlock() + f, err := os.Open(cacheFile) if err != nil { return nil, err } defer f.Close() - fs, err := f.Stat() + stat, err := f.Stat() if err != nil { return nil, err } - age := time.Since(fs.ModTime()) - if age > ttl { + age := time.Since(stat.ModTime()) + if age > fs.ttl { return nil, errors.New("cache expired") } @@ -85,11 +145,16 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re return nil, err } - res, err := http.ReadResponse(bufio.NewReader(body), req) + res, err := http.ReadResponse(bufio.NewReader(body), nil) return res, err } -func writeCache(cacheFile string, res *http.Response) error { +func (fs *fileStorage) store(key string, res *http.Response) error { + cacheFile := fs.filePath(key) + + fs.mu.Lock() + defer fs.mu.Unlock() + err := os.MkdirAll(filepath.Dir(cacheFile), 0755) if err != nil { return err @@ -101,10 +166,10 @@ func writeCache(cacheFile string, res *http.Response) error { } defer f.Close() - bodyCopy := &bytes.Buffer{} + var origBody io.ReadCloser + origBody, res.Body = copyStream(res.Body) defer res.Body.Close() - res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy)) err = res.Write(f) - res.Body = ioutil.NopCloser(bodyCopy) + res.Body = origBody return err } diff --git a/api/cache_test.go b/api/cache_test.go index 8540e7d44..d1039d71b 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -20,8 +20,12 @@ func Test_CacheReponse(t *testing.T) { roundTrip: func(req *http.Request) (*http.Response, error) { counter += 1 body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) + status := 200 + if req.URL.Path == "/error" { + status = 500 + } return &http.Response{ - StatusCode: 200, + StatusCode: status, Body: ioutil.NopCloser(bytes.NewBufferString(body)), }, nil }, @@ -47,25 +51,39 @@ func Test_CacheReponse(t *testing.T) { return string(resBody), err } - res1, err := do("GET", "http://example.com/path", nil) - require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res1) - res2, err := do("GET", "http://example.com/path", nil) - require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res2) + var res string + var err error - res3, err := do("GET", "http://example.com/path2", nil) + res, err = do("GET", "http://example.com/path", nil) require.NoError(t, err) - assert.Equal(t, "2: GET http://example.com/path2", res3) + assert.Equal(t, "1: GET http://example.com/path", res) + res, err = do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res) - res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) + res, err = do("GET", "http://example.com/path2", nil) require.NoError(t, err) - assert.Equal(t, "3: POST http://example.com/path", res4) - res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`)) - require.NoError(t, err) - assert.Equal(t, "3: POST http://example.com/path", res5) + assert.Equal(t, "2: GET http://example.com/path2", res) - res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`)) + res, err = do("POST", "http://example.com/path2", nil) require.NoError(t, err) - assert.Equal(t, "4: POST http://example.com/path", res6) + assert.Equal(t, "3: POST http://example.com/path2", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) + require.NoError(t, err) + assert.Equal(t, "5: POST http://example.com/graphql", res) + + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + assert.Equal(t, "6: GET http://example.com/error", res) + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + assert.Equal(t, "7: GET http://example.com/error", res) }