diff --git a/Makefile b/Makefile index dd863cb0d..b925c11fd 100644 --- a/Makefile +++ b/Makefile @@ -11,8 +11,8 @@ endif LDFLAGS := -X github.com/cli/cli/command.Version=$(GH_VERSION) $(LDFLAGS) LDFLAGS := -X github.com/cli/cli/command.BuildDate=$(BUILD_DATE) $(LDFLAGS) ifdef GH_OAUTH_CLIENT_SECRET - LDFLAGS := -X github.com/cli/cli/context.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS) - LDFLAGS := -X github.com/cli/cli/context.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS) + LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS) + LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS) endif bin/gh: $(BUILD_FILES) diff --git a/api/client.go b/api/client.go index fe1fcf842..b8c6857fb 100644 --- a/api/client.go +++ b/api/client.go @@ -165,6 +165,10 @@ func (c Client) HasScopes(wantedScopes ...string) (bool, string, error) { } defer res.Body.Close() + if res.StatusCode != 200 { + return false, "", handleHTTPError(res) + } + appID := res.Header.Get("X-Oauth-Client-Id") hasScopes := strings.Split(res.Header.Get("X-Oauth-Scopes"), ",") diff --git a/auth/oauth.go b/auth/oauth.go index 5e286443e..3568956c5 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -66,9 +66,9 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { fmt.Fprintf(os.Stderr, "Please open the following URL manually:\n%s\n", startURL) fmt.Fprintf(os.Stderr, "") // TODO: Temporary workaround for https://github.com/cli/cli/issues/297 - fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:") - fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system") - fmt.Fprintf(os.Stderr, " 2. Copy the contents of ~/.config/gh/config.yml to this system") + fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:\n") + fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system;\n") + fmt.Fprintf(os.Stderr, " 2. Copy the contents of `~/.config/gh/hosts.yml` to this system.\n") } _ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/command/alias_test.go b/command/alias_test.go index 09e0c7bc5..666d74589 100644 --- a/command/alias_test.go +++ b/command/alias_test.go @@ -12,8 +12,9 @@ import ( func TestAliasSet_gh_command(t *testing.T) { initBlankContext("", "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() _, err := RunCommand("alias set pr pr status") if err == nil { @@ -26,15 +27,14 @@ func TestAliasSet_gh_command(t *testing.T) { func TestAliasSet_empty_aliases(t *testing.T) { cfg := `--- aliases: -hosts: - github.com: - user: OWNER - oauth_token: token123 +editor: vim ` initBlankContext(cfg, "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("alias set co pr checkout") if err != nil { @@ -45,12 +45,9 @@ hosts: expected := `aliases: co: pr checkout -hosts: - github.com: - user: OWNER - oauth_token: token123 +editor: vim ` - eq(t, buf.String(), expected) + eq(t, mainBuf.String(), expected) } func TestAliasSet_existing_alias(t *testing.T) { @@ -64,8 +61,10 @@ aliases: ` initBlankContext(cfg, "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("alias set co pr checkout -Rcool/repo") if err != nil { @@ -78,8 +77,9 @@ aliases: func TestAliasSet_space_args(t *testing.T) { initBlankContext("", "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand(`alias set il issue list -l 'cool story'`) @@ -89,7 +89,7 @@ func TestAliasSet_space_args(t *testing.T) { test.ExpectLines(t, output.String(), `Adding alias for il: issue list -l "cool story"`) - test.ExpectLines(t, buf.String(), `il: issue list -l "cool story"`) + test.ExpectLines(t, mainBuf.String(), `il: issue list -l "cool story"`) } func TestAliasSet_arg_processing(t *testing.T) { @@ -118,77 +118,66 @@ func TestAliasSet_arg_processing(t *testing.T) { `ix: issue list --author=\$1 --label=\$2`}, } - var buf *bytes.Buffer for _, c := range cases { - buf = bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand(c.Cmd) if err != nil { t.Fatalf("got unexpected error running %s: %s", c.Cmd, err) } test.ExpectLines(t, output.String(), c.ExpectedOutputLine) - test.ExpectLines(t, buf.String(), c.ExpectedConfigLine) + test.ExpectLines(t, mainBuf.String(), c.ExpectedConfigLine) } } func TestAliasSet_init_alias_cfg(t *testing.T) { cfg := `--- -hosts: - github.com: - user: OWNER - oauth_token: token123 +editor: vim ` initBlankContext(cfg, "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("alias set diff pr diff") if err != nil { t.Fatalf("unexpected error: %s", err) } - expected := `hosts: - github.com: - user: OWNER - oauth_token: token123 + expected := `editor: vim aliases: diff: pr diff ` test.ExpectLines(t, output.String(), "Adding alias for diff: pr diff", "Added alias.") - eq(t, buf.String(), expected) + eq(t, mainBuf.String(), expected) } func TestAliasSet_existing_aliases(t *testing.T) { cfg := `--- -hosts: - github.com: - user: OWNER - oauth_token: token123 aliases: foo: bar ` initBlankContext(cfg, "OWNER/REPO", "trunk") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("alias set view pr view") if err != nil { t.Fatalf("unexpected error: %s", err) } - expected := `hosts: - github.com: - user: OWNER - oauth_token: token123 -aliases: + expected := `aliases: foo: bar view: pr view ` test.ExpectLines(t, output.String(), "Adding alias for view: pr view", "Added alias.") - eq(t, buf.String(), expected) + eq(t, mainBuf.String(), expected) } diff --git a/command/config_test.go b/command/config_test.go index 00148f3b4..53654f54f 100644 --- a/command/config_test.go +++ b/command/config_test.go @@ -49,23 +49,31 @@ func TestConfigGet_not_found(t *testing.T) { func TestConfigSet(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("config set editor ed") if err != nil { t.Fatalf("error running command `config set editor ed`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: 1234567890 -editor: ed + expectedMain := "editor: ed\n" + expectedHosts := `github.com: + user: OWNER + oauth_token: "1234567890" ` - eq(t, buf.String(), expected) + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigSet_update(t *testing.T) { @@ -79,23 +87,31 @@ editor: ed initBlankContext(cfg, "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("config set editor vim") if err != nil { t.Fatalf("error running command `config get editor`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: MUSTBEHIGHCUZIMATOKEN -editor: vim + expectedMain := "editor: vim\n" + expectedHosts := `github.com: + user: OWNER + oauth_token: MUSTBEHIGHCUZIMATOKEN ` - eq(t, buf.String(), expected) + + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigGetHost(t *testing.T) { @@ -141,23 +157,32 @@ git_protocol: ssh func TestConfigSetHost(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() + output, err := RunCommand("config set -hgithub.com git_protocol ssh") if err != nil { t.Fatalf("error running command `config set editor ed`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - user: OWNER - oauth_token: 1234567890 - git_protocol: ssh + expectedMain := "" + expectedHosts := `github.com: + user: OWNER + oauth_token: "1234567890" + git_protocol: ssh ` - eq(t, buf.String(), expected) + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } func TestConfigSetHost_update(t *testing.T) { @@ -171,21 +196,30 @@ hosts: initBlankContext(cfg, "OWNER/REPO", "master") - buf := bytes.NewBufferString("") - defer config.StubWriteConfig(buf)() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer config.StubWriteConfig(&mainBuf, &hostsBuf)() output, err := RunCommand("config set -hgithub.com git_protocol https") if err != nil { t.Fatalf("error running command `config get editor`: %v", err) } - eq(t, output.String(), "") + if len(output.String()) > 0 { + t.Errorf("expected output to be blank: %q", output.String()) + } - expected := `hosts: - github.com: - git_protocol: https - user: OWNER - oauth_token: MUSTBEHIGHCUZIMATOKEN + expectedMain := "" + expectedHosts := `github.com: + git_protocol: https + user: OWNER + oauth_token: MUSTBEHIGHCUZIMATOKEN ` - eq(t, buf.String(), expected) + + if mainBuf.String() != expectedMain { + t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String()) + } + if hostsBuf.String() != expectedHosts { + t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String()) + } } diff --git a/command/root.go b/command/root.go index 0055a2716..db76509ed 100644 --- a/command/root.go +++ b/command/root.go @@ -181,24 +181,16 @@ var apiClientForContext = func(ctx context.Context) (*api.Client, error) { checkScopesFunc := func(appID string) error { if config.IsGitHubApp(appID) && !tokenFromEnv() && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) { - newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required") - if err != nil { - return err - } cfg, err := ctx.Config() if err != nil { return err } - _ = cfg.Set(defaultHostname, "oauth_token", newToken) - _ = cfg.Set(defaultHostname, "user", loginHandle) - // update config file on disk - err = cfg.Write() + newToken, err := config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required") if err != nil { return err } // update configuration in memory token = newToken - config.AuthFlowComplete() } else { fmt.Fprintln(os.Stderr, "Warning: gh now requires the `read:org` OAuth scope.") fmt.Fprintln(os.Stderr, "Visit https://github.com/settings/tokens and edit your token to enable `read:org`") @@ -235,28 +227,19 @@ var ensureScopes = func(ctx context.Context, client *api.Client, wantedScopes .. tokenFromEnv := len(os.Getenv("GITHUB_TOKEN")) > 0 if config.IsGitHubApp(appID) && !tokenFromEnv && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) { - newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required") - if err != nil { - return client, err - } cfg, err := ctx.Config() if err != nil { - return client, err + return nil, err } - _ = cfg.Set(defaultHostname, "oauth_token", newToken) - _ = cfg.Set(defaultHostname, "user", loginHandle) - // update config file on disk - err = cfg.Write() + _, err = config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required") if err != nil { - return client, err + return nil, err } - // update configuration in memory - config.AuthFlowComplete() + reloadedClient, err := apiClientForContext(ctx) if err != nil { return client, err } - return reloadedClient, nil } else { fmt.Fprintln(os.Stderr, fmt.Sprintf("Warning: gh now requires %s OAuth scopes.", wantedScopes)) diff --git a/command/testing.go b/command/testing.go index 5612f58ea..af713aa29 100644 --- a/command/testing.go +++ b/command/testing.go @@ -19,7 +19,7 @@ import ( const defaultTestConfig = `hosts: github.com: user: OWNER - oauth_token: 1234567890 + oauth_token: "1234567890" ` type askStubber struct { @@ -88,7 +88,7 @@ func initBlankContext(cfg, repo, branch string) { // NOTE we are not restoring the original readConfig; we never want to touch the config file on // disk during tests. - config.StubConfig(cfg) + config.StubConfig(cfg, "") return ctx } diff --git a/context/blank_context.go b/context/blank_context.go index ed8784cfc..3ea657abe 100644 --- a/context/blank_context.go +++ b/context/blank_context.go @@ -23,7 +23,7 @@ type blankContext struct { } func (c *blankContext) Config() (config.Config, error) { - cfg, err := config.ParseConfig("boom.txt") + cfg, err := config.ParseConfig("config.yml") if err != nil { panic(fmt.Sprintf("failed to parse config during tests. did you remember to stub? error: %s", err)) } diff --git a/context/context.go b/context/context.go index ddeb82c4d..236f9e722 100644 --- a/context/context.go +++ b/context/context.go @@ -3,6 +3,7 @@ package context import ( "errors" "fmt" + "os" "sort" "github.com/cli/cli/api" @@ -162,11 +163,13 @@ type fsContext struct { func (c *fsContext) Config() (config.Config, error) { if c.config == nil { - config, err := config.ParseOrSetupConfigFile(config.ConfigFile()) - if err != nil { + cfg, err := config.ParseDefaultConfig() + if errors.Is(err, os.ErrNotExist) { + cfg = config.NewBlankConfig() + } else if err != nil { return nil, err } - c.config = config + c.config = cfg c.authToken = "" } return c.config, nil @@ -182,8 +185,12 @@ func (c *fsContext) AuthToken() (string, error) { return "", err } + var notFound *config.NotFoundError token, err := cfg.Get(defaultHostname, "oauth_token") - if token == "" || err != nil { + if token == "" || errors.As(err, ¬Found) { + // interactive OAuth flow + return config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: authentication required") + } else if err != nil { return "", err } diff --git a/internal/config/config_file.go b/internal/config/config_file.go index 2670aff20..49a2770d3 100644 --- a/internal/config/config_file.go +++ b/internal/config/config_file.go @@ -21,20 +21,16 @@ func ConfigFile() string { return path.Join(ConfigDir(), "config.yml") } -func ParseOrSetupConfigFile(fn string) (Config, error) { - config, err := ParseConfig(fn) - if err != nil && errors.Is(err, os.ErrNotExist) { - return setupConfigFile(fn) - } - return config, err +func hostsConfigFile(filename string) string { + return path.Join(path.Dir(filename), "hosts.yml") } func ParseDefaultConfig() (Config, error) { return ParseConfig(ConfigFile()) } -var ReadConfigFile = func(fn string) ([]byte, error) { - f, err := os.Open(fn) +var ReadConfigFile = func(filename string) ([]byte, error) { + f, err := os.Open(filename) if err != nil { return nil, err } @@ -48,8 +44,13 @@ var ReadConfigFile = func(fn string) ([]byte, error) { return data, nil } -var WriteConfigFile = func(fn string, data []byte) error { - cfgFile, err := os.OpenFile(ConfigFile(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup +var WriteConfigFile = func(filename string, data []byte) error { + err := os.MkdirAll(path.Dir(filename), 0771) + if err != nil { + return err + } + + cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup if err != nil { return err } @@ -63,12 +64,12 @@ var WriteConfigFile = func(fn string, data []byte) error { return err } -var BackupConfigFile = func(fn string) error { - return os.Rename(fn, fn+".bak") +var BackupConfigFile = func(filename string) error { + return os.Rename(filename, filename+".bak") } -func parseConfigFile(fn string) ([]byte, *yaml.Node, error) { - data, err := ReadConfigFile(fn) +func parseConfigFile(filename string) ([]byte, *yaml.Node, error) { + data, err := ReadConfigFile(filename) if err != nil { return nil, nil, err } @@ -78,8 +79,11 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) { if err != nil { return data, nil, err } - if len(root.Content) < 1 { - return data, &root, fmt.Errorf("malformed config") + if len(root.Content) == 0 { + return data, &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{{Kind: yaml.MappingNode}}, + }, nil } if root.Content[0].Kind != yaml.MappingNode { return data, &root, fmt.Errorf("expected a top level map") @@ -90,91 +94,76 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) { func isLegacy(root *yaml.Node) bool { for _, v := range root.Content[0].Content { - if v.Value == "hosts" { - return false + if v.Value == "github.com" { + return true } } - return true + return false } -func migrateConfig(fn string, root *yaml.Node) error { - type ConfigEntry map[string]string - type ConfigHash map[string]ConfigEntry - - newConfigData := map[string]ConfigHash{} - newConfigData["hosts"] = ConfigHash{} - - topLevelKeys := root.Content[0].Content - - for i, x := range topLevelKeys { - if x.Value == "" { - continue - } - if i+1 == len(topLevelKeys) { - break - } - hostname := x.Value - newConfigData["hosts"][hostname] = ConfigEntry{} - - authKeys := topLevelKeys[i+1].Content[0].Content - - for j, y := range authKeys { - if j+1 == len(authKeys) { - break - } - switch y.Value { - case "user": - newConfigData["hosts"][hostname]["user"] = authKeys[j+1].Value - case "oauth_token": - newConfigData["hosts"][hostname]["oauth_token"] = authKeys[j+1].Value - } - } - } - - if _, ok := newConfigData["hosts"][defaultHostname]; !ok { - return errors.New("could not find default host configuration") - } - - defaultHostConfig := newConfigData["hosts"][defaultHostname] - - if _, ok := defaultHostConfig["user"]; !ok { - return errors.New("default host configuration missing user") - } - - if _, ok := defaultHostConfig["oauth_token"]; !ok { - return errors.New("default host configuration missing oauth_token") - } - - newConfig, err := yaml.Marshal(newConfigData) +func migrateConfig(filename string) error { + b, err := ReadConfigFile(filename) if err != nil { return err } - err = BackupConfigFile(fn) + var hosts map[string][]yaml.Node + err = yaml.Unmarshal(b, &hosts) + if err != nil { + return fmt.Errorf("error decoding legacy format: %w", err) + } + + cfg := NewBlankConfig() + for hostname, entries := range hosts { + if len(entries) < 1 { + continue + } + mapContent := entries[0].Content + for i := 0; i < len(mapContent)-1; i += 2 { + if err := cfg.Set(hostname, mapContent[i].Value, mapContent[i+1].Value); err != nil { + return err + } + } + } + + err = BackupConfigFile(filename) if err != nil { return fmt.Errorf("failed to back up existing config: %w", err) } - return WriteConfigFile(fn, newConfig) + return cfg.Write() } -func ParseConfig(fn string) (Config, error) { - _, root, err := parseConfigFile(fn) +func ParseConfig(filename string) (Config, error) { + _, root, err := parseConfigFile(filename) if err != nil { return nil, err } if isLegacy(root) { - err = migrateConfig(fn, root) + err = migrateConfig(filename) if err != nil { - return nil, err + return nil, fmt.Errorf("error migrating legacy config: %w", err) } - _, root, err = parseConfigFile(fn) + _, root, err = parseConfigFile(filename) if err != nil { return nil, fmt.Errorf("failed to reparse migrated config: %w", err) } + } else { + if _, hostsRoot, err := parseConfigFile(hostsConfigFile(filename)); err == nil { + if len(hostsRoot.Content[0].Content) > 0 { + newContent := []*yaml.Node{ + {Value: "hosts"}, + hostsRoot.Content[0], + } + restContent := root.Content[0].Content + root.Content[0].Content = append(newContent, restContent...) + } + } else if !errors.Is(err, os.ErrNotExist) { + return nil, err + } } return NewConfig(root), nil diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go index 4ddd63435..edc3a6ae9 100644 --- a/internal/config/config_file_test.go +++ b/internal/config/config_file_test.go @@ -3,6 +3,7 @@ package config import ( "bytes" "errors" + "fmt" "reflect" "testing" @@ -22,8 +23,8 @@ hosts: github.com: user: monalisa oauth_token: OTOKEN -`)() - config, err := ParseConfig("filename") +`, "")() + config, err := ParseConfig("config.yml") eq(t, err, nil) user, err := config.Get("github.com", "user") eq(t, err, nil) @@ -42,8 +43,24 @@ hosts: github.com: user: monalisa oauth_token: OTOKEN +`, "")() + config, err := ParseConfig("config.yml") + eq(t, err, nil) + user, err := config.Get("github.com", "user") + eq(t, err, nil) + eq(t, user, "monalisa") + token, err := config.Get("github.com", "oauth_token") + eq(t, err, nil) + eq(t, token, "OTOKEN") +} + +func Test_parseConfig_hostsFile(t *testing.T) { + defer StubConfig("", `--- +github.com: + user: monalisa + oauth_token: OTOKEN `)() - config, err := ParseConfig("filename") + config, err := ParseConfig("config.yml") eq(t, err, nil) user, err := config.Get("github.com", "user") eq(t, err, nil) @@ -59,38 +76,47 @@ hosts: example.com: user: wronguser oauth_token: NOTTHIS -`)() - config, err := ParseConfig("filename") +`, "")() + config, err := ParseConfig("config.yml") eq(t, err, nil) _, err = config.Get("github.com", "user") - eq(t, err, errors.New(`could not find config entry for "github.com"`)) + eq(t, err, &NotFoundError{errors.New(`could not find config entry for "github.com"`)}) } -func Test_migrateConfig(t *testing.T) { - oldStyle := `--- +func Test_ParseConfig_migrateConfig(t *testing.T) { + defer StubConfig(`--- github.com: - user: keiyuri - oauth_token: 123456` - - var root yaml.Node - err := yaml.Unmarshal([]byte(oldStyle), &root) - if err != nil { - panic("failed to parse test yaml") - } - - buf := bytes.NewBufferString("") - defer StubWriteConfig(buf)() + oauth_token: 123456 +`, "")() + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() defer StubBackupConfig()() - err = migrateConfig("boom.txt", &root) + _, err := ParseConfig("config.yml") eq(t, err, nil) - expected := `hosts: - github.com: - oauth_token: "123456" - user: keiyuri + expectedMain := "" + expectedHosts := `github.com: + user: keiyuri + oauth_token: "123456" ` - eq(t, buf.String(), expected) + eq(t, mainBuf.String(), expectedMain) + eq(t, hostsBuf.String(), expectedHosts) +} + +func Test_parseConfigFile(t *testing.T) { + fileContents := []string{"", " ", "\n"} + for _, contents := range fileContents { + t.Run(fmt.Sprintf("contents: %q", contents), func(t *testing.T) { + defer StubConfig(contents, "")() + _, yamlRoot, err := parseConfigFile("config.yml") + eq(t, err, nil) + eq(t, yamlRoot.Content[0].Kind, yaml.MappingNode) + eq(t, len(yamlRoot.Content[0].Content), 0) + }) + } } diff --git a/internal/config/config_setup.go b/internal/config/config_setup.go index 2fc414a1b..7217eb3d9 100644 --- a/internal/config/config_setup.go +++ b/internal/config/config_setup.go @@ -5,16 +5,10 @@ import ( "fmt" "io" "os" - "path/filepath" "strings" "github.com/cli/cli/api" "github.com/cli/cli/auth" - "gopkg.in/yaml.v3" -) - -const ( - oauthHost = "github.com" ) var ( @@ -31,7 +25,31 @@ func IsGitHubApp(id string) bool { return id == "178c6fc778ccc68e1d6a" || id == "4d747ba5675d5d66553f" } -func AuthFlow(notice string) (string, string, error) { +func AuthFlowWithConfig(cfg Config, hostname, notice string) (string, error) { + token, userLogin, err := authFlow(hostname, notice) + if err != nil { + return "", err + } + + err = cfg.Set(hostname, "user", userLogin) + if err != nil { + return "", err + } + err = cfg.Set(hostname, "oauth_token", token) + if err != nil { + return "", err + } + + err = cfg.Write() + if err != nil { + return "", err + } + + AuthFlowComplete() + return token, nil +} + +func authFlow(oauthHost, notice string) (string, string, error) { var verboseStream io.Writer if strings.Contains(os.Getenv("DEBUG"), "oauth") { verboseStream = os.Stderr @@ -69,53 +87,6 @@ func AuthFlowComplete() { _ = waitForEnter(os.Stdin) } -// FIXME: make testable -func setupConfigFile(filename string) (Config, error) { - token, userLogin, err := AuthFlow("Notice: authentication required") - if err != nil { - return nil, err - } - - // TODO this sucks. It precludes us laying out a nice config with comments and such. - type yamlConfig struct { - Hosts map[string]map[string]string - } - - yamlHosts := map[string]map[string]string{} - yamlHosts[oauthHost] = map[string]string{} - yamlHosts[oauthHost]["user"] = userLogin - yamlHosts[oauthHost]["oauth_token"] = token - - defaultConfig := yamlConfig{ - Hosts: yamlHosts, - } - - err = os.MkdirAll(filepath.Dir(filename), 0771) - if err != nil { - return nil, err - } - - cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return nil, err - } - defer cfgFile.Close() - - yamlData, err := yaml.Marshal(defaultConfig) - if err != nil { - return nil, err - } - _, err = cfgFile.Write(yamlData) - if err != nil { - return nil, err - } - - AuthFlowComplete() - - // TODO cleaner error handling? this "should" always work given that we /just/ wrote the file... - return ParseConfig(filename) -} - func getViewer(token string) (string, error) { http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token))) return api.CurrentLoginName(http) diff --git a/internal/config/config_type.go b/internal/config/config_type.go index 8d2c77b9b..f13666364 100644 --- a/internal/config/config_type.go +++ b/internal/config/config_type.go @@ -1,18 +1,17 @@ package config import ( + "bytes" "errors" "fmt" "gopkg.in/yaml.v3" ) -const defaultHostname = "github.com" const defaultGitProtocol = "https" // This interface describes interacting with some persistent configuration for gh. type Config interface { - Hosts() ([]*HostConfig, error) Get(string, string) (string, error) Set(string, string, string) error Aliases() (*AliasConfig, error) @@ -61,6 +60,7 @@ func (cm *ConfigMap) SetStringValue(key, value string) error { } valueNode = &yaml.Node{ Kind: yaml.ScalarNode, + Tag: "!!str", Value: "", } @@ -107,11 +107,17 @@ func NewConfig(root *yaml.Node) Config { } } +func NewBlankConfig() Config { + return NewConfig(&yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{{Kind: yaml.MappingNode}}, + }) +} + // This type implements a Config interface and represents a config file on disk. type fileConfig struct { ConfigMap documentRoot *yaml.Node - hosts []*HostConfig } func (c *fileConfig) Root() *yaml.Node { @@ -159,7 +165,10 @@ func (c *fileConfig) Set(hostname, key, value string) error { return c.SetStringValue(key, value) } else { hostCfg, err := c.configForHost(hostname) - if err != nil { + var notFound *NotFoundError + if errors.As(err, ¬Found) { + hostCfg = c.makeConfigForHost(hostname) + } else if err != nil { return err } return hostCfg.SetStringValue(key, value) @@ -167,7 +176,7 @@ func (c *fileConfig) Set(hostname, key, value string) error { } func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) { - hosts, err := c.Hosts() + hosts, err := c.hostEntries() if err != nil { return nil, fmt.Errorf("failed to parse hosts config: %w", err) } @@ -177,16 +186,46 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) { return hc, nil } } - return nil, fmt.Errorf("could not find config entry for %q", hostname) + return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)} } func (c *fileConfig) Write() error { - marshalled, err := yaml.Marshal(c.documentRoot) + mainData := yaml.Node{Kind: yaml.MappingNode} + hostsData := yaml.Node{Kind: yaml.MappingNode} + + nodes := c.documentRoot.Content[0].Content + for i := 0; i < len(nodes)-1; i += 2 { + if nodes[i].Value == "hosts" { + hostsData.Content = append(hostsData.Content, nodes[i+1].Content...) + } else { + mainData.Content = append(mainData.Content, nodes[i], nodes[i+1]) + } + } + + mainBytes, err := yaml.Marshal(&mainData) if err != nil { return err } - return WriteConfigFile(ConfigFile(), marshalled) + filename := ConfigFile() + err = WriteConfigFile(filename, yamlNormalize(mainBytes)) + if err != nil { + return err + } + + hostsBytes, err := yaml.Marshal(&hostsData) + if err != nil { + return err + } + + return WriteConfigFile(hostsConfigFile(filename), yamlNormalize(hostsBytes)) +} + +func yamlNormalize(b []byte) []byte { + if bytes.Equal(b, []byte("{}\n")) { + return []byte{} + } + return b } func (c *fileConfig) Aliases() (*AliasConfig, error) { @@ -243,11 +282,7 @@ func (c *fileConfig) Aliases() (*AliasConfig, error) { }, nil } -func (c *fileConfig) Hosts() ([]*HostConfig, error) { - if len(c.hosts) > 0 { - return c.hosts, nil - } - +func (c *fileConfig) hostEntries() ([]*HostConfig, error) { entry, err := c.FindEntry("hosts") if err != nil { return nil, fmt.Errorf("could not find hosts config: %w", err) @@ -258,11 +293,39 @@ func (c *fileConfig) Hosts() ([]*HostConfig, error) { return nil, fmt.Errorf("could not parse hosts config: %w", err) } - c.hosts = hostConfigs - return hostConfigs, nil } +func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig { + hostRoot := &yaml.Node{Kind: yaml.MappingNode} + hostCfg := &HostConfig{ + Host: hostname, + ConfigMap: ConfigMap{Root: hostRoot}, + } + + var notFound *NotFoundError + hostsEntry, err := c.FindEntry("hosts") + if errors.As(err, ¬Found) { + hostsEntry.KeyNode = &yaml.Node{ + Kind: yaml.ScalarNode, + Value: "hosts", + } + hostsEntry.ValueNode = &yaml.Node{Kind: yaml.MappingNode} + root := c.Root() + root.Content = append(root.Content, hostsEntry.KeyNode, hostsEntry.ValueNode) + } else if err != nil { + panic(err) + } + + hostsEntry.ValueNode.Content = append(hostsEntry.ValueNode.Content, + &yaml.Node{ + Kind: yaml.ScalarNode, + Value: hostname, + }, hostRoot) + + return hostCfg +} + func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) { hostConfigs := []*HostConfig{} diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go new file mode 100644 index 000000000..1c8105186 --- /dev/null +++ b/internal/config/config_type_test.go @@ -0,0 +1,41 @@ +package config + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_fileConfig_Set(t *testing.T) { + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() + + c := NewBlankConfig() + assert.NoError(t, c.Set("", "editor", "nano")) + assert.NoError(t, c.Set("github.com", "git_protocol", "ssh")) + assert.NoError(t, c.Set("example.com", "editor", "vim")) + assert.NoError(t, c.Set("github.com", "user", "hubot")) + assert.NoError(t, c.Write()) + + assert.Equal(t, "editor: nano\n", mainBuf.String()) + assert.Equal(t, `github.com: + git_protocol: ssh + user: hubot +example.com: + editor: vim +`, hostsBuf.String()) +} + +func Test_fileConfig_Write(t *testing.T) { + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + defer StubWriteConfig(&mainBuf, &hostsBuf)() + + c := NewBlankConfig() + assert.NoError(t, c.Write()) + + assert.Equal(t, "", mainBuf.String()) + assert.Equal(t, "", hostsBuf.String()) +} diff --git a/internal/config/testing.go b/internal/config/testing.go index a91cedfae..59c2ff212 100644 --- a/internal/config/testing.go +++ b/internal/config/testing.go @@ -1,7 +1,10 @@ package config import ( + "fmt" "io" + "os" + "path" ) func StubBackupConfig() func() { @@ -15,21 +18,41 @@ func StubBackupConfig() func() { } } -func StubWriteConfig(w io.Writer) func() { +func StubWriteConfig(wc io.Writer, wh io.Writer) func() { orig := WriteConfigFile WriteConfigFile = func(fn string, data []byte) error { - _, err := w.Write(data) - return err + switch path.Base(fn) { + case "config.yml": + _, err := wc.Write(data) + return err + case "hosts.yml": + _, err := wh.Write(data) + return err + default: + return fmt.Errorf("write to unstubbed file: %q", fn) + } } return func() { WriteConfigFile = orig } } -func StubConfig(content string) func() { +func StubConfig(main, hosts string) func() { orig := ReadConfigFile ReadConfigFile = func(fn string) ([]byte, error) { - return []byte(content), nil + switch path.Base(fn) { + case "config.yml": + return []byte(main), nil + case "hosts.yml": + if hosts == "" { + return []byte(nil), os.ErrNotExist + } else { + return []byte(hosts), nil + } + default: + return []byte(nil), fmt.Errorf("read from unstubbed file: %q", fn) + } + } return func() { ReadConfigFile = orig