diff --git a/api/http_client.go b/api/http_client.go index daeb0f3da..6e92df0e9 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -14,6 +14,7 @@ import ( type configGetter interface { Get(string, string) (string, error) + AuthToken(string) (string, string) } type HTTPClientOptions struct { @@ -52,7 +53,9 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { return nil, err } - client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) + if opts.Config != nil { + client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) + } return client, nil } @@ -75,7 +78,7 @@ func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTrippe func AddAuthTokenHeader(rt http.RoundTripper, cfg configGetter) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { hostname := ghinstance.NormalizeHostname(getHost(req)) - if token, err := cfg.Get(hostname, "oauth_token"); err == nil && token != "" { + if token, _ := cfg.AuthToken(hostname); token != "" { req.Header.Set("Authorization", fmt.Sprintf("token %s", token)) } return rt.RoundTrip(req) diff --git a/api/http_client_test.go b/api/http_client_test.go index 06021e2f1..fa94f84cc 100644 --- a/api/http_client_test.go +++ b/api/http_client_test.go @@ -211,6 +211,10 @@ func (c tinyConfig) Get(host, key string) (string, error) { return c[fmt.Sprintf("%s:%s", host, key)], nil } +func (c tinyConfig) AuthToken(host string) (string, string) { + return c[fmt.Sprintf("%s:%s", host, "oauth_token")], "oauth_token" +} + var requestAtRE = regexp.MustCompile(`(?m)^\* Request at .+`) var dateRE = regexp.MustCompile(`(?m)^< Date: .+`) var hostWithPortRE = regexp.MustCompile(`127\.0\.0\.1:\d+`) diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 433844338..444e1262f 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -98,10 +98,8 @@ func mainRun() exitCode { return exitError } - // TODO: remove after FromFullName has been revisited - if host, err := cfg.DefaultHost(); err == nil { - ghrepo.SetDefaultHost(host) - } + host, _ := cfg.DefaultHost() + ghrepo.SetDefaultHost(host) expandedArgs := []string{} if len(os.Args) > 0 { @@ -170,18 +168,17 @@ func mainRun() exitCode { // provide completions for aliases and extensions rootCmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { var results []string - if aliases, err := cfg.Aliases(); err == nil { - for aliasName, aliasValue := range aliases.All() { - if strings.HasPrefix(aliasName, toComplete) { - var s string - if strings.HasPrefix(aliasValue, "!") { - s = fmt.Sprintf("%s\tShell alias", aliasName) - } else { - aliasValue = text.Truncate(80, aliasValue) - s = fmt.Sprintf("%s\tAlias for %s", aliasName, aliasValue) - } - results = append(results, s) + aliases := cfg.Aliases() + for aliasName, aliasValue := range aliases.All() { + if strings.HasPrefix(aliasName, toComplete) { + var s string + if strings.HasPrefix(aliasValue, "!") { + s = fmt.Sprintf("%s\tShell alias", aliasName) + } else { + aliasValue = text.Truncate(80, aliasValue) + s = fmt.Sprintf("%s\tAlias for %s", aliasName, aliasValue) } + results = append(results, s) } } for _, ext := range cmdFactory.ExtensionManager.List() { diff --git a/go.mod b/go.mod index 938f128fb..855d18b16 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/charmbracelet/glamour v0.4.0 github.com/charmbracelet/lipgloss v0.5.0 github.com/cli/browser v1.1.0 - github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638 + github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4 github.com/cli/oauth v0.9.0 github.com/cli/safeexec v1.0.0 github.com/cli/shurcooL-graphql v0.0.1 diff --git a/go.sum b/go.sum index d5fa65673..3b8e9d44b 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY= github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI= github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8= github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638 h1:7MXhocX2RDlWrjKZ1pZsy8eMNGa3xkZzPrGC1IPBfx4= -github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638/go.mod h1:Y/QFb/VxnXQH0W4VlP+507HVxMzQ430x8kdjUuVcono= +github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4 h1:6WrekNBE2Y+Xl9OCl7vsg49SSN68hwaVryfEawQevaQ= +github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4/go.mod h1:Y/QFb/VxnXQH0W4VlP+507HVxMzQ430x8kdjUuVcono= github.com/cli/oauth v0.9.0 h1:nxBC0Df4tUzMkqffAB+uZvisOwT3/N9FpkfdTDtafxc= github.com/cli/oauth v0.9.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI= diff --git a/internal/authflow/flow.go b/internal/authflow/flow.go index bbe6f2524..fd7db40dd 100644 --- a/internal/authflow/flow.go +++ b/internal/authflow/flow.go @@ -32,9 +32,8 @@ var ( type iconfig interface { Get(string, string) (string, error) - Set(string, string, string) error + Set(string, string, string) Write() error - WriteHosts() error } func AuthFlowWithConfig(cfg iconfig, IO *iostreams.IOStreams, hostname, notice string, additionalScopes []string, isInteractive bool) (string, error) { @@ -55,16 +54,10 @@ func AuthFlowWithConfig(cfg iconfig, IO *iostreams.IOStreams, hostname, notice s return "", err } - err = cfg.Set(hostname, "user", userLogin) - if err != nil { - return "", err - } - err = cfg.Set(hostname, "oauth_token", token) - if err != nil { - return "", err - } + cfg.Set(hostname, "user", userLogin) + cfg.Set(hostname, "oauth_token", token) - return token, cfg.WriteHosts() + return token, cfg.Write() } func authFlow(oauthHost string, IO *iostreams.IOStreams, notice string, additionalScopes []string, isInteractive bool, browserLauncher string) (string, string, error) { diff --git a/internal/config/alias_config.go b/internal/config/alias_config.go deleted file mode 100644 index 148eb21f9..000000000 --- a/internal/config/alias_config.go +++ /dev/null @@ -1,60 +0,0 @@ -package config - -import ( - "fmt" -) - -type AliasConfig struct { - ConfigMap - Parent Config -} - -func (a *AliasConfig) Get(alias string) (string, bool) { - if a.Empty() { - return "", false - } - value, _ := a.GetStringValue(alias) - - return value, value != "" -} - -func (a *AliasConfig) Add(alias, expansion string) error { - err := a.SetStringValue(alias, expansion) - if err != nil { - return fmt.Errorf("failed to update config: %w", err) - } - - err = a.Parent.Write() - if err != nil { - return fmt.Errorf("failed to write config: %w", err) - } - - return nil -} - -func (a *AliasConfig) Delete(alias string) error { - a.RemoveEntry(alias) - - err := a.Parent.Write() - if err != nil { - return fmt.Errorf("failed to write config: %w", err) - } - - return nil -} - -func (a *AliasConfig) All() map[string]string { - out := map[string]string{} - - if a.Empty() { - return out - } - - for i := 0; i < len(a.Root.Content)-1; i += 2 { - key := a.Root.Content[i].Value - value := a.Root.Content[i+1].Value - out[key] = value - } - - return out -} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 000000000..cd8cae0dc --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,223 @@ +package config + +import ( + "os" + "path/filepath" + + ghAuth "github.com/cli/go-gh/pkg/auth" + ghConfig "github.com/cli/go-gh/pkg/config" +) + +const ( + hosts = "hosts" + aliases = "aliases" +) + +// This interface describes interacting with some persistent configuration for gh. +//go:generate moq -rm -out config_mock.go . Config +type Config interface { + AuthToken(string) (string, string) + Get(string, string) (string, error) + GetOrDefault(string, string) (string, error) + Set(string, string, string) + UnsetHost(string) + Hosts() []string + DefaultHost() (string, string) + Aliases() *AliasConfig + Write() error +} + +func NewConfig() (Config, error) { + c, err := ghConfig.Read() + if err != nil { + return nil, err + } + return &cfg{c}, nil +} + +// Implements Config interface +type cfg struct { + cfg *ghConfig.Config +} + +func (c *cfg) AuthToken(hostname string) (string, string) { + return ghAuth.TokenForHost(hostname) +} + +func (c *cfg) Get(hostname, key string) (string, error) { + if hostname != "" { + val, err := c.cfg.Get([]string{hosts, hostname, key}) + if err == nil { + return val, err + } + } + + return c.cfg.Get([]string{key}) +} + +func (c *cfg) GetOrDefault(hostname, key string) (string, error) { + var val string + var err error + if hostname != "" { + val, err = c.cfg.Get([]string{hosts, hostname, key}) + if err == nil { + return val, err + } + } + + val, err = c.cfg.Get([]string{key}) + if err == nil { + return val, err + } + + if defaultExists(key) { + return defaultFor(key), nil + } + + return val, err +} + +func (c *cfg) Set(hostname, key, value string) { + if hostname == "" { + c.cfg.Set([]string{key}, value) + } + c.cfg.Set([]string{hosts, hostname, key}, value) +} + +func (c *cfg) UnsetHost(hostname string) { + if hostname == "" { + return + } + _ = c.cfg.Remove([]string{hosts, hostname}) +} + +func (c *cfg) Hosts() []string { + return ghAuth.KnownHosts() +} + +func (c *cfg) DefaultHost() (string, string) { + return ghAuth.DefaultHost() +} + +func (c *cfg) Aliases() *AliasConfig { + return &AliasConfig{cfg: c.cfg} +} + +func (c *cfg) Write() error { + return ghConfig.Write(c.cfg) +} + +func defaultFor(key string) string { + for _, co := range configOptions { + if co.Key == key { + return co.DefaultValue + } + } + return "" +} + +func defaultExists(key string) bool { + for _, co := range configOptions { + if co.Key == key { + return true + } + } + return false +} + +type AliasConfig struct { + cfg *ghConfig.Config +} + +func (a *AliasConfig) Get(alias string) (string, error) { + return a.cfg.Get([]string{aliases, alias}) +} + +func (a *AliasConfig) Add(alias, expansion string) { + a.cfg.Set([]string{aliases, alias}, expansion) +} + +func (a *AliasConfig) Delete(alias string) error { + return a.cfg.Remove([]string{aliases, alias}) +} + +func (a *AliasConfig) All() map[string]string { + out := map[string]string{} + keys, err := a.cfg.Keys([]string{aliases}) + if err != nil { + return out + } + for _, key := range keys { + val, _ := a.cfg.Get([]string{aliases, key}) + out[key] = val + } + return out +} + +type ConfigOption struct { + Key string + Description string + DefaultValue string + AllowedValues []string +} + +var configOptions = []ConfigOption{ + { + Key: "git_protocol", + Description: "the protocol to use for git clone and push operations", + DefaultValue: "https", + AllowedValues: []string{"https", "ssh"}, + }, + { + Key: "editor", + Description: "the text editor program to use for authoring text", + DefaultValue: "", + }, + { + Key: "prompt", + Description: "toggle interactive prompting in the terminal", + DefaultValue: "enabled", + AllowedValues: []string{"enabled", "disabled"}, + }, + { + Key: "pager", + Description: "the terminal pager program to send standard output to", + DefaultValue: "", + }, + { + Key: "http_unix_socket", + Description: "the path to a Unix socket through which to make an HTTP connection", + DefaultValue: "", + }, + { + Key: "browser", + Description: "the web browser to use for opening URLs", + DefaultValue: "", + }, +} + +func ConfigOptions() []ConfigOption { + return configOptions +} + +func HomeDirPath(subdir string) (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + + newPath := filepath.Join(homeDir, subdir) + return newPath, nil +} + +func StateDir() string { + return ghConfig.StateDir() +} + +func DataDir() string { + return ghConfig.DataDir() +} + +func ConfigDir() string { + return ghConfig.ConfigDir() +} diff --git a/internal/config/config_file.go b/internal/config/config_file.go deleted file mode 100644 index f3e2a7d1b..000000000 --- a/internal/config/config_file.go +++ /dev/null @@ -1,349 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "io" - "os" - "path/filepath" - "runtime" - "syscall" - - "gopkg.in/yaml.v3" -) - -const ( - GH_CONFIG_DIR = "GH_CONFIG_DIR" - XDG_CONFIG_HOME = "XDG_CONFIG_HOME" - XDG_STATE_HOME = "XDG_STATE_HOME" - XDG_DATA_HOME = "XDG_DATA_HOME" - APP_DATA = "AppData" - LOCAL_APP_DATA = "LocalAppData" -) - -// Config path precedence -// 1. GH_CONFIG_DIR -// 2. XDG_CONFIG_HOME -// 3. AppData (windows only) -// 4. HOME -func ConfigDir() string { - var path string - if a := os.Getenv(GH_CONFIG_DIR); a != "" { - path = a - } else if b := os.Getenv(XDG_CONFIG_HOME); b != "" { - path = filepath.Join(b, "gh") - } else if c := os.Getenv(APP_DATA); runtime.GOOS == "windows" && c != "" { - path = filepath.Join(c, "GitHub CLI") - } else { - d, _ := os.UserHomeDir() - path = filepath.Join(d, ".config", "gh") - } - - // If the path does not exist and the GH_CONFIG_DIR flag is not set try - // migrating config from default paths. - if !dirExists(path) && os.Getenv(GH_CONFIG_DIR) == "" { - _ = autoMigrateConfigDir(path) - } - - return path -} - -// State path precedence -// 1. XDG_STATE_HOME -// 2. LocalAppData (windows only) -// 3. HOME -func StateDir() string { - var path string - if a := os.Getenv(XDG_STATE_HOME); a != "" { - path = filepath.Join(a, "gh") - } else if b := os.Getenv(LOCAL_APP_DATA); runtime.GOOS == "windows" && b != "" { - path = filepath.Join(b, "GitHub CLI") - } else { - c, _ := os.UserHomeDir() - path = filepath.Join(c, ".local", "state", "gh") - } - - // If the path does not exist try migrating state from default paths - if !dirExists(path) { - _ = autoMigrateStateDir(path) - } - - return path -} - -// Data path precedence -// 1. XDG_DATA_HOME -// 2. LocalAppData (windows only) -// 3. HOME -func DataDir() string { - var path string - if a := os.Getenv(XDG_DATA_HOME); a != "" { - path = filepath.Join(a, "gh") - } else if b := os.Getenv(LOCAL_APP_DATA); runtime.GOOS == "windows" && b != "" { - path = filepath.Join(b, "GitHub CLI") - } else { - c, _ := os.UserHomeDir() - path = filepath.Join(c, ".local", "share", "gh") - } - - return path -} - -var errSamePath = errors.New("same path") -var errNotExist = errors.New("not exist") - -// Check default path, os.UserHomeDir, for existing configs -// If configs exist then move them to newPath -func autoMigrateConfigDir(newPath string) error { - path, err := os.UserHomeDir() - if oldPath := filepath.Join(path, ".config", "gh"); err == nil && dirExists(oldPath) { - return migrateDir(oldPath, newPath) - } - - return errNotExist -} - -// Check default path, os.UserHomeDir, for existing state file (state.yml) -// If state file exist then move it to newPath -func autoMigrateStateDir(newPath string) error { - path, err := os.UserHomeDir() - if oldPath := filepath.Join(path, ".config", "gh"); err == nil && dirExists(oldPath) { - return migrateFile(oldPath, newPath, "state.yml") - } - - return errNotExist -} - -func migrateFile(oldPath, newPath, file string) error { - if oldPath == newPath { - return errSamePath - } - - oldFile := filepath.Join(oldPath, file) - newFile := filepath.Join(newPath, file) - - if !fileExists(oldFile) { - return errNotExist - } - - _ = os.MkdirAll(filepath.Dir(newFile), 0755) - return os.Rename(oldFile, newFile) -} - -func migrateDir(oldPath, newPath string) error { - if oldPath == newPath { - return errSamePath - } - - if !dirExists(oldPath) { - return errNotExist - } - - _ = os.MkdirAll(filepath.Dir(newPath), 0755) - return os.Rename(oldPath, newPath) -} - -func dirExists(path string) bool { - f, err := os.Stat(path) - return err == nil && f.IsDir() -} - -func fileExists(path string) bool { - f, err := os.Stat(path) - return err == nil && !f.IsDir() -} - -func ConfigFile() string { - return filepath.Join(ConfigDir(), "config.yml") -} - -func HostsConfigFile() string { - return filepath.Join(ConfigDir(), "hosts.yml") -} - -func ParseDefaultConfig() (Config, error) { - return parseConfig(ConfigFile()) -} - -func HomeDirPath(subdir string) (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", err - } - - newPath := filepath.Join(homeDir, subdir) - return newPath, nil -} - -var ReadConfigFile = func(filename string) ([]byte, error) { - f, err := os.Open(filename) - if err != nil { - return nil, pathError(err) - } - defer f.Close() - - data, err := io.ReadAll(f) - if err != nil { - return nil, err - } - - return data, nil -} - -var WriteConfigFile = func(filename string, data []byte) error { - err := os.MkdirAll(filepath.Dir(filename), 0771) - if err != nil { - return pathError(err) - } - - cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup - if err != nil { - return err - } - defer cfgFile.Close() - - _, err = cfgFile.Write(data) - return err -} - -var BackupConfigFile = func(filename string) error { - return os.Rename(filename, filename+".bak") -} - -func parseConfigFile(filename string) ([]byte, *yaml.Node, error) { - data, err := ReadConfigFile(filename) - if err != nil { - return nil, nil, err - } - - root, err := parseConfigData(data) - if err != nil { - return nil, nil, err - } - return data, root, err -} - -func parseConfigData(data []byte) (*yaml.Node, error) { - var root yaml.Node - err := yaml.Unmarshal(data, &root) - if err != nil { - return nil, err - } - - if len(root.Content) == 0 { - return &yaml.Node{ - Kind: yaml.DocumentNode, - Content: []*yaml.Node{{Kind: yaml.MappingNode}}, - }, nil - } - if root.Content[0].Kind != yaml.MappingNode { - return &root, fmt.Errorf("expected a top level map") - } - return &root, nil -} - -func isLegacy(root *yaml.Node) bool { - for _, v := range root.Content[0].Content { - if v.Value == "github.com" { - return true - } - } - - return false -} - -func migrateConfig(filename string) error { - b, err := ReadConfigFile(filename) - if err != nil { - return err - } - - var hosts map[string][]yaml.Node - err = yaml.Unmarshal(b, &hosts) - if err != nil { - return fmt.Errorf("error decoding legacy format: %w", err) - } - - cfg := NewBlankConfig() - for hostname, entries := range hosts { - if len(entries) < 1 { - continue - } - mapContent := entries[0].Content - for i := 0; i < len(mapContent)-1; i += 2 { - if err := cfg.Set(hostname, mapContent[i].Value, mapContent[i+1].Value); err != nil { - return err - } - } - } - - err = BackupConfigFile(filename) - if err != nil { - return fmt.Errorf("failed to back up existing config: %w", err) - } - - return cfg.Write() -} - -func parseConfig(filename string) (Config, error) { - _, root, err := parseConfigFile(filename) - if err != nil { - if os.IsNotExist(err) { - root = NewBlankRoot() - } else { - return nil, err - } - } - - if isLegacy(root) { - err = migrateConfig(filename) - if err != nil { - return nil, fmt.Errorf("error migrating legacy config: %w", err) - } - - _, root, err = parseConfigFile(filename) - if err != nil { - return nil, fmt.Errorf("failed to reparse migrated config: %w", err) - } - } else { - if _, hostsRoot, err := parseConfigFile(HostsConfigFile()); err == nil { - if len(hostsRoot.Content[0].Content) > 0 { - newContent := []*yaml.Node{ - {Value: "hosts"}, - hostsRoot.Content[0], - } - restContent := root.Content[0].Content - root.Content[0].Content = append(newContent, restContent...) - } - } else if !errors.Is(err, os.ErrNotExist) { - return nil, err - } - } - - return NewConfig(root), nil -} - -func pathError(err error) error { - var pathError *os.PathError - if errors.As(err, &pathError) && errors.Is(pathError.Err, syscall.ENOTDIR) { - if p := findRegularFile(pathError.Path); p != "" { - return fmt.Errorf("remove or rename regular file `%s` (must be a directory)", p) - } - - } - return err -} - -func findRegularFile(p string) string { - for { - if s, err := os.Stat(p); err == nil && s.Mode().IsRegular() { - return p - } - newPath := filepath.Dir(p) - if newPath == p || newPath == "/" || newPath == "." { - break - } - p = newPath - } - return "" -} diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go deleted file mode 100644 index e7c37caa0..000000000 --- a/internal/config/config_file_test.go +++ /dev/null @@ -1,576 +0,0 @@ -package config - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "runtime" - "testing" - - "github.com/MakeNowJust/heredoc" - "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" -) - -func Test_parseConfig(t *testing.T) { - defer stubConfig(`--- -hosts: - github.com: - user: monalisa - oauth_token: OTOKEN -`, "")() - config, err := parseConfig("config.yml") - assert.NoError(t, err) - user, err := config.Get("github.com", "user") - assert.NoError(t, err) - assert.Equal(t, "monalisa", user) - token, err := config.Get("github.com", "oauth_token") - assert.NoError(t, err) - assert.Equal(t, "OTOKEN", token) -} - -func Test_parseConfig_multipleHosts(t *testing.T) { - defer stubConfig(`--- -hosts: - example.com: - user: wronguser - oauth_token: NOTTHIS - github.com: - user: monalisa - oauth_token: OTOKEN -`, "")() - config, err := parseConfig("config.yml") - assert.NoError(t, err) - user, err := config.Get("github.com", "user") - assert.NoError(t, err) - assert.Equal(t, "monalisa", user) - token, err := config.Get("github.com", "oauth_token") - assert.NoError(t, err) - assert.Equal(t, "OTOKEN", token) -} - -func Test_parseConfig_hostsFile(t *testing.T) { - defer stubConfig("", `--- -github.com: - user: monalisa - oauth_token: OTOKEN -`)() - config, err := parseConfig("config.yml") - assert.NoError(t, err) - user, err := config.Get("github.com", "user") - assert.NoError(t, err) - assert.Equal(t, "monalisa", user) - token, err := config.Get("github.com", "oauth_token") - assert.NoError(t, err) - assert.Equal(t, "OTOKEN", token) -} - -func Test_parseConfig_hostFallback(t *testing.T) { - defer stubConfig(`--- -git_protocol: ssh -`, `--- -github.com: - user: monalisa - oauth_token: OTOKEN -example.com: - user: wronguser - oauth_token: NOTTHIS - git_protocol: https -`)() - config, err := parseConfig("config.yml") - assert.NoError(t, err) - val, err := config.GetOrDefault("example.com", "git_protocol") - assert.NoError(t, err) - assert.Equal(t, "https", val) - val, err = config.GetOrDefault("github.com", "git_protocol") - assert.NoError(t, err) - assert.Equal(t, "ssh", val) - val, err = config.GetOrDefault("nonexistent.io", "git_protocol") - assert.NoError(t, err) - assert.Equal(t, "ssh", val) -} - -func Test_parseConfig_migrateConfig(t *testing.T) { - defer stubConfig(`--- -github.com: - - user: keiyuri - oauth_token: 123456 -`, "")() - - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer StubWriteConfig(&mainBuf, &hostsBuf)() - defer StubBackupConfig()() - - _, err := parseConfig("config.yml") - assert.NoError(t, err) - - expectedHosts := `github.com: - user: keiyuri - oauth_token: "123456" -` - - assert.Equal(t, expectedHosts, hostsBuf.String()) - assert.NotContains(t, mainBuf.String(), "github.com") - assert.NotContains(t, mainBuf.String(), "oauth_token") -} - -func Test_parseConfigFile(t *testing.T) { - tests := []struct { - contents string - wantsErr bool - }{ - { - contents: "", - wantsErr: true, - }, - { - contents: " ", - wantsErr: false, - }, - { - contents: "\n", - wantsErr: false, - }, - } - - for _, tt := range tests { - t.Run(fmt.Sprintf("contents: %q", tt.contents), func(t *testing.T) { - defer stubConfig(tt.contents, "")() - _, yamlRoot, err := parseConfigFile("config.yml") - if tt.wantsErr != (err != nil) { - t.Fatalf("got error: %v", err) - } - if tt.wantsErr { - return - } - assert.Equal(t, yaml.MappingNode, yamlRoot.Content[0].Kind) - assert.Equal(t, 0, len(yamlRoot.Content[0].Content)) - }) - } -} - -func Test_ConfigDir(t *testing.T) { - tempDir := t.TempDir() - - tests := []struct { - name string - onlyWindows bool - env map[string]string - output string - }{ - { - name: "HOME/USERPROFILE specified", - env: map[string]string{ - "GH_CONFIG_DIR": "", - "XDG_CONFIG_HOME": "", - "AppData": "", - "USERPROFILE": tempDir, - "HOME": tempDir, - }, - output: filepath.Join(tempDir, ".config", "gh"), - }, - { - name: "GH_CONFIG_DIR specified", - env: map[string]string{ - "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"), - }, - output: filepath.Join(tempDir, "gh_config_dir"), - }, - { - name: "XDG_CONFIG_HOME specified", - env: map[string]string{ - "XDG_CONFIG_HOME": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - { - name: "GH_CONFIG_DIR and XDG_CONFIG_HOME specified", - env: map[string]string{ - "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"), - "XDG_CONFIG_HOME": tempDir, - }, - output: filepath.Join(tempDir, "gh_config_dir"), - }, - { - name: "AppData specified", - onlyWindows: true, - env: map[string]string{ - "AppData": tempDir, - }, - output: filepath.Join(tempDir, "GitHub CLI"), - }, - { - name: "GH_CONFIG_DIR and AppData specified", - onlyWindows: true, - env: map[string]string{ - "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"), - "AppData": tempDir, - }, - output: filepath.Join(tempDir, "gh_config_dir"), - }, - { - name: "XDG_CONFIG_HOME and AppData specified", - onlyWindows: true, - env: map[string]string{ - "XDG_CONFIG_HOME": tempDir, - "AppData": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - } - - for _, tt := range tests { - if tt.onlyWindows && runtime.GOOS != "windows" { - continue - } - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - old := os.Getenv(k) - os.Setenv(k, v) - defer os.Setenv(k, old) - } - } - - // Create directory to skip auto migration code - // which gets run when target directory does not exist - _ = os.MkdirAll(tt.output, 0755) - - assert.Equal(t, tt.output, ConfigDir()) - }) - } -} - -func Test_configFile_Write_toDisk(t *testing.T) { - configDir := filepath.Join(t.TempDir(), ".config", "gh") - _ = os.MkdirAll(configDir, 0755) - os.Setenv(GH_CONFIG_DIR, configDir) - defer os.Unsetenv(GH_CONFIG_DIR) - - cfg := NewFromString(`pager: less`) - err := cfg.Write() - if err != nil { - t.Fatal(err) - } - - expectedConfig := "pager: less\n" - if configBytes, err := os.ReadFile(filepath.Join(configDir, "config.yml")); err != nil { - t.Error(err) - } else if string(configBytes) != expectedConfig { - t.Errorf("expected config.yml %q, got %q", expectedConfig, string(configBytes)) - } - - if configBytes, err := os.ReadFile(filepath.Join(configDir, "hosts.yml")); err != nil { - t.Error(err) - } else if string(configBytes) != "" { - t.Errorf("unexpected hosts.yml: %q", string(configBytes)) - } -} - -func Test_configFile_WriteHosts_toDisk(t *testing.T) { - configDir := filepath.Join(t.TempDir(), ".config", "gh") - _ = os.MkdirAll(configDir, 0755) - os.Setenv(GH_CONFIG_DIR, configDir) - defer os.Unsetenv(GH_CONFIG_DIR) - - cfg := NewFromString(heredoc.Doc(` - hosts: - github.com: - user: monalisa - oauth_token: TOKEN - `)) - err := cfg.WriteHosts() - if err != nil { - t.Fatal(err) - } - - expectedConfig := "github.com:\n user: monalisa\n oauth_token: TOKEN\n" - actualConfig, err := os.ReadFile(filepath.Join(configDir, "hosts.yml")) - assert.NoError(t, err) - assert.Equal(t, expectedConfig, string(actualConfig)) - _, nonExistErr := os.Stat(filepath.Join(configDir, "config.yml")) - assert.Error(t, nonExistErr) -} - -func Test_autoMigrateConfigDir_noMigration_notExist(t *testing.T) { - homeDir := t.TempDir() - migrateDir := t.TempDir() - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err := autoMigrateConfigDir(migrateDir) - assert.Equal(t, errNotExist, err) - - files, err := os.ReadDir(migrateDir) - assert.NoError(t, err) - assert.Equal(t, 0, len(files)) -} - -func Test_autoMigrateConfigDir_noMigration_samePath(t *testing.T) { - homeDir := t.TempDir() - migrateDir := filepath.Join(homeDir, ".config", "gh") - err := os.MkdirAll(migrateDir, 0755) - assert.NoError(t, err) - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err = autoMigrateConfigDir(migrateDir) - assert.Equal(t, errSamePath, err) - - files, err := os.ReadDir(migrateDir) - assert.NoError(t, err) - assert.Equal(t, 0, len(files)) -} - -func Test_autoMigrateConfigDir_migration(t *testing.T) { - homeDir := t.TempDir() - migrateDir := t.TempDir() - homeConfigDir := filepath.Join(homeDir, ".config", "gh") - migrateConfigDir := filepath.Join(migrateDir, ".config", "gh") - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err := os.MkdirAll(homeConfigDir, 0755) - assert.NoError(t, err) - f, err := os.CreateTemp(homeConfigDir, "") - assert.NoError(t, err) - f.Close() - - err = autoMigrateConfigDir(migrateConfigDir) - assert.NoError(t, err) - - _, err = os.ReadDir(homeConfigDir) - assert.True(t, os.IsNotExist(err)) - - files, err := os.ReadDir(migrateConfigDir) - assert.NoError(t, err) - assert.Equal(t, 1, len(files)) -} - -func Test_StateDir(t *testing.T) { - tempDir := t.TempDir() - - tests := []struct { - name string - onlyWindows bool - env map[string]string - output string - }{ - { - name: "HOME/USERPROFILE specified", - env: map[string]string{ - "XDG_STATE_HOME": "", - "GH_CONFIG_DIR": "", - "XDG_CONFIG_HOME": "", - "LocalAppData": "", - "USERPROFILE": tempDir, - "HOME": tempDir, - }, - output: filepath.Join(tempDir, ".local", "state", "gh"), - }, - { - name: "XDG_STATE_HOME specified", - env: map[string]string{ - "XDG_STATE_HOME": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - { - name: "LocalAppData specified", - onlyWindows: true, - env: map[string]string{ - "LocalAppData": tempDir, - }, - output: filepath.Join(tempDir, "GitHub CLI"), - }, - { - name: "XDG_STATE_HOME and LocalAppData specified", - onlyWindows: true, - env: map[string]string{ - "XDG_STATE_HOME": tempDir, - "LocalAppData": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - } - - for _, tt := range tests { - if tt.onlyWindows && runtime.GOOS != "windows" { - continue - } - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - old := os.Getenv(k) - os.Setenv(k, v) - defer os.Setenv(k, old) - } - } - - // Create directory to skip auto migration code - // which gets run when target directory does not exist - _ = os.MkdirAll(tt.output, 0755) - - assert.Equal(t, tt.output, StateDir()) - }) - } -} - -func Test_autoMigrateStateDir_noMigration_notExist(t *testing.T) { - homeDir := t.TempDir() - migrateDir := t.TempDir() - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err := autoMigrateStateDir(migrateDir) - assert.Equal(t, errNotExist, err) - - files, err := os.ReadDir(migrateDir) - assert.NoError(t, err) - assert.Equal(t, 0, len(files)) -} - -func Test_autoMigrateStateDir_noMigration_samePath(t *testing.T) { - homeDir := t.TempDir() - migrateDir := filepath.Join(homeDir, ".config", "gh") - err := os.MkdirAll(migrateDir, 0755) - assert.NoError(t, err) - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err = autoMigrateStateDir(migrateDir) - assert.Equal(t, errSamePath, err) - - files, err := os.ReadDir(migrateDir) - assert.NoError(t, err) - assert.Equal(t, 0, len(files)) -} - -func Test_autoMigrateStateDir_migration(t *testing.T) { - homeDir := t.TempDir() - migrateDir := t.TempDir() - homeConfigDir := filepath.Join(homeDir, ".config", "gh") - migrateStateDir := filepath.Join(migrateDir, ".local", "state", "gh") - - homeEnvVar := "HOME" - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } - old := os.Getenv(homeEnvVar) - os.Setenv(homeEnvVar, homeDir) - defer os.Setenv(homeEnvVar, old) - - err := os.MkdirAll(homeConfigDir, 0755) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(homeConfigDir, "state.yml"), nil, 0755) - assert.NoError(t, err) - - err = autoMigrateStateDir(migrateStateDir) - assert.NoError(t, err) - - files, err := os.ReadDir(homeConfigDir) - assert.NoError(t, err) - assert.Equal(t, 0, len(files)) - - files, err = os.ReadDir(migrateStateDir) - assert.NoError(t, err) - assert.Equal(t, 1, len(files)) - assert.Equal(t, "state.yml", files[0].Name()) -} - -func Test_DataDir(t *testing.T) { - tempDir := t.TempDir() - - tests := []struct { - name string - onlyWindows bool - env map[string]string - output string - }{ - { - name: "HOME/USERPROFILE specified", - env: map[string]string{ - "XDG_DATA_HOME": "", - "GH_CONFIG_DIR": "", - "XDG_CONFIG_HOME": "", - "LocalAppData": "", - "USERPROFILE": tempDir, - "HOME": tempDir, - }, - output: filepath.Join(tempDir, ".local", "share", "gh"), - }, - { - name: "XDG_DATA_HOME specified", - env: map[string]string{ - "XDG_DATA_HOME": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - { - name: "LocalAppData specified", - onlyWindows: true, - env: map[string]string{ - "LocalAppData": tempDir, - }, - output: filepath.Join(tempDir, "GitHub CLI"), - }, - { - name: "XDG_DATA_HOME and LocalAppData specified", - onlyWindows: true, - env: map[string]string{ - "XDG_DATA_HOME": tempDir, - "LocalAppData": tempDir, - }, - output: filepath.Join(tempDir, "gh"), - }, - } - - for _, tt := range tests { - if tt.onlyWindows && runtime.GOOS != "windows" { - continue - } - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - old := os.Getenv(k) - os.Setenv(k, v) - defer os.Setenv(k, old) - } - } - - assert.Equal(t, tt.output, DataDir()) - }) - } -} diff --git a/internal/config/config_map.go b/internal/config/config_map.go deleted file mode 100644 index c391bc486..000000000 --- a/internal/config/config_map.go +++ /dev/null @@ -1,113 +0,0 @@ -package config - -import ( - "errors" - - "gopkg.in/yaml.v3" -) - -// This type implements a low-level get/set config that is backed by an in-memory tree of yaml -// nodes. It allows us to interact with a yaml-based config programmatically, preserving any -// comments that were present when the yaml was parsed. -type ConfigMap struct { - Root *yaml.Node -} - -type ConfigEntry struct { - KeyNode *yaml.Node - ValueNode *yaml.Node - Index int -} - -type NotFoundError struct { - error -} - -func (cm *ConfigMap) Empty() bool { - return cm.Root == nil || len(cm.Root.Content) == 0 -} - -func (cm *ConfigMap) GetStringValue(key string) (string, error) { - entry, err := cm.FindEntry(key) - if err != nil { - return "", err - } - return entry.ValueNode.Value, nil -} - -func (cm *ConfigMap) SetStringValue(key, value string) error { - entry, err := cm.FindEntry(key) - if err == nil { - entry.ValueNode.Value = value - return nil - } - - var notFound *NotFoundError - if err != nil && !errors.As(err, ¬Found) { - return err - } - - keyNode := &yaml.Node{ - Kind: yaml.ScalarNode, - Value: key, - } - valueNode := &yaml.Node{ - Kind: yaml.ScalarNode, - Tag: "!!str", - Value: value, - } - - cm.Root.Content = append(cm.Root.Content, keyNode, valueNode) - return nil -} - -func (cm *ConfigMap) FindEntry(key string) (*ConfigEntry, error) { - ce := &ConfigEntry{} - - if cm.Empty() { - return ce, &NotFoundError{errors.New("not found")} - } - - // Content slice goes [key1, value1, key2, value2, ...]. - topLevelPairs := cm.Root.Content - for i, v := range topLevelPairs { - // Skip every other slice item since we only want to check against keys. - if i%2 != 0 { - continue - } - if v.Value == key { - ce.KeyNode = v - ce.Index = i - if i+1 < len(topLevelPairs) { - ce.ValueNode = topLevelPairs[i+1] - } - return ce, nil - } - } - - return ce, &NotFoundError{errors.New("not found")} -} - -func (cm *ConfigMap) RemoveEntry(key string) { - if cm.Empty() { - return - } - - newContent := []*yaml.Node{} - - var skipNext bool - for i, v := range cm.Root.Content { - if skipNext { - skipNext = false - continue - } - if i%2 != 0 || v.Value != key { - newContent = append(newContent, v) - } else { - // Don't append current node and skip the next which is this key's value. - skipNext = true - } - } - - cm.Root.Content = newContent -} diff --git a/internal/config/config_map_test.go b/internal/config/config_map_test.go deleted file mode 100644 index 4dc49d01b..000000000 --- a/internal/config/config_map_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" -) - -func TestFindEntry(t *testing.T) { - tests := []struct { - name string - key string - output string - wantErr bool - }{ - { - name: "find key", - key: "valid", - output: "present", - }, - { - name: "find key that is not present", - key: "invalid", - wantErr: true, - }, - { - name: "find key with blank value", - key: "blank", - output: "", - }, - { - name: "find key that has same content as a value", - key: "same", - output: "logical", - }, - } - - for _, tt := range tests { - cm := ConfigMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - out, err := cm.FindEntry(tt.key) - if tt.wantErr { - assert.EqualError(t, err, "not found") - return - } - assert.NoError(t, err) - assert.Equal(t, tt.output, out.ValueNode.Value) - }) - } -} - -func TestEmpty(t *testing.T) { - cm := ConfigMap{} - assert.Equal(t, true, cm.Empty()) - cm.Root = &yaml.Node{ - Content: []*yaml.Node{ - { - Value: "test", - }, - }, - } - assert.Equal(t, false, cm.Empty()) -} - -func TestGetStringValue(t *testing.T) { - tests := []struct { - name string - key string - wantValue string - wantErr bool - }{ - { - name: "get key", - key: "valid", - wantValue: "present", - }, - { - name: "get key that is not present", - key: "invalid", - wantErr: true, - }, - { - name: "get key that has same content as a value", - key: "same", - wantValue: "logical", - }, - } - - for _, tt := range tests { - cm := ConfigMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - val, err := cm.GetStringValue(tt.key) - if tt.wantErr { - assert.EqualError(t, err, "not found") - return - } - assert.Equal(t, tt.wantValue, val) - }) - } -} - -func TestSetStringValue(t *testing.T) { - tests := []struct { - name string - key string - value string - }{ - { - name: "set key that is not present", - key: "notPresent", - value: "test1", - }, - { - name: "set key that is present", - key: "erroneous", - value: "test2", - }, - { - name: "set key that is blank", - key: "blank", - value: "test3", - }, - { - name: "set key that has same content as a value", - key: "present", - value: "test4", - }, - } - - for _, tt := range tests { - cm := ConfigMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - err := cm.SetStringValue(tt.key, tt.value) - assert.NoError(t, err) - val, err := cm.GetStringValue(tt.key) - assert.NoError(t, err) - assert.Equal(t, tt.value, val) - }) - } -} - -func TestRemoveEntry(t *testing.T) { - tests := []struct { - name string - key string - wantLength int - }{ - { - name: "remove key", - key: "erroneous", - wantLength: 6, - }, - { - name: "remove key that is not present", - key: "invalid", - wantLength: 8, - }, - { - name: "remove key that has same content as a value", - key: "same", - wantLength: 6, - }, - } - - for _, tt := range tests { - cm := ConfigMap{Root: testYaml()} - t.Run(tt.name, func(t *testing.T) { - cm.RemoveEntry(tt.key) - assert.Equal(t, tt.wantLength, len(cm.Root.Content)) - _, err := cm.FindEntry(tt.key) - assert.EqualError(t, err, "not found") - }) - } -} - -func testYaml() *yaml.Node { - var root yaml.Node - var data = ` -valid: present -erroneous: same -blank: -same: logical -` - _ = yaml.Unmarshal([]byte(data), &root) - return root.Content[0] -} diff --git a/internal/config/config_mock.go b/internal/config/config_mock.go new file mode 100644 index 000000000..13fd1fae8 --- /dev/null +++ b/internal/config/config_mock.go @@ -0,0 +1,413 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package config + +import ( + "sync" +) + +// Ensure, that ConfigMock does implement Config. +// If this is not the case, regenerate this file with moq. +var _ Config = &ConfigMock{} + +// ConfigMock is a mock implementation of Config. +// +// func TestSomethingThatUsesConfig(t *testing.T) { +// +// // make and configure a mocked Config +// mockedConfig := &ConfigMock{ +// AliasesFunc: func() *AliasConfig { +// panic("mock out the Aliases method") +// }, +// AuthTokenFunc: func(s string) (string, string) { +// panic("mock out the AuthToken method") +// }, +// DefaultHostFunc: func() (string, string) { +// panic("mock out the DefaultHost method") +// }, +// GetFunc: func(s1 string, s2 string) (string, error) { +// panic("mock out the Get method") +// }, +// GetOrDefaultFunc: func(s1 string, s2 string) (string, error) { +// panic("mock out the GetOrDefault method") +// }, +// HostsFunc: func() []string { +// panic("mock out the Hosts method") +// }, +// SetFunc: func(s1 string, s2 string, s3 string) { +// panic("mock out the Set method") +// }, +// UnsetHostFunc: func(s string) { +// panic("mock out the UnsetHost method") +// }, +// WriteFunc: func() error { +// panic("mock out the Write method") +// }, +// } +// +// // use mockedConfig in code that requires Config +// // and then make assertions. +// +// } +type ConfigMock struct { + // AliasesFunc mocks the Aliases method. + AliasesFunc func() *AliasConfig + + // AuthTokenFunc mocks the AuthToken method. + AuthTokenFunc func(s string) (string, string) + + // DefaultHostFunc mocks the DefaultHost method. + DefaultHostFunc func() (string, string) + + // GetFunc mocks the Get method. + GetFunc func(s1 string, s2 string) (string, error) + + // GetOrDefaultFunc mocks the GetOrDefault method. + GetOrDefaultFunc func(s1 string, s2 string) (string, error) + + // HostsFunc mocks the Hosts method. + HostsFunc func() []string + + // SetFunc mocks the Set method. + SetFunc func(s1 string, s2 string, s3 string) + + // UnsetHostFunc mocks the UnsetHost method. + UnsetHostFunc func(s string) + + // WriteFunc mocks the Write method. + WriteFunc func() error + + // calls tracks calls to the methods. + calls struct { + // Aliases holds details about calls to the Aliases method. + Aliases []struct { + } + // AuthToken holds details about calls to the AuthToken method. + AuthToken []struct { + // S is the s argument value. + S string + } + // DefaultHost holds details about calls to the DefaultHost method. + DefaultHost []struct { + } + // Get holds details about calls to the Get method. + Get []struct { + // S1 is the s1 argument value. + S1 string + // S2 is the s2 argument value. + S2 string + } + // GetOrDefault holds details about calls to the GetOrDefault method. + GetOrDefault []struct { + // S1 is the s1 argument value. + S1 string + // S2 is the s2 argument value. + S2 string + } + // Hosts holds details about calls to the Hosts method. + Hosts []struct { + } + // Set holds details about calls to the Set method. + Set []struct { + // S1 is the s1 argument value. + S1 string + // S2 is the s2 argument value. + S2 string + // S3 is the s3 argument value. + S3 string + } + // UnsetHost holds details about calls to the UnsetHost method. + UnsetHost []struct { + // S is the s argument value. + S string + } + // Write holds details about calls to the Write method. + Write []struct { + } + } + lockAliases sync.RWMutex + lockAuthToken sync.RWMutex + lockDefaultHost sync.RWMutex + lockGet sync.RWMutex + lockGetOrDefault sync.RWMutex + lockHosts sync.RWMutex + lockSet sync.RWMutex + lockUnsetHost sync.RWMutex + lockWrite sync.RWMutex +} + +// Aliases calls AliasesFunc. +func (mock *ConfigMock) Aliases() *AliasConfig { + if mock.AliasesFunc == nil { + panic("ConfigMock.AliasesFunc: method is nil but Config.Aliases was just called") + } + callInfo := struct { + }{} + mock.lockAliases.Lock() + mock.calls.Aliases = append(mock.calls.Aliases, callInfo) + mock.lockAliases.Unlock() + return mock.AliasesFunc() +} + +// AliasesCalls gets all the calls that were made to Aliases. +// Check the length with: +// len(mockedConfig.AliasesCalls()) +func (mock *ConfigMock) AliasesCalls() []struct { +} { + var calls []struct { + } + mock.lockAliases.RLock() + calls = mock.calls.Aliases + mock.lockAliases.RUnlock() + return calls +} + +// AuthToken calls AuthTokenFunc. +func (mock *ConfigMock) AuthToken(s string) (string, string) { + if mock.AuthTokenFunc == nil { + panic("ConfigMock.AuthTokenFunc: method is nil but Config.AuthToken was just called") + } + callInfo := struct { + S string + }{ + S: s, + } + mock.lockAuthToken.Lock() + mock.calls.AuthToken = append(mock.calls.AuthToken, callInfo) + mock.lockAuthToken.Unlock() + return mock.AuthTokenFunc(s) +} + +// AuthTokenCalls gets all the calls that were made to AuthToken. +// Check the length with: +// len(mockedConfig.AuthTokenCalls()) +func (mock *ConfigMock) AuthTokenCalls() []struct { + S string +} { + var calls []struct { + S string + } + mock.lockAuthToken.RLock() + calls = mock.calls.AuthToken + mock.lockAuthToken.RUnlock() + return calls +} + +// DefaultHost calls DefaultHostFunc. +func (mock *ConfigMock) DefaultHost() (string, string) { + if mock.DefaultHostFunc == nil { + panic("ConfigMock.DefaultHostFunc: method is nil but Config.DefaultHost was just called") + } + callInfo := struct { + }{} + mock.lockDefaultHost.Lock() + mock.calls.DefaultHost = append(mock.calls.DefaultHost, callInfo) + mock.lockDefaultHost.Unlock() + return mock.DefaultHostFunc() +} + +// DefaultHostCalls gets all the calls that were made to DefaultHost. +// Check the length with: +// len(mockedConfig.DefaultHostCalls()) +func (mock *ConfigMock) DefaultHostCalls() []struct { +} { + var calls []struct { + } + mock.lockDefaultHost.RLock() + calls = mock.calls.DefaultHost + mock.lockDefaultHost.RUnlock() + return calls +} + +// Get calls GetFunc. +func (mock *ConfigMock) Get(s1 string, s2 string) (string, error) { + if mock.GetFunc == nil { + panic("ConfigMock.GetFunc: method is nil but Config.Get was just called") + } + callInfo := struct { + S1 string + S2 string + }{ + S1: s1, + S2: s2, + } + mock.lockGet.Lock() + mock.calls.Get = append(mock.calls.Get, callInfo) + mock.lockGet.Unlock() + return mock.GetFunc(s1, s2) +} + +// GetCalls gets all the calls that were made to Get. +// Check the length with: +// len(mockedConfig.GetCalls()) +func (mock *ConfigMock) GetCalls() []struct { + S1 string + S2 string +} { + var calls []struct { + S1 string + S2 string + } + mock.lockGet.RLock() + calls = mock.calls.Get + mock.lockGet.RUnlock() + return calls +} + +// GetOrDefault calls GetOrDefaultFunc. +func (mock *ConfigMock) GetOrDefault(s1 string, s2 string) (string, error) { + if mock.GetOrDefaultFunc == nil { + panic("ConfigMock.GetOrDefaultFunc: method is nil but Config.GetOrDefault was just called") + } + callInfo := struct { + S1 string + S2 string + }{ + S1: s1, + S2: s2, + } + mock.lockGetOrDefault.Lock() + mock.calls.GetOrDefault = append(mock.calls.GetOrDefault, callInfo) + mock.lockGetOrDefault.Unlock() + return mock.GetOrDefaultFunc(s1, s2) +} + +// GetOrDefaultCalls gets all the calls that were made to GetOrDefault. +// Check the length with: +// len(mockedConfig.GetOrDefaultCalls()) +func (mock *ConfigMock) GetOrDefaultCalls() []struct { + S1 string + S2 string +} { + var calls []struct { + S1 string + S2 string + } + mock.lockGetOrDefault.RLock() + calls = mock.calls.GetOrDefault + mock.lockGetOrDefault.RUnlock() + return calls +} + +// Hosts calls HostsFunc. +func (mock *ConfigMock) Hosts() []string { + if mock.HostsFunc == nil { + panic("ConfigMock.HostsFunc: method is nil but Config.Hosts was just called") + } + callInfo := struct { + }{} + mock.lockHosts.Lock() + mock.calls.Hosts = append(mock.calls.Hosts, callInfo) + mock.lockHosts.Unlock() + return mock.HostsFunc() +} + +// HostsCalls gets all the calls that were made to Hosts. +// Check the length with: +// len(mockedConfig.HostsCalls()) +func (mock *ConfigMock) HostsCalls() []struct { +} { + var calls []struct { + } + mock.lockHosts.RLock() + calls = mock.calls.Hosts + mock.lockHosts.RUnlock() + return calls +} + +// Set calls SetFunc. +func (mock *ConfigMock) Set(s1 string, s2 string, s3 string) { + if mock.SetFunc == nil { + panic("ConfigMock.SetFunc: method is nil but Config.Set was just called") + } + callInfo := struct { + S1 string + S2 string + S3 string + }{ + S1: s1, + S2: s2, + S3: s3, + } + mock.lockSet.Lock() + mock.calls.Set = append(mock.calls.Set, callInfo) + mock.lockSet.Unlock() + mock.SetFunc(s1, s2, s3) +} + +// SetCalls gets all the calls that were made to Set. +// Check the length with: +// len(mockedConfig.SetCalls()) +func (mock *ConfigMock) SetCalls() []struct { + S1 string + S2 string + S3 string +} { + var calls []struct { + S1 string + S2 string + S3 string + } + mock.lockSet.RLock() + calls = mock.calls.Set + mock.lockSet.RUnlock() + return calls +} + +// UnsetHost calls UnsetHostFunc. +func (mock *ConfigMock) UnsetHost(s string) { + if mock.UnsetHostFunc == nil { + panic("ConfigMock.UnsetHostFunc: method is nil but Config.UnsetHost was just called") + } + callInfo := struct { + S string + }{ + S: s, + } + mock.lockUnsetHost.Lock() + mock.calls.UnsetHost = append(mock.calls.UnsetHost, callInfo) + mock.lockUnsetHost.Unlock() + mock.UnsetHostFunc(s) +} + +// UnsetHostCalls gets all the calls that were made to UnsetHost. +// Check the length with: +// len(mockedConfig.UnsetHostCalls()) +func (mock *ConfigMock) UnsetHostCalls() []struct { + S string +} { + var calls []struct { + S string + } + mock.lockUnsetHost.RLock() + calls = mock.calls.UnsetHost + mock.lockUnsetHost.RUnlock() + return calls +} + +// Write calls WriteFunc. +func (mock *ConfigMock) Write() error { + if mock.WriteFunc == nil { + panic("ConfigMock.WriteFunc: method is nil but Config.Write was just called") + } + callInfo := struct { + }{} + mock.lockWrite.Lock() + mock.calls.Write = append(mock.calls.Write, callInfo) + mock.lockWrite.Unlock() + return mock.WriteFunc() +} + +// WriteCalls gets all the calls that were made to Write. +// Check the length with: +// len(mockedConfig.WriteCalls()) +func (mock *ConfigMock) WriteCalls() []struct { +} { + var calls []struct { + } + mock.lockWrite.RLock() + calls = mock.calls.Write + mock.lockWrite.RUnlock() + return calls +} diff --git a/internal/config/config_type.go b/internal/config/config_type.go deleted file mode 100644 index a4f63e069..000000000 --- a/internal/config/config_type.go +++ /dev/null @@ -1,218 +0,0 @@ -package config - -import ( - "fmt" - - "gopkg.in/yaml.v3" -) - -// This interface describes interacting with some persistent configuration for gh. -type Config interface { - Get(string, string) (string, error) - GetOrDefault(string, string) (string, error) - GetWithSource(string, string) (string, string, error) - GetOrDefaultWithSource(string, string) (string, string, error) - Default(string) string - Set(string, string, string) error - UnsetHost(string) - Hosts() ([]string, error) - DefaultHost() (string, error) - DefaultHostWithSource() (string, string, error) - Aliases() (*AliasConfig, error) - CheckWriteable(string, string) error - Write() error - WriteHosts() error -} - -type ConfigOption struct { - Key string - Description string - DefaultValue string - AllowedValues []string -} - -var configOptions = []ConfigOption{ - { - Key: "git_protocol", - Description: "the protocol to use for git clone and push operations", - DefaultValue: "https", - AllowedValues: []string{"https", "ssh"}, - }, - { - Key: "editor", - Description: "the text editor program to use for authoring text", - DefaultValue: "", - }, - { - Key: "prompt", - Description: "toggle interactive prompting in the terminal", - DefaultValue: "enabled", - AllowedValues: []string{"enabled", "disabled"}, - }, - { - Key: "pager", - Description: "the terminal pager program to send standard output to", - DefaultValue: "", - }, - { - Key: "http_unix_socket", - Description: "the path to a Unix socket through which to make an HTTP connection", - DefaultValue: "", - }, - { - Key: "browser", - Description: "the web browser to use for opening URLs", - DefaultValue: "", - }, -} - -func ConfigOptions() []ConfigOption { - return configOptions -} - -func ValidateKey(key string) error { - for _, configKey := range configOptions { - if key == configKey.Key { - return nil - } - } - - return fmt.Errorf("invalid key") -} - -type InvalidValueError struct { - ValidValues []string -} - -func (e InvalidValueError) Error() string { - return "invalid value" -} - -func ValidateValue(key, value string) error { - var validValues []string - - for _, v := range configOptions { - if v.Key == key { - validValues = v.AllowedValues - break - } - } - - if validValues == nil { - return nil - } - - for _, v := range validValues { - if v == value { - return nil - } - } - - return &InvalidValueError{ValidValues: validValues} -} - -func NewConfig(root *yaml.Node) Config { - return &fileConfig{ - ConfigMap: ConfigMap{Root: root.Content[0]}, - documentRoot: root, - } -} - -// NewFromString initializes a Config from a yaml string -func NewFromString(str string) Config { - root, err := parseConfigData([]byte(str)) - if err != nil { - panic(err) - } - return NewConfig(root) -} - -// NewBlankConfig initializes a config file pre-populated with comments and default values -func NewBlankConfig() Config { - return NewConfig(NewBlankRoot()) -} - -func NewBlankRoot() *yaml.Node { - return &yaml.Node{ - Kind: yaml.DocumentNode, - Content: []*yaml.Node{ - { - Kind: yaml.MappingNode, - Content: []*yaml.Node{ - { - HeadComment: "What protocol to use when performing git operations. Supported values: ssh, https", - Kind: yaml.ScalarNode, - Value: "git_protocol", - }, - { - Kind: yaml.ScalarNode, - Value: "https", - }, - { - HeadComment: "What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.", - Kind: yaml.ScalarNode, - Value: "editor", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled", - Kind: yaml.ScalarNode, - Value: "prompt", - }, - { - Kind: yaml.ScalarNode, - Value: "enabled", - }, - { - HeadComment: "A pager program to send command output to, e.g. \"less\". Set the value to \"cat\" to disable the pager.", - Kind: yaml.ScalarNode, - Value: "pager", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "Aliases allow you to create nicknames for gh commands", - Kind: yaml.ScalarNode, - Value: "aliases", - }, - { - Kind: yaml.MappingNode, - Content: []*yaml.Node{ - { - Kind: yaml.ScalarNode, - Value: "co", - }, - { - Kind: yaml.ScalarNode, - Value: "pr checkout", - }, - }, - }, - { - HeadComment: "The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.", - Kind: yaml.ScalarNode, - Value: "http_unix_socket", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - { - HeadComment: "What web browser gh should use when opening URLs. If blank, will refer to environment.", - Kind: yaml.ScalarNode, - Value: "browser", - }, - { - Kind: yaml.ScalarNode, - Value: "", - }, - }, - }, - }, - } -} diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go deleted file mode 100644 index c16455bcc..000000000 --- a/internal/config/config_type_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package config - -import ( - "bytes" - "testing" - - "github.com/MakeNowJust/heredoc" - "github.com/stretchr/testify/assert" -) - -func Test_fileConfig_Set(t *testing.T) { - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer StubWriteConfig(&mainBuf, &hostsBuf)() - - c := NewBlankConfig() - assert.NoError(t, c.Set("", "editor", "nano")) - assert.NoError(t, c.Set("github.com", "git_protocol", "ssh")) - assert.NoError(t, c.Set("example.com", "editor", "vim")) - assert.NoError(t, c.Set("github.com", "user", "hubot")) - assert.NoError(t, c.Write()) - - assert.Contains(t, mainBuf.String(), "editor: nano") - assert.Contains(t, mainBuf.String(), "git_protocol: https") - assert.Equal(t, `github.com: - git_protocol: ssh - user: hubot -example.com: - editor: vim -`, hostsBuf.String()) -} - -func Test_defaultConfig(t *testing.T) { - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer StubWriteConfig(&mainBuf, &hostsBuf)() - - cfg := NewBlankConfig() - assert.NoError(t, cfg.Write()) - - expected := heredoc.Doc(` - # What protocol to use when performing git operations. Supported values: ssh, https - git_protocol: https - # What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment. - editor: - # When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled - prompt: enabled - # A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager. - pager: - # Aliases allow you to create nicknames for gh commands - aliases: - co: pr checkout - # The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport. - http_unix_socket: - # What web browser gh should use when opening URLs. If blank, will refer to environment. - browser: - `) - assert.Equal(t, expected, mainBuf.String()) - assert.Equal(t, "", hostsBuf.String()) - - proto, err := cfg.GetOrDefault("", "git_protocol") - assert.NoError(t, err) - assert.Equal(t, "https", proto) - - editor, err := cfg.Get("", "editor") - assert.NoError(t, err) - assert.Equal(t, "", editor) - - aliases, err := cfg.Aliases() - assert.NoError(t, err) - assert.Equal(t, len(aliases.All()), 1) - expansion, _ := aliases.Get("co") - assert.Equal(t, expansion, "pr checkout") - - browser, err := cfg.Get("", "browser") - assert.NoError(t, err) - assert.Equal(t, "", browser) -} - -func Test_ValidateValue(t *testing.T) { - err := ValidateValue("git_protocol", "sshpps") - assert.EqualError(t, err, "invalid value") - - err = ValidateValue("git_protocol", "ssh") - assert.NoError(t, err) - - err = ValidateValue("editor", "vim") - assert.NoError(t, err) - - err = ValidateValue("got", "123") - assert.NoError(t, err) - - err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate") - assert.NoError(t, err) -} - -func Test_ValidateKey(t *testing.T) { - err := ValidateKey("invalid") - assert.EqualError(t, err, "invalid key") - - err = ValidateKey("git_protocol") - assert.NoError(t, err) - - err = ValidateKey("editor") - assert.NoError(t, err) - - err = ValidateKey("prompt") - assert.NoError(t, err) - - err = ValidateKey("pager") - assert.NoError(t, err) - - err = ValidateKey("http_unix_socket") - assert.NoError(t, err) - - err = ValidateKey("browser") - assert.NoError(t, err) -} diff --git a/internal/config/from_env.go b/internal/config/from_env.go deleted file mode 100644 index 3cc19879d..000000000 --- a/internal/config/from_env.go +++ /dev/null @@ -1,156 +0,0 @@ -package config - -import ( - "fmt" - "os" - "sort" - "strconv" - - "github.com/cli/cli/v2/internal/ghinstance" - "github.com/cli/cli/v2/pkg/set" -) - -const ( - GH_HOST = "GH_HOST" - GH_TOKEN = "GH_TOKEN" - GITHUB_TOKEN = "GITHUB_TOKEN" - GH_ENTERPRISE_TOKEN = "GH_ENTERPRISE_TOKEN" - GITHUB_ENTERPRISE_TOKEN = "GITHUB_ENTERPRISE_TOKEN" - CODESPACES = "CODESPACES" -) - -type ReadOnlyEnvError struct { - Variable string -} - -func (e *ReadOnlyEnvError) Error() string { - return fmt.Sprintf("read-only value in %s", e.Variable) -} - -func InheritEnv(c Config) Config { - return &envConfig{Config: c} -} - -type envConfig struct { - Config -} - -func (c *envConfig) Hosts() ([]string, error) { - hosts, err := c.Config.Hosts() - if err != nil { - return nil, err - } - - hostSet := set.NewStringSet() - hostSet.AddValues(hosts) - - // If GH_HOST is set then add it to list. - if host := os.Getenv(GH_HOST); host != "" { - hostSet.Add(host) - } - - // If there is a valid environment variable token for the - // default host then add default host to list. - if token, _ := AuthTokenFromEnv(ghinstance.Default()); token != "" { - hostSet.Add(ghinstance.Default()) - } - - s := hostSet.ToSlice() - // If default host is in list then move it to the front. - sort.SliceStable(s, func(i, j int) bool { return s[i] == ghinstance.Default() }) - return s, nil -} - -func (c *envConfig) DefaultHost() (string, error) { - val, _, err := c.DefaultHostWithSource() - return val, err -} - -func (c *envConfig) DefaultHostWithSource() (string, string, error) { - if host := os.Getenv(GH_HOST); host != "" { - return host, GH_HOST, nil - } - return c.Config.DefaultHostWithSource() -} - -func (c *envConfig) Get(hostname, key string) (string, error) { - val, _, err := c.GetWithSource(hostname, key) - return val, err -} - -func (c *envConfig) GetWithSource(hostname, key string) (string, string, error) { - if hostname != "" && key == "oauth_token" { - if token, env := AuthTokenFromEnv(hostname); token != "" { - return token, env, nil - } - } - - return c.Config.GetWithSource(hostname, key) -} - -func (c *envConfig) GetOrDefault(hostname, key string) (val string, err error) { - val, _, err = c.GetOrDefaultWithSource(hostname, key) - return -} - -func (c *envConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { - val, src, err = c.GetWithSource(hostname, key) - if err == nil && val == "" { - val = c.Default(key) - } - - return -} - -func (c *envConfig) Default(key string) string { - return c.Config.Default(key) -} - -func (c *envConfig) CheckWriteable(hostname, key string) error { - if hostname != "" && key == "oauth_token" { - if token, env := AuthTokenFromEnv(hostname); token != "" { - return &ReadOnlyEnvError{Variable: env} - } - } - - return c.Config.CheckWriteable(hostname, key) -} - -func AuthTokenFromEnv(hostname string) (string, string) { - if ghinstance.IsEnterprise(hostname) { - if token := os.Getenv(GH_ENTERPRISE_TOKEN); token != "" { - return token, GH_ENTERPRISE_TOKEN - } - - if token := os.Getenv(GITHUB_ENTERPRISE_TOKEN); token != "" { - return token, GITHUB_ENTERPRISE_TOKEN - } - - if isCodespaces, _ := strconv.ParseBool(os.Getenv(CODESPACES)); isCodespaces { - return os.Getenv(GITHUB_TOKEN), GITHUB_TOKEN - } - - return "", "" - } - - if token := os.Getenv(GH_TOKEN); token != "" { - return token, GH_TOKEN - } - - return os.Getenv(GITHUB_TOKEN), GITHUB_TOKEN -} - -func AuthTokenProvidedFromEnv() bool { - return os.Getenv(GH_ENTERPRISE_TOKEN) != "" || - os.Getenv(GITHUB_ENTERPRISE_TOKEN) != "" || - os.Getenv(GH_TOKEN) != "" || - os.Getenv(GITHUB_TOKEN) != "" -} - -func IsHostEnv(src string) bool { - return src == GH_HOST -} - -func IsEnterpriseEnv(src string) bool { - return src == GH_ENTERPRISE_TOKEN || src == GITHUB_ENTERPRISE_TOKEN -} diff --git a/internal/config/from_env_test.go b/internal/config/from_env_test.go deleted file mode 100644 index 4bce09e85..000000000 --- a/internal/config/from_env_test.go +++ /dev/null @@ -1,389 +0,0 @@ -package config - -import ( - "os" - "testing" - - "github.com/MakeNowJust/heredoc" - "github.com/stretchr/testify/assert" -) - -func setenv(t *testing.T, key, newValue string) { - oldValue, hasValue := os.LookupEnv(key) - os.Setenv(key, newValue) - t.Cleanup(func() { - if hasValue { - os.Setenv(key, oldValue) - } else { - os.Unsetenv(key) - } - }) -} - -func TestInheritEnv(t *testing.T) { - orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") - orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN") - orig_AppData := os.Getenv("AppData") - t.Cleanup(func() { - os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN) - os.Setenv("AppData", orig_AppData) - }) - - type wants struct { - hosts []string - token string - source string - writeable bool - } - - tests := []struct { - name string - baseConfig string - GH_HOST string - GITHUB_TOKEN string - GITHUB_ENTERPRISE_TOKEN string - GH_TOKEN string - GH_ENTERPRISE_TOKEN string - CODESPACES string - hostname string - wants wants - }{ - { - name: "blank", - baseConfig: ``, - hostname: "github.com", - wants: wants{ - hosts: []string{}, - token: "", - source: ".config.gh.config.yml", - writeable: true, - }, - }, - { - name: "GITHUB_TOKEN over blank config", - baseConfig: ``, - GITHUB_TOKEN: "OTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "OTOKEN", - source: "GITHUB_TOKEN", - writeable: false, - }, - }, - { - name: "GH_TOKEN over blank config", - baseConfig: ``, - GH_TOKEN: "OTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "OTOKEN", - source: "GH_TOKEN", - writeable: false, - }, - }, - { - name: "GITHUB_TOKEN not applicable to GHE", - baseConfig: ``, - GITHUB_TOKEN: "OTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{"github.com"}, - token: "", - source: ".config.gh.config.yml", - writeable: true, - }, - }, - { - name: "GH_TOKEN not applicable to GHE", - baseConfig: ``, - GH_TOKEN: "OTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{"github.com"}, - token: "", - source: ".config.gh.config.yml", - writeable: true, - }, - }, - { - name: "GITHUB_TOKEN allowed in Codespaces", - baseConfig: ``, - GITHUB_TOKEN: "OTOKEN", - hostname: "example.org", - CODESPACES: "true", - wants: wants{ - hosts: []string{"github.com"}, - token: "OTOKEN", - source: "GITHUB_TOKEN", - writeable: false, - }, - }, - { - name: "GITHUB_ENTERPRISE_TOKEN over blank config", - baseConfig: ``, - GITHUB_ENTERPRISE_TOKEN: "ENTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{}, - token: "ENTOKEN", - source: "GITHUB_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - { - name: "GH_ENTERPRISE_TOKEN over blank config", - baseConfig: ``, - GH_ENTERPRISE_TOKEN: "ENTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{}, - token: "ENTOKEN", - source: "GH_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - { - name: "token from file", - baseConfig: heredoc.Doc(` - hosts: - github.com: - oauth_token: OTOKEN - `), - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "OTOKEN", - source: ".config.gh.hosts.yml", - writeable: true, - }, - }, - { - name: "GITHUB_TOKEN shadows token from file", - baseConfig: heredoc.Doc(` - hosts: - github.com: - oauth_token: OTOKEN - `), - GITHUB_TOKEN: "ENVTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "ENVTOKEN", - source: "GITHUB_TOKEN", - writeable: false, - }, - }, - { - name: "GH_TOKEN shadows token from file", - baseConfig: heredoc.Doc(` - hosts: - github.com: - oauth_token: OTOKEN - `), - GH_TOKEN: "ENVTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "ENVTOKEN", - source: "GH_TOKEN", - writeable: false, - }, - }, - { - name: "GITHUB_ENTERPRISE_TOKEN shadows token from file", - baseConfig: heredoc.Doc(` - hosts: - example.org: - oauth_token: OTOKEN - `), - GITHUB_ENTERPRISE_TOKEN: "ENVTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{"example.org"}, - token: "ENVTOKEN", - source: "GITHUB_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - { - name: "GH_ENTERPRISE_TOKEN shadows token from file", - baseConfig: heredoc.Doc(` - hosts: - example.org: - oauth_token: OTOKEN - `), - GH_ENTERPRISE_TOKEN: "ENVTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{"example.org"}, - token: "ENVTOKEN", - source: "GH_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - { - name: "GH_TOKEN shadows token from GITHUB_TOKEN", - baseConfig: ``, - GH_TOKEN: "GHTOKEN", - GITHUB_TOKEN: "GITHUBTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com"}, - token: "GHTOKEN", - source: "GH_TOKEN", - writeable: false, - }, - }, - { - name: "GH_ENTERPRISE_TOKEN shadows token from GITHUB_ENTERPRISE_TOKEN", - baseConfig: ``, - GH_ENTERPRISE_TOKEN: "GHTOKEN", - GITHUB_ENTERPRISE_TOKEN: "GITHUBTOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{}, - token: "GHTOKEN", - source: "GH_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - { - name: "GITHUB_TOKEN adds host entry", - baseConfig: heredoc.Doc(` - hosts: - example.org: - oauth_token: OTOKEN - `), - GITHUB_TOKEN: "ENVTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com", "example.org"}, - token: "ENVTOKEN", - source: "GITHUB_TOKEN", - writeable: false, - }, - }, - { - name: "GH_TOKEN adds host entry", - baseConfig: heredoc.Doc(` - hosts: - example.org: - oauth_token: OTOKEN - `), - GH_TOKEN: "ENVTOKEN", - hostname: "github.com", - wants: wants{ - hosts: []string{"github.com", "example.org"}, - token: "ENVTOKEN", - source: "GH_TOKEN", - writeable: false, - }, - }, - { - name: "GH_HOST adds host entry when paired with environment token", - baseConfig: ``, - GH_HOST: "example.org", - GH_ENTERPRISE_TOKEN: "GH_ENTERPRISE_TOKEN", - hostname: "example.org", - wants: wants{ - hosts: []string{"example.org"}, - token: "GH_ENTERPRISE_TOKEN", - source: "GH_ENTERPRISE_TOKEN", - writeable: false, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - setenv(t, "GH_HOST", tt.GH_HOST) - setenv(t, "GITHUB_TOKEN", tt.GITHUB_TOKEN) - setenv(t, "GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN) - setenv(t, "GH_TOKEN", tt.GH_TOKEN) - setenv(t, "GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN) - setenv(t, "AppData", "") - setenv(t, "CODESPACES", tt.CODESPACES) - - baseCfg := NewFromString(tt.baseConfig) - cfg := InheritEnv(baseCfg) - - hosts, _ := cfg.Hosts() - assert.Equal(t, tt.wants.hosts, hosts) - - val, source, _ := cfg.GetWithSource(tt.hostname, "oauth_token") - assert.Equal(t, tt.wants.token, val) - assert.Regexp(t, tt.wants.source, source) - - val, _ = cfg.Get(tt.hostname, "oauth_token") - assert.Equal(t, tt.wants.token, val) - - err := cfg.CheckWriteable(tt.hostname, "oauth_token") - if tt.wants.writeable != (err == nil) { - t.Errorf("CheckWriteable() = %v, wants %v", err, tt.wants.writeable) - } - }) - } -} - -func TestAuthTokenProvidedFromEnv(t *testing.T) { - orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") - orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN") - t.Cleanup(func() { - os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN) - }) - - tests := []struct { - name string - GITHUB_TOKEN string - GITHUB_ENTERPRISE_TOKEN string - GH_TOKEN string - GH_ENTERPRISE_TOKEN string - provided bool - }{ - { - name: "no env tokens", - provided: false, - }, - { - name: "GH_TOKEN", - GH_TOKEN: "TOKEN", - provided: true, - }, - { - name: "GITHUB_TOKEN", - GITHUB_TOKEN: "TOKEN", - provided: true, - }, - { - name: "GH_ENTERPRISE_TOKEN", - GH_ENTERPRISE_TOKEN: "TOKEN", - provided: true, - }, - { - name: "GITHUB_ENTERPRISE_TOKEN", - GITHUB_ENTERPRISE_TOKEN: "TOKEN", - provided: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - os.Setenv("GITHUB_TOKEN", tt.GITHUB_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN) - os.Setenv("GH_TOKEN", tt.GH_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN) - assert.Equal(t, tt.provided, AuthTokenProvidedFromEnv()) - }) - } -} diff --git a/internal/config/from_file.go b/internal/config/from_file.go deleted file mode 100644 index a4adec2f4..000000000 --- a/internal/config/from_file.go +++ /dev/null @@ -1,342 +0,0 @@ -package config - -import ( - "bytes" - "errors" - "fmt" - "sort" - "strings" - - "github.com/cli/cli/v2/internal/ghinstance" - "gopkg.in/yaml.v3" -) - -// This type implements a Config interface and represents a config file on disk. -type fileConfig struct { - ConfigMap - documentRoot *yaml.Node -} - -type HostConfig struct { - ConfigMap - Host string -} - -func (c *fileConfig) Root() *yaml.Node { - return c.ConfigMap.Root -} - -func (c *fileConfig) Get(hostname, key string) (string, error) { - val, _, err := c.GetWithSource(hostname, key) - return val, err -} - -func (c *fileConfig) GetWithSource(hostname, key string) (string, string, error) { - if hostname != "" { - var notFound *NotFoundError - - hostCfg, err := c.configForHost(hostname) - if err != nil && !errors.As(err, ¬Found) { - return "", "", err - } - - var hostValue string - if hostCfg != nil { - hostValue, err = hostCfg.GetStringValue(key) - if err != nil && !errors.As(err, ¬Found) { - return "", "", err - } - } - - if hostValue != "" { - return hostValue, HostsConfigFile(), nil - } - } - - defaultSource := ConfigFile() - - value, err := c.GetStringValue(key) - - var notFound *NotFoundError - - if err != nil && errors.As(err, ¬Found) { - return defaultFor(key), defaultSource, nil - } else if err != nil { - return "", defaultSource, err - } - - return value, defaultSource, nil -} - -func (c *fileConfig) GetOrDefault(hostname, key string) (val string, err error) { - val, _, err = c.GetOrDefaultWithSource(hostname, key) - return -} - -func (c *fileConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { - val, src, err = c.GetWithSource(hostname, key) - if err != nil && val == "" { - val = c.Default(key) - } - return -} - -func (c *fileConfig) Default(key string) string { - return defaultFor(key) -} - -func (c *fileConfig) Set(hostname, key, value string) error { - if hostname == "" { - return c.SetStringValue(key, value) - } else { - hostCfg, err := c.configForHost(hostname) - var notFound *NotFoundError - if errors.As(err, ¬Found) { - hostCfg = c.makeConfigForHost(hostname) - } else if err != nil { - return err - } - return hostCfg.SetStringValue(key, value) - } -} - -func (c *fileConfig) UnsetHost(hostname string) { - if hostname == "" { - return - } - - hostsEntry, err := c.FindEntry("hosts") - if err != nil { - return - } - - cm := ConfigMap{hostsEntry.ValueNode} - cm.RemoveEntry(hostname) -} - -func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) { - hosts, err := c.hostEntries() - if err != nil { - return nil, err - } - - for _, hc := range hosts { - if strings.EqualFold(hc.Host, hostname) { - return hc, nil - } - } - return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)} -} - -func (c *fileConfig) CheckWriteable(hostname, key string) error { - // TODO: check filesystem permissions - return nil -} - -func (c *fileConfig) Write() error { - mainData := yaml.Node{Kind: yaml.MappingNode} - - nodes := c.documentRoot.Content[0].Content - for i := 0; i < len(nodes)-1; i += 2 { - if nodes[i].Value != "hosts" { - mainData.Content = append(mainData.Content, nodes[i], nodes[i+1]) - } - } - - mainBytes, err := yaml.Marshal(&mainData) - if err != nil { - return err - } - - err = WriteConfigFile(ConfigFile(), yamlNormalize(mainBytes)) - if err != nil { - return err - } - - return c.WriteHosts() -} - -// Write the hosts config file only, so as to allow logging in when the main -// config file is not writable. -func (c *fileConfig) WriteHosts() error { - hostsData := yaml.Node{Kind: yaml.MappingNode} - - nodes := c.documentRoot.Content[0].Content - for i := 0; i < len(nodes)-1; i += 2 { - if nodes[i].Value == "hosts" { - hostsData.Content = append(hostsData.Content, nodes[i+1].Content...) - } - } - - hostsBytes, err := yaml.Marshal(&hostsData) - if err != nil { - return err - } - - return WriteConfigFile(HostsConfigFile(), yamlNormalize(hostsBytes)) -} - -func (c *fileConfig) Aliases() (*AliasConfig, error) { - // The complexity here is for dealing with either a missing or empty aliases key. It's something - // we'll likely want for other config sections at some point. - entry, err := c.FindEntry("aliases") - var nfe *NotFoundError - notFound := errors.As(err, &nfe) - if err != nil && !notFound { - return nil, err - } - - toInsert := []*yaml.Node{} - - keyNode := entry.KeyNode - valueNode := entry.ValueNode - - if keyNode == nil { - keyNode = &yaml.Node{ - Kind: yaml.ScalarNode, - Value: "aliases", - } - toInsert = append(toInsert, keyNode) - } - - if valueNode == nil || valueNode.Kind != yaml.MappingNode { - valueNode = &yaml.Node{ - Kind: yaml.MappingNode, - Value: "", - } - toInsert = append(toInsert, valueNode) - } - - if len(toInsert) > 0 { - newContent := []*yaml.Node{} - if notFound { - newContent = append(c.Root().Content, keyNode, valueNode) - } else { - for i := 0; i < len(c.Root().Content); i++ { - if i == entry.Index { - newContent = append(newContent, keyNode, valueNode) - i++ - } else { - newContent = append(newContent, c.Root().Content[i]) - } - } - } - c.Root().Content = newContent - } - - return &AliasConfig{ - Parent: c, - ConfigMap: ConfigMap{Root: valueNode}, - }, nil -} - -func (c *fileConfig) hostEntries() ([]*HostConfig, error) { - entry, err := c.FindEntry("hosts") - if err != nil { - return []*HostConfig{}, nil - } - - hostConfigs, err := c.parseHosts(entry.ValueNode) - if err != nil { - return nil, fmt.Errorf("could not parse hosts config: %w", err) - } - - return hostConfigs, nil -} - -// Hosts returns a list of all known hostnames configured in hosts.yml -func (c *fileConfig) Hosts() ([]string, error) { - entries, err := c.hostEntries() - if err != nil { - return nil, err - } - - hostnames := []string{} - for _, entry := range entries { - hostnames = append(hostnames, entry.Host) - } - - sort.SliceStable(hostnames, func(i, j int) bool { return hostnames[i] == ghinstance.Default() }) - - return hostnames, nil -} - -func (c *fileConfig) DefaultHost() (string, error) { - val, _, err := c.DefaultHostWithSource() - return val, err -} - -func (c *fileConfig) DefaultHostWithSource() (string, string, error) { - hosts, err := c.Hosts() - if err == nil && len(hosts) == 1 { - return hosts[0], HostsConfigFile(), nil - } - - return ghinstance.Default(), "", nil -} - -func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig { - hostRoot := &yaml.Node{Kind: yaml.MappingNode} - hostCfg := &HostConfig{ - Host: hostname, - ConfigMap: ConfigMap{Root: hostRoot}, - } - - var notFound *NotFoundError - hostsEntry, err := c.FindEntry("hosts") - if errors.As(err, ¬Found) { - hostsEntry.KeyNode = &yaml.Node{ - Kind: yaml.ScalarNode, - Value: "hosts", - } - hostsEntry.ValueNode = &yaml.Node{Kind: yaml.MappingNode} - root := c.Root() - root.Content = append(root.Content, hostsEntry.KeyNode, hostsEntry.ValueNode) - } else if err != nil { - panic(err) - } - - hostsEntry.ValueNode.Content = append(hostsEntry.ValueNode.Content, - &yaml.Node{ - Kind: yaml.ScalarNode, - Value: hostname, - }, hostRoot) - - return hostCfg -} - -func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) { - hostConfigs := []*HostConfig{} - - for i := 0; i < len(hostsEntry.Content)-1; i = i + 2 { - hostname := hostsEntry.Content[i].Value - hostRoot := hostsEntry.Content[i+1] - hostConfig := HostConfig{ - ConfigMap: ConfigMap{Root: hostRoot}, - Host: hostname, - } - hostConfigs = append(hostConfigs, &hostConfig) - } - - if len(hostConfigs) == 0 { - return nil, errors.New("could not find any host configurations") - } - - return hostConfigs, nil -} - -func yamlNormalize(b []byte) []byte { - if bytes.Equal(b, []byte("{}\n")) { - return []byte{} - } - return b -} - -func defaultFor(key string) string { - for _, co := range configOptions { - if co.Key == key { - return co.DefaultValue - } - } - return "" -} diff --git a/internal/config/from_file_test.go b/internal/config/from_file_test.go deleted file mode 100644 index 0c43c43a7..000000000 --- a/internal/config/from_file_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_fileConfig_Hosts(t *testing.T) { - c := NewBlankConfig() - hosts, err := c.Hosts() - require.NoError(t, err) - assert.Equal(t, []string{}, hosts) -} diff --git a/internal/config/stub.go b/internal/config/stub.go index 4b1704849..bbed1d763 100644 --- a/internal/config/stub.go +++ b/internal/config/stub.go @@ -1,80 +1,107 @@ package config import ( - "errors" + "io" + "os" + "path/filepath" + "testing" + + ghConfig "github.com/cli/go-gh/pkg/config" ) -type ConfigStub map[string]string +func NewBlankConfig() *ConfigMock { + defaultStr := ` +# What protocol to use when performing git operations. Supported values: ssh, https +git_protocol: https +# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment. +editor: +# When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled +prompt: enabled +# A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager. +pager: +# Aliases allow you to create nicknames for gh commands +aliases: + co: pr checkout +# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport. +http_unix_socket: +# What web browser gh should use when opening URLs. If blank, will refer to environment. +browser: +` + return NewFromString(defaultStr) +} -func genKey(host, key string) string { - if host != "" { - return host + ":" + key +func NewFromString(cfgStr string) *ConfigMock { + c := ghConfig.ReadFromString(cfgStr) + cfg := cfg{c} + mock := &ConfigMock{} + mock.AuthTokenFunc = func(host string) (string, string) { + token, _ := c.Get([]string{"hosts", host, "oauth_token"}) + return token, "oauth_token" } - return key -} - -func (c ConfigStub) Get(host, key string) (string, error) { - val, _, err := c.GetWithSource(host, key) - return val, err -} - -func (c ConfigStub) GetWithSource(host, key string) (string, string, error) { - if v, found := c[genKey(host, key)]; found { - return v, "(memory)", nil + mock.GetFunc = func(host, key string) (string, error) { + return cfg.Get(host, key) } - return "", "", errors.New("not found") -} - -func (c ConfigStub) GetOrDefault(hostname, key string) (val string, err error) { - val, _, err = c.GetOrDefaultWithSource(hostname, key) - return -} - -func (c ConfigStub) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { - val, src, err = c.GetWithSource(hostname, key) - if err == nil && val == "" { - val = c.Default(key) + mock.GetOrDefaultFunc = func(host, key string) (string, error) { + return cfg.GetOrDefault(host, key) } - return + mock.SetFunc = func(host, key, value string) { + cfg.Set(host, key, value) + } + mock.UnsetHostFunc = func(host string) { + cfg.UnsetHost(host) + } + mock.HostsFunc = func() []string { + keys, _ := c.Keys([]string{"hosts"}) + return keys + } + mock.DefaultHostFunc = func() (string, string) { + return "github.com", "default" + } + mock.AliasesFunc = func() *AliasConfig { + return &AliasConfig{cfg: c} + } + mock.WriteFunc = func() error { + return cfg.Write() + } + return mock } -func (c ConfigStub) Default(key string) string { - return defaultFor(key) -} +// StubWriteConfig stubs out the filesystem where config file are written. +// It then returns a function that will read in the config files into io.Writers. +// It automatically cleans up environment variables and written files. +func StubWriteConfig(t *testing.T) func(io.Writer, io.Writer) { + t.Helper() + tempDir := t.TempDir() + old := os.Getenv("GH_CONFIG_DIR") + os.Setenv("GH_CONFIG_DIR", tempDir) + t.Cleanup(func() { os.Setenv("GH_CONFIG_DIR", old) }) + return func(wc io.Writer, wh io.Writer) { + config, err := os.Open(filepath.Join(tempDir, "config.yml")) + if err != nil { + return + } + defer config.Close() + configData, err := io.ReadAll(config) + if err != nil { + return + } + _, err = wc.Write(configData) + if err != nil { + return + } -func (c ConfigStub) Set(host, key, value string) error { - c[genKey(host, key)] = value - return nil -} - -func (c ConfigStub) Aliases() (*AliasConfig, error) { - return nil, nil -} - -func (c ConfigStub) Hosts() ([]string, error) { - return nil, nil -} - -func (c ConfigStub) UnsetHost(hostname string) { -} - -func (c ConfigStub) CheckWriteable(host, key string) error { - return nil -} - -func (c ConfigStub) Write() error { - c["_written"] = "true" - return nil -} - -func (c ConfigStub) WriteHosts() error { - return nil -} - -func (c ConfigStub) DefaultHost() (string, error) { - return "", nil -} - -func (c ConfigStub) DefaultHostWithSource() (string, string, error) { - return "", "", nil + hosts, err := os.Open(filepath.Join(tempDir, "hosts.yml")) + if err != nil { + return + } + defer hosts.Close() + hostsData, err := io.ReadAll(hosts) + if err != nil { + return + } + _, err = wh.Write(hostsData) + if err != nil { + return + } + } } diff --git a/internal/config/testing.go b/internal/config/testing.go deleted file mode 100644 index 31a5fb2a8..000000000 --- a/internal/config/testing.go +++ /dev/null @@ -1,64 +0,0 @@ -package config - -import ( - "fmt" - "io" - "os" - "path/filepath" -) - -func StubBackupConfig() func() { - orig := BackupConfigFile - BackupConfigFile = func(_ string) error { - return nil - } - - return func() { - BackupConfigFile = orig - } -} - -func StubWriteConfig(wc io.Writer, wh io.Writer) func() { - orig := WriteConfigFile - WriteConfigFile = func(fn string, data []byte) error { - switch filepath.Base(fn) { - case "config.yml": - _, err := wc.Write(data) - return err - case "hosts.yml": - _, err := wh.Write(data) - return err - default: - return fmt.Errorf("write to unstubbed file: %q", fn) - } - } - return func() { - WriteConfigFile = orig - } -} - -func stubConfig(main, hosts string) func() { - orig := ReadConfigFile - ReadConfigFile = func(fn string) ([]byte, error) { - switch filepath.Base(fn) { - case "config.yml": - if main == "" { - return []byte(nil), os.ErrNotExist - } else { - return []byte(main), nil - } - case "hosts.yml": - if hosts == "" { - return []byte(nil), os.ErrNotExist - } else { - return []byte(hosts), nil - } - default: - return []byte(nil), fmt.Errorf("read from unstubbed file: %q", fn) - } - - } - return func() { - ReadConfigFile = orig - } -} diff --git a/pkg/cmd/alias/delete/delete.go b/pkg/cmd/alias/delete/delete.go index 85372d181..daaaa2435 100644 --- a/pkg/cmd/alias/delete/delete.go +++ b/pkg/cmd/alias/delete/delete.go @@ -45,13 +45,10 @@ func deleteRun(opts *DeleteOptions) error { return err } - aliasCfg, err := cfg.Aliases() - if err != nil { - return fmt.Errorf("couldn't read aliases config: %w", err) - } + aliasCfg := cfg.Aliases() - expansion, ok := aliasCfg.Get(opts.Name) - if !ok { + expansion, err := aliasCfg.Get(opts.Name) + if err != nil { return fmt.Errorf("no such alias %s", opts.Name) } @@ -61,6 +58,11 @@ func deleteRun(opts *DeleteOptions) error { return fmt.Errorf("failed to delete alias %s: %w", opts.Name, err) } + err = cfg.Write() + if err != nil { + return err + } + if opts.IO.IsStdoutTTY() { cs := opts.IO.ColorScheme() fmt.Fprintf(opts.IO.ErrOut, "%s Deleted alias %s; was %s\n", cs.SuccessIconWithColor(cs.Red), opts.Name, expansion) diff --git a/pkg/cmd/alias/delete/delete_test.go b/pkg/cmd/alias/delete/delete_test.go index bb955d055..d6c7098b1 100644 --- a/pkg/cmd/alias/delete/delete_test.go +++ b/pkg/cmd/alias/delete/delete_test.go @@ -15,6 +15,8 @@ import ( ) func TestAliasDelete(t *testing.T) { + _ = config.StubWriteConfig(t) + tests := []struct { name string config string @@ -48,8 +50,6 @@ func TestAliasDelete(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer config.StubWriteConfig(io.Discard, io.Discard)() - cfg := config.NewFromString(tt.config) ios, _, stdout, stderr := iostreams.Test() diff --git a/pkg/cmd/alias/expand/expand.go b/pkg/cmd/alias/expand/expand.go index f67a93942..d0e2627fb 100644 --- a/pkg/cmd/alias/expand/expand.go +++ b/pkg/cmd/alias/expand/expand.go @@ -23,13 +23,10 @@ func ExpandAlias(cfg config.Config, args []string, findShFunc func() (string, er } expanded = args[1:] - aliases, err := cfg.Aliases() - if err != nil { - return - } + aliases := cfg.Aliases() - expansion, ok := aliases.Get(args[1]) - if !ok { + expansion, getErr := aliases.Get(args[1]) + if getErr != nil { return } diff --git a/pkg/cmd/alias/list/list.go b/pkg/cmd/alias/list/list.go index 785a8fcf5..a7ea40ac4 100644 --- a/pkg/cmd/alias/list/list.go +++ b/pkg/cmd/alias/list/list.go @@ -1,7 +1,6 @@ package list import ( - "fmt" "sort" "github.com/MakeNowJust/heredoc" @@ -48,18 +47,14 @@ func listRun(opts *ListOptions) error { return err } - aliasCfg, err := cfg.Aliases() - if err != nil { - return fmt.Errorf("couldn't read aliases config: %w", err) - } + aliasCfg := cfg.Aliases() - if aliasCfg.Empty() { + aliasMap := aliasCfg.All() + if len(aliasMap) == 0 { return cmdutil.NewNoResultsError("no aliases configured") } tp := utils.NewTablePrinter(opts.IO) - - aliasMap := aliasCfg.All() keys := []string{} for alias := range aliasMap { keys = append(keys, alias) diff --git a/pkg/cmd/alias/list/list_test.go b/pkg/cmd/alias/list/list_test.go index c93486b3f..59b7ed0ef 100644 --- a/pkg/cmd/alias/list/list_test.go +++ b/pkg/cmd/alias/list/list_test.go @@ -44,10 +44,6 @@ func TestAliasList(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // TODO: change underlying config implementation so Write is not - // automatically called when editing aliases in-memory - defer config.StubWriteConfig(io.Discard, io.Discard)() - cfg := config.NewFromString(tt.config) ios, _, stdout, stderr := iostreams.Test() diff --git a/pkg/cmd/alias/set/set.go b/pkg/cmd/alias/set/set.go index f85b9af99..6d09787a4 100644 --- a/pkg/cmd/alias/set/set.go +++ b/pkg/cmd/alias/set/set.go @@ -109,10 +109,7 @@ func setRun(opts *SetOptions) error { return err } - aliasCfg, err := cfg.Aliases() - if err != nil { - return err - } + aliasCfg := cfg.Aliases() expansion, err := getExpansion(opts) if err != nil { @@ -139,7 +136,7 @@ func setRun(opts *SetOptions) error { } successMsg := fmt.Sprintf("%s Added alias.", cs.SuccessIcon()) - if oldExpansion, ok := aliasCfg.Get(opts.Name); ok { + if oldExpansion, err := aliasCfg.Get(opts.Name); err == nil { successMsg = fmt.Sprintf("%s Changed alias %s from %s to %s", cs.SuccessIcon(), cs.Bold(opts.Name), @@ -148,9 +145,11 @@ func setRun(opts *SetOptions) error { ) } - err = aliasCfg.Add(opts.Name, expansion) + aliasCfg.Add(opts.Name, expansion) + + err = cfg.Write() if err != nil { - return fmt.Errorf("could not create alias: %s", err) + return err } if isTerminal { diff --git a/pkg/cmd/alias/set/set_test.go b/pkg/cmd/alias/set/set_test.go index a43866109..72d6fb413 100644 --- a/pkg/cmd/alias/set/set_test.go +++ b/pkg/cmd/alias/set/set_test.go @@ -70,8 +70,6 @@ func runCommand(cfg config.Config, isTTY bool, cli string, in string) (*test.Cmd } func TestAliasSet_gh_command(t *testing.T) { - defer config.StubWriteConfig(io.Discard, io.Discard)() - cfg := config.NewFromString(``) _, err := runCommand(cfg, true, "pr 'pr status'", "") @@ -79,8 +77,7 @@ func TestAliasSet_gh_command(t *testing.T) { } func TestAliasSet_empty_aliases(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(heredoc.Doc(` aliases: @@ -93,6 +90,9 @@ func TestAliasSet_empty_aliases(t *testing.T) { t.Fatalf("unexpected error: %s", err) } + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), "Added alias") //nolint:staticcheck // prefer exact matchers over ExpectLines @@ -106,8 +106,7 @@ editor: vim } func TestAliasSet_existing_alias(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + _ = config.StubWriteConfig(t) cfg := config.NewFromString(heredoc.Doc(` aliases: @@ -122,14 +121,16 @@ func TestAliasSet_existing_alias(t *testing.T) { } func TestAliasSet_space_args(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(``) output, err := runCommand(cfg, true, `il 'issue list -l "cool story"'`, "") require.NoError(t, err) + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), `Adding alias for.*il.*issue list -l "cool story"`) @@ -138,6 +139,8 @@ func TestAliasSet_space_args(t *testing.T) { } func TestAliasSet_arg_processing(t *testing.T) { + readConfigs := config.StubWriteConfig(t) + cases := []struct { Cmd string ExpectedOutputLine string @@ -158,9 +161,6 @@ func TestAliasSet_arg_processing(t *testing.T) { for _, c := range cases { t.Run(c.Cmd, func(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() - cfg := config.NewFromString(``) output, err := runCommand(cfg, true, c.Cmd, "") @@ -168,6 +168,9 @@ func TestAliasSet_arg_processing(t *testing.T) { t.Fatalf("got unexpected error running %s: %s", c.Cmd, err) } + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), c.ExpectedOutputLine) //nolint:staticcheck // prefer exact matchers over ExpectLines @@ -177,8 +180,7 @@ func TestAliasSet_arg_processing(t *testing.T) { } func TestAliasSet_init_alias_cfg(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(heredoc.Doc(` editor: vim @@ -187,6 +189,9 @@ func TestAliasSet_init_alias_cfg(t *testing.T) { output, err := runCommand(cfg, true, "diff 'pr diff'", "") require.NoError(t, err) + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + expected := `editor: vim aliases: diff: pr diff @@ -198,8 +203,7 @@ aliases: } func TestAliasSet_existing_aliases(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(heredoc.Doc(` aliases: @@ -209,6 +213,9 @@ func TestAliasSet_existing_aliases(t *testing.T) { output, err := runCommand(cfg, true, "view 'pr view'", "") require.NoError(t, err) + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + expected := `aliases: foo: bar view: pr view @@ -221,8 +228,6 @@ func TestAliasSet_existing_aliases(t *testing.T) { } func TestAliasSet_invalid_command(t *testing.T) { - defer config.StubWriteConfig(io.Discard, io.Discard)() - cfg := config.NewFromString(``) _, err := runCommand(cfg, true, "co 'pe checkout'", "") @@ -230,8 +235,7 @@ func TestAliasSet_invalid_command(t *testing.T) { } func TestShellAlias_flag(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(``) @@ -240,6 +244,9 @@ func TestShellAlias_flag(t *testing.T) { t.Fatalf("unexpected error: %s", err) } + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), "Adding alias for.*igrep") @@ -250,14 +257,16 @@ func TestShellAlias_flag(t *testing.T) { } func TestShellAlias_bang(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(``) output, err := runCommand(cfg, true, "igrep '!gh issue list | grep'", "") require.NoError(t, err) + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), "Adding alias for.*igrep") @@ -268,8 +277,7 @@ func TestShellAlias_bang(t *testing.T) { } func TestShellAlias_from_stdin(t *testing.T) { - mainBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, io.Discard)() + readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString(``) @@ -282,6 +290,9 @@ func TestShellAlias_from_stdin(t *testing.T) { require.NoError(t, err) + mainBuf := bytes.Buffer{} + readConfigs(&mainBuf, io.Discard) + //nolint:staticcheck // prefer exact matchers over ExpectLines test.ExpectLines(t, output.Stderr(), "Adding alias for.*users") diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 12bf3a56c..02b288fa6 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -288,10 +288,7 @@ func apiRun(opts *ApiOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() if opts.Hostname != "" { host = opts.Hostname diff --git a/pkg/cmd/auth/gitcredential/helper.go b/pkg/cmd/auth/gitcredential/helper.go index bda962e50..46d05d92c 100644 --- a/pkg/cmd/auth/gitcredential/helper.go +++ b/pkg/cmd/auth/gitcredential/helper.go @@ -14,7 +14,8 @@ import ( const tokenUser = "x-access-token" type config interface { - GetWithSource(string, string) (string, string, error) + AuthToken(string) (string, string) + Get(string, string) (string, error) } type CredentialOptions struct { @@ -102,16 +103,16 @@ func helperRun(opts *CredentialOptions) error { lookupHost := wants["host"] var gotUser string - gotToken, source, _ := cfg.GetWithSource(lookupHost, "oauth_token") + gotToken, source := cfg.AuthToken(lookupHost) if gotToken == "" && strings.HasPrefix(lookupHost, "gist.") { lookupHost = strings.TrimPrefix(lookupHost, "gist.") - gotToken, source, _ = cfg.GetWithSource(lookupHost, "oauth_token") + gotToken, source = cfg.AuthToken(lookupHost) } if strings.HasSuffix(source, "_TOKEN") { gotUser = tokenUser } else { - gotUser, _, _ = cfg.GetWithSource(lookupHost, "user") + gotUser, _ = cfg.Get(lookupHost, "user") if gotUser == "" { gotUser = tokenUser } diff --git a/pkg/cmd/auth/gitcredential/helper_test.go b/pkg/cmd/auth/gitcredential/helper_test.go index 5501be4ed..c1f76e6fa 100644 --- a/pkg/cmd/auth/gitcredential/helper_test.go +++ b/pkg/cmd/auth/gitcredential/helper_test.go @@ -11,8 +11,12 @@ import ( // why not just use the config stub argh type tinyConfig map[string]string -func (c tinyConfig) GetWithSource(host, key string) (string, string, error) { - return c[fmt.Sprintf("%s:%s", host, key)], c["_source"], nil +func (c tinyConfig) AuthToken(host string) (string, string) { + return c[fmt.Sprintf("%s:%s", host, "oauth_token")], c["_source"] +} + +func (c tinyConfig) Get(host, key string) (string, error) { + return c[fmt.Sprintf("%s:%s", host, key)], nil } func Test_helperRun(t *testing.T) { diff --git a/pkg/cmd/auth/login/login.go b/pkg/cmd/auth/login/login.go index dccec0c82..474c21588 100644 --- a/pkg/cmd/auth/login/login.go +++ b/pkg/cmd/auth/login/login.go @@ -1,7 +1,6 @@ package login import ( - "errors" "fmt" "io" "net/http" @@ -136,14 +135,10 @@ func loginRun(opts *LoginOptions) error { } } - if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil { - var roErr *config.ReadOnlyEnvError - if errors.As(err, &roErr) { - fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable) - fmt.Fprint(opts.IO.ErrOut, "To have GitHub CLI store credentials instead, first clear the value from the environment.\n") - return cmdutil.SilentError - } - return err + if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) + fmt.Fprint(opts.IO.ErrOut, "To have GitHub CLI store credentials instead, first clear the value from the environment.\n") + return cmdutil.SilentError } httpClient, err := opts.HttpClient() @@ -152,19 +147,16 @@ func loginRun(opts *LoginOptions) error { } if opts.Token != "" { - err := cfg.Set(hostname, "oauth_token", opts.Token) - if err != nil { - return err - } + cfg.Set(hostname, "oauth_token", opts.Token) if err := shared.HasMinimumScopes(httpClient, hostname, opts.Token); err != nil { return fmt.Errorf("error validating token: %w", err) } - return cfg.WriteHosts() + return cfg.Write() } - existingToken, _ := cfg.Get(hostname, "oauth_token") + existingToken, _ := cfg.AuthToken(hostname) if existingToken != "" && opts.Interactive { if err := shared.HasMinimumScopes(httpClient, hostname, existingToken); err == nil { var keepGoing bool diff --git a/pkg/cmd/auth/login/login_test.go b/pkg/cmd/auth/login/login_test.go index d87602b06..4a48aece7 100644 --- a/pkg/cmd/auth/login/login_test.go +++ b/pkg/cmd/auth/login/login_test.go @@ -216,7 +216,7 @@ func Test_loginRun_nontty(t *testing.T) { name string opts *LoginOptions httpStubs func(*httpmock.Registry) - env map[string]string + cfgStubs func(*config.ConfigMock) wantHosts string wantErr string wantStderr string @@ -282,8 +282,10 @@ func Test_loginRun_nontty(t *testing.T) { Hostname: "github.com", Token: "abc456", }, - env: map[string]string{ - "GH_TOKEN": "value_from_env", + cfgStubs: func(c *config.ConfigMock) { + c.AuthTokenFunc = func(string) (string, string) { + return "value_from_env", "GH_TOKEN" + } }, wantErr: "SilentError", wantStderr: heredoc.Doc(` @@ -297,8 +299,10 @@ func Test_loginRun_nontty(t *testing.T) { Hostname: "ghe.io", Token: "abc456", }, - env: map[string]string{ - "GH_ENTERPRISE_TOKEN": "value_from_env", + cfgStubs: func(c *config.ConfigMock) { + c.AuthTokenFunc = func(string) (string, string) { + return "value_from_env", "GH_ENTERPRISE_TOKEN" + } }, wantErr: "SilentError", wantStderr: heredoc.Doc(` @@ -310,37 +314,24 @@ func Test_loginRun_nontty(t *testing.T) { for _, tt := range tests { ios, _, stdout, stderr := iostreams.Test() - ios.SetStdinTTY(false) ios.SetStdoutTTY(false) - - tt.opts.Config = func() (config.Config, error) { - cfg := config.NewBlankConfig() - return config.InheritEnv(cfg), nil - } - tt.opts.IO = ios + t.Run(tt.name, func(t *testing.T) { + readConfigs := config.StubWriteConfig(t) + cfg := config.NewBlankConfig() + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) + } + tt.opts.Config = func() (config.Config, error) { + return cfg, nil + } + reg := &httpmock.Registry{} tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - - old_GH_TOKEN := os.Getenv("GH_TOKEN") - os.Setenv("GH_TOKEN", tt.env["GH_TOKEN"]) - old_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") - os.Setenv("GITHUB_TOKEN", tt.env["GITHUB_TOKEN"]) - old_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN") - os.Setenv("GH_ENTERPRISE_TOKEN", tt.env["GH_ENTERPRISE_TOKEN"]) - old_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN") - os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.env["GITHUB_ENTERPRISE_TOKEN"]) - defer func() { - os.Setenv("GH_TOKEN", old_GH_TOKEN) - os.Setenv("GITHUB_TOKEN", old_GITHUB_TOKEN) - os.Setenv("GH_ENTERPRISE_TOKEN", old_GH_ENTERPRISE_TOKEN) - os.Setenv("GITHUB_ENTERPRISE_TOKEN", old_GITHUB_ENTERPRISE_TOKEN) - }() - if tt.httpStubs != nil { tt.httpStubs(reg) } @@ -348,10 +339,6 @@ func Test_loginRun_nontty(t *testing.T) { _, restoreRun := run.Stub() defer restoreRun(t) - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() - err := loginRun(tt.opts) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) @@ -359,6 +346,10 @@ func Test_loginRun_nontty(t *testing.T) { assert.NoError(t, err) } + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + readConfigs(&mainBuf, &hostsBuf) + assert.Equal(t, "", stdout.String()) assert.Equal(t, tt.wantStderr, stderr.String()) assert.Equal(t, tt.wantHosts, hostsBuf.String()) @@ -378,27 +369,26 @@ func Test_loginRun_Survey(t *testing.T) { runStubs func(*run.CommandStubber) wantHosts string wantErrOut *regexp.Regexp - cfg func(config.Config) + cfgStubs func(*config.ConfigMock) }{ { name: "already authenticated", opts: &LoginOptions{ Interactive: true, }, - cfg: func(cfg config.Config) { - _ = cfg.Set("github.com", "oauth_token", "ghi789") + cfgStubs: func(c *config.ConfigMock) { + c.AuthTokenFunc = func(h string) (string, string) { + return "ghi789", "oauth_token" + } }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) - // reg.Register( - // httpmock.GraphQL(`query UserCurrent\b`), - // httpmock.StringResponse(`{"data":{"viewer":{"login":"jillv"}}}`)) }, askStubs: func(as *prompt.AskStubber) { as.StubPrompt("What account do you want to log into?").AnswerWith("GitHub.com") as.StubPrompt("You're already logged into github.com. Do you want to re-authenticate?").AnswerWith(false) }, - wantHosts: "", // nothing should have been written to hosts + wantHosts: "", wantErrOut: nil, }, { @@ -521,10 +511,11 @@ func Test_loginRun_Survey(t *testing.T) { tt.opts.IO = ios - cfg := config.NewBlankConfig() + readConfigs := config.StubWriteConfig(t) - if tt.cfg != nil { - tt.cfg(cfg) + cfg := config.NewBlankConfig() + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) } tt.opts.Config = func() (config.Config, error) { return cfg, nil @@ -544,10 +535,6 @@ func Test_loginRun_Survey(t *testing.T) { httpmock.StringResponse(`{"data":{"viewer":{"login":"jillv"}}}`)) } - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() - as := prompt.NewAskStubber(t) if tt.askStubs != nil { tt.askStubs(as) @@ -564,6 +551,10 @@ func Test_loginRun_Survey(t *testing.T) { t.Fatalf("unexpected error: %s", err) } + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + readConfigs(&mainBuf, &hostsBuf) + assert.Equal(t, tt.wantHosts, hostsBuf.String()) if tt.wantErrOut == nil { assert.Equal(t, "", stderr.String()) diff --git a/pkg/cmd/auth/logout/logout.go b/pkg/cmd/auth/logout/logout.go index 04cb2a087..4027fd0f8 100644 --- a/pkg/cmd/auth/logout/logout.go +++ b/pkg/cmd/auth/logout/logout.go @@ -1,7 +1,6 @@ package logout import ( - "errors" "fmt" "net/http" @@ -9,6 +8,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/pkg/cmd/auth/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/prompt" @@ -70,10 +70,7 @@ func logoutRun(opts *LogoutOptions) error { return err } - candidates, err := cfg.Hosts() - if err != nil { - return err - } + candidates := cfg.Hosts() if len(candidates) == 0 { return fmt.Errorf("not logged in to any hosts") } @@ -105,14 +102,10 @@ func logoutRun(opts *LogoutOptions) error { } } - if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil { - var roErr *config.ReadOnlyEnvError - if errors.As(err, &roErr) { - fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable) - fmt.Fprint(opts.IO.ErrOut, "To erase credentials stored in GitHub CLI, first clear the value from the environment.\n") - return cmdutil.SilentError - } - return err + if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) + fmt.Fprint(opts.IO.ErrOut, "To erase credentials stored in GitHub CLI, first clear the value from the environment.\n") + return cmdutil.SilentError } httpClient, err := opts.HttpClient() @@ -134,7 +127,7 @@ func logoutRun(opts *LogoutOptions) error { } cfg.UnsetHost(hostname) - err = cfg.WriteHosts() + err = cfg.Write() if err != nil { return fmt.Errorf("failed to write config, authentication configuration not updated: %w", err) } diff --git a/pkg/cmd/auth/logout/logout_test.go b/pkg/cmd/auth/logout/logout_test.go index 73fd50d40..aa176f9f0 100644 --- a/pkg/cmd/auth/logout/logout_test.go +++ b/pkg/cmd/auth/logout/logout_test.go @@ -114,6 +114,7 @@ func Test_logoutRun_tty(t *testing.T) { name: "no arguments, one host", opts: &LogoutOptions{}, cfgHosts: []string{"github.com"}, + wantHosts: "{}\n", wantErrOut: regexp.MustCompile(`Logged out of github.com account 'cybilb'`), }, { @@ -134,34 +135,29 @@ func Test_logoutRun_tty(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ios, _, _, stderr := iostreams.Test() - - ios.SetStdinTTY(true) - ios.SetStdoutTTY(true) - - tt.opts.IO = ios - cfg := config.NewBlankConfig() + readConfigs := config.StubWriteConfig(t) + cfg := config.NewFromString("") + for _, hostname := range tt.cfgHosts { + cfg.Set(hostname, "oauth_token", "abc123") + } tt.opts.Config = func() (config.Config, error) { return cfg, nil } - for _, hostname := range tt.cfgHosts { - _ = cfg.Set(hostname, "oauth_token", "abc123") - } + ios, _, _, stderr := iostreams.Test() + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + tt.opts.IO = ios reg := &httpmock.Registry{} reg.Register( httpmock.GraphQL(`query UserCurrent\b`), - httpmock.StringResponse(`{"data":{"viewer":{"login":"cybilb"}}}`)) - + httpmock.StringResponse(`{"data":{"viewer":{"login":"cybilb"}}}`), + ) tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() - as := prompt.NewAskStubber(t) if tt.askStubs != nil { tt.askStubs(as) @@ -181,6 +177,10 @@ func Test_logoutRun_tty(t *testing.T) { assert.True(t, tt.wantErrOut.MatchString(stderr.String())) } + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + readConfigs(&mainBuf, &hostsBuf) + assert.Equal(t, tt.wantHosts, hostsBuf.String()) reg.Verify(t) }) @@ -201,7 +201,8 @@ func Test_logoutRun_nontty(t *testing.T) { opts: &LogoutOptions{ Hostname: "harry.mason", }, - cfgHosts: []string{"harry.mason"}, + cfgHosts: []string{"harry.mason"}, + wantHosts: "{}\n", }, { name: "hostname, multiple hosts", @@ -222,30 +223,25 @@ func Test_logoutRun_nontty(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ios, _, _, stderr := iostreams.Test() - - ios.SetStdinTTY(false) - ios.SetStdoutTTY(false) - - tt.opts.IO = ios - cfg := config.NewBlankConfig() + readConfigs := config.StubWriteConfig(t) + cfg := config.NewFromString("") + for _, hostname := range tt.cfgHosts { + cfg.Set(hostname, "oauth_token", "abc123") + } tt.opts.Config = func() (config.Config, error) { return cfg, nil } - for _, hostname := range tt.cfgHosts { - _ = cfg.Set(hostname, "oauth_token", "abc123") - } + ios, _, _, stderr := iostreams.Test() + ios.SetStdinTTY(false) + ios.SetStdoutTTY(false) + tt.opts.IO = ios reg := &httpmock.Registry{} tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() - err := logoutRun(tt.opts) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) @@ -255,6 +251,10 @@ func Test_logoutRun_nontty(t *testing.T) { assert.Equal(t, "", stderr.String()) + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + readConfigs(&mainBuf, &hostsBuf) + assert.Equal(t, tt.wantHosts, hostsBuf.String()) reg.Verify(t) }) diff --git a/pkg/cmd/auth/refresh/refresh.go b/pkg/cmd/auth/refresh/refresh.go index 4a7095eef..32d1dade4 100644 --- a/pkg/cmd/auth/refresh/refresh.go +++ b/pkg/cmd/auth/refresh/refresh.go @@ -1,7 +1,6 @@ package refresh import ( - "errors" "fmt" "net/http" "strings" @@ -85,10 +84,7 @@ func refreshRun(opts *RefreshOptions) error { return err } - candidates, err := cfg.Hosts() - if err != nil { - return err - } + candidates := cfg.Hosts() if len(candidates) == 0 { return fmt.Errorf("not logged in to any hosts. Use 'gh auth login' to authenticate with a host") } @@ -121,18 +117,14 @@ func refreshRun(opts *RefreshOptions) error { } } - if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil { - var roErr *config.ReadOnlyEnvError - if errors.As(err, &roErr) { - fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable) - fmt.Fprint(opts.IO.ErrOut, "To refresh credentials stored in GitHub CLI, first clear the value from the environment.\n") - return cmdutil.SilentError - } - return err + if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) + fmt.Fprint(opts.IO.ErrOut, "To refresh credentials stored in GitHub CLI, first clear the value from the environment.\n") + return cmdutil.SilentError } var additionalScopes []string - if oldToken, _ := cfg.Get(hostname, "oauth_token"); oldToken != "" { + if oldToken, _ := cfg.AuthToken(hostname); oldToken != "" { if oldScopes, err := shared.GetScopes(opts.httpClient, hostname, oldToken); err == nil { for _, s := range strings.Split(oldScopes, ",") { s = strings.TrimSpace(s) @@ -163,7 +155,7 @@ func refreshRun(opts *RefreshOptions) error { if credentialFlow.ShouldSetup() { username, _ := cfg.Get(hostname, "user") - password, _ := cfg.Get(hostname, "oauth_token") + password, _ := cfg.AuthToken(hostname) if err := credentialFlow.Setup(hostname, username, password); err != nil { return err } diff --git a/pkg/cmd/auth/refresh/refresh_test.go b/pkg/cmd/auth/refresh/refresh_test.go index d8a88d70f..3de01897f 100644 --- a/pkg/cmd/auth/refresh/refresh_test.go +++ b/pkg/cmd/auth/refresh/refresh_test.go @@ -238,19 +238,19 @@ func Test_refreshRun(t *testing.T) { return nil } - ios, _, _, _ := iostreams.Test() - - ios.SetStdinTTY(!tt.nontty) - ios.SetStdoutTTY(!tt.nontty) - - tt.opts.IO = ios - cfg := config.NewBlankConfig() + _ = config.StubWriteConfig(t) + cfg := config.NewFromString("") + for _, hostname := range tt.cfgHosts { + cfg.Set(hostname, "oauth_token", "abc123") + } tt.opts.Config = func() (config.Config, error) { return cfg, nil } - for _, hostname := range tt.cfgHosts { - _ = cfg.Set(hostname, "oauth_token", "abc123") - } + + ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(!tt.nontty) + ios.SetStdoutTTY(!tt.nontty) + tt.opts.IO = ios httpReg := &httpmock.Registry{} httpReg.Register( @@ -272,10 +272,6 @@ func Test_refreshRun(t *testing.T) { ) tt.opts.httpClient = &http.Client{Transport: httpReg} - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() - as := prompt.NewAskStubber(t) if tt.askStubs != nil { tt.askStubs(as) diff --git a/pkg/cmd/auth/setupgit/setupgit.go b/pkg/cmd/auth/setupgit/setupgit.go index 5295ff424..b5caefe5b 100644 --- a/pkg/cmd/auth/setupgit/setupgit.go +++ b/pkg/cmd/auth/setupgit/setupgit.go @@ -54,10 +54,7 @@ func setupGitRun(opts *SetupGitOptions) error { return err } - hostnames, err := cfg.Hosts() - if err != nil { - return err - } + hostnames := cfg.Hosts() stderr := opts.IO.ErrOut cs := opts.IO.ColorScheme() diff --git a/pkg/cmd/auth/setupgit/setupgit_test.go b/pkg/cmd/auth/setupgit/setupgit_test.go index e5c9cda0d..1f3f0be92 100644 --- a/pkg/cmd/auth/setupgit/setupgit_test.go +++ b/pkg/cmd/auth/setupgit/setupgit_test.go @@ -7,7 +7,6 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/iostreams" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type mockGitConfigurer struct { @@ -35,8 +34,16 @@ func Test_setupGitRun(t *testing.T) { expectedErr: "oops", }, { - name: "no authenticated hostnames", - opts: &SetupGitOptions{}, + name: "no authenticated hostnames", + opts: &SetupGitOptions{ + Config: func() (config.Config, error) { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{} + } + return cfg, nil + }, + }, expectedErr: "SilentError", expectedErrOut: "You are not logged into any GitHub hosts. Run gh auth login to authenticate.\n", }, @@ -45,8 +52,10 @@ func Test_setupGitRun(t *testing.T) { opts: &SetupGitOptions{ Hostname: "foo", Config: func() (config.Config, error) { - cfg := config.NewBlankConfig() - require.NoError(t, cfg.Set("bar", "", "")) + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"bar"} + } return cfg, nil }, }, @@ -60,8 +69,10 @@ func Test_setupGitRun(t *testing.T) { setupErr: fmt.Errorf("broken"), }, Config: func() (config.Config, error) { - cfg := config.NewBlankConfig() - require.NoError(t, cfg.Set("bar", "", "")) + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"bar"} + } return cfg, nil }, }, @@ -73,8 +84,10 @@ func Test_setupGitRun(t *testing.T) { opts: &SetupGitOptions{ gitConfigure: &mockGitConfigurer{}, Config: func() (config.Config, error) { - cfg := config.NewBlankConfig() - require.NoError(t, cfg.Set("bar", "", "")) + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"bar"} + } return cfg, nil }, }, @@ -85,9 +98,10 @@ func Test_setupGitRun(t *testing.T) { Hostname: "yes", gitConfigure: &mockGitConfigurer{}, Config: func() (config.Config, error) { - cfg := config.NewBlankConfig() - require.NoError(t, cfg.Set("bar", "", "")) - require.NoError(t, cfg.Set("yes", "", "")) + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"bar", "yes"} + } return cfg, nil }, }, @@ -98,7 +112,7 @@ func Test_setupGitRun(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if tt.opts.Config == nil { tt.opts.Config = func() (config.Config, error) { - return config.NewBlankConfig(), nil + return &config.ConfigMock{}, nil } } diff --git a/pkg/cmd/auth/shared/login_flow.go b/pkg/cmd/auth/shared/login_flow.go index 48b0c773c..0afb84981 100644 --- a/pkg/cmd/auth/shared/login_flow.go +++ b/pkg/cmd/auth/shared/login_flow.go @@ -21,9 +21,8 @@ const defaultSSHKeyTitle = "GitHub CLI" type iconfig interface { Get(string, string) (string, error) - Set(string, string, string) error + Set(string, string, string) Write() error - WriteHosts() error } type LoginOptions struct { @@ -175,9 +174,7 @@ func Login(opts *LoginOptions) error { return fmt.Errorf("error validating token: %w", err) } - if err := cfg.Set(hostname, "oauth_token", authToken); err != nil { - return err - } + cfg.Set(hostname, "oauth_token", authToken) } var username string @@ -191,22 +188,16 @@ func Login(opts *LoginOptions) error { return fmt.Errorf("error using api: %w", err) } - err = cfg.Set(hostname, "user", username) - if err != nil { - return err - } + cfg.Set(hostname, "user", username) } if gitProtocol != "" { fmt.Fprintf(opts.IO.ErrOut, "- gh config set -h %s git_protocol %s\n", hostname, gitProtocol) - err := cfg.Set(hostname, "git_protocol", gitProtocol) - if err != nil { - return err - } + cfg.Set(hostname, "git_protocol", gitProtocol) fmt.Fprintf(opts.IO.ErrOut, "%s Configured git protocol\n", cs.SuccessIcon()) } - err := cfg.WriteHosts() + err := cfg.Write() if err != nil { return err } diff --git a/pkg/cmd/auth/shared/login_flow_test.go b/pkg/cmd/auth/shared/login_flow_test.go index d32e2cc87..467388c6a 100644 --- a/pkg/cmd/auth/shared/login_flow_test.go +++ b/pkg/cmd/auth/shared/login_flow_test.go @@ -22,19 +22,14 @@ func (c tinyConfig) Get(host, key string) (string, error) { return c[fmt.Sprintf("%s:%s", host, key)], nil } -func (c tinyConfig) Set(host string, key string, value string) error { +func (c tinyConfig) Set(host string, key string, value string) { c[fmt.Sprintf("%s:%s", host, key)] = value - return nil } func (c tinyConfig) Write() error { return nil } -func (c tinyConfig) WriteHosts() error { - return nil -} - func TestLogin_ssh(t *testing.T) { dir := t.TempDir() ios, _, stdout, stderr := iostreams.Test() diff --git a/pkg/cmd/auth/shared/writeable.go b/pkg/cmd/auth/shared/writeable.go new file mode 100644 index 000000000..ef117f32d --- /dev/null +++ b/pkg/cmd/auth/shared/writeable.go @@ -0,0 +1,14 @@ +package shared + +import ( + "github.com/cli/cli/v2/internal/config" +) + +const ( + oauthToken = "oauth_token" +) + +func AuthTokenWriteable(cfg config.Config, hostname string) (string, bool) { + token, src := cfg.AuthToken(hostname) + return src, (token == "" || src == oauthToken) +} diff --git a/pkg/cmd/auth/status/status.go b/pkg/cmd/auth/status/status.go index e09273e99..750c74e29 100644 --- a/pkg/cmd/auth/status/status.go +++ b/pkg/cmd/auth/status/status.go @@ -68,10 +68,7 @@ func statusRun(opts *StatusOptions) error { statusInfo := map[string][]string{} - hostnames, err := cfg.Hosts() - if err != nil { - return err - } + hostnames := cfg.Hosts() if len(hostnames) == 0 { fmt.Fprintf(stderr, "You are not logged into any GitHub hosts. Run %s to authenticate.\n", cs.Bold("gh auth login")) @@ -92,8 +89,8 @@ func statusRun(opts *StatusOptions) error { } isHostnameFound = true - token, tokenSource, _ := cfg.GetWithSource(hostname, "oauth_token") - tokenIsWriteable := cfg.CheckWriteable(hostname, "oauth_token") == nil + token, tokenSource := cfg.AuthToken(hostname) + _, tokenIsWriteable := shared.AuthTokenWriteable(cfg, hostname) statusInfo[hostname] = []string{} addMsg := func(x string, ys ...interface{}) { diff --git a/pkg/cmd/auth/status/status_test.go b/pkg/cmd/auth/status/status_test.go index 4bf548643..33aded088 100644 --- a/pkg/cmd/auth/status/status_test.go +++ b/pkg/cmd/auth/status/status_test.go @@ -71,11 +71,13 @@ func Test_NewCmdStatus(t *testing.T) { } func Test_statusRun(t *testing.T) { + readConfigs := config.StubWriteConfig(t) + tests := []struct { name string opts *StatusOptions httpStubs func(*httpmock.Registry) - cfg func(config.Config) + cfgStubs func(*config.ConfigMock) wantErr string wantErrOut *regexp.Regexp }{ @@ -84,9 +86,9 @@ func Test_statusRun(t *testing.T) { opts: &StatusOptions{ Hostname: "joel.miller", }, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "abc123") + cfgStubs: func(c *config.ConfigMock) { + c.Set("joel.miller", "oauth_token", "abc123") + c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) @@ -99,9 +101,9 @@ func Test_statusRun(t *testing.T) { { name: "missing scope", opts: &StatusOptions{}, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "abc123") + cfgStubs: func(c *config.ConfigMock) { + c.Set("joel.miller", "oauth_token", "abc123") + c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo")) @@ -116,9 +118,9 @@ func Test_statusRun(t *testing.T) { { name: "bad token", opts: &StatusOptions{}, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "abc123") + cfgStubs: func(c *config.ConfigMock) { + c.Set("joel.miller", "oauth_token", "abc123") + c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.StatusStringResponse(400, "no bueno")) @@ -133,9 +135,9 @@ func Test_statusRun(t *testing.T) { { name: "all good", opts: &StatusOptions{}, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "abc123") + cfgStubs: func(c *config.ConfigMock) { + c.Set("github.com", "oauth_token", "abc123") + c.Set("joel.miller", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) @@ -152,9 +154,9 @@ func Test_statusRun(t *testing.T) { { name: "hide token", opts: &StatusOptions{}, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "xyz456") + cfgStubs: func(c *config.ConfigMock) { + c.Set("joel.miller", "oauth_token", "abc123") + c.Set("github.com", "oauth_token", "xyz456") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) @@ -173,9 +175,9 @@ func Test_statusRun(t *testing.T) { opts: &StatusOptions{ ShowToken: true, }, - cfg: func(c config.Config) { - _ = c.Set("joel.miller", "oauth_token", "abc123") - _ = c.Set("github.com", "oauth_token", "xyz456") + cfgStubs: func(c *config.ConfigMock) { + c.Set("github.com", "oauth_token", "xyz456") + c.Set("joel.miller", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) @@ -188,13 +190,14 @@ func Test_statusRun(t *testing.T) { httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, wantErrOut: regexp.MustCompile(`(?s)Token: xyz456.*Token: abc123`), - }, { + }, + { name: "missing hostname", opts: &StatusOptions{ Hostname: "github.example.com", }, - cfg: func(c config.Config) { - _ = c.Set("github.com", "oauth_token", "abc123") + cfgStubs: func(c *config.ConfigMock) { + c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) {}, wantErrOut: regexp.MustCompile(`(?s)Hostname "github.example.com" not found among authenticated GitHub hosts`), @@ -213,13 +216,11 @@ func Test_statusRun(t *testing.T) { ios.SetStdinTTY(true) ios.SetStderrTTY(true) ios.SetStdoutTTY(true) - tt.opts.IO = ios - cfg := config.NewBlankConfig() - - if tt.cfg != nil { - tt.cfg(cfg) + cfg := config.NewFromString("") + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) } tt.opts.Config = func() (config.Config, error) { return cfg, nil @@ -232,9 +233,6 @@ func Test_statusRun(t *testing.T) { if tt.httpStubs != nil { tt.httpStubs(reg) } - mainBuf := bytes.Buffer{} - hostsBuf := bytes.Buffer{} - defer config.StubWriteConfig(&mainBuf, &hostsBuf)() err := statusRun(tt.opts) if tt.wantErr != "" { @@ -250,6 +248,10 @@ func Test_statusRun(t *testing.T) { assert.True(t, tt.wantErrOut.MatchString(stderr.String())) } + mainBuf := bytes.Buffer{} + hostsBuf := bytes.Buffer{} + readConfigs(&mainBuf, &hostsBuf) + assert.Equal(t, "", mainBuf.String()) assert.Equal(t, "", hostsBuf.String()) diff --git a/pkg/cmd/config/get/get_test.go b/pkg/cmd/config/get/get_test.go index 0a530012d..baebb584a 100644 --- a/pkg/cmd/config/get/get_test.go +++ b/pkg/cmd/config/get/get_test.go @@ -42,7 +42,7 @@ func TestNewCmdConfigGet(t *testing.T) { t.Run(tt.name, func(t *testing.T) { f := &cmdutil.Factory{ Config: func() (config.Config, error) { - return config.ConfigStub{}, nil + return config.NewBlankConfig(), nil }, } @@ -86,9 +86,11 @@ func Test_getRun(t *testing.T) { name: "get key", input: &GetOptions{ Key: "editor", - Config: config.ConfigStub{ - "editor": "ed", - }, + Config: func() config.Config { + cfg := config.NewBlankConfig() + cfg.Set("", "editor", "ed") + return cfg + }(), }, stdout: "ed\n", }, @@ -97,10 +99,12 @@ func Test_getRun(t *testing.T) { input: &GetOptions{ Hostname: "github.com", Key: "editor", - Config: config.ConfigStub{ - "editor": "ed", - "github.com:editor": "vim", - }, + Config: func() config.Config { + cfg := config.NewBlankConfig() + cfg.Set("", "editor", "ed") + cfg.Set("github.com", "editor", "vim") + return cfg + }(), }, stdout: "vim\n", }, @@ -115,10 +119,6 @@ func Test_getRun(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.stdout, stdout.String()) assert.Equal(t, tt.stderr, stderr.String()) - _, err = tt.input.Config.GetOrDefault("", "_written") - assert.Error(t, err) - _, err = tt.input.Config.Get("", "_written") - assert.Error(t, err) }) } } diff --git a/pkg/cmd/config/list/list.go b/pkg/cmd/config/list/list.go index cf1422f0f..f9da66d3f 100644 --- a/pkg/cmd/config/list/list.go +++ b/pkg/cmd/config/list/list.go @@ -51,10 +51,7 @@ func listRun(opts *ListOptions) error { if opts.Hostname != "" { host = opts.Hostname } else { - host, err = cfg.DefaultHost() - if err != nil { - return err - } + host, _ = cfg.DefaultHost() } configOptions := config.ConfigOptions() diff --git a/pkg/cmd/config/list/list_test.go b/pkg/cmd/config/list/list_test.go index 7297d4816..945979414 100644 --- a/pkg/cmd/config/list/list_test.go +++ b/pkg/cmd/config/list/list_test.go @@ -36,7 +36,7 @@ func TestNewCmdConfigList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { f := &cmdutil.Factory{ Config: func() (config.Config, error) { - return config.ConfigStub{}, nil + return config.NewBlankConfig(), nil }, } @@ -71,21 +71,23 @@ func Test_listRun(t *testing.T) { tests := []struct { name string input *ListOptions - config config.ConfigStub + config config.Config stdout string wantErr bool }{ { name: "list", - config: config.ConfigStub{ - "HOST:git_protocol": "ssh", - "HOST:editor": "/usr/bin/vim", - "HOST:prompt": "disabled", - "HOST:pager": "less", - "HOST:http_unix_socket": "", - "HOST:browser": "brave", - }, - input: &ListOptions{Hostname: "HOST"}, // ConfigStub gives empty DefaultHost + config: func() config.Config { + cfg := config.NewBlankConfig() + cfg.Set("HOST", "git_protocol", "ssh") + cfg.Set("HOST", "editor", "/usr/bin/vim") + cfg.Set("HOST", "prompt", "disabled") + cfg.Set("HOST", "pager", "less") + cfg.Set("HOST", "http_unix_socket", "") + cfg.Set("HOST", "browser", "brave") + return cfg + }(), + input: &ListOptions{Hostname: "HOST"}, stdout: `git_protocol=ssh editor=/usr/bin/vim prompt=disabled diff --git a/pkg/cmd/config/set/set.go b/pkg/cmd/config/set/set.go index 38a23c899..e99bfb17c 100644 --- a/pkg/cmd/config/set/set.go +++ b/pkg/cmd/config/set/set.go @@ -59,15 +59,15 @@ func NewCmdConfigSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Co } func setRun(opts *SetOptions) error { - err := config.ValidateKey(opts.Key) + err := ValidateKey(opts.Key) if err != nil { warningIcon := opts.IO.ColorScheme().WarningIcon() fmt.Fprintf(opts.IO.ErrOut, "%s warning: '%s' is not a known configuration key\n", warningIcon, opts.Key) } - err = config.ValidateValue(opts.Key, opts.Value) + err = ValidateValue(opts.Key, opts.Value) if err != nil { - var invalidValue *config.InvalidValueError + var invalidValue InvalidValueError if errors.As(err, &invalidValue) { var values []string for _, v := range invalidValue.ValidValues { @@ -77,10 +77,7 @@ func setRun(opts *SetOptions) error { } } - err = opts.Config.Set(opts.Hostname, opts.Key, opts.Value) - if err != nil { - return fmt.Errorf("failed to set %q to %q: %w", opts.Key, opts.Value, err) - } + opts.Config.Set(opts.Hostname, opts.Key, opts.Value) err = opts.Config.Write() if err != nil { @@ -88,3 +85,44 @@ func setRun(opts *SetOptions) error { } return nil } + +func ValidateKey(key string) error { + for _, configKey := range config.ConfigOptions() { + if key == configKey.Key { + return nil + } + } + + return fmt.Errorf("invalid key") +} + +type InvalidValueError struct { + ValidValues []string +} + +func (e InvalidValueError) Error() string { + return "invalid value" +} + +func ValidateValue(key, value string) error { + var validValues []string + + for _, v := range config.ConfigOptions() { + if v.Key == key { + validValues = v.AllowedValues + break + } + } + + if validValues == nil { + return nil + } + + for _, v := range validValues { + if v == value { + return nil + } + } + + return InvalidValueError{ValidValues: validValues} +} diff --git a/pkg/cmd/config/set/set_test.go b/pkg/cmd/config/set/set_test.go index 015e13def..2e26fd044 100644 --- a/pkg/cmd/config/set/set_test.go +++ b/pkg/cmd/config/set/set_test.go @@ -46,9 +46,11 @@ func TestNewCmdConfigSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + _ = config.StubWriteConfig(t) + f := &cmdutil.Factory{ Config: func() (config.Config, error) { - return config.ConfigStub{}, nil + return config.NewBlankConfig(), nil }, } @@ -94,7 +96,7 @@ func Test_setRun(t *testing.T) { { name: "set key value", input: &SetOptions{ - Config: config.ConfigStub{}, + Config: config.NewBlankConfig(), Key: "editor", Value: "vim", }, @@ -103,7 +105,7 @@ func Test_setRun(t *testing.T) { { name: "set key value scoped by host", input: &SetOptions{ - Config: config.ConfigStub{}, + Config: config.NewBlankConfig(), Hostname: "github.com", Key: "editor", Value: "vim", @@ -113,7 +115,7 @@ func Test_setRun(t *testing.T) { { name: "set unknown key", input: &SetOptions{ - Config: config.ConfigStub{}, + Config: config.NewBlankConfig(), Key: "unknownKey", Value: "someValue", }, @@ -123,7 +125,7 @@ func Test_setRun(t *testing.T) { { name: "set invalid value", input: &SetOptions{ - Config: config.ConfigStub{}, + Config: config.NewBlankConfig(), Key: "git_protocol", Value: "invalid", }, @@ -132,10 +134,12 @@ func Test_setRun(t *testing.T) { }, } for _, tt := range tests { - ios, _, stdout, stderr := iostreams.Test() - tt.input.IO = ios - t.Run(tt.name, func(t *testing.T) { + _ = config.StubWriteConfig(t) + + ios, _, stdout, stderr := iostreams.Test() + tt.input.IO = ios + err := setRun(tt.input) if tt.wantsErr { assert.EqualError(t, err, tt.errMsg) @@ -148,10 +152,46 @@ func Test_setRun(t *testing.T) { val, err := tt.input.Config.GetOrDefault(tt.input.Hostname, tt.input.Key) assert.NoError(t, err) assert.Equal(t, tt.expectedValue, val) - - val, err = tt.input.Config.GetOrDefault("", "_written") - assert.NoError(t, err) - assert.Equal(t, "true", val) }) } } + +func Test_ValidateValue(t *testing.T) { + err := ValidateValue("git_protocol", "sshpps") + assert.EqualError(t, err, "invalid value") + + err = ValidateValue("git_protocol", "ssh") + assert.NoError(t, err) + + err = ValidateValue("editor", "vim") + assert.NoError(t, err) + + err = ValidateValue("got", "123") + assert.NoError(t, err) + + err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate") + assert.NoError(t, err) +} + +func Test_ValidateKey(t *testing.T) { + err := ValidateKey("invalid") + assert.EqualError(t, err, "invalid key") + + err = ValidateKey("git_protocol") + assert.NoError(t, err) + + err = ValidateKey("editor") + assert.NoError(t, err) + + err = ValidateKey("prompt") + assert.NoError(t, err) + + err = ValidateKey("pager") + assert.NoError(t, err) + + err = ValidateKey("http_unix_socket") + assert.NoError(t, err) + + err = ValidateKey("browser") + assert.NoError(t, err) +} diff --git a/pkg/cmd/extension/manager.go b/pkg/cmd/extension/manager.go index 9b766c1c2..20a60087d 100644 --- a/pkg/cmd/extension/manager.go +++ b/pkg/cmd/extension/manager.go @@ -702,10 +702,7 @@ func (m *Manager) goBinScaffolding(gitExe, name string) error { return err } - host, err := m.config.DefaultHost() - if err != nil { - return err - } + host, _ := m.config.DefaultHost() currentUser, err := api.CurrentLoginName(api.NewClientFromHTTP(m.client), host) if err != nil { diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index 76b18b248..fdc56bb3e 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -1,7 +1,6 @@ package factory import ( - "errors" "fmt" "net/http" "os" @@ -134,12 +133,7 @@ func configFunc() func() (config.Config, error) { if cachedConfig != nil || configError != nil { return cachedConfig, configError } - cachedConfig, configError = config.ParseDefaultConfig() - if errors.Is(configError, os.ErrNotExist) { - cachedConfig = config.NewBlankConfig() - configError = nil - } - cachedConfig = config.InheritEnv(cachedConfig) + cachedConfig, configError = config.NewConfig() return cachedConfig, configError } } diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go index e6399ae7c..7863ce933 100644 --- a/pkg/cmd/factory/default_test.go +++ b/pkg/cmd/factory/default_test.go @@ -7,7 +7,6 @@ import ( "os" "testing" - "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/cmdutil" @@ -17,15 +16,9 @@ import ( ) func Test_BaseRepo(t *testing.T) { - orig_GH_HOST := os.Getenv("GH_HOST") - t.Cleanup(func() { - os.Setenv("GH_HOST", orig_GH_HOST) - }) - tests := []struct { name string remotes git.RemoteSet - config config.Config override string wantsErr bool wantsName string @@ -37,7 +30,6 @@ func Test_BaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://nonsense.com/owner/repo.git"), }, - config: defaultConfig(), wantsName: "repo", wantsOwner: "owner", wantsHost: "nonsense.com", @@ -47,7 +39,6 @@ func Test_BaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://test.com/owner/repo.git"), }, - config: defaultConfig(), wantsErr: true, }, { @@ -55,7 +46,6 @@ func Test_BaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://test.com/owner/repo.git"), }, - config: defaultConfig(), override: "test.com", wantsName: "repo", wantsOwner: "owner", @@ -66,7 +56,6 @@ func Test_BaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://nonsense.com/owner/repo.git"), }, - config: defaultConfig(), override: "test.com", wantsErr: true, }, @@ -74,18 +63,30 @@ func Test_BaseRepo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.override != "" { - os.Setenv("GH_HOST", tt.override) - } else { - os.Unsetenv("GH_HOST") - } f := New("1") rr := &remoteResolver{ readRemotes: func() (git.RemoteSet, error) { return tt.remotes, nil }, getConfig: func() (config.Config, error) { - return tt.config, nil + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + hosts := []string{"nonsense.com"} + if tt.override != "" { + hosts = append([]string{tt.override}, hosts...) + } + return hosts + } + cfg.DefaultHostFunc = func() (string, string) { + if tt.override != "" { + return tt.override, "GH_HOST" + } + return "nonsense.com", "hosts" + } + cfg.AuthTokenFunc = func(string) (string, string) { + return "", "" + } + return cfg, nil }, } f.Remotes = rr.Resolver() @@ -105,15 +106,10 @@ func Test_BaseRepo(t *testing.T) { func Test_SmartBaseRepo(t *testing.T) { pu, _ := url.Parse("https://test.com/newowner/newrepo.git") - orig_GH_HOST := os.Getenv("GH_HOST") - t.Cleanup(func() { - os.Setenv("GH_HOST", orig_GH_HOST) - }) tests := []struct { name string remotes git.RemoteSet - config config.Config override string wantsErr bool wantsName string @@ -125,7 +121,6 @@ func Test_SmartBaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://test.com/owner/repo.git"), }, - config: defaultConfig(), override: "test.com", wantsName: "repo", wantsOwner: "owner", @@ -139,7 +134,6 @@ func Test_SmartBaseRepo(t *testing.T) { FetchURL: pu, PushURL: pu}, }, - config: defaultConfig(), override: "test.com", wantsName: "newrepo", wantsOwner: "newowner", @@ -153,7 +147,6 @@ func Test_SmartBaseRepo(t *testing.T) { FetchURL: pu, PushURL: pu}, }, - config: defaultConfig(), override: "test.com", wantsName: "test", wantsOwner: "johnny", @@ -164,7 +157,6 @@ func Test_SmartBaseRepo(t *testing.T) { remotes: git.RemoteSet{ git.NewRemote("origin", "https://example.com/owner/repo.git"), }, - config: defaultConfig(), override: "test.com", wantsErr: true, }, @@ -172,18 +164,27 @@ func Test_SmartBaseRepo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.override != "" { - os.Setenv("GH_HOST", tt.override) - } else { - os.Unsetenv("GH_HOST") - } f := New("1") rr := &remoteResolver{ readRemotes: func() (git.RemoteSet, error) { return tt.remotes, nil }, getConfig: func() (config.Config, error) { - return tt.config, nil + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + hosts := []string{"nonsense.com"} + if tt.override != "" { + hosts = append([]string{tt.override}, hosts...) + } + return hosts + } + cfg.DefaultHostFunc = func() (string, string) { + if tt.override != "" { + return tt.override, "GH_HOST" + } + return "nonsense.com", "hosts" + } + return cfg, nil }, } f.HttpClient = func() (*http.Client, error) { return nil, nil } @@ -204,11 +205,6 @@ func Test_SmartBaseRepo(t *testing.T) { // Defined in pkg/cmdutil/repo_override.go but test it along with other BaseRepo functions func Test_OverrideBaseRepo(t *testing.T) { - orig_GH_HOST := os.Getenv("GH_REPO") - t.Cleanup(func() { - os.Setenv("GH_REPO", orig_GH_HOST) - }) - tests := []struct { name string remotes git.RemoteSet @@ -249,9 +245,9 @@ func Test_OverrideBaseRepo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.envOverride != "" { + old := os.Getenv("GH_REPO") os.Setenv("GH_REPO", tt.envOverride) - } else { - os.Unsetenv("GH_REPO") + defer os.Setenv("GH_REPO", old) } f := New("1") rr := &remoteResolver{ @@ -511,12 +507,10 @@ func TestSSOURL(t *testing.T) { } } -func defaultConfig() config.Config { - return config.InheritEnv(config.NewFromString(heredoc.Doc(` - hosts: - nonsense.com: - oauth_token: BLAH - `))) +func defaultConfig() *config.ConfigMock { + cfg := config.NewFromString("") + cfg.Set("nonsense.com", "oauth_token", "BLAH") + return cfg } func pagerConfig() config.Config { diff --git a/pkg/cmd/factory/remote_resolver.go b/pkg/cmd/factory/remote_resolver.go index 73872c776..5a8893a3d 100644 --- a/pkg/cmd/factory/remote_resolver.go +++ b/pkg/cmd/factory/remote_resolver.go @@ -13,6 +13,10 @@ import ( "github.com/cli/go-gh/pkg/ssh" ) +const ( + GH_HOST = "GH_HOST" +) + type remoteResolver struct { readRemotes func() (git.RemoteSet, error) getConfig func() (config.Config, error) @@ -49,14 +53,12 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { return nil, err } - authedHosts, err := cfg.Hosts() - if err != nil { - return nil, err - } - defaultHost, src, err := cfg.DefaultHostWithSource() - if err != nil { - return nil, err + authedHosts := cfg.Hosts() + if len(authedHosts) == 0 { + return nil, errors.New("could not find any host configurations") } + defaultHost, src := cfg.DefaultHost() + // Use set to dedupe list of hosts hostsSet := set.NewStringSet() hostsSet.AddValues(authedHosts) @@ -72,18 +74,19 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { // Filter again by default host if one is set // For config file default host fallback to cachedRemotes if none match // For enviornment default host (GH_HOST) do not fallback to cachedRemotes if none match - if src != "" { + if src != "default" { filteredRemotes := cachedRemotes.FilterByHosts([]string{defaultHost}) - if config.IsHostEnv(src) || len(filteredRemotes) > 0 { + if isHostEnv(src) || len(filteredRemotes) > 0 { cachedRemotes = filteredRemotes } } if len(cachedRemotes) == 0 { - dummyHostname := "example.com" // any non-github.com hostname is fine here - if config.IsHostEnv(src) { + // Any non-github.com hostname is fine here + dummyHostname := "example.com" + if isHostEnv(src) { return nil, fmt.Errorf("none of the git remotes configured for this repository correspond to the %s environment variable. Try adding a matching remote or unsetting the variable.", src) - } else if v, src, _ := cfg.GetWithSource(dummyHostname, "oauth_token"); v != "" && config.IsEnterpriseEnv(src) { + } else if v, _ := cfg.AuthToken(dummyHostname); v != "" { return nil, errors.New("set the GH_HOST environment variable to specify which GitHub host to use") } return nil, errors.New("none of the git remotes configured for this repository point to a known GitHub host. To tell gh about a new GitHub host, please use `gh auth login`") @@ -92,3 +95,7 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { return cachedRemotes, nil } } + +func isHostEnv(src string) bool { + return src == GH_HOST +} diff --git a/pkg/cmd/factory/remote_resolver_test.go b/pkg/cmd/factory/remote_resolver_test.go index 69b8c1bba..07ae5863d 100644 --- a/pkg/cmd/factory/remote_resolver_test.go +++ b/pkg/cmd/factory/remote_resolver_test.go @@ -5,7 +5,6 @@ import ( "os" "testing" - "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/config" "github.com/stretchr/testify/assert" @@ -26,8 +25,7 @@ func Test_remoteResolver(t *testing.T) { tests := []struct { name string remotes func() (git.RemoteSet, error) - config func() (config.Config, error) - override string + config config.Config output []string wantsErr bool }{ @@ -38,9 +36,16 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://github.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(`hosts:`)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{} + } + cfg.DefaultHostFunc = func() (string, string) { + return "github.com", "default" + } + return cfg + }(), wantsErr: true, }, { @@ -48,13 +53,16 @@ func Test_remoteResolver(t *testing.T) { remotes: func() (git.RemoteSet, error) { return git.RemoteSet{}, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "example.com", "hosts" + } + return cfg + }(), wantsErr: true, }, { @@ -64,13 +72,19 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://test.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "example.com", "hosts" + } + cfg.AuthTokenFunc = func(string) (string, string) { + return "", "" + } + return cfg + }(), wantsErr: true, }, { @@ -80,30 +94,35 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://github.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "example.com", "hosts" + } + return cfg + }(), output: []string{"origin"}, }, { name: "one authenticated host with matching git remote", remotes: func() (git.RemoteSet, error) { return git.RemoteSet{ - git.NewRemote("upstream", "https://github.com/owner/repo.git"), git.NewRemote("origin", "https://example.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "example.com", "default" + } + return cfg + }(), output: []string{"origin"}, }, { @@ -116,13 +135,16 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("fork", "https://example.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "example.com", "default" + } + return cfg + }(), output: []string{"upstream", "github", "origin", "fork"}, }, { @@ -132,15 +154,19 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://test.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - github.com: - oauth_token: GHTOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com", "github.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "github.com", "default" + } + cfg.AuthTokenFunc = func(string) (string, string) { + return "", "" + } + return cfg + }(), wantsErr: true, }, { @@ -151,15 +177,16 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://example.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - github.com: - oauth_token: GHTOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com", "github.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "github.com", "default" + } + return cfg + }(), output: []string{"origin"}, }, { @@ -173,15 +200,16 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("test", "https://test.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - github.com: - oauth_token: GHTOKEN - `)), nil - }, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com", "github.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "github.com", "default" + } + return cfg + }(), output: []string{"upstream", "github", "origin", "fork"}, }, { @@ -191,14 +219,16 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://example.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.InheritEnv(config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `))), nil - }, - override: "test.com", + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "test.com", "GH_HOST" + } + return cfg + }(), wantsErr: true, }, { @@ -209,15 +239,17 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://test.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.InheritEnv(config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `))), nil - }, - override: "test.com", - output: []string{"origin"}, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "test.com", "GH_HOST" + } + return cfg + }(), + output: []string{"origin"}, }, { name: "override host with multiple matching git remotes", @@ -228,26 +260,25 @@ func Test_remoteResolver(t *testing.T) { git.NewRemote("origin", "https://test.com/owner/repo.git"), }, nil }, - config: func() (config.Config, error) { - return config.InheritEnv(config.NewFromString(heredoc.Doc(` - hosts: - example.com: - oauth_token: GHETOKEN - `))), nil - }, - override: "test.com", - output: []string{"upstream", "origin"}, + config: func() config.Config { + cfg := &config.ConfigMock{} + cfg.HostsFunc = func() []string { + return []string{"example.com", "test.com"} + } + cfg.DefaultHostFunc = func() (string, string) { + return "test.com", "GH_HOST" + } + return cfg + }(), + output: []string{"upstream", "origin"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.override != "" { - os.Setenv("GH_HOST", tt.override) - } rr := &remoteResolver{ readRemotes: tt.remotes, - getConfig: tt.config, + getConfig: func() (config.Config, error) { return tt.config, nil }, urlTranslator: identityTranslator{}, } resolver := rr.Resolver() diff --git a/pkg/cmd/gist/clone/clone.go b/pkg/cmd/gist/clone/clone.go index 5127dba64..c1dfee31b 100644 --- a/pkg/cmd/gist/clone/clone.go +++ b/pkg/cmd/gist/clone/clone.go @@ -75,10 +75,7 @@ func cloneRun(opts *CloneOptions) error { if err != nil { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } + hostname, _ := cfg.DefaultHost() protocol, err := cfg.GetOrDefault(hostname, "git_protocol") if err != nil { return err diff --git a/pkg/cmd/gist/create/create.go b/pkg/cmd/gist/create/create.go index 254dbc0a7..1122bdb7d 100644 --- a/pkg/cmd/gist/create/create.go +++ b/pkg/cmd/gist/create/create.go @@ -143,10 +143,7 @@ func createRun(opts *CreateOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() opts.IO.StartProgressIndicator() gist, err := createGist(httpClient, host, opts.Description, opts.Public, files) diff --git a/pkg/cmd/gist/delete/delete.go b/pkg/cmd/gist/delete/delete.go index 70dbb8d2a..b4e76a69a 100644 --- a/pkg/cmd/gist/delete/delete.go +++ b/pkg/cmd/gist/delete/delete.go @@ -64,10 +64,7 @@ func deleteRun(opts *DeleteOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() apiClient := api.NewClientFromHTTP(client) if err := deleteGist(apiClient, host, gistID); err != nil { diff --git a/pkg/cmd/gist/edit/edit.go b/pkg/cmd/gist/edit/edit.go index ff4e59b98..47b122c27 100644 --- a/pkg/cmd/gist/edit/edit.go +++ b/pkg/cmd/gist/edit/edit.go @@ -107,10 +107,7 @@ func editRun(opts *EditOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() gist, err := shared.GetGist(client, host, gistID) if err != nil { diff --git a/pkg/cmd/gist/list/list.go b/pkg/cmd/gist/list/list.go index 46d373d0e..e1f453781 100644 --- a/pkg/cmd/gist/list/list.go +++ b/pkg/cmd/gist/list/list.go @@ -76,10 +76,7 @@ func listRun(opts *ListOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() gists, err := shared.ListGists(client, host, opts.Limit, opts.Visibility) if err != nil { diff --git a/pkg/cmd/gist/view/view.go b/pkg/cmd/gist/view/view.go index 121ca2fac..1c4e618df 100644 --- a/pkg/cmd/gist/view/view.go +++ b/pkg/cmd/gist/view/view.go @@ -86,10 +86,7 @@ func viewRun(opts *ViewOptions) error { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } + hostname, _ := cfg.DefaultHost() cs := opts.IO.ColorScheme() if gistID == "" { diff --git a/pkg/cmd/gpg-key/add/add.go b/pkg/cmd/gpg-key/add/add.go index 3e8da7ad7..c5b8c2abe 100644 --- a/pkg/cmd/gpg-key/add/add.go +++ b/pkg/cmd/gpg-key/add/add.go @@ -76,10 +76,7 @@ func runAdd(opts *AddOptions) error { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } + hostname, _ := cfg.DefaultHost() err = gpgKeyUpload(httpClient, hostname, keyReader) if err != nil { diff --git a/pkg/cmd/gpg-key/list/list.go b/pkg/cmd/gpg-key/list/list.go index 30d28518d..137f61f0b 100644 --- a/pkg/cmd/gpg-key/list/list.go +++ b/pkg/cmd/gpg-key/list/list.go @@ -53,10 +53,7 @@ func listRun(opts *ListOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() gpgKeys, err := userKeys(apiClient, host, "") if err != nil { diff --git a/pkg/cmd/repo/archive/archive.go b/pkg/cmd/repo/archive/archive.go index 5dc8d0fb6..2c4487af6 100644 --- a/pkg/cmd/repo/archive/archive.go +++ b/pkg/cmd/repo/archive/archive.go @@ -83,10 +83,7 @@ func archiveRun(opts *ArchiveOptions) error { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } + hostname, _ := cfg.DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { diff --git a/pkg/cmd/repo/clone/clone.go b/pkg/cmd/repo/clone/clone.go index 57eb4b022..ba289aa47 100644 --- a/pkg/cmd/repo/clone/clone.go +++ b/pkg/cmd/repo/clone/clone.go @@ -111,10 +111,7 @@ func cloneRun(opts *CloneOptions) error { if repositoryIsFullName { fullName = opts.Repository } else { - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, host) if err != nil { return err diff --git a/pkg/cmd/repo/create/create.go b/pkg/cmd/repo/create/create.go index 2dcf70524..0b237c963 100644 --- a/pkg/cmd/repo/create/create.go +++ b/pkg/cmd/repo/create/create.go @@ -191,10 +191,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co if err != nil { return nil, cobra.ShellCompDirectiveError } - hostname, err := cfg.DefaultHost() - if err != nil { - return nil, cobra.ShellCompDirectiveError - } + hostname, _ := cfg.DefaultHost() results, err := listGitIgnoreTemplates(httpClient, hostname) if err != nil { return nil, cobra.ShellCompDirectiveError @@ -211,10 +208,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co if err != nil { return nil, cobra.ShellCompDirectiveError } - hostname, err := cfg.DefaultHost() - if err != nil { - return nil, cobra.ShellCompDirectiveError - } + hostname, _ := cfg.DefaultHost() licenses, err := listLicenseTemplates(httpClient, hostname) if err != nil { return nil, cobra.ShellCompDirectiveError @@ -266,10 +260,7 @@ func createFromScratch(opts *CreateOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() if opts.Interactive { opts.Name, opts.Description, opts.Visibility, err = interactiveRepoInfo("") @@ -409,10 +400,7 @@ func createFromLocal(opts *CreateOptions) error { if err != nil { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() if opts.Interactive { opts.Source, err = interactiveSource() diff --git a/pkg/cmd/repo/fork/fork.go b/pkg/cmd/repo/fork/fork.go index f7bf0140a..1752e6272 100644 --- a/pkg/cmd/repo/fork/fork.go +++ b/pkg/cmd/repo/fork/fork.go @@ -225,10 +225,7 @@ func forkRun(opts *ForkOptions) error { if err != nil { return err } - protocol, err := cfg.Get(repoToFork.RepoHost(), "git_protocol") - if err != nil { - return err - } + protocol, _ := cfg.Get(repoToFork.RepoHost(), "git_protocol") if inParent { remotes, err := opts.Remotes() @@ -248,7 +245,7 @@ func forkRun(opts *ForkOptions) error { if scheme != "" { protocol = scheme } else { - protocol = cfg.Default("git_protocol") + protocol = "https" } } } diff --git a/pkg/cmd/repo/fork/fork_test.go b/pkg/cmd/repo/fork/fork_test.go index ef47b82ba..111eb98f6 100644 --- a/pkg/cmd/repo/fork/fork_test.go +++ b/pkg/cmd/repo/fork/fork_test.go @@ -210,7 +210,7 @@ func TestRepoFork(t *testing.T) { httpStubs func(*httpmock.Registry) execStubs func(*run.CommandStubber) askStubs func(*prompt.AskStubber) - cfg func(config.Config) config.Config + cfgStubs func(*config.ConfigMock) remotes []*context.Remote wantOut string wantErrOut string @@ -253,9 +253,8 @@ func TestRepoFork(t *testing.T) { Repo: ghrepo.New("OWNER", "REPO"), }, }, - cfg: func(c config.Config) config.Config { - _ = c.Set("", "git_protocol", "") - return c + cfgStubs: func(c *config.ConfigMock) { + c.Set("", "git_protocol", "") }, httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { @@ -679,8 +678,8 @@ func TestRepoFork(t *testing.T) { } cfg := config.NewBlankConfig() - if tt.cfg != nil { - cfg = tt.cfg(cfg) + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) } tt.opts.Config = func() (config.Config, error) { return cfg, nil diff --git a/pkg/cmd/repo/garden/garden.go b/pkg/cmd/repo/garden/garden.go index 95591bd32..f8e0722f3 100644 --- a/pkg/cmd/repo/garden/garden.go +++ b/pkg/cmd/repo/garden/garden.go @@ -155,11 +155,7 @@ func gardenRun(opts *GardenOptions) error { if err != nil { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } - + hostname, _ := cfg.DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { return err diff --git a/pkg/cmd/repo/list/list.go b/pkg/cmd/repo/list/list.go index 22d710dde..223c16fb6 100644 --- a/pkg/cmd/repo/list/list.go +++ b/pkg/cmd/repo/list/list.go @@ -119,10 +119,7 @@ func listRun(opts *ListOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() if opts.Detector == nil { cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) diff --git a/pkg/cmd/repo/list/list_test.go b/pkg/cmd/repo/list/list_test.go index 8d0bf72b2..50d0e380d 100644 --- a/pkg/cmd/repo/list/list_test.go +++ b/pkg/cmd/repo/list/list_test.go @@ -421,7 +421,7 @@ func TestRepoList_noVisibilityField(t *testing.T) { return &http.Client{Transport: reg}, nil }, Config: func() (config.Config, error) { - return config.InheritEnv(config.NewBlankConfig()), nil + return config.NewBlankConfig(), nil }, Now: func() time.Time { t, _ := time.Parse(time.RFC822, "19 Feb 21 15:00 UTC") diff --git a/pkg/cmd/repo/view/view.go b/pkg/cmd/repo/view/view.go index e799306f0..25eebb0bb 100644 --- a/pkg/cmd/repo/view/view.go +++ b/pkg/cmd/repo/view/view.go @@ -97,11 +97,7 @@ func viewRun(opts *ViewOptions) error { if err != nil { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } - + hostname, _ := cfg.DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { return err diff --git a/pkg/cmd/search/shared/shared.go b/pkg/cmd/search/shared/shared.go index 1b506ea0f..918d95882 100644 --- a/pkg/cmd/search/shared/shared.go +++ b/pkg/cmd/search/shared/shared.go @@ -40,10 +40,7 @@ func Searcher(f *cmdutil.Factory) (search.Searcher, error) { if err != nil { return nil, err } - host, err := cfg.DefaultHost() - if err != nil { - return nil, err - } + host, _ := cfg.DefaultHost() client, err := f.HttpClient() if err != nil { return nil, err diff --git a/pkg/cmd/secret/delete/delete.go b/pkg/cmd/secret/delete/delete.go index ee3a3f836..bf1eb5c77 100644 --- a/pkg/cmd/secret/delete/delete.go +++ b/pkg/cmd/secret/delete/delete.go @@ -122,10 +122,7 @@ func removeRun(opts *DeleteOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() err = client.REST(host, "DELETE", path, nil, nil) if err != nil { diff --git a/pkg/cmd/secret/list/list.go b/pkg/cmd/secret/list/list.go index 7fab513e2..2a3d5a100 100644 --- a/pkg/cmd/secret/list/list.go +++ b/pkg/cmd/secret/list/list.go @@ -123,10 +123,7 @@ func listRun(opts *ListOptions) error { return err } - host, err = cfg.DefaultHost() - if err != nil { - return err - } + host, _ = cfg.DefaultHost() if secretEntity == shared.User { secrets, err = getUserSecrets(client, host, showSelectedRepoInfo) diff --git a/pkg/cmd/secret/set/set.go b/pkg/cmd/secret/set/set.go index 3c08b6e09..9be244772 100644 --- a/pkg/cmd/secret/set/set.go +++ b/pkg/cmd/secret/set/set.go @@ -186,11 +186,7 @@ func setRun(opts *SetOptions) error { if err != nil { return err } - - host, err = cfg.DefaultHost() - if err != nil { - return err - } + host, _ = cfg.DefaultHost() } secretEntity, err := shared.GetSecretEntity(orgName, envName, opts.UserSecrets) diff --git a/pkg/cmd/ssh-key/add/add.go b/pkg/cmd/ssh-key/add/add.go index 8892527ed..5b0cb7aa5 100644 --- a/pkg/cmd/ssh-key/add/add.go +++ b/pkg/cmd/ssh-key/add/add.go @@ -77,10 +77,7 @@ func runAdd(opts *AddOptions) error { return err } - hostname, err := cfg.DefaultHost() - if err != nil { - return err - } + hostname, _ := cfg.DefaultHost() err = SSHKeyUpload(httpClient, hostname, keyReader, opts.Title) if err != nil { diff --git a/pkg/cmd/ssh-key/list/list.go b/pkg/cmd/ssh-key/list/list.go index dd1c4d92f..943a3ddb0 100644 --- a/pkg/cmd/ssh-key/list/list.go +++ b/pkg/cmd/ssh-key/list/list.go @@ -51,10 +51,7 @@ func listRun(opts *ListOptions) error { return err } - host, err := cfg.DefaultHost() - if err != nil { - return err - } + host, _ := cfg.DefaultHost() sshKeys, err := userKeys(apiClient, host, "") if err != nil { diff --git a/pkg/cmd/status/status.go b/pkg/cmd/status/status.go index 0f551da9c..7a19d07dd 100644 --- a/pkg/cmd/status/status.go +++ b/pkg/cmd/status/status.go @@ -26,7 +26,7 @@ import ( ) type hostConfig interface { - DefaultHost() (string, error) + DefaultHost() (string, string) } type StatusOptions struct { @@ -619,10 +619,7 @@ func statusRun(opts *StatusOptions) error { return fmt.Errorf("could not create client: %w", err) } - hostname, err := opts.HostConfig.DefaultHost() - if err != nil { - return err - } + hostname, _ := opts.HostConfig.DefaultHost() sg := NewStatusGetter(client, hostname, opts) diff --git a/pkg/cmd/status/status_test.go b/pkg/cmd/status/status_test.go index 816034808..73ca3d5ff 100644 --- a/pkg/cmd/status/status_test.go +++ b/pkg/cmd/status/status_test.go @@ -19,8 +19,8 @@ import ( type testHostConfig string -func (c testHostConfig) DefaultHost() (string, error) { - return string(c), nil +func (c testHostConfig) DefaultHost() (string, string) { + return string(c), "" } func TestNewCmdStatus(t *testing.T) { diff --git a/pkg/cmdutil/auth_check.go b/pkg/cmdutil/auth_check.go index 3da9cfcfd..03f76de6f 100644 --- a/pkg/cmdutil/auth_check.go +++ b/pkg/cmdutil/auth_check.go @@ -14,20 +14,17 @@ func DisableAuthCheck(cmd *cobra.Command) { } func CheckAuth(cfg config.Config) bool { - if config.AuthTokenProvidedFromEnv() { + // This will check if there are any environment variable + // authentication tokens set for enterprise hosts. + // Any non-github.com hostname is fine here + dummyHostname := "example.com" + token, _ := cfg.AuthToken(dummyHostname) + if token != "" { return true } - hosts, err := cfg.Hosts() - if err != nil { - return false - } - - for _, hostname := range hosts { - token, _ := cfg.Get(hostname, "oauth_token") - if token != "" { - return true - } + if len(cfg.Hosts()) > 0 { + return true } return false diff --git a/pkg/cmdutil/auth_check_test.go b/pkg/cmdutil/auth_check_test.go index adbfac992..797824f1c 100644 --- a/pkg/cmdutil/auth_check_test.go +++ b/pkg/cmdutil/auth_check_test.go @@ -1,7 +1,6 @@ package cmdutil import ( - "os" "testing" "github.com/cli/cli/v2/internal/config" @@ -9,56 +8,38 @@ import ( ) func Test_CheckAuth(t *testing.T) { - orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN") - t.Cleanup(func() { - os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN) - }) - tests := []struct { name string - cfg func(config.Config) - envToken bool + cfgStubs func(*config.ConfigMock) expected bool }{ { - name: "no hosts", - cfg: func(c config.Config) {}, - envToken: false, + name: "no known hosts, no env auth token", + cfgStubs: func(c *config.ConfigMock) {}, expected: false, }, - {name: "no hosts, env auth token", - cfg: func(c config.Config) {}, - envToken: true, + { + name: "no known hosts, env auth token", + cfgStubs: func(c *config.ConfigMock) { + c.AuthTokenFunc = func(string) (string, string) { + return "token", "GITHUB_TOKEN" + } + }, expected: true, }, { - name: "host, no token", - cfg: func(c config.Config) { - _ = c.Set("github.com", "oauth_token", "") + name: "known host", + cfgStubs: func(c *config.ConfigMock) { + c.Set("github.com", "oauth_token", "token") }, - envToken: false, - expected: false, - }, - { - name: "host, token", - cfg: func(c config.Config) { - _ = c.Set("github.com", "oauth_token", "a token") - }, - envToken: false, expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.envToken { - os.Setenv("GITHUB_TOKEN", "TOKEN") - } else { - os.Setenv("GITHUB_TOKEN", "") - } - cfg := config.NewBlankConfig() - tt.cfg(cfg) + tt.cfgStubs(cfg) result := CheckAuth(cfg) assert.Equal(t, tt.expected, result) }) diff --git a/pkg/cmdutil/repo_override.go b/pkg/cmdutil/repo_override.go index f0550e2d7..29043d54d 100644 --- a/pkg/cmdutil/repo_override.go +++ b/pkg/cmdutil/repo_override.go @@ -31,10 +31,7 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) { if err != nil { return nil, cobra.ShellCompDirectiveError } - defaultHost, err := config.DefaultHost() - if err != nil { - return nil, cobra.ShellCompDirectiveError - } + defaultHost, _ := config.DefaultHost() var results []string for _, remote := range remotes {