diff --git a/command/config_test.go b/command/config_test.go index 61f89ddf8..53654f54f 100644 --- a/command/config_test.go +++ b/command/config_test.go @@ -49,23 +49,31 @@ func TestConfigGet_not_found(t *testing.T) { func TestConfigSet(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf, nil)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("config set editor ed") if err != nil { t.Fatalf("error running command `config set editor ed`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: 1234567890 -editor: ed + expectedMain := "editor: ed\n" + expectedHosts := `github.com: + user: OWNER + oauth_token: "1234567890" ` - eq(t, buf.String(), expected) + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigSet_update(t *testing.T) { @@ -79,23 +87,31 @@ editor: ed initBlankContext(cfg, "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf, nil)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("config set editor vim") if err != nil { t.Fatalf("error running command `config get editor`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: MUSTBEHIGHCUZIMATOKEN -editor: vim + expectedMain := "editor: vim\n" + expectedHosts := `github.com: + user: OWNER + oauth_token: MUSTBEHIGHCUZIMATOKEN ` - eq(t, buf.String(), expected) + + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigGetHost(t *testing.T) { @@ -141,23 +157,32 @@ git_protocol: ssh func TestConfigSetHost(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf, nil)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("config set -hgithub.com git_protocol ssh") if err != nil { t.Fatalf("error running command `config set editor ed`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: 1234567890 - git_protocol: ssh + expectedMain := "" + expectedHosts := `github.com: + user: OWNER + oauth_token: "1234567890" + git_protocol: ssh ` - eq(t, buf.String(), expected) + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigSetHost_update(t *testing.T) { @@ -171,21 +196,30 @@ hosts: initBlankContext(cfg, "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf, nil)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("config set -hgithub.com git_protocol https") if err != nil { t.Fatalf("error running command `config get editor`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - git_protocol: https - user: OWNER - oauth_token: MUSTBEHIGHCUZIMATOKEN + expectedMain := "" + expectedHosts := `github.com: + git_protocol: https + user: OWNER + oauth_token: MUSTBEHIGHCUZIMATOKEN ` - eq(t, buf.String(), expected) + + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } diff --git a/command/testing.go b/command/testing.go index 2ca2ad19e..af713aa29 100644 --- a/command/testing.go +++ b/command/testing.go @@ -19,7 +19,7 @@ import ( const defaultTestConfig = `hosts: github.com: user: OWNER - oauth_token: 1234567890 + oauth_token: "1234567890" ` type askStubber struct { diff --git a/internal/config/config_file.go b/internal/config/config_file.go index ff10d446f..71fffb0d1 100644 --- a/internal/config/config_file.go +++ b/internal/config/config_file.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "os" "path" - "path/filepath" "github.com/mitchellh/go-homedir" "gopkg.in/yaml.v3" @@ -22,6 +21,10 @@ func ConfigFile() string { return path.Join(ConfigDir(), "config.yml") } +func hostsConfigFile(fn string) string { + return path.Join(path.Dir(fn), "hosts.yml") +} + func ParseDefaultConfig() (Config, error) { return ParseConfig(ConfigFile()) } @@ -42,7 +45,7 @@ var ReadConfigFile = func(fn string) ([]byte, error) { } var WriteConfigFile = func(fn string, data []byte) error { - err := os.MkdirAll(filepath.Dir(fn), 0771) + err := os.MkdirAll(path.Dir(fn), 0771) if err != nil { return err } @@ -91,73 +94,44 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) { func isLegacy(root *yaml.Node) bool { for _, v := range root.Content[0].Content { - if v.Value == "hosts" { - return false + if v.Value == "github.com" { + return true } } - return true + return false } -func migrateConfig(fn string, root *yaml.Node) error { - type ConfigEntry map[string]string - type ConfigHash map[string]ConfigEntry - - newConfigData := map[string]ConfigHash{} - newConfigData["hosts"] = ConfigHash{} - - topLevelKeys := root.Content[0].Content - - for i, x := range topLevelKeys { - if x.Value == "" { - continue - } - if i+1 == len(topLevelKeys) { - break - } - hostname := x.Value - newConfigData["hosts"][hostname] = ConfigEntry{} - - authKeys := topLevelKeys[i+1].Content[0].Content - - for j, y := range authKeys { - if j+1 == len(authKeys) { - break - } - switch y.Value { - case "user": - newConfigData["hosts"][hostname]["user"] = authKeys[j+1].Value - case "oauth_token": - newConfigData["hosts"][hostname]["oauth_token"] = authKeys[j+1].Value - } - } - } - - if _, ok := newConfigData["hosts"][defaultHostname]; !ok { - return errors.New("could not find default host configuration") - } - - defaultHostConfig := newConfigData["hosts"][defaultHostname] - - if _, ok := defaultHostConfig["user"]; !ok { - return errors.New("default host configuration missing user") - } - - if _, ok := defaultHostConfig["oauth_token"]; !ok { - return errors.New("default host configuration missing oauth_token") - } - - newConfig, err := yaml.Marshal(newConfigData) +func migrateConfig(fn string) error { + b, err := ReadConfigFile(fn) if err != nil { return err } + var hosts map[string][]map[string]string + err = yaml.Unmarshal(b, &hosts) + if err != nil { + return fmt.Errorf("error decoding legacy format: %w", err) + } + + cfg := NewBlankConfig() + for hostname, entries := range hosts { + if len(entries) < 1 { + continue + } + for key, value := range entries[0] { + if err := cfg.Set(hostname, key, value); err != nil { + return err + } + } + } + err = BackupConfigFile(fn) if err != nil { return fmt.Errorf("failed to back up existing config: %w", err) } - return WriteConfigFile(fn, newConfig) + return cfg.Write() } func ParseConfig(fn string) (Config, error) { @@ -167,15 +141,28 @@ func ParseConfig(fn string) (Config, error) { } if isLegacy(root) { - err = migrateConfig(fn, root) + err = migrateConfig(fn) if err != nil { - return nil, err + return nil, fmt.Errorf("error migrating legacy config: %w", err) } _, root, err = parseConfigFile(fn) if err != nil { return nil, fmt.Errorf("failed to reparse migrated config: %w", err) } + } else { + if _, hostsRoot, err := parseConfigFile(hostsConfigFile(fn)); err == nil { + if len(hostsRoot.Content[0].Content) > 0 { + newContent := []*yaml.Node{ + {Value: "hosts"}, + hostsRoot.Content[0], + } + restContent := root.Content[0].Content + root.Content[0].Content = append(newContent, restContent...) + } + } else if !errors.Is(err, os.ErrNotExist) { + return nil, err + } } return NewConfig(root), nil diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go index 441dc22f2..edc3a6ae9 100644 --- a/internal/config/config_file_test.go +++ b/internal/config/config_file_test.go @@ -54,6 +54,22 @@ hosts: eq(t, token, "OTOKEN") } +func Test_parseConfig_hostsFile(t *testing.T) { + defer StubConfig("", `--- +github.com: + user: monalisa + oauth_token: OTOKEN +`)() + config, err := ParseConfig("config.yml") + eq(t, err, nil) + user, err := config.Get("github.com", "user") + eq(t, err, nil) + eq(t, user, "monalisa") + token, err := config.Get("github.com", "oauth_token") + eq(t, err, nil) + eq(t, token, "OTOKEN") +} + func Test_parseConfig_notFound(t *testing.T) { defer StubConfig(`--- hosts: @@ -67,33 +83,29 @@ hosts: eq(t, err, &NotFoundError{errors.New(`could not find config entry for "github.com"`)}) } -func Test_migrateConfig(t *testing.T) { - oldStyle := `--- +func Test_ParseConfig_migrateConfig(t *testing.T) { + defer StubConfig(`--- github.com: - user: keiyuri - oauth_token: 123456` - - var root yaml.Node - err := yaml.Unmarshal([]byte(oldStyle), &root) - if err != nil { - panic("failed to parse test yaml") - } - - buf := bytes.NewBufferString("") - defer StubWriteConfig(buf, nil)() + oauth_token: 123456 +`, "")() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() defer StubBackupConfig()() - err = migrateConfig("config.yml", &root) + _, err := ParseConfig("config.yml") eq(t, err, nil) - expected := `hosts: - github.com: - oauth_token: "123456" - user: keiyuri + expectedMain := "" + expectedHosts := `github.com: + user: keiyuri + oauth_token: "123456" ` - eq(t, buf.String(), expected) + eq(t, mainBuf.String(), expectedMain) + eq(t, hostsBuf.String(), expectedHosts) } func Test_parseConfigFile(t *testing.T) { diff --git a/internal/config/config_type.go b/internal/config/config_type.go index c2fc66d9b..f965b2c66 100644 --- a/internal/config/config_type.go +++ b/internal/config/config_type.go @@ -54,6 +54,7 @@ func (cm *ConfigMap) SetStringValue(key, value string) error { } valueNode = &yaml.Node{ Kind: yaml.ScalarNode, + Tag: "!!str", Value: "", } @@ -168,16 +169,42 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) { } func (c *fileConfig) Write() error { - marshalled, err := yaml.Marshal(c.documentRoot) + mainData := yaml.Node{Kind: yaml.MappingNode} + hostsData := yaml.Node{Kind: yaml.MappingNode} + + nodes := c.documentRoot.Content[0].Content + for i := 0; i < len(nodes)-1; i += 2 { + if nodes[i].Value == "hosts" { + hostsData.Content = append(hostsData.Content, nodes[i+1].Content...) + } else { + mainData.Content = append(mainData.Content, nodes[i], nodes[i+1]) + } + } + + mainBytes, err := yaml.Marshal(&mainData) if err != nil { return err } - if bytes.Equal(marshalled, []byte("{}\n")) { - marshalled = []byte{} + fn := ConfigFile() + err = WriteConfigFile(fn, yamlNormalize(mainBytes)) + if err != nil { + return err } - return WriteConfigFile(ConfigFile(), marshalled) + hostsBytes, err := yaml.Marshal(&hostsData) + if err != nil { + return err + } + + return WriteConfigFile(hostsConfigFile(fn), yamlNormalize(hostsBytes)) +} + +func yamlNormalize(b []byte) []byte { + if bytes.Equal(b, []byte("{}\n")) { + return []byte{} + } + return b } func (c *fileConfig) hostEntries() ([]*HostConfig, error) { diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go index 97ffd9aa7..1c8105186 100644 --- a/internal/config/config_type_test.go +++ b/internal/config/config_type_test.go @@ -8,8 +8,9 @@ import ( ) func Test_fileConfig_Set(t *testing.T) { - cb := bytes.Buffer{} - StubWriteConfig(&cb, nil) + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() c := NewBlankConfig() assert.NoError(t, c.Set("", "editor", "nano")) @@ -18,22 +19,23 @@ func Test_fileConfig_Set(t *testing.T) { assert.NoError(t, c.Set("github.com", "user", "hubot")) assert.NoError(t, c.Write()) - assert.Equal(t, `editor: nano -hosts: - github.com: - git_protocol: ssh - user: hubot - example.com: - editor: vim -`, cb.String()) + assert.Equal(t, "editor: nano\n", mainBuf.String()) + assert.Equal(t, `github.com: + git_protocol: ssh + user: hubot +example.com: + editor: vim +`, hostsBuf.String()) } func Test_fileConfig_Write(t *testing.T) { - cb := bytes.Buffer{} - StubWriteConfig(&cb, nil) + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() c := NewBlankConfig() assert.NoError(t, c.Write()) - assert.Equal(t, "", cb.String()) + assert.Equal(t, "", mainBuf.String()) + assert.Equal(t, "", hostsBuf.String()) }