Properly handle closing files that have been writen to (#7199)

This commit is contained in:
Sam Coe 2023-03-23 12:17:47 +11:00 committed by GitHub
parent 3534cf7527
commit 39805fa9b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 83 additions and 63 deletions

View file

@ -140,7 +140,8 @@ func (c *AuthConfig) Token(hostname string) (string, string) {
return token, source
}
// HasEnvToken checks whether the current env or config contains a token
// HasEnvToken returns true when a token has been specified in an
// environment variable, else returns false.
func (c *AuthConfig) HasEnvToken() bool {
// This will check if there are any environment variable
// authentication tokens set for enterprise hosts.

View file

@ -74,32 +74,37 @@ type release struct {
}
// downloadAsset downloads a single asset to the given file path.
func downloadAsset(httpClient *http.Client, asset releaseAsset, destPath string) error {
req, err := http.NewRequest("GET", asset.APIURL, nil)
if err != nil {
return err
func downloadAsset(httpClient *http.Client, asset releaseAsset, destPath string) (downloadErr error) {
var req *http.Request
if req, downloadErr = http.NewRequest("GET", asset.APIURL, nil); downloadErr != nil {
return
}
req.Header.Set("Accept", "application/octet-stream")
resp, err := httpClient.Do(req)
if err != nil {
return err
var resp *http.Response
if resp, downloadErr = httpClient.Do(req); downloadErr != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode > 299 {
return api.HandleHTTPError(resp)
downloadErr = api.HandleHTTPError(resp)
return
}
f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755)
if err != nil {
return err
var f *os.File
if f, downloadErr = os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755); downloadErr != nil {
return
}
defer f.Close()
defer func() {
if err := f.Close(); downloadErr == nil && err != nil {
downloadErr = err
}
}()
_, err = io.Copy(f, resp.Body)
return err
_, downloadErr = io.Copy(f, resp.Body)
return
}
var commitNotFoundErr = errors.New("commit not found")

View file

@ -436,23 +436,32 @@ func (m *Manager) installBin(repo ghrepo.Interface, target string) error {
}
if !m.dryRunMode {
manifestPath := filepath.Join(targetDir, manifestName)
f, err := os.OpenFile(manifestPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open manifest for writing: %w", err)
}
defer f.Close()
_, err = f.Write(bs)
if err != nil {
return fmt.Errorf("failed write manifest file: %w", err)
if err := writeManifest(targetDir, manifestName, bs); err != nil {
return err
}
}
return nil
}
func writeManifest(dir, name string, data []byte) (writeErr error) {
path := filepath.Join(dir, name)
var f *os.File
if f, writeErr = os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); writeErr != nil {
writeErr = fmt.Errorf("failed to open manifest for writing: %w", writeErr)
return
}
defer func() {
if err := f.Close(); writeErr == nil && err != nil {
writeErr = err
}
}()
if _, writeErr = f.Write(data); writeErr != nil {
writeErr = fmt.Errorf("failed write manifest file: %w", writeErr)
}
return
}
func (m *Manager) installGit(repo ghrepo.Interface, target string) error {
protocol, _ := m.config.GetOrDefault(repo.RepoHost(), "git_protocol")
cloneURL := ghrepo.FormatRemoteURL(repo, protocol)

View file

@ -356,29 +356,34 @@ func (w destinationWriter) check(fp string) error {
return nil
}
// Copy writes the data from r into a file specified by name
func (w destinationWriter) Copy(name string, r io.Reader) error {
// Copy writes the data from r into a file specified by name.
func (w destinationWriter) Copy(name string, r io.Reader) (copyErr error) {
fp := w.makePath(name)
if fp == "-" {
_, err := io.Copy(w.stdout, r)
return err
_, copyErr = io.Copy(w.stdout, r)
return
}
if err := w.check(fp); err != nil {
return err
if copyErr = w.check(fp); copyErr != nil {
return
}
if dir := filepath.Dir(fp); dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
return err
if copyErr = os.MkdirAll(dir, 0755); copyErr != nil {
return
}
}
f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
var f *os.File
if f, copyErr = os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); copyErr != nil {
return
}
defer f.Close()
_, err = io.Copy(f, r)
return err
defer func() {
if err := f.Close(); copyErr == nil && err != nil {
copyErr = err
}
}()
_, copyErr = io.Copy(f, r)
return
}

View file

@ -28,32 +28,39 @@ func extractZip(zr *zip.Reader, destDir string) error {
return nil
}
func extractZipFile(zf *zip.File, dest string) error {
func extractZipFile(zf *zip.File, dest string) (extractErr error) {
zm := zf.Mode()
if zm.IsDir() {
return os.MkdirAll(dest, dirMode)
extractErr = os.MkdirAll(dest, dirMode)
return
}
f, err := zf.Open()
if err != nil {
return err
var f io.ReadCloser
f, extractErr = zf.Open()
if extractErr != nil {
return
}
defer f.Close()
if dir := filepath.Dir(dest); dir != "." {
if err := os.MkdirAll(dir, dirMode); err != nil {
return err
if extractErr = os.MkdirAll(dir, dirMode); extractErr != nil {
return
}
}
df, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_EXCL, getPerm(zm))
if err != nil {
return err
var df *os.File
if df, extractErr = os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_EXCL, getPerm(zm)); extractErr != nil {
return
}
defer df.Close()
_, err = io.Copy(df, f)
return err
defer func() {
if err := df.Close(); extractErr == nil && err != nil {
extractErr = err
}
}()
_, extractErr = io.Copy(df, f)
return
}
func getPerm(m os.FileMode) os.FileMode {

View file

@ -11,23 +11,16 @@ import (
func Test_extractZip(t *testing.T) {
tmpDir := t.TempDir()
wd, err := os.Getwd()
require.NoError(t, err)
t.Cleanup(func() { _ = os.Chdir(wd) })
extractPath := filepath.Join(tmpDir, "artifact")
zipFile, err := zip.OpenReader("./fixtures/myproject.zip")
require.NoError(t, err)
defer zipFile.Close()
extractPath := filepath.Join(tmpDir, "artifact")
err = os.MkdirAll(extractPath, 0700)
require.NoError(t, err)
require.NoError(t, os.Chdir(extractPath))
err = extractZip(&zipFile.Reader, ".")
err = extractZip(&zipFile.Reader, extractPath)
require.NoError(t, err)
_, err = os.Stat(filepath.Join("src", "main.go"))
_, err = os.Stat(filepath.Join(extractPath, "src", "main.go"))
require.NoError(t, err)
}