diff --git a/api/cache.go b/api/cache.go index 1c9936d37..fc9d1c5ba 100644 --- a/api/cache.go +++ b/api/cache.go @@ -167,9 +167,13 @@ func (fs *fileStorage) store(key string, res *http.Response) error { defer f.Close() var origBody io.ReadCloser - origBody, res.Body = copyStream(res.Body) - defer res.Body.Close() + if res.Body != nil { + origBody, res.Body = copyStream(res.Body) + defer res.Body.Close() + } err = res.Write(f) - res.Body = origBody + if origBody != nil { + res.Body = origBody + } return err } diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 8ebdd3bc7..8273b3635 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -14,8 +14,10 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/api" "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/cmdutil" @@ -38,6 +40,7 @@ type ApiOptions struct { ShowResponseHeaders bool Paginate bool Silent bool + CacheTTL time.Duration HttpClient func() (*http.Client, error) BaseRepo func() (ghrepo.Interface, error) @@ -176,6 +179,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command cmd.Flags().BoolVar(&opts.Paginate, "paginate", false, "Make additional HTTP requests to fetch all pages of results") cmd.Flags().StringVar(&opts.RequestInputFile, "input", "", "The `file` to use as body for the HTTP request") cmd.Flags().BoolVar(&opts.Silent, "silent", false, "Do not print the response body") + cmd.Flags().DurationVar(&opts.CacheTTL, "cache", 0, "Cache the response, e.g. \"3600s\", \"60m\", \"1h\"") return cmd } @@ -219,6 +223,9 @@ func apiRun(opts *ApiOptions) error { if err != nil { return err } + if opts.CacheTTL > 0 { + httpClient = api.NewCachedClient(httpClient, opts.CacheTTL) + } headersOutputStream := opts.IO.Out if opts.Silent { diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index 6cd693ce1..69ffa100d 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -7,7 +7,9 @@ import ( "io/ioutil" "net/http" "os" + "path/filepath" "testing" + "time" "github.com/cli/cli/git" "github.com/cli/cli/internal/ghrepo" @@ -42,6 +44,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -60,6 +63,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -78,6 +82,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -96,6 +101,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: true, Paginate: false, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -114,6 +120,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: true, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -132,6 +139,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: true, + CacheTTL: 0, }, wantsErr: false, }, @@ -155,6 +163,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: true, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -178,6 +187,7 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: false, + CacheTTL: 0, }, wantsErr: false, }, @@ -201,22 +211,35 @@ func Test_NewCmdApi(t *testing.T) { ShowResponseHeaders: false, Paginate: false, Silent: false, + CacheTTL: 0, + }, + wantsErr: false, + }, + { + name: "with cache", + cli: "user --cache 5m", + wants: ApiOptions{ + Hostname: "", + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "user", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: false, + Silent: false, + CacheTTL: time.Minute * 5, }, wantsErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + var opts *ApiOptions cmd := NewCmdApi(f, func(o *ApiOptions) error { - assert.Equal(t, tt.wants.Hostname, o.Hostname) - assert.Equal(t, tt.wants.RequestMethod, o.RequestMethod) - assert.Equal(t, tt.wants.RequestMethodPassed, o.RequestMethodPassed) - assert.Equal(t, tt.wants.RequestPath, o.RequestPath) - assert.Equal(t, tt.wants.RequestInputFile, o.RequestInputFile) - assert.Equal(t, tt.wants.RawFields, o.RawFields) - assert.Equal(t, tt.wants.MagicFields, o.MagicFields) - assert.Equal(t, tt.wants.RequestHeaders, o.RequestHeaders) - assert.Equal(t, tt.wants.ShowResponseHeaders, o.ShowResponseHeaders) + opts = o return nil }) @@ -232,6 +255,19 @@ func Test_NewCmdApi(t *testing.T) { return } assert.NoError(t, err) + + assert.Equal(t, tt.wants.Hostname, opts.Hostname) + assert.Equal(t, tt.wants.RequestMethod, opts.RequestMethod) + assert.Equal(t, tt.wants.RequestMethodPassed, opts.RequestMethodPassed) + assert.Equal(t, tt.wants.RequestPath, opts.RequestPath) + assert.Equal(t, tt.wants.RequestInputFile, opts.RequestInputFile) + assert.Equal(t, tt.wants.RawFields, opts.RawFields) + assert.Equal(t, tt.wants.MagicFields, opts.MagicFields) + assert.Equal(t, tt.wants.RequestHeaders, opts.RequestHeaders) + assert.Equal(t, tt.wants.ShowResponseHeaders, opts.ShowResponseHeaders) + assert.Equal(t, tt.wants.Paginate, opts.Paginate) + assert.Equal(t, tt.wants.Silent, opts.Silent) + assert.Equal(t, tt.wants.CacheTTL, opts.CacheTTL) }) } } @@ -593,6 +629,42 @@ func Test_apiRun_inputFile(t *testing.T) { } } +func Test_apiRun_cache(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + requestCount := 0 + options := ApiOptions{ + IO: io, + HttpClient: func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + requestCount++ + return &http.Response{ + Request: req, + StatusCode: 204, + }, nil + } + return &http.Client{Transport: tr}, nil + }, + + RequestPath: "issues", + CacheTTL: time.Minute, + } + + t.Cleanup(func() { + cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") + os.RemoveAll(cacheDir) + }) + + err := apiRun(&options) + assert.NoError(t, err) + err = apiRun(&options) + assert.NoError(t, err) + + assert.Equal(t, 1, requestCount) + assert.Equal(t, "", stdout.String(), "stdout") + assert.Equal(t, "", stderr.String(), "stderr") +} + func Test_parseFields(t *testing.T) { io, stdin, _, _ := iostreams.Test() fmt.Fprint(stdin, "pasted contents")