From 8058c4ea34bca49212caf01a7635037653278939 Mon Sep 17 00:00:00 2001 From: lpessoa Date: Mon, 11 Oct 2021 18:55:39 +0000 Subject: [PATCH] Adding gh release download for .zip and .tar.gz MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mislav Marohnić --- pkg/cmd/release/download/download.go | 90 +++++++++++++++++++---- pkg/cmd/release/download/download_test.go | 85 ++++++++++++++++++++- pkg/httpmock/stub.go | 11 +++ 3 files changed, 170 insertions(+), 16 deletions(-) diff --git a/pkg/cmd/release/download/download.go b/pkg/cmd/release/download/download.go index 5a2b48706..f410bae7b 100644 --- a/pkg/cmd/release/download/download.go +++ b/pkg/cmd/release/download/download.go @@ -2,7 +2,9 @@ package download import ( "errors" + "fmt" "io" + "mime" "net/http" "os" "path/filepath" @@ -27,6 +29,8 @@ type DownloadOptions struct { // maximum number of simultaneous downloads Concurrency int + + ArchiveType string } func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobra.Command { @@ -47,12 +51,15 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr Example: heredoc.Doc(` # download all assets from a specific release $ gh release download v1.2.3 - + # download only Debian packages for the latest release $ gh release download --pattern '*.deb' - + # specify multiple file patterns $ gh release download -p '*.deb' -p '*.rpm' + + # download the archive of the source code for a release + $ gh release download v1.2.3 --archive=zip `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { @@ -67,6 +74,11 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr opts.TagName = args[0] } + // check archive type option validity + if err := checkArchiveTypeOption(opts); err != nil { + return err + } + opts.Concurrency = 5 if runF != nil { @@ -78,10 +90,30 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr cmd.Flags().StringVarP(&opts.Destination, "dir", "D", ".", "The directory to download files into") cmd.Flags().StringArrayVarP(&opts.FilePatterns, "pattern", "p", nil, "Download only assets that match a glob pattern") + cmd.Flags().StringVarP(&opts.ArchiveType, "archive", "A", "", "Download the source code archive in the specified `format` (zip or tar.gz)") return cmd } +func checkArchiveTypeOption(opts *DownloadOptions) error { + if len(opts.ArchiveType) == 0 { + return nil + } + + if err := cmdutil.MutuallyExclusive( + "specify only one of '--pattern' or '--archive'", + true, // ArchiveType len > 0 + len(opts.FilePatterns) > 0, + ); err != nil { + return err + } + + if opts.ArchiveType != "zip" && opts.ArchiveType != "tar.gz" { + return cmdutil.FlagErrorf("the value for `--archive` must be one of \"zip\" or \"tar.gz\"") + } + return nil +} + func downloadRun(opts *DownloadOptions) error { httpClient, err := opts.HttpClient() if err != nil { @@ -93,8 +125,10 @@ func downloadRun(opts *DownloadOptions) error { return err } - var release *shared.Release + opts.IO.StartProgressIndicator() + defer opts.IO.StopProgressIndicator() + var release *shared.Release if opts.TagName == "" { release, err = shared.FetchLatestRelease(httpClient, baseRepo) if err != nil { @@ -108,11 +142,22 @@ func downloadRun(opts *DownloadOptions) error { } var toDownload []shared.ReleaseAsset - for _, a := range release.Assets { - if len(opts.FilePatterns) > 0 && !matchAny(opts.FilePatterns, a.Name) { - continue + isArchive := false + if opts.ArchiveType != "" { + var archiveURL = release.ZipballURL + if opts.ArchiveType == "tar.gz" { + archiveURL = release.TarballURL + } + // create pseudo-Asset with no name and pointing to ZipBallURL or TarBallURL + toDownload = append(toDownload, shared.ReleaseAsset{APIURL: archiveURL}) + isArchive = true + } else { + for _, a := range release.Assets { + if len(opts.FilePatterns) > 0 && !matchAny(opts.FilePatterns, a.Name) { + continue + } + toDownload = append(toDownload, a) } - toDownload = append(toDownload, a) } if len(toDownload) == 0 { @@ -129,10 +174,7 @@ func downloadRun(opts *DownloadOptions) error { } } - opts.IO.StartProgressIndicator() - err = downloadAssets(httpClient, toDownload, opts.Destination, opts.Concurrency) - opts.IO.StopProgressIndicator() - return err + return downloadAssets(httpClient, toDownload, opts.Destination, opts.Concurrency, isArchive) } func matchAny(patterns []string, name string) bool { @@ -144,7 +186,7 @@ func matchAny(patterns []string, name string) bool { return false } -func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, destDir string, numWorkers int) error { +func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, destDir string, numWorkers int, isArchive bool) error { if numWorkers == 0 { return errors.New("the number of concurrent workers needs to be greater than 0") } @@ -159,7 +201,7 @@ func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, d for w := 1; w <= numWorkers; w++ { go func() { for a := range jobs { - results <- downloadAsset(httpClient, a.APIURL, filepath.Join(destDir, a.Name)) + results <- downloadAsset(httpClient, a.APIURL, destDir, a.Name, isArchive) } }() } @@ -179,13 +221,17 @@ func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, d return downloadError } -func downloadAsset(httpClient *http.Client, assetURL, destinationPath string) error { +func downloadAsset(httpClient *http.Client, assetURL, destinationDir string, fileName string, isArchive bool) error { req, err := http.NewRequest("GET", assetURL, nil) if err != nil { return err } req.Header.Set("Accept", "application/octet-stream") + // adding application/json to Accept header due to a bug in the zipball/tarball API endpoint that makes it mandatory + if isArchive { + req.Header.Set("Accept", "application/octet-stream, application/json") + } resp, err := httpClient.Do(req) if err != nil { @@ -197,6 +243,22 @@ func downloadAsset(httpClient *http.Client, assetURL, destinationPath string) er return api.HandleHTTPError(resp) } + var destinationPath = filepath.Join(destinationDir, fileName) + + if len(fileName) == 0 { + contentDisposition := resp.Header.Get("Content-Disposition") + + _, params, err := mime.ParseMediaType(contentDisposition) + if err != nil { + return fmt.Errorf("unable to parse file name of archive: %w", err) + } + if serverFileName, ok := params["filename"]; ok { + destinationPath = filepath.Join(destinationDir, serverFileName) + } else { + return errors.New("unable to determine file name of archive") + } + } + f, err := os.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) if err != nil { return err diff --git a/pkg/cmd/release/download/download_test.go b/pkg/cmd/release/download/download_test.go index 2c2aa5891..95c2f8c77 100644 --- a/pkg/cmd/release/download/download_test.go +++ b/pkg/cmd/release/download/download_test.go @@ -80,12 +80,36 @@ func Test_NewCmdDownload(t *testing.T) { Concurrency: 5, }, }, + { + name: "download archive with valid option", + args: "v1.2.3 -A zip", + isTTY: true, + want: DownloadOptions{ + TagName: "v1.2.3", + FilePatterns: []string(nil), + Destination: ".", + ArchiveType: "zip", + Concurrency: 5, + }, + }, { name: "no arguments", args: "", isTTY: true, wantErr: "the '--pattern' flag is required when downloading the latest release", }, + { + name: "simultaneous pattern and archive arguments", + args: "-p * -A zip", + isTTY: true, + wantErr: "specify only one of '--pattern' or '--archive'", + }, + { + name: "invalid archive argument", + args: "v1.2.3 -A abc", + isTTY: true, + wantErr: "the value for `--archive` must be one of \"zip\" or \"tar.gz\"", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -184,6 +208,36 @@ func Test_downloadRun(t *testing.T) { wantStderr: ``, wantErr: "no assets match the file pattern", }, + { + name: "download archive in zip format into destination directory", + isTTY: true, + opts: DownloadOptions{ + TagName: "v1.2.3", + ArchiveType: "zip", + Destination: "tmp/packages", + Concurrency: 2, + }, + wantStdout: ``, + wantStderr: ``, + wantFiles: []string{ + "tmp/packages/zipball.zip", + }, + }, + { + name: "download archive in `tar.gz` format into destination directory", + isTTY: true, + opts: DownloadOptions{ + TagName: "v1.2.3", + ArchiveType: "tar.gz", + Destination: "tmp/packages", + Concurrency: 2, + }, + wantStdout: ``, + wantStderr: ``, + wantFiles: []string{ + "tmp/packages/tarball.tgz", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -204,12 +258,34 @@ func Test_downloadRun(t *testing.T) { "url": "https://api.github.com/assets/3456" }, { "name": "linux.tgz", "size": 56, "url": "https://api.github.com/assets/5678" } - ] + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" }`)) fakeHTTP.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) fakeHTTP.Register(httpmock.REST("GET", "assets/3456"), httpmock.StringResponse(`3456`)) fakeHTTP.Register(httpmock.REST("GET", "assets/5678"), httpmock.StringResponse(`5678`)) + fakeHTTP.Register( + httpmock.REST( + "GET", + "repos/OWNER/REPO/tarball/v1.2.3", + ), + httpmock.WithHeader( + httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=tarball.tgz", + ), + ) + + fakeHTTP.Register( + httpmock.REST( + "GET", + "repos/OWNER/REPO/zipball/v1.2.3", + ), + httpmock.WithHeader( + httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=zipball.zip", + ), + ) + tt.opts.IO = io tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: fakeHTTP}, nil @@ -226,7 +302,12 @@ func Test_downloadRun(t *testing.T) { require.NoError(t, err) } - assert.Equal(t, "application/octet-stream", fakeHTTP.Requests[1].Header.Get("Accept")) + var expectedAcceptHeader = "application/octet-stream" + if len(tt.opts.ArchiveType) > 0 { + expectedAcceptHeader = "application/octet-stream, application/json" + } + + assert.Equal(t, expectedAcceptHeader, fakeHTTP.Requests[1].Header.Get("Accept")) assert.Equal(t, tt.wantStdout, stdout.String()) assert.Equal(t, tt.wantStderr, stderr.String()) diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 98edbcb58..5633c2caf 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -77,6 +77,17 @@ func StringResponse(body string) Responder { } } +func WithHeader(responder Responder, header string, value string) Responder { + return func(req *http.Request) (*http.Response, error) { + resp, _ := responder(req) + if resp.Header == nil { + resp.Header = make(http.Header) + } + resp.Header.Set(header, value) + return resp, nil + } +} + func StatusStringResponse(status int, body string) Responder { return func(req *http.Request) (*http.Response, error) { return httpResponse(status, req, bytes.NewBufferString(body)), nil