From 8a82e3a85631809ce17d4a45fb5cb845680b4518 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 8 May 2024 13:23:08 +0200 Subject: [PATCH] Provide more type safety around config values --- internal/config/config.go | 68 ++++++----- internal/config/config_test.go | 62 +++++----- internal/config/stub.go | 26 ++--- internal/gh/gh.go | 5 +- internal/gh/mock/config.go | 203 +++++++++++++++++---------------- pkg/cmd/config/get/get.go | 15 ++- pkg/cmd/config/get/get_test.go | 25 ++-- pkg/cmd/config/list/list.go | 9 +- pkg/cmd/config/set/set_test.go | 6 +- pkg/option/option.go | 115 +++++++++++++++++++ 10 files changed, 331 insertions(+), 203 deletions(-) create mode 100644 pkg/option/option.go diff --git a/internal/config/config.go b/internal/config/config.go index ad451ee97..cf4d761c1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ import ( "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/keyring" + o "github.com/cli/cli/v2/pkg/option" ghAuth "github.com/cli/go-gh/v2/pkg/auth" ghConfig "github.com/cli/go-gh/v2/pkg/config" ) @@ -41,28 +42,32 @@ type cfg struct { cfg *ghConfig.Config } -func (c *cfg) Get(hostname, key string) (string, error) { +func (c *cfg) Get(hostname, key string) o.Option[string] { if hostname != "" { val, err := c.cfg.Get([]string{hostsKey, hostname, key}) if err == nil { - return val, err + return o.Some(val) } } - return c.cfg.Get([]string{key}) + val, err := c.cfg.Get([]string{key}) + if err == nil { + return o.Some(val) + } + + return o.None[string]() } -func (c *cfg) GetOrDefault(hostname, key string) (string, error) { - val, err := c.Get(hostname, key) - if err == nil { - return val, err +func (c *cfg) GetOrDefault(hostname, key string) o.Option[string] { + if val := c.Get(hostname, key); val.IsSome() { + return val } - if val, ok := defaultFor(key); ok { - return val, nil + if defaultVal := defaultFor(key); defaultVal.IsSome() { + return defaultVal } - return val, err + return o.None[string]() } func (c *cfg) Set(hostname, key, value string) { @@ -91,42 +96,43 @@ func (c *cfg) Authentication() gh.AuthConfig { } func (c *cfg) Browser(hostname string) string { - val, _ := c.GetOrDefault(hostname, browserKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, browserKey).Unwrap() } func (c *cfg) Editor(hostname string) string { - val, _ := c.GetOrDefault(hostname, editorKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, editorKey).Unwrap() } func (c *cfg) GitProtocol(hostname string) string { - val, _ := c.GetOrDefault(hostname, gitProtocolKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, gitProtocolKey).Unwrap() } func (c *cfg) HTTPUnixSocket(hostname string) string { - val, _ := c.GetOrDefault(hostname, httpUnixSocketKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, httpUnixSocketKey).Unwrap() } func (c *cfg) Pager(hostname string) string { - val, _ := c.GetOrDefault(hostname, pagerKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, pagerKey).Unwrap() } func (c *cfg) Prompt(hostname string) string { - val, _ := c.GetOrDefault(hostname, promptKey) - return val + // Intentionally panic as this is a programmer error + return c.GetOrDefault(hostname, promptKey).Unwrap() } -func (c *cfg) Version() string { - val, _ := c.GetOrDefault("", versionKey) - return val +func (c *cfg) Version() o.Option[string] { + return c.Get("", versionKey) } func (c *cfg) Migrate(m gh.Migration) error { - version := c.Version() + // If there is no version entry we must never have applied a migration, and the following conditional logic + // handles the version as an empty string correctly. + version := c.Version().UnwrapOrZero() // If migration has already occurred then do not attempt to migrate again. if m.PostVersion() == version { @@ -156,16 +162,16 @@ func (c *cfg) CacheDir() string { return ghConfig.CacheDir() } -func defaultFor(key string) (string, bool) { +func defaultFor(key string) o.Option[string] { for _, co := range ConfigOptions() { if co.Key == key { - return co.DefaultValue, true + return o.Some(co.DefaultValue) } } - return "", false + return o.None[string]() } -// AuthConfig is used for interacting with some persistent configuration for gh, +// AuthConfig is used for interacting with o.Some persistent configuration for gh, // with knowledge on how to access encrypted storage when neccesarry. // Behavior is scoped to authentication specific tasks. type AuthConfig struct { @@ -332,7 +338,7 @@ func (c *AuthConfig) SwitchUser(hostname, user string) error { if err != nil { // Given that activateUser can only fail before the config is written, or when writing the config // we know for sure that the config has not been written. However, we still should restore it back - // to its previous clean state just in case something else tries to make use of the config, or tries + // to its previous clean state just in case o.Something else tries to make use of the config, or tries // to write it again. if previousSource == "keyring" { if setErr := keyring.Set(keyringServiceName(hostname), "", previouslyActiveToken); setErr != nil { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f6832f4a4..3d13fb7d7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -37,12 +37,10 @@ func TestGetNonExistentKey(t *testing.T) { cfg := newTestConfig() // When we get a key that has no value - val, err := cfg.Get("", "non-existent-key") + optionalVal := cfg.Get("", "non-existent-key") - // Then it returns an error and the value is empty - var keyNotFoundError *ghConfig.KeyNotFoundError - require.ErrorAs(t, err, &keyNotFoundError) - require.Empty(t, val) + // Then it returns a None variant + require.True(t, optionalVal.IsNone(), "expected there to be no value") } func TestGetNonExistentHostSpecificKey(t *testing.T) { @@ -50,12 +48,10 @@ func TestGetNonExistentHostSpecificKey(t *testing.T) { cfg := newTestConfig() // When we get a key for a host that has no value - val, err := cfg.Get("non-existent-host", "non-existent-key") + optionalVal := cfg.Get("non-existent-host", "non-existent-key") - // Then it returns an error and the value is empty - var keyNotFoundError *ghConfig.KeyNotFoundError - require.ErrorAs(t, err, &keyNotFoundError) - require.Empty(t, val) + // Then it returns a None variant + require.True(t, optionalVal.IsNone(), "expected there to be no value") } func TestGetExistingTopLevelKey(t *testing.T) { @@ -64,11 +60,11 @@ func TestGetExistingTopLevelKey(t *testing.T) { cfg.Set("", "top-level-key", "top-level-value") // When we get that key - val, err := cfg.Get("non-existent-host", "top-level-key") + optionalVal := cfg.Get("non-existent-host", "top-level-key") - // Then it returns successfully with the correct value - require.NoError(t, err) - require.Equal(t, "top-level-value", val) + // Then it returns a Some variant containing the correct value + require.True(t, optionalVal.IsSome(), "expected there to be a value") + require.Equal(t, "top-level-value", optionalVal.Unwrap()) } func TestGetExistingHostSpecificKey(t *testing.T) { @@ -77,11 +73,11 @@ func TestGetExistingHostSpecificKey(t *testing.T) { cfg.Set("github.com", "host-specific-key", "host-specific-value") // When we get that key - val, err := cfg.Get("github.com", "host-specific-key") + optionalVal := cfg.Get("github.com", "host-specific-key") - // Then it returns successfully with the correct value - require.NoError(t, err) - require.Equal(t, "host-specific-value", val) + // Then it returns a Some variant containing the correct value + require.True(t, optionalVal.IsSome(), "expected there to be a value") + require.Equal(t, "host-specific-value", optionalVal.Unwrap()) } func TestGetHostnameSpecificKeyFallsBackToTopLevel(t *testing.T) { @@ -90,11 +86,11 @@ func TestGetHostnameSpecificKeyFallsBackToTopLevel(t *testing.T) { cfg.Set("", "key", "value") // When we get that key on a specific host - val, err := cfg.Get("github.com", "key") + optionalVal := cfg.Get("github.com", "key") - // Then it returns successfully, falling back to the top level config - require.NoError(t, err) - require.Equal(t, "value", val) + // Then it returns a Some variant containing the correct value by falling back to the top level config + require.True(t, optionalVal.IsSome(), "expected there to be a value") + require.Equal(t, "value", optionalVal.Unwrap()) } func TestGetOrDefaultApplicationDefaults(t *testing.T) { @@ -116,11 +112,11 @@ func TestGetOrDefaultApplicationDefaults(t *testing.T) { cfg := newTestConfig() // When we get a key that has no value, but has a default - val, err := cfg.GetOrDefault("", tt.key) + optionalVal := cfg.GetOrDefault("", tt.key) - // Then it returns the default value - require.NoError(t, err) - require.Equal(t, tt.expectedDefault, val) + // Then there is an entry with the default value, and source set as default + require.True(t, optionalVal.IsSome(), "expected there to be a value") + require.Equal(t, tt.expectedDefault, optionalVal.Unwrap()) }) } } @@ -131,12 +127,12 @@ func TestGetOrDefaultExistingKey(t *testing.T) { cfg.Set("", gitProtocolKey, "ssh") // When we get that key - val, err := cfg.GetOrDefault("", gitProtocolKey) + optionalVal := cfg.GetOrDefault("", gitProtocolKey) // Then it returns successfully with the correct value, and doesn't fall back // to the default - require.NoError(t, err) - require.Equal(t, "ssh", val) + require.True(t, optionalVal.IsSome(), "expected there to be a value") + require.Equal(t, "ssh", optionalVal.Unwrap()) } func TestGetOrDefaultNotFoundAndNoDefault(t *testing.T) { @@ -144,12 +140,10 @@ func TestGetOrDefaultNotFoundAndNoDefault(t *testing.T) { cfg := newTestConfig() // When we get a non-existent-key that has no default - val, err := cfg.GetOrDefault("", "non-existent-key") + optionalEntry := cfg.GetOrDefault("", "non-existent-key") - // Then it returns an error and the value is empty - var keyNotFoundError *ghConfig.KeyNotFoundError - require.ErrorAs(t, err, &keyNotFoundError) - require.Empty(t, val) + // Then it returns with no entry + require.False(t, optionalEntry.IsSome(), "expected the config to not contain a value") } func TestFallbackConfig(t *testing.T) { diff --git a/internal/config/stub.go b/internal/config/stub.go index d0500cfe4..ec088ed07 100644 --- a/internal/config/stub.go +++ b/internal/config/stub.go @@ -9,6 +9,7 @@ import ( "github.com/cli/cli/v2/internal/gh" ghmock "github.com/cli/cli/v2/internal/gh/mock" "github.com/cli/cli/v2/internal/keyring" + o "github.com/cli/cli/v2/pkg/option" ghConfig "github.com/cli/go-gh/v2/pkg/config" ) @@ -20,7 +21,7 @@ func NewFromString(cfgStr string) *ghmock.ConfigMock { c := ghConfig.ReadFromString(cfgStr) cfg := cfg{c} mock := &ghmock.ConfigMock{} - mock.GetOrDefaultFunc = func(host, key string) (string, error) { + mock.GetOrDefaultFunc = func(host, key string) o.Option[string] { return cfg.GetOrDefault(host, key) } mock.SetFunc = func(host, key, value string) { @@ -52,32 +53,25 @@ func NewFromString(cfgStr string) *ghmock.ConfigMock { } } mock.BrowserFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, browserKey) - return val + return cfg.Browser(hostname) } mock.EditorFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, editorKey) - return val + return cfg.Editor(hostname) } mock.GitProtocolFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, gitProtocolKey) - return val + return cfg.GitProtocol(hostname) } mock.HTTPUnixSocketFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, httpUnixSocketKey) - return val + return cfg.HTTPUnixSocket(hostname) } mock.PagerFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, pagerKey) - return val + return cfg.Pager(hostname) } mock.PromptFunc = func(hostname string) string { - val, _ := cfg.GetOrDefault(hostname, promptKey) - return val + return cfg.Prompt(hostname) } - mock.VersionFunc = func() string { - val, _ := cfg.GetOrDefault("", versionKey) - return val + mock.VersionFunc = func() o.Option[string] { + return cfg.Version() } mock.CacheDirFunc = func() string { return cfg.CacheDir() diff --git a/internal/gh/gh.go b/internal/gh/gh.go index 6e3094ea1..5abde3af7 100644 --- a/internal/gh/gh.go +++ b/internal/gh/gh.go @@ -10,6 +10,7 @@ package gh import ( + o "github.com/cli/cli/v2/pkg/option" ghConfig "github.com/cli/go-gh/v2/pkg/config" ) @@ -18,7 +19,7 @@ import ( //go:generate moq -rm -pkg ghmock -out mock/config.go . Config type Config interface { // GetOrDefault provides primitive access for fetching configuration values, optionally scoped by host. - GetOrDefault(hostname string, key string) (string, error) + GetOrDefault(hostname string, key string) o.Option[string] // Set provides primitive access for setting configuration values, optionally scoped by host. Set(hostname string, key string, value string) @@ -48,7 +49,7 @@ type Config interface { Migrate(Migration) error // Version returns the current schema version of the configuration. - Version() string + Version() o.Option[string] // Write persists modifications to the configuration. Write() error diff --git a/internal/gh/mock/config.go b/internal/gh/mock/config.go index 736c8c262..18d9e0cbc 100644 --- a/internal/gh/mock/config.go +++ b/internal/gh/mock/config.go @@ -5,6 +5,7 @@ package ghmock import ( "github.com/cli/cli/v2/internal/gh" + o "github.com/cli/cli/v2/pkg/option" "sync" ) @@ -24,37 +25,37 @@ var _ gh.Config = &ConfigMock{} // AuthenticationFunc: func() gh.AuthConfig { // panic("mock out the Authentication method") // }, -// BrowserFunc: func(s string) string { +// BrowserFunc: func(hostname string) string { // panic("mock out the Browser method") // }, // CacheDirFunc: func() string { // panic("mock out the CacheDir method") // }, -// EditorFunc: func(s string) string { +// EditorFunc: func(hostname string) string { // panic("mock out the Editor method") // }, -// GetOrDefaultFunc: func(s1 string, s2 string) (string, error) { +// GetOrDefaultFunc: func(hostname string, key string) o.Option[string] { // panic("mock out the GetOrDefault method") // }, -// GitProtocolFunc: func(s string) string { +// GitProtocolFunc: func(hostname string) string { // panic("mock out the GitProtocol method") // }, -// HTTPUnixSocketFunc: func(s string) string { +// HTTPUnixSocketFunc: func(hostname string) string { // panic("mock out the HTTPUnixSocket method") // }, // MigrateFunc: func(migration gh.Migration) error { // panic("mock out the Migrate method") // }, -// PagerFunc: func(s string) string { +// PagerFunc: func(hostname string) string { // panic("mock out the Pager method") // }, -// PromptFunc: func(s string) string { +// PromptFunc: func(hostname string) string { // panic("mock out the Prompt method") // }, -// SetFunc: func(s1 string, s2 string, s3 string) { +// SetFunc: func(hostname string, key string, value string) { // panic("mock out the Set method") // }, -// VersionFunc: func() string { +// VersionFunc: func() o.Option[string] { // panic("mock out the Version method") // }, // WriteFunc: func() error { @@ -74,37 +75,37 @@ type ConfigMock struct { AuthenticationFunc func() gh.AuthConfig // BrowserFunc mocks the Browser method. - BrowserFunc func(s string) string + BrowserFunc func(hostname string) string // CacheDirFunc mocks the CacheDir method. CacheDirFunc func() string // EditorFunc mocks the Editor method. - EditorFunc func(s string) string + EditorFunc func(hostname string) string // GetOrDefaultFunc mocks the GetOrDefault method. - GetOrDefaultFunc func(s1 string, s2 string) (string, error) + GetOrDefaultFunc func(hostname string, key string) o.Option[string] // GitProtocolFunc mocks the GitProtocol method. - GitProtocolFunc func(s string) string + GitProtocolFunc func(hostname string) string // HTTPUnixSocketFunc mocks the HTTPUnixSocket method. - HTTPUnixSocketFunc func(s string) string + HTTPUnixSocketFunc func(hostname string) string // MigrateFunc mocks the Migrate method. MigrateFunc func(migration gh.Migration) error // PagerFunc mocks the Pager method. - PagerFunc func(s string) string + PagerFunc func(hostname string) string // PromptFunc mocks the Prompt method. - PromptFunc func(s string) string + PromptFunc func(hostname string) string // SetFunc mocks the Set method. - SetFunc func(s1 string, s2 string, s3 string) + SetFunc func(hostname string, key string, value string) // VersionFunc mocks the Version method. - VersionFunc func() string + VersionFunc func() o.Option[string] // WriteFunc mocks the Write method. WriteFunc func() error @@ -119,33 +120,33 @@ type ConfigMock struct { } // Browser holds details about calls to the Browser method. Browser []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // CacheDir holds details about calls to the CacheDir method. CacheDir []struct { } // Editor holds details about calls to the Editor method. Editor []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // GetOrDefault holds details about calls to the GetOrDefault method. GetOrDefault []struct { - // S1 is the s1 argument value. - S1 string - // S2 is the s2 argument value. - S2 string + // Hostname is the hostname argument value. + Hostname string + // Key is the key argument value. + Key string } // GitProtocol holds details about calls to the GitProtocol method. GitProtocol []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // HTTPUnixSocket holds details about calls to the HTTPUnixSocket method. HTTPUnixSocket []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // Migrate holds details about calls to the Migrate method. Migrate []struct { @@ -154,22 +155,22 @@ type ConfigMock struct { } // Pager holds details about calls to the Pager method. Pager []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // Prompt holds details about calls to the Prompt method. Prompt []struct { - // S is the s argument value. - S string + // Hostname is the hostname argument value. + Hostname string } // Set holds details about calls to the Set method. Set []struct { - // S1 is the s1 argument value. - S1 string - // S2 is the s2 argument value. - S2 string - // S3 is the s3 argument value. - S3 string + // Hostname is the hostname argument value. + Hostname string + // Key is the key argument value. + Key string + // Value is the value argument value. + Value string } // Version holds details about calls to the Version method. Version []struct { @@ -249,19 +250,19 @@ func (mock *ConfigMock) AuthenticationCalls() []struct { } // Browser calls BrowserFunc. -func (mock *ConfigMock) Browser(s string) string { +func (mock *ConfigMock) Browser(hostname string) string { if mock.BrowserFunc == nil { panic("ConfigMock.BrowserFunc: method is nil but Config.Browser was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockBrowser.Lock() mock.calls.Browser = append(mock.calls.Browser, callInfo) mock.lockBrowser.Unlock() - return mock.BrowserFunc(s) + return mock.BrowserFunc(hostname) } // BrowserCalls gets all the calls that were made to Browser. @@ -269,10 +270,10 @@ func (mock *ConfigMock) Browser(s string) string { // // len(mockedConfig.BrowserCalls()) func (mock *ConfigMock) BrowserCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockBrowser.RLock() calls = mock.calls.Browser @@ -308,19 +309,19 @@ func (mock *ConfigMock) CacheDirCalls() []struct { } // Editor calls EditorFunc. -func (mock *ConfigMock) Editor(s string) string { +func (mock *ConfigMock) Editor(hostname string) string { if mock.EditorFunc == nil { panic("ConfigMock.EditorFunc: method is nil but Config.Editor was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockEditor.Lock() mock.calls.Editor = append(mock.calls.Editor, callInfo) mock.lockEditor.Unlock() - return mock.EditorFunc(s) + return mock.EditorFunc(hostname) } // EditorCalls gets all the calls that were made to Editor. @@ -328,10 +329,10 @@ func (mock *ConfigMock) Editor(s string) string { // // len(mockedConfig.EditorCalls()) func (mock *ConfigMock) EditorCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockEditor.RLock() calls = mock.calls.Editor @@ -340,21 +341,21 @@ func (mock *ConfigMock) EditorCalls() []struct { } // GetOrDefault calls GetOrDefaultFunc. -func (mock *ConfigMock) GetOrDefault(s1 string, s2 string) (string, error) { +func (mock *ConfigMock) GetOrDefault(hostname string, key string) o.Option[string] { if mock.GetOrDefaultFunc == nil { panic("ConfigMock.GetOrDefaultFunc: method is nil but Config.GetOrDefault was just called") } callInfo := struct { - S1 string - S2 string + Hostname string + Key string }{ - S1: s1, - S2: s2, + Hostname: hostname, + Key: key, } mock.lockGetOrDefault.Lock() mock.calls.GetOrDefault = append(mock.calls.GetOrDefault, callInfo) mock.lockGetOrDefault.Unlock() - return mock.GetOrDefaultFunc(s1, s2) + return mock.GetOrDefaultFunc(hostname, key) } // GetOrDefaultCalls gets all the calls that were made to GetOrDefault. @@ -362,12 +363,12 @@ func (mock *ConfigMock) GetOrDefault(s1 string, s2 string) (string, error) { // // len(mockedConfig.GetOrDefaultCalls()) func (mock *ConfigMock) GetOrDefaultCalls() []struct { - S1 string - S2 string + Hostname string + Key string } { var calls []struct { - S1 string - S2 string + Hostname string + Key string } mock.lockGetOrDefault.RLock() calls = mock.calls.GetOrDefault @@ -376,19 +377,19 @@ func (mock *ConfigMock) GetOrDefaultCalls() []struct { } // GitProtocol calls GitProtocolFunc. -func (mock *ConfigMock) GitProtocol(s string) string { +func (mock *ConfigMock) GitProtocol(hostname string) string { if mock.GitProtocolFunc == nil { panic("ConfigMock.GitProtocolFunc: method is nil but Config.GitProtocol was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockGitProtocol.Lock() mock.calls.GitProtocol = append(mock.calls.GitProtocol, callInfo) mock.lockGitProtocol.Unlock() - return mock.GitProtocolFunc(s) + return mock.GitProtocolFunc(hostname) } // GitProtocolCalls gets all the calls that were made to GitProtocol. @@ -396,10 +397,10 @@ func (mock *ConfigMock) GitProtocol(s string) string { // // len(mockedConfig.GitProtocolCalls()) func (mock *ConfigMock) GitProtocolCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockGitProtocol.RLock() calls = mock.calls.GitProtocol @@ -408,19 +409,19 @@ func (mock *ConfigMock) GitProtocolCalls() []struct { } // HTTPUnixSocket calls HTTPUnixSocketFunc. -func (mock *ConfigMock) HTTPUnixSocket(s string) string { +func (mock *ConfigMock) HTTPUnixSocket(hostname string) string { if mock.HTTPUnixSocketFunc == nil { panic("ConfigMock.HTTPUnixSocketFunc: method is nil but Config.HTTPUnixSocket was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockHTTPUnixSocket.Lock() mock.calls.HTTPUnixSocket = append(mock.calls.HTTPUnixSocket, callInfo) mock.lockHTTPUnixSocket.Unlock() - return mock.HTTPUnixSocketFunc(s) + return mock.HTTPUnixSocketFunc(hostname) } // HTTPUnixSocketCalls gets all the calls that were made to HTTPUnixSocket. @@ -428,10 +429,10 @@ func (mock *ConfigMock) HTTPUnixSocket(s string) string { // // len(mockedConfig.HTTPUnixSocketCalls()) func (mock *ConfigMock) HTTPUnixSocketCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockHTTPUnixSocket.RLock() calls = mock.calls.HTTPUnixSocket @@ -472,19 +473,19 @@ func (mock *ConfigMock) MigrateCalls() []struct { } // Pager calls PagerFunc. -func (mock *ConfigMock) Pager(s string) string { +func (mock *ConfigMock) Pager(hostname string) string { if mock.PagerFunc == nil { panic("ConfigMock.PagerFunc: method is nil but Config.Pager was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockPager.Lock() mock.calls.Pager = append(mock.calls.Pager, callInfo) mock.lockPager.Unlock() - return mock.PagerFunc(s) + return mock.PagerFunc(hostname) } // PagerCalls gets all the calls that were made to Pager. @@ -492,10 +493,10 @@ func (mock *ConfigMock) Pager(s string) string { // // len(mockedConfig.PagerCalls()) func (mock *ConfigMock) PagerCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockPager.RLock() calls = mock.calls.Pager @@ -504,19 +505,19 @@ func (mock *ConfigMock) PagerCalls() []struct { } // Prompt calls PromptFunc. -func (mock *ConfigMock) Prompt(s string) string { +func (mock *ConfigMock) Prompt(hostname string) string { if mock.PromptFunc == nil { panic("ConfigMock.PromptFunc: method is nil but Config.Prompt was just called") } callInfo := struct { - S string + Hostname string }{ - S: s, + Hostname: hostname, } mock.lockPrompt.Lock() mock.calls.Prompt = append(mock.calls.Prompt, callInfo) mock.lockPrompt.Unlock() - return mock.PromptFunc(s) + return mock.PromptFunc(hostname) } // PromptCalls gets all the calls that were made to Prompt. @@ -524,10 +525,10 @@ func (mock *ConfigMock) Prompt(s string) string { // // len(mockedConfig.PromptCalls()) func (mock *ConfigMock) PromptCalls() []struct { - S string + Hostname string } { var calls []struct { - S string + Hostname string } mock.lockPrompt.RLock() calls = mock.calls.Prompt @@ -536,23 +537,23 @@ func (mock *ConfigMock) PromptCalls() []struct { } // Set calls SetFunc. -func (mock *ConfigMock) Set(s1 string, s2 string, s3 string) { +func (mock *ConfigMock) Set(hostname string, key string, value string) { if mock.SetFunc == nil { panic("ConfigMock.SetFunc: method is nil but Config.Set was just called") } callInfo := struct { - S1 string - S2 string - S3 string + Hostname string + Key string + Value string }{ - S1: s1, - S2: s2, - S3: s3, + Hostname: hostname, + Key: key, + Value: value, } mock.lockSet.Lock() mock.calls.Set = append(mock.calls.Set, callInfo) mock.lockSet.Unlock() - mock.SetFunc(s1, s2, s3) + mock.SetFunc(hostname, key, value) } // SetCalls gets all the calls that were made to Set. @@ -560,14 +561,14 @@ func (mock *ConfigMock) Set(s1 string, s2 string, s3 string) { // // len(mockedConfig.SetCalls()) func (mock *ConfigMock) SetCalls() []struct { - S1 string - S2 string - S3 string + Hostname string + Key string + Value string } { var calls []struct { - S1 string - S2 string - S3 string + Hostname string + Key string + Value string } mock.lockSet.RLock() calls = mock.calls.Set @@ -576,7 +577,7 @@ func (mock *ConfigMock) SetCalls() []struct { } // Version calls VersionFunc. -func (mock *ConfigMock) Version() string { +func (mock *ConfigMock) Version() o.Option[string] { if mock.VersionFunc == nil { panic("ConfigMock.VersionFunc: method is nil but Config.Version was just called") } diff --git a/pkg/cmd/config/get/get.go b/pkg/cmd/config/get/get.go index e714c6a64..82d75111b 100644 --- a/pkg/cmd/config/get/get.go +++ b/pkg/cmd/config/get/get.go @@ -64,13 +64,22 @@ func getRun(opts *GetOptions) error { return nil } - val, err := opts.Config.GetOrDefault(opts.Hostname, opts.Key) - if err != nil { - return err + optionalValue := opts.Config.GetOrDefault(opts.Hostname, opts.Key) + if optionalValue.IsNone() { + return nonExistentKeyError{key: opts.Key} } + val := optionalValue.Unwrap() if val != "" { fmt.Fprintf(opts.IO.Out, "%s\n", val) } return nil } + +type nonExistentKeyError struct { + key string +} + +func (e nonExistentKeyError) Error() string { + return fmt.Sprintf("could not find key \"%s\"", e.key) +} diff --git a/pkg/cmd/config/get/get_test.go b/pkg/cmd/config/get/get_test.go index 40eca49e5..6320ffa16 100644 --- a/pkg/cmd/config/get/get_test.go +++ b/pkg/cmd/config/get/get_test.go @@ -10,6 +10,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdConfigGet(t *testing.T) { @@ -77,11 +78,10 @@ func TestNewCmdConfigGet(t *testing.T) { func Test_getRun(t *testing.T) { tests := []struct { - name string - input *GetOptions - stdout string - stderr string - wantErr bool + name string + input *GetOptions + stdout string + err error }{ { name: "get key", @@ -109,17 +109,24 @@ func Test_getRun(t *testing.T) { }, stdout: "vim\n", }, + { + name: "non-existent key", + input: &GetOptions{ + Key: "non-existent", + Config: config.NewBlankConfig(), + }, + err: nonExistentKeyError{key: "non-existent"}, + }, } for _, tt := range tests { - ios, _, stdout, stderr := iostreams.Test() + ios, _, stdout, _ := iostreams.Test() tt.input.IO = ios t.Run(tt.name, func(t *testing.T) { err := getRun(tt.input) - assert.NoError(t, err) - assert.Equal(t, tt.stdout, stdout.String()) - assert.Equal(t, tt.stderr, stderr.String()) + require.Equal(t, err, tt.err) + require.Equal(t, tt.stdout, stdout.String()) }) } } diff --git a/pkg/cmd/config/list/list.go b/pkg/cmd/config/list/list.go index da2707134..7e4efc06c 100644 --- a/pkg/cmd/config/list/list.go +++ b/pkg/cmd/config/list/list.go @@ -58,11 +58,12 @@ func listRun(opts *ListOptions) error { configOptions := config.ConfigOptions() for _, key := range configOptions { - val, err := cfg.GetOrDefault(host, key.Key) - if err != nil { - return err + optionalValue := cfg.GetOrDefault(host, key.Key) + if optionalValue.IsNone() { + return fmt.Errorf("invalid key: %s", key.Key) } - fmt.Fprintf(opts.IO.Out, "%s=%s\n", key.Key, val) + + fmt.Fprintf(opts.IO.Out, "%s=%s\n", key.Key, optionalValue.Unwrap()) } return nil diff --git a/pkg/cmd/config/set/set_test.go b/pkg/cmd/config/set/set_test.go index 532d78f62..98c5fb1f2 100644 --- a/pkg/cmd/config/set/set_test.go +++ b/pkg/cmd/config/set/set_test.go @@ -150,9 +150,9 @@ func Test_setRun(t *testing.T) { assert.Equal(t, tt.stdout, stdout.String()) assert.Equal(t, tt.stderr, stderr.String()) - val, err := tt.input.Config.GetOrDefault(tt.input.Hostname, tt.input.Key) - assert.NoError(t, err) - assert.Equal(t, tt.expectedValue, val) + optionalValue := tt.input.Config.GetOrDefault(tt.input.Hostname, tt.input.Key) + assert.True(t, optionalValue.IsSome(), "expected value to be set") + assert.Equal(t, tt.expectedValue, optionalValue.Unwrap()) }) } } diff --git a/pkg/option/option.go b/pkg/option/option.go new file mode 100644 index 000000000..6120994ba --- /dev/null +++ b/pkg/option/option.go @@ -0,0 +1,115 @@ +package o + +import "fmt" + +// Option represents an optional value. The [Some] variant contains a value and +// the [None] variant represents the absence of a value. +type Option[T any] struct { + value T + present bool +} + +// Some instantiates an [Option] with a value. +func Some[T any](value T) Option[T] { + return Option[T]{value, true} +} + +// None instantiates an [Option] with no value. +func None[T any]() Option[T] { + return Option[T]{} +} + +// String implements the [fmt.Stringer] interface. +func (o Option[T]) String() string { + if o.present { + return fmt.Sprintf("Some(%v)", o.value) + } + + return "None" +} + +var _ fmt.Stringer = Option[struct{}]{} + +// Unwrap returns the underlying value of a [Some] variant, or panics if called +// on a [None] variant. +func (o Option[T]) Unwrap() T { + if o.present { + return o.value + } + + panic("called `Option.Unwrap()` on a `None` value") +} + +// UnwrapOr returns the underlying value of a [Some] variant, or the provided +// value on a [None] variant. +func (o Option[T]) UnwrapOr(value T) T { + if o.present { + return o.value + } + + return value +} + +// UnwrapOrElse returns the underlying value of a [Some] variant, or the result +// of calling the provided function on a [None] variant. +func (o Option[T]) UnwrapOrElse(f func() T) T { + if o.present { + return o.value + } + + return f() +} + +// UnwrapOrZero returns the underlying value of a [Some] variant, or the zero +// value on a [None] variant. +func (o Option[T]) UnwrapOrZero() T { + if o.present { + return o.value + } + + var value T + return value +} + +// IsSome returns true if the [Option] is a [Some] variant. +func (o Option[T]) IsSome() bool { + return o.present +} + +// IsSome returns true if the [Option] is a [Some] variant and the value inside of it equals the provided value. +// func (o Option[T]) Is(t T) bool { +// return o.present && o.value == t +// } + +func (o Option[T]) IsSomeAnd(f func(T) bool) bool { + return o.present && f(o.value) +} + +// IsNone returns true if the [Option] is a [None] variant. +func (o Option[T]) IsNone() bool { + return !o.present +} + +// Value returns the underlying value and true for a [Some] variant, or the +// zero value and false for a [None] variant. +func (o Option[T]) Value() (T, bool) { + return o.value, o.present +} + +// Expect returns the underlying value for a [Some] variant, or panics with the +// provided message for a [None] variant. +func (o Option[T]) Expect(message string) T { + if o.present { + return o.value + } + + panic(message) +} + +func Map[T any, U any](f func(T) U, o Option[T]) Option[U] { + if o.present { + return Some(f(o.value)) + } + + return None[U]() +}