diff --git a/internal/config/config_file_test.go b/internal/config/config_file_test.go index 4c35f24f9..100c065f3 100644 --- a/internal/config/config_file_test.go +++ b/internal/config/config_file_test.go @@ -80,13 +80,13 @@ example.com: `)() config, err := parseConfig("config.yml") assert.NoError(t, err) - val, err := config.Get("example.com", "git_protocol") + val, err := config.GetOrDefault("example.com", "git_protocol") assert.NoError(t, err) assert.Equal(t, "https", val) - val, err = config.Get("github.com", "git_protocol") + val, err = config.GetOrDefault("github.com", "git_protocol") assert.NoError(t, err) assert.Equal(t, "ssh", val) - val, err = config.Get("nonexistent.io", "git_protocol") + val, err = config.GetOrDefault("nonexistent.io", "git_protocol") assert.NoError(t, err) assert.Equal(t, "ssh", val) } diff --git a/internal/config/config_type.go b/internal/config/config_type.go index 92792e93f..7a71e0c96 100644 --- a/internal/config/config_type.go +++ b/internal/config/config_type.go @@ -9,7 +9,10 @@ import ( // This interface describes interacting with some persistent configuration for gh. type Config interface { Get(string, string) (string, error) + GetOrDefault(string, string) (string, error) GetWithSource(string, string) (string, string, error) + GetOrDefaultWithSource(string, string) (string, string, error) + Default(string) string Set(string, string, string) error UnsetHost(string) Hosts() ([]string, error) diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go index bf53aabe4..c16455bcc 100644 --- a/internal/config/config_type_test.go +++ b/internal/config/config_type_test.go @@ -58,7 +58,7 @@ func Test_defaultConfig(t *testing.T) { assert.Equal(t, expected, mainBuf.String()) assert.Equal(t, "", hostsBuf.String()) - proto, err := cfg.Get("", "git_protocol") + proto, err := cfg.GetOrDefault("", "git_protocol") assert.NoError(t, err) assert.Equal(t, "https", proto) diff --git a/internal/config/from_env.go b/internal/config/from_env.go index ad31537f4..27cf3c54b 100644 --- a/internal/config/from_env.go +++ b/internal/config/from_env.go @@ -76,6 +76,24 @@ func (c *envConfig) GetWithSource(hostname, key string) (string, string, error) return c.Config.GetWithSource(hostname, key) } +func (c *envConfig) GetOrDefault(hostname, key string) (val string, err error) { + val, _, err = c.GetOrDefaultWithSource(hostname, key) + return +} + +func (c *envConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { + val, src, err = c.GetWithSource(hostname, key) + if err == nil && val == "" { + val = c.Default(key) + } + + return +} + +func (c *envConfig) Default(key string) string { + return c.Config.Default(key) +} + func (c *envConfig) CheckWriteable(hostname, key string) error { if hostname != "" && key == "oauth_token" { if token, env := AuthTokenFromEnv(hostname); token != "" { diff --git a/internal/config/from_file.go b/internal/config/from_file.go index 080143df4..3c1cfd65b 100644 --- a/internal/config/from_file.go +++ b/internal/config/from_file.go @@ -65,13 +65,26 @@ func (c *fileConfig) GetWithSource(hostname, key string) (string, string, error) return "", defaultSource, err } - if value == "" { - return defaultFor(key), defaultSource, nil - } - return value, defaultSource, nil } +func (c *fileConfig) GetOrDefault(hostname, key string) (val string, err error) { + val, _, err = c.GetOrDefaultWithSource(hostname, key) + return +} + +func (c *fileConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { + val, src, err = c.GetWithSource(hostname, key) + if err != nil && val == "" { + val = c.Default(key) + } + return +} + +func (c *fileConfig) Default(key string) string { + return defaultFor(key) +} + func (c *fileConfig) Set(hostname, key, value string) error { if hostname == "" { return c.SetStringValue(key, value) diff --git a/internal/config/stub.go b/internal/config/stub.go index e68183d32..aeb2e5526 100644 --- a/internal/config/stub.go +++ b/internal/config/stub.go @@ -25,6 +25,23 @@ func (c ConfigStub) GetWithSource(host, key string) (string, string, error) { return "", "", errors.New("not found") } +func (c ConfigStub) GetOrDefault(hostname, key string) (val string, err error) { + val, _, err = c.GetOrDefaultWithSource(hostname, key) + return +} + +func (c ConfigStub) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) { + val, src, err = c.GetWithSource(hostname, key) + if err == nil && val == "" { + val = c.Default(key) + } + return +} + +func (c ConfigStub) Default(key string) string { + return defaultFor(key) +} + func (c ConfigStub) Set(host, key, value string) error { c[genKey(host, key)] = value return nil diff --git a/pkg/cmd/auth/gitcredential/helper_test.go b/pkg/cmd/auth/gitcredential/helper_test.go index c06d6877f..92cdcb692 100644 --- a/pkg/cmd/auth/gitcredential/helper_test.go +++ b/pkg/cmd/auth/gitcredential/helper_test.go @@ -8,12 +8,18 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" ) +// why not just use the config stub argh type tinyConfig map[string]string func (c tinyConfig) GetWithSource(host, key string) (string, string, error) { return c[fmt.Sprintf("%s:%s", host, key)], c["_source"], nil } +func (c tinyConfig) Get(host, key string) (val string, err error) { + val, _, err = c.GetWithSource(host, key) + return +} + func Test_helperRun(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/auth/refresh/refresh.go b/pkg/cmd/auth/refresh/refresh.go index cd470132c..4ef58f73c 100644 --- a/pkg/cmd/auth/refresh/refresh.go +++ b/pkg/cmd/auth/refresh/refresh.go @@ -146,7 +146,7 @@ func refreshRun(opts *RefreshOptions) error { credentialFlow := &shared.GitCredentialFlow{ Executable: opts.MainExecutable, } - gitProtocol, _ := cfg.Get(hostname, "git_protocol") + gitProtocol, _ := cfg.GetOrDefault(hostname, "git_protocol") if opts.Interactive && gitProtocol == "https" { if err := credentialFlow.Prompt(hostname); err != nil { return err diff --git a/pkg/cmd/auth/status/status.go b/pkg/cmd/auth/status/status.go index f84004c87..e09273e99 100644 --- a/pkg/cmd/auth/status/status.go +++ b/pkg/cmd/auth/status/status.go @@ -35,7 +35,7 @@ func NewCmdStatus(f *cmdutil.Factory, runF func(*StatusOptions) error) *cobra.Co Args: cobra.ExactArgs(0), Short: "View authentication status", Long: heredoc.Doc(`Verifies and displays information about your authentication state. - + This command will test your authentication state for each GitHub host that gh knows about and report on any issues. `), @@ -127,7 +127,7 @@ func statusRun(opts *StatusOptions) error { addMsg("%s %s: api call failed: %s", cs.Red("X"), hostname, err) } addMsg("%s Logged in to %s as %s (%s)", cs.SuccessIcon(), hostname, cs.Bold(username), tokenSource) - proto, _ := cfg.Get(hostname, "git_protocol") + proto, _ := cfg.GetOrDefault(hostname, "git_protocol") if proto != "" { addMsg("%s Git operations for %s configured to use %s protocol.", cs.SuccessIcon(), hostname, cs.Bold(proto)) diff --git a/pkg/cmd/config/get/get.go b/pkg/cmd/config/get/get.go index 3a5634458..94694adb2 100644 --- a/pkg/cmd/config/get/get.go +++ b/pkg/cmd/config/get/get.go @@ -53,7 +53,7 @@ func NewCmdConfigGet(f *cmdutil.Factory, runF func(*GetOptions) error) *cobra.Co } func getRun(opts *GetOptions) error { - val, err := opts.Config.Get(opts.Hostname, opts.Key) + val, err := opts.Config.GetOrDefault(opts.Hostname, opts.Key) if err != nil { return err } diff --git a/pkg/cmd/config/get/get_test.go b/pkg/cmd/config/get/get_test.go index 7c5efa9be..46f187394 100644 --- a/pkg/cmd/config/get/get_test.go +++ b/pkg/cmd/config/get/get_test.go @@ -115,6 +115,8 @@ func Test_getRun(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.stdout, stdout.String()) assert.Equal(t, tt.stderr, stderr.String()) + _, err = tt.input.Config.GetOrDefault("", "_written") + assert.Error(t, err) _, err = tt.input.Config.Get("", "_written") assert.Error(t, err) }) diff --git a/pkg/cmd/config/list/list.go b/pkg/cmd/config/list/list.go index 8b1157b1f..c2ad0397b 100644 --- a/pkg/cmd/config/list/list.go +++ b/pkg/cmd/config/list/list.go @@ -59,7 +59,7 @@ func listRun(opts *ListOptions) error { configOptions := config.ConfigOptions() for _, key := range configOptions { - val, err := cfg.Get(host, key.Key) + val, err := cfg.GetOrDefault(host, key.Key) if err != nil { return err } diff --git a/pkg/cmd/config/set/set_test.go b/pkg/cmd/config/set/set_test.go index cdd2e7c94..2beb20edc 100644 --- a/pkg/cmd/config/set/set_test.go +++ b/pkg/cmd/config/set/set_test.go @@ -145,11 +145,11 @@ func Test_setRun(t *testing.T) { assert.Equal(t, tt.stdout, stdout.String()) assert.Equal(t, tt.stderr, stderr.String()) - val, err := tt.input.Config.Get(tt.input.Hostname, tt.input.Key) + val, err := tt.input.Config.GetOrDefault(tt.input.Hostname, tt.input.Key) assert.NoError(t, err) assert.Equal(t, tt.expectedValue, val) - val, err = tt.input.Config.Get("", "_written") + val, err = tt.input.Config.GetOrDefault("", "_written") assert.NoError(t, err) assert.Equal(t, "true", val) }) diff --git a/pkg/cmd/extension/manager.go b/pkg/cmd/extension/manager.go index 691110656..5652fb564 100644 --- a/pkg/cmd/extension/manager.go +++ b/pkg/cmd/extension/manager.go @@ -337,7 +337,7 @@ func (m *Manager) Install(repo ghrepo.Interface) error { return errors.New("extension is uninstallable: missing executable") } - protocol, _ := m.config.Get(repo.RepoHost(), "git_protocol") + protocol, _ := m.config.GetOrDefault(repo.RepoHost(), "git_protocol") return m.installGit(ghrepo.FormatRemoteURL(repo, protocol), m.io.Out, m.io.ErrOut) } diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index 110dfb4e9..b89d28e21 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -166,7 +166,7 @@ func ioStreams(f *cmdutil.Factory) *iostreams.IOStreams { return io } - if prompt, _ := cfg.Get("", "prompt"); prompt == "disabled" { + if prompt, _ := cfg.GetOrDefault("", "prompt"); prompt == "disabled" { io.SetNeverPrompt(true) } diff --git a/pkg/cmd/factory/http.go b/pkg/cmd/factory/http.go index fd61f2dc7..d1b8b54ed 100644 --- a/pkg/cmd/factory/http.go +++ b/pkg/cmd/factory/http.go @@ -54,6 +54,7 @@ var timezoneNames = map[int]string{ } type configGetter interface { + GetOrDefault(string, string) (string, error) Get(string, string) (string, error) } diff --git a/pkg/cmd/factory/http_test.go b/pkg/cmd/factory/http_test.go index 1505d1b65..0039289a7 100644 --- a/pkg/cmd/factory/http_test.go +++ b/pkg/cmd/factory/http_test.go @@ -95,10 +95,10 @@ func TestNewHTTPClient(t *testing.T) { > Accept: application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview > Authorization: token ████████████████████ > User-Agent: GitHub CLI v1.2.3 - + < HTTP/1.1 204 No Content < Date: