diff --git a/internal/config/config_file.go b/internal/config/config_file.go index 540bafa46..e3d7d2d52 100644 --- a/internal/config/config_file.go +++ b/internal/config/config_file.go @@ -3,10 +3,8 @@ package config import ( "errors" "fmt" - "io" "io/ioutil" "os" - "path" "path/filepath" "syscall" @@ -111,7 +109,7 @@ var ReadConfigFile = func(filename string) ([]byte, error) { } var WriteConfigFile = func(filename string, data []byte) error { - err := os.MkdirAll(path.Dir(filename), 0771) + err := os.MkdirAll(filepath.Dir(filename), 0771) if err != nil { return pathError(err) } @@ -122,11 +120,7 @@ var WriteConfigFile = func(filename string, data []byte) error { } defer cfgFile.Close() - n, err := cfgFile.Write(data) - if err == nil && n < len(data) { - err = io.ErrShortWrite - } - + _, err = cfgFile.Write(data) return err } @@ -263,7 +257,7 @@ func findRegularFile(p string) string { if s, err := os.Stat(p); err == nil && s.Mode().IsRegular() { return p } - newPath := path.Dir(p) + newPath := filepath.Dir(p) if newPath == p || newPath == "/" || newPath == "." { break } diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go index f12629d7a..c7343b7ea 100644 --- a/internal/config/config_file_test.go +++ b/internal/config/config_file_test.go @@ -3,7 +3,9 @@ package config import ( "bytes" "fmt" + "io/ioutil" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -167,3 +169,28 @@ func Test_ConfigDir(t *testing.T) { }) } } + +func Test_configFile_Write_toDisk(t *testing.T) { + configDir := filepath.Join(t.TempDir(), ".config", "gh") + os.Setenv(GH_CONFIG_DIR, configDir) + defer os.Unsetenv(GH_CONFIG_DIR) + + cfg := NewFromString(`pager: less`) + err := cfg.Write() + if err != nil { + t.Fatal(err) + } + + expectedConfig := "pager: less\n" + if configBytes, err := ioutil.ReadFile(filepath.Join(configDir, "config.yml")); err != nil { + t.Error(err) + } else if string(configBytes) != expectedConfig { + t.Errorf("expected config.yml %q, got %q", expectedConfig, string(configBytes)) + } + + if configBytes, err := ioutil.ReadFile(filepath.Join(configDir, "hosts.yml")); err != nil { + t.Error(err) + } else if string(configBytes) != "" { + t.Errorf("unexpected hosts.yml: %q", string(configBytes)) + } +}