From 39805fa9b1a5c04eb3c88a1c3bf89f28657d1d33 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 23 Mar 2023 12:17:47 +1100 Subject: [PATCH] Properly handle closing files that have been writen to (#7199) --- internal/config/config.go | 3 ++- pkg/cmd/extension/http.go | 33 ++++++++++++++++------------ pkg/cmd/extension/manager.go | 31 ++++++++++++++++---------- pkg/cmd/release/download/download.go | 33 ++++++++++++++++------------ pkg/cmd/run/download/zip.go | 33 +++++++++++++++++----------- pkg/cmd/run/download/zip_test.go | 13 +++-------- 6 files changed, 83 insertions(+), 63 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 96d6b5ed0..890dcec9f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -140,7 +140,8 @@ func (c *AuthConfig) Token(hostname string) (string, string) { return token, source } -// HasEnvToken checks whether the current env or config contains a token +// HasEnvToken returns true when a token has been specified in an +// environment variable, else returns false. func (c *AuthConfig) HasEnvToken() bool { // This will check if there are any environment variable // authentication tokens set for enterprise hosts. diff --git a/pkg/cmd/extension/http.go b/pkg/cmd/extension/http.go index 2fae2f023..90ccd64ce 100644 --- a/pkg/cmd/extension/http.go +++ b/pkg/cmd/extension/http.go @@ -74,32 +74,37 @@ type release struct { } // downloadAsset downloads a single asset to the given file path. -func downloadAsset(httpClient *http.Client, asset releaseAsset, destPath string) error { - req, err := http.NewRequest("GET", asset.APIURL, nil) - if err != nil { - return err +func downloadAsset(httpClient *http.Client, asset releaseAsset, destPath string) (downloadErr error) { + var req *http.Request + if req, downloadErr = http.NewRequest("GET", asset.APIURL, nil); downloadErr != nil { + return } req.Header.Set("Accept", "application/octet-stream") - resp, err := httpClient.Do(req) - if err != nil { - return err + var resp *http.Response + if resp, downloadErr = httpClient.Do(req); downloadErr != nil { + return } defer resp.Body.Close() if resp.StatusCode > 299 { - return api.HandleHTTPError(resp) + downloadErr = api.HandleHTTPError(resp) + return } - f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) - if err != nil { - return err + var f *os.File + if f, downloadErr = os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755); downloadErr != nil { + return } - defer f.Close() + defer func() { + if err := f.Close(); downloadErr == nil && err != nil { + downloadErr = err + } + }() - _, err = io.Copy(f, resp.Body) - return err + _, downloadErr = io.Copy(f, resp.Body) + return } var commitNotFoundErr = errors.New("commit not found") diff --git a/pkg/cmd/extension/manager.go b/pkg/cmd/extension/manager.go index af4dbe9a4..7a76e6a47 100644 --- a/pkg/cmd/extension/manager.go +++ b/pkg/cmd/extension/manager.go @@ -436,23 +436,32 @@ func (m *Manager) installBin(repo ghrepo.Interface, target string) error { } if !m.dryRunMode { - manifestPath := filepath.Join(targetDir, manifestName) - - f, err := os.OpenFile(manifestPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return fmt.Errorf("failed to open manifest for writing: %w", err) - } - defer f.Close() - - _, err = f.Write(bs) - if err != nil { - return fmt.Errorf("failed write manifest file: %w", err) + if err := writeManifest(targetDir, manifestName, bs); err != nil { + return err } } return nil } +func writeManifest(dir, name string, data []byte) (writeErr error) { + path := filepath.Join(dir, name) + var f *os.File + if f, writeErr = os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); writeErr != nil { + writeErr = fmt.Errorf("failed to open manifest for writing: %w", writeErr) + return + } + defer func() { + if err := f.Close(); writeErr == nil && err != nil { + writeErr = err + } + }() + if _, writeErr = f.Write(data); writeErr != nil { + writeErr = fmt.Errorf("failed write manifest file: %w", writeErr) + } + return +} + func (m *Manager) installGit(repo ghrepo.Interface, target string) error { protocol, _ := m.config.GetOrDefault(repo.RepoHost(), "git_protocol") cloneURL := ghrepo.FormatRemoteURL(repo, protocol) diff --git a/pkg/cmd/release/download/download.go b/pkg/cmd/release/download/download.go index d595dca4a..b957ba509 100644 --- a/pkg/cmd/release/download/download.go +++ b/pkg/cmd/release/download/download.go @@ -356,29 +356,34 @@ func (w destinationWriter) check(fp string) error { return nil } -// Copy writes the data from r into a file specified by name -func (w destinationWriter) Copy(name string, r io.Reader) error { +// Copy writes the data from r into a file specified by name. +func (w destinationWriter) Copy(name string, r io.Reader) (copyErr error) { fp := w.makePath(name) if fp == "-" { - _, err := io.Copy(w.stdout, r) - return err + _, copyErr = io.Copy(w.stdout, r) + return } - if err := w.check(fp); err != nil { - return err + if copyErr = w.check(fp); copyErr != nil { + return } if dir := filepath.Dir(fp); dir != "." { - if err := os.MkdirAll(dir, 0755); err != nil { - return err + if copyErr = os.MkdirAll(dir, 0755); copyErr != nil { + return } } - f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err + var f *os.File + if f, copyErr = os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); copyErr != nil { + return } - defer f.Close() - _, err = io.Copy(f, r) - return err + defer func() { + if err := f.Close(); copyErr == nil && err != nil { + copyErr = err + } + }() + + _, copyErr = io.Copy(f, r) + return } diff --git a/pkg/cmd/run/download/zip.go b/pkg/cmd/run/download/zip.go index bf56ea081..ab5723e94 100644 --- a/pkg/cmd/run/download/zip.go +++ b/pkg/cmd/run/download/zip.go @@ -28,32 +28,39 @@ func extractZip(zr *zip.Reader, destDir string) error { return nil } -func extractZipFile(zf *zip.File, dest string) error { +func extractZipFile(zf *zip.File, dest string) (extractErr error) { zm := zf.Mode() if zm.IsDir() { - return os.MkdirAll(dest, dirMode) + extractErr = os.MkdirAll(dest, dirMode) + return } - f, err := zf.Open() - if err != nil { - return err + var f io.ReadCloser + f, extractErr = zf.Open() + if extractErr != nil { + return } defer f.Close() if dir := filepath.Dir(dest); dir != "." { - if err := os.MkdirAll(dir, dirMode); err != nil { - return err + if extractErr = os.MkdirAll(dir, dirMode); extractErr != nil { + return } } - df, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_EXCL, getPerm(zm)) - if err != nil { - return err + var df *os.File + if df, extractErr = os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_EXCL, getPerm(zm)); extractErr != nil { + return } - defer df.Close() - _, err = io.Copy(df, f) - return err + defer func() { + if err := df.Close(); extractErr == nil && err != nil { + extractErr = err + } + }() + + _, extractErr = io.Copy(df, f) + return } func getPerm(m os.FileMode) os.FileMode { diff --git a/pkg/cmd/run/download/zip_test.go b/pkg/cmd/run/download/zip_test.go index 97861b183..ca401cdb9 100644 --- a/pkg/cmd/run/download/zip_test.go +++ b/pkg/cmd/run/download/zip_test.go @@ -11,23 +11,16 @@ import ( func Test_extractZip(t *testing.T) { tmpDir := t.TempDir() - wd, err := os.Getwd() - require.NoError(t, err) - t.Cleanup(func() { _ = os.Chdir(wd) }) + extractPath := filepath.Join(tmpDir, "artifact") zipFile, err := zip.OpenReader("./fixtures/myproject.zip") require.NoError(t, err) defer zipFile.Close() - extractPath := filepath.Join(tmpDir, "artifact") - err = os.MkdirAll(extractPath, 0700) - require.NoError(t, err) - require.NoError(t, os.Chdir(extractPath)) - - err = extractZip(&zipFile.Reader, ".") + err = extractZip(&zipFile.Reader, extractPath) require.NoError(t, err) - _, err = os.Stat(filepath.Join("src", "main.go")) + _, err = os.Stat(filepath.Join(extractPath, "src", "main.go")) require.NoError(t, err) }