Improve HTTP caching layer

- make thread-safe
- only cache GET, HEAD, and GraphQL requests
- only cache non-5xx, non-403 responses
- include `Accept` and `Authorization` headers in cache key
This commit is contained in:
Mislav Marohnić 2020-10-02 15:19:40 +02:00
parent f7a82a216b
commit 7663acdc29
2 changed files with 119 additions and 36 deletions

View file

@ -11,6 +11,8 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
@ -21,39 +23,79 @@ func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Cli
}
}
// isCacheableRequest reports whether req is eligible for the response cache.
//
// Eligible requests are GET and HEAD requests, plus POST requests whose URL
// path is exactly "/graphql" or "/api/graphql". Method names are compared
// case-insensitively. An empty method is treated as GET, which is how
// net/http interprets it for client requests. A POST request without a URL is
// reported as not cacheable rather than dereferencing a nil pointer.
func isCacheableRequest(req *http.Request) bool {
	method := req.Method
	if method == "" {
		// net/http documents an empty Method on a client request as GET.
		method = http.MethodGet
	}

	switch {
	case strings.EqualFold(method, http.MethodGet), strings.EqualFold(method, http.MethodHead):
		return true
	case strings.EqualFold(method, http.MethodPost):
		if req.URL == nil {
			return false
		}
		// Only the GraphQL endpoints are cached for POST.
		switch req.URL.Path {
		case "/graphql", "/api/graphql":
			return true
		}
	}
	return false
}
// isCacheableResponse reports whether res may be written to the cache.
// Responses with a status code of 500 or above, and 403 responses, are
// rejected; every other status code is accepted.
func isCacheableResponse(res *http.Response) bool {
	switch code := res.StatusCode; {
	case code >= 500:
		return false
	case code == 403:
		return false
	}
	return true
}
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time.
//
// ttl and dir are handed to the fileStorage that backs the cache. The returned
// ClientOption wraps a transport tr: requests rejected by isCacheableRequest
// are passed straight to tr, while for the others a response is served from
// the cache when one can be read and is otherwise fetched through tr.
func CacheReponse(ttl time.Duration, dir string) ClientOption {
fs := fileStorage{
dir: dir,
ttl: ttl,
mu: &sync.RWMutex{},
}
return func(tr http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
// Requests that are not cacheable bypass the cache entirely.
if !isCacheableRequest(req) {
return tr.RoundTrip(req)
}
// A key error does not abort the request; it only disables the cache
// read and write guarded by keyErr == nil below.
key, keyErr := cacheKey(req)
cacheFile := filepath.Join(dir, key)
if keyErr == nil {
// TODO: make thread-safe
if res, err := readCache(ttl, cacheFile, req); err == nil {
if res, err := fs.read(key); err == nil {
// Attach the live request to the response loaded from the cache.
res.Request = req
return res, nil
}
}
res, err := tr.RoundTrip(req)
if err == nil && keyErr == nil {
// TODO: make thread-safe
_ = writeCache(cacheFile, res)
if err == nil && keyErr == nil && isCacheableResponse(res) {
// Best-effort: a failure to store is ignored and the response is
// still returned to the caller.
_ = fs.store(key, res)
}
return res, err
}}
}
}
// copyStream wraps r so that its contents can be consumed twice.
//
// The second return value reads from r and records every byte it reads;
// closing it closes r. The first return value replays the bytes recorded so
// far, so it only yields data after the second stream has been read, and its
// Close is a no-op.
func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
	var recorded bytes.Buffer
	live := &readCloser{
		Reader: io.TeeReader(r, &recorded),
		Closer: r,
	}
	return ioutil.NopCloser(&recorded), live
}

// readCloser pairs an independent Reader and Closer into one io.ReadCloser.
type readCloser struct {
	io.Reader
	io.Closer
}
func cacheKey(req *http.Request) (string, error) {
h := sha256.New()
fmt.Fprintf(h, "%s:", req.Method)
fmt.Fprintf(h, "%s:", req.URL.String())
fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
if req.Body != nil {
bodyCopy := &bytes.Buffer{}
defer req.Body.Close()
_, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy))
req.Body = ioutil.NopCloser(bodyCopy)
if err != nil {
var bodyCopy io.ReadCloser
req.Body, bodyCopy = copyStream(req.Body)
defer bodyCopy.Close()
if _, err := io.Copy(h, bodyCopy); err != nil {
return "", err
}
}
@ -62,20 +104,38 @@ func cacheKey(req *http.Request) (string, error) {
return fmt.Sprintf("%x", digest), nil
}
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
// fileStorage is the on-disk store used by the response cache.
type fileStorage struct {
	// dir is the root directory under which cache files are placed.
	dir string
	// ttl is the maximum age a cache file may have and still be served.
	ttl time.Duration
	// mu is read-locked by read and write-locked by store.
	mu *sync.RWMutex
}

// filePath returns the location of the cache file for key. Keys of six or
// more characters are spread over two directory levels named after the first
// two pairs of characters, with the remainder as the file name; shorter keys
// are placed directly under the storage directory.
func (fs *fileStorage) filePath(key string) string {
	if len(key) < 6 {
		return filepath.Join(fs.dir, key)
	}
	segments := []string{fs.dir, key[:2], key[2:4], key[4:]}
	return filepath.Join(segments...)
}
func (fs *fileStorage) read(key string) (*http.Response, error) {
cacheFile := fs.filePath(key)
fs.mu.RLock()
defer fs.mu.RUnlock()
f, err := os.Open(cacheFile)
if err != nil {
return nil, err
}
defer f.Close()
fs, err := f.Stat()
stat, err := f.Stat()
if err != nil {
return nil, err
}
age := time.Since(fs.ModTime())
if age > ttl {
age := time.Since(stat.ModTime())
if age > fs.ttl {
return nil, errors.New("cache expired")
}
@ -85,11 +145,16 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
return nil, err
}
res, err := http.ReadResponse(bufio.NewReader(body), req)
res, err := http.ReadResponse(bufio.NewReader(body), nil)
return res, err
}
func writeCache(cacheFile string, res *http.Response) error {
func (fs *fileStorage) store(key string, res *http.Response) error {
cacheFile := fs.filePath(key)
fs.mu.Lock()
defer fs.mu.Unlock()
err := os.MkdirAll(filepath.Dir(cacheFile), 0755)
if err != nil {
return err
@ -101,10 +166,10 @@ func writeCache(cacheFile string, res *http.Response) error {
}
defer f.Close()
bodyCopy := &bytes.Buffer{}
var origBody io.ReadCloser
origBody, res.Body = copyStream(res.Body)
defer res.Body.Close()
res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy))
err = res.Write(f)
res.Body = ioutil.NopCloser(bodyCopy)
res.Body = origBody
return err
}

View file

@ -20,8 +20,12 @@ func Test_CacheReponse(t *testing.T) {
roundTrip: func(req *http.Request) (*http.Response, error) {
counter += 1
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
status := 200
if req.URL.Path == "/error" {
status = 500
}
return &http.Response{
StatusCode: 200,
StatusCode: status,
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
}, nil
},
@ -47,25 +51,39 @@ func Test_CacheReponse(t *testing.T) {
return string(resBody), err
}
res1, err := do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res1)
res2, err := do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res2)
var res string
var err error
res3, err := do("GET", "http://example.com/path2", nil)
res, err = do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "2: GET http://example.com/path2", res3)
assert.Equal(t, "1: GET http://example.com/path", res)
res, err = do("GET", "http://example.com/path", nil)
require.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
res, err = do("GET", "http://example.com/path2", nil)
require.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path", res4)
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path", res5)
assert.Equal(t, "2: GET http://example.com/path2", res)
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
res, err = do("POST", "http://example.com/path2", nil)
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/path", res6)
assert.Equal(t, "3: POST http://example.com/path2", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
require.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
require.NoError(t, err)
assert.Equal(t, "5: POST http://example.com/graphql", res)
res, err = do("GET", "http://example.com/error", nil)
require.NoError(t, err)
assert.Equal(t, "6: GET http://example.com/error", res)
res, err = do("GET", "http://example.com/error", nil)
require.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}