Improve HTTP caching layer
- make thread-safe - only cache GET, HEAD, and GraphQL requests - only cache non-5xx, non-403 responses - include `Accept` and `Authorization` headers in cache key
This commit is contained in:
parent
f7a82a216b
commit
7663acdc29
2 changed files with 119 additions and 36 deletions
105
api/cache.go
105
api/cache.go
|
|
@ -11,6 +11,8 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -21,39 +23,79 @@ func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Cli
|
|||
}
|
||||
}
|
||||
|
||||
func isCacheableRequest(req *http.Request) bool {
|
||||
if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isCacheableResponse(res *http.Response) bool {
|
||||
return res.StatusCode < 500 && res.StatusCode != 403
|
||||
}
|
||||
|
||||
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
|
||||
func CacheReponse(ttl time.Duration, dir string) ClientOption {
|
||||
fs := fileStorage{
|
||||
dir: dir,
|
||||
ttl: ttl,
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
return func(tr http.RoundTripper) http.RoundTripper {
|
||||
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
if !isCacheableRequest(req) {
|
||||
return tr.RoundTrip(req)
|
||||
}
|
||||
|
||||
key, keyErr := cacheKey(req)
|
||||
cacheFile := filepath.Join(dir, key)
|
||||
if keyErr == nil {
|
||||
// TODO: make thread-safe
|
||||
if res, err := readCache(ttl, cacheFile, req); err == nil {
|
||||
if res, err := fs.read(key); err == nil {
|
||||
res.Request = req
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err == nil && keyErr == nil {
|
||||
// TODO: make thread-safe
|
||||
_ = writeCache(cacheFile, res)
|
||||
if err == nil && keyErr == nil && isCacheableResponse(res) {
|
||||
_ = fs.store(key, res)
|
||||
}
|
||||
return res, err
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
|
||||
b := &bytes.Buffer{}
|
||||
nr := io.TeeReader(r, b)
|
||||
return ioutil.NopCloser(b), &readCloser{
|
||||
Reader: nr,
|
||||
Closer: r,
|
||||
}
|
||||
}
|
||||
|
||||
type readCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func cacheKey(req *http.Request) (string, error) {
|
||||
h := sha256.New()
|
||||
fmt.Fprintf(h, "%s:", req.Method)
|
||||
fmt.Fprintf(h, "%s:", req.URL.String())
|
||||
fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
|
||||
fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
|
||||
|
||||
if req.Body != nil {
|
||||
bodyCopy := &bytes.Buffer{}
|
||||
defer req.Body.Close()
|
||||
_, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy))
|
||||
req.Body = ioutil.NopCloser(bodyCopy)
|
||||
if err != nil {
|
||||
var bodyCopy io.ReadCloser
|
||||
req.Body, bodyCopy = copyStream(req.Body)
|
||||
defer bodyCopy.Close()
|
||||
if _, err := io.Copy(h, bodyCopy); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
|
@ -62,20 +104,38 @@ func cacheKey(req *http.Request) (string, error) {
|
|||
return fmt.Sprintf("%x", digest), nil
|
||||
}
|
||||
|
||||
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
|
||||
type fileStorage struct {
|
||||
dir string
|
||||
ttl time.Duration
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
func (fs *fileStorage) filePath(key string) string {
|
||||
if len(key) >= 6 {
|
||||
return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
|
||||
}
|
||||
return filepath.Join(fs.dir, key)
|
||||
}
|
||||
|
||||
func (fs *fileStorage) read(key string) (*http.Response, error) {
|
||||
cacheFile := fs.filePath(key)
|
||||
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
f, err := os.Open(cacheFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fs, err := f.Stat()
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
age := time.Since(fs.ModTime())
|
||||
if age > ttl {
|
||||
age := time.Since(stat.ModTime())
|
||||
if age > fs.ttl {
|
||||
return nil, errors.New("cache expired")
|
||||
}
|
||||
|
||||
|
|
@ -85,11 +145,16 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
|
|||
return nil, err
|
||||
}
|
||||
|
||||
res, err := http.ReadResponse(bufio.NewReader(body), req)
|
||||
res, err := http.ReadResponse(bufio.NewReader(body), nil)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func writeCache(cacheFile string, res *http.Response) error {
|
||||
func (fs *fileStorage) store(key string, res *http.Response) error {
|
||||
cacheFile := fs.filePath(key)
|
||||
|
||||
fs.mu.Lock()
|
||||
defer fs.mu.Unlock()
|
||||
|
||||
err := os.MkdirAll(filepath.Dir(cacheFile), 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -101,10 +166,10 @@ func writeCache(cacheFile string, res *http.Response) error {
|
|||
}
|
||||
defer f.Close()
|
||||
|
||||
bodyCopy := &bytes.Buffer{}
|
||||
var origBody io.ReadCloser
|
||||
origBody, res.Body = copyStream(res.Body)
|
||||
defer res.Body.Close()
|
||||
res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy))
|
||||
err = res.Write(f)
|
||||
res.Body = ioutil.NopCloser(bodyCopy)
|
||||
res.Body = origBody
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,8 +20,12 @@ func Test_CacheReponse(t *testing.T) {
|
|||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
counter += 1
|
||||
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
|
||||
status := 200
|
||||
if req.URL.Path == "/error" {
|
||||
status = 500
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
StatusCode: status,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
|
||||
}, nil
|
||||
},
|
||||
|
|
@ -47,25 +51,39 @@ func Test_CacheReponse(t *testing.T) {
|
|||
return string(resBody), err
|
||||
}
|
||||
|
||||
res1, err := do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res1)
|
||||
res2, err := do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res2)
|
||||
var res string
|
||||
var err error
|
||||
|
||||
res3, err := do("GET", "http://example.com/path2", nil)
|
||||
res, err = do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "2: GET http://example.com/path2", res3)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res)
|
||||
res, err = do("GET", "http://example.com/path", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1: GET http://example.com/path", res)
|
||||
|
||||
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
|
||||
res, err = do("GET", "http://example.com/path2", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "3: POST http://example.com/path", res4)
|
||||
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "3: POST http://example.com/path", res5)
|
||||
assert.Equal(t, "2: GET http://example.com/path2", res)
|
||||
|
||||
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
|
||||
res, err = do("POST", "http://example.com/path2", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "4: POST http://example.com/path", res6)
|
||||
assert.Equal(t, "3: POST http://example.com/path2", res)
|
||||
|
||||
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "4: POST http://example.com/graphql", res)
|
||||
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "4: POST http://example.com/graphql", res)
|
||||
|
||||
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "5: POST http://example.com/graphql", res)
|
||||
|
||||
res, err = do("GET", "http://example.com/error", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "6: GET http://example.com/error", res)
|
||||
res, err = do("GET", "http://example.com/error", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "7: GET http://example.com/error", res)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue