From fd7b87f3fa0263cbd70feea93dee75b9fc4595e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 27 May 2020 21:01:01 +0200 Subject: [PATCH] Allow writing host-specific keys in a blank config --- internal/config/config_file.go | 8 +++++- internal/config/config_file_test.go | 2 +- internal/config/config_setup.go | 42 ++++++---------------------- internal/config/config_type.go | 43 +++++++++++++++++++++++++++-- internal/config/config_type_test.go | 29 +++++++++++++++++++ 5 files changed, 86 insertions(+), 38 deletions(-) create mode 100644 internal/config/config_type_test.go diff --git a/internal/config/config_file.go b/internal/config/config_file.go index 2670aff20..67304eb8f 100644 --- a/internal/config/config_file.go +++ b/internal/config/config_file.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "path" + "path/filepath" "github.com/mitchellh/go-homedir" "gopkg.in/yaml.v3" @@ -49,7 +50,12 @@ var ReadConfigFile = func(fn string) ([]byte, error) { } var WriteConfigFile = func(fn string, data []byte) error { - cfgFile, err := os.OpenFile(ConfigFile(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup + err := os.MkdirAll(filepath.Dir(fn), 0771) + if err != nil { + return err + } + + cfgFile, err := os.OpenFile(fn, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup if err != nil { return err } diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go index 6af10d0e1..28d50a37e 100644 --- a/internal/config/config_file_test.go +++ b/internal/config/config_file_test.go @@ -63,7 +63,7 @@ hosts: config, err := ParseConfig("config.yml") eq(t, err, nil) _, err = config.Get("github.com", "user") - eq(t, err, errors.New(`could not find config entry for "github.com"`)) + eq(t, err, &NotFoundError{errors.New(`could not find config entry for "github.com"`)}) } func Test_migrateConfig(t *testing.T) { diff --git a/internal/config/config_setup.go b/internal/config/config_setup.go index 2fc414a1b..17dbcb190 100644 --- a/internal/config/config_setup.go +++ b/internal/config/config_setup.go @@ -5,12 +5,10 @@ import ( "fmt" "io" "os" - "path/filepath" "strings" "github.com/cli/cli/api" "github.com/cli/cli/auth" - "gopkg.in/yaml.v3" ) const ( @@ -76,44 +74,20 @@ func setupConfigFile(filename string) (Config, error) { return nil, err } - // TODO this sucks. It precludes us laying out a nice config with comments and such. - type yamlConfig struct { - Hosts map[string]map[string]string + cfg := NewBlankConfig() + err = cfg.Set(oauthHost, "user", userLogin) + if err != nil { + return nil, err } - - yamlHosts := map[string]map[string]string{} - yamlHosts[oauthHost] = map[string]string{} - yamlHosts[oauthHost]["user"] = userLogin - yamlHosts[oauthHost]["oauth_token"] = token - - defaultConfig := yamlConfig{ - Hosts: yamlHosts, - } - - err = os.MkdirAll(filepath.Dir(filename), 0771) + err = cfg.Set(oauthHost, "oauth_token", token) if err != nil { return nil, err } - cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return nil, err + if err = cfg.Write(); err == nil { + AuthFlowComplete() } - defer cfgFile.Close() - - yamlData, err := yaml.Marshal(defaultConfig) - if err != nil { - return nil, err - } - _, err = cfgFile.Write(yamlData) - if err != nil { - return nil, err - } - - AuthFlowComplete() - - // TODO cleaner error handling? this "should" always work given that we /just/ wrote the file... - return ParseConfig(filename) + return cfg, err } func getViewer(token string) (string, error) { diff --git a/internal/config/config_type.go b/internal/config/config_type.go index 6517e1be3..4c38c421c 100644 --- a/internal/config/config_type.go +++ b/internal/config/config_type.go @@ -88,6 +88,13 @@ func NewConfig(root *yaml.Node) Config { } } +func NewBlankConfig() Config { + return NewConfig(&yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{{Kind: yaml.MappingNode}}, + }) +} + // This type implements a Config interface and represents a config file on disk. type fileConfig struct { ConfigMap @@ -136,7 +143,10 @@ func (c *fileConfig) Set(hostname, key, value string) error { return c.SetStringValue(key, value) } else { hostCfg, err := c.configForHost(hostname) - if err != nil { + var notFound *NotFoundError + if errors.As(err, ¬Found) { + hostCfg = c.makeConfigForHost(hostname) + } else if err != nil { return err } return hostCfg.SetStringValue(key, value) @@ -154,7 +164,7 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) { return hc, nil } } - return nil, fmt.Errorf("could not find config entry for %q", hostname) + return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)} } func (c *fileConfig) Write() error { @@ -186,6 +196,35 @@ func (c *fileConfig) hostEntries() ([]*HostConfig, error) { return hostConfigs, nil } +func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig { + hostRoot := &yaml.Node{Kind: yaml.MappingNode} + hostCfg := &HostConfig{ + Host: hostname, + ConfigMap: ConfigMap{Root: hostRoot}, + } + + var notFound *NotFoundError + _, hostsEntry, err := c.FindEntry("hosts") + if errors.As(err, ¬Found) { + hostsEntry = &yaml.Node{Kind: yaml.MappingNode} + c.Root.Content = append(c.Root.Content, + &yaml.Node{ + Kind: yaml.ScalarNode, + Value: "hosts", + }, hostsEntry) + } else if err != nil { + panic(err) + } + + hostsEntry.Content = append(hostsEntry.Content, + &yaml.Node{ + Kind: yaml.ScalarNode, + Value: hostname, + }, hostRoot) + + return hostCfg +} + func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) { hostConfigs := []*HostConfig{} diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go new file mode 100644 index 000000000..eb0747571 --- /dev/null +++ b/internal/config/config_type_test.go @@ -0,0 +1,29 @@ +package config + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_fileConfig_Set(t *testing.T) { + cb := bytes.Buffer{} + StubWriteConfig(&cb, nil) + + c := NewBlankConfig() + assert.NoError(t, c.Set("", "editor", "nano")) + assert.NoError(t, c.Set("github.com", "git_protocol", "ssh")) + assert.NoError(t, c.Set("example.com", "editor", "vim")) + assert.NoError(t, c.Set("github.com", "user", "hubot")) + assert.NoError(t, c.Write()) + + assert.Equal(t, `editor: nano +hosts: + github.com: + git_protocol: ssh + user: hubot + example.com: + editor: vim +`, cb.String()) +}