diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index 3ffc6a0b3..02b08f842 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -6,8 +6,8 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "os" - "path/filepath" "runtime" "strings" "testing" @@ -1063,40 +1063,41 @@ func Test_apiRun_inputFile(t *testing.T) { } func Test_apiRun_cache(t *testing.T) { - ios, _, stdout, stderr := iostreams.Test() - + // Given we have a test server that spies on the number of requests it receives requestCount := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusNoContent) + })) + t.Cleanup(s.Close) + + ios, _, stdout, stderr := iostreams.Test() options := ApiOptions{ IO: ios, - HttpClient: func() (*http.Client, error) { - var tr roundTripper = func(req *http.Request) (*http.Response, error) { - requestCount++ - return &http.Response{ - Request: req, - StatusCode: 204, - }, nil - } - return &http.Client{Transport: tr}, nil - }, Config: func() (config.Config, error) { - return config.NewBlankConfig(), nil + return &config.ConfigMock{ + AuthenticationFunc: func() *config.AuthConfig { + return &config.AuthConfig{} + }, + // Cached responses are stored in a tempdir that gets automatically cleaned up + CacheDirFunc: func() string { + return t.TempDir() + }, + }, nil }, - - RequestPath: "issues", + // You might think that we want to set Host: s.URL here, but you'd be wrong. + // The host field is later used to evaluate an API URL e.g. https://api.host.com/graphql + // The RequestPath field is used exactly as is, for the request if it includes a host. + RequestPath: s.URL, CacheTTL: time.Minute, } - t.Cleanup(func() { - cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") - os.RemoveAll(cacheDir) - }) + // When we run the API behaviour twice + require.NoError(t, apiRun(&options)) + require.NoError(t, apiRun(&options)) - err := apiRun(&options) - assert.NoError(t, err) - err = apiRun(&options) - assert.NoError(t, err) - - assert.Equal(t, 2, requestCount) + // We only get one request to the http server because it uses the cached response + assert.Equal(t, 1, requestCount) assert.Equal(t, "", stdout.String(), "stdout") assert.Equal(t, "", stderr.String(), "stderr") }