diff --git a/internal/config/migration/multi_account.go b/internal/config/migration/multi_account.go index 92a84d131..d7d28b1c6 100644 --- a/internal/config/migration/multi_account.go +++ b/internal/config/migration/multi_account.go @@ -10,6 +10,8 @@ import ( "github.com/cli/go-gh/v2/pkg/config" ) +var noTokenError = errors.New("no token found") + type CowardlyRefusalError struct { err error } @@ -21,6 +23,11 @@ func (e CowardlyRefusalError) Error() string { var hostsKey = []string{"hosts"} +type tokenSource struct { + token string + inKeyring bool +} + // This migration exists to take a hosts section of the following structure: // // github.com: @@ -95,12 +102,21 @@ func (m MultiAccount) Do(c *config.Config) error { // Otherwise let's get to the business of migrating! for _, hostname := range hostnames { - token, inKeyring, err := getToken(c, hostname) + tokenSource, err := getToken(c, hostname) + // If no token existed for this host we'll remove the entry from the hosts file + // by deleting it and moving on to the next one. + if errors.Is(err, noTokenError) { + // The only error that can be returned here is the key not existing, which + // we know can't be true. + _ = c.Remove(append(hostsKey, hostname)) + continue + } + // For any other error we'll error out if err != nil { return CowardlyRefusalError{fmt.Errorf("couldn't find oauth token for %q: %w", hostname, err)} } - username, err := getUsername(c, hostname, token, m.Transport) + username, err := getUsername(c, hostname, tokenSource.token, m.Transport) if err != nil { return CowardlyRefusalError{fmt.Errorf("couldn't get user name for %q: %w", hostname, err)} } @@ -109,7 +125,7 @@ func (m MultiAccount) Do(c *config.Config) error { return CowardlyRefusalError{fmt.Errorf("couldn't not migrate config for %q: %w", hostname, err)} } - if err := migrateToken(hostname, username, token, inKeyring); err != nil { + if err := migrateToken(hostname, username, tokenSource); err != nil { return CowardlyRefusalError{fmt.Errorf("couldn't not migrate oauth token for %q: %w", hostname, err)} } } @@ -117,18 +133,27 @@ func (m MultiAccount) Do(c *config.Config) error { return nil } -func getToken(c *config.Config, hostname string) (string, bool, error) { +func getToken(c *config.Config, hostname string) (tokenSource, error) { if token, _ := c.Get(append(hostsKey, hostname, "oauth_token")); token != "" { - return token, false, nil + return tokenSource{token: token, inKeyring: false}, nil } token, err := keyring.Get(keyringServiceName(hostname), "") - if err != nil { - return "", false, err + + // If we have an error and it's not relating to there being no token + // then we'll return the error cause that's really unexpected. + if err != nil && !errors.Is(err, keyring.ErrNotFound) { + return tokenSource{}, err } - if token == "" { - return "", false, errors.New("token not found in config or keyring") + + // Otherwise we'll return a sentinel error + if err != nil || token == "" { + return tokenSource{}, noTokenError } - return token, true, nil + + return tokenSource{ + token: token, + inKeyring: true, + }, nil } func getUsername(c *config.Config, hostname, token string, transport http.RoundTripper) (string, error) { @@ -157,14 +182,14 @@ func getUsername(c *config.Config, hostname, token string, transport http.RoundT return query.Viewer.Login, nil } -func migrateToken(hostname, username, token string, inKeyring bool) error { +func migrateToken(hostname, username string, tokenSource tokenSource) error { // If token is not currently stored in the keyring do not migrate it, // as it is being stored in the config and is being handled when // when migrating the config. - if !inKeyring { + if !tokenSource.inKeyring { return nil } - return keyring.Set(keyringServiceName(hostname), username, token) + return keyring.Set(keyringServiceName(hostname), username, tokenSource.token) } func migrateConfig(c *config.Config, hostname, username string) error { diff --git a/internal/config/migration/multi_account_test.go b/internal/config/migration/multi_account_test.go index bcdbcc343..f70839881 100644 --- a/internal/config/migration/multi_account_test.go +++ b/internal/config/migration/multi_account_test.go @@ -1,6 +1,7 @@ package migration_test import ( + "errors" "fmt" "testing" @@ -252,6 +253,40 @@ hosts: requireKeyWithValue(t, cfg, []string{"hosts", "github.com", "users", "monalisa", "git_protocol"}, "ssh") } +func TestMigrationRemovesHostsWithInvalidTokens(t *testing.T) { + // Simulates config when user is logged in securely + // but no token entry is in the keyring. + keyring.MockInit() + cfg := config.ReadFromString(` +hosts: + github.com: + user: user1 + git_protocol: ssh +`) + + m := migration.MultiAccount{} + require.NoError(t, m.Do(cfg)) + + requireNoKey(t, cfg, []string{"hosts", "github.com"}) +} + +func TestMigrationErrorsWhenUnableToGetExpectedSecureToken(t *testing.T) { + // Simulates config when user is logged in securely + // but no token entry is in the keyring. + keyring.MockInitWithError(errors.New("keyring test error")) + cfg := config.ReadFromString(` +hosts: + github.com: + user: user1 + git_protocol: ssh +`) + + m := migration.MultiAccount{} + err := m.Do(cfg) + + require.ErrorContains(t, err, `couldn't find oauth token for "github.com": keyring test error`) +} + func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value string) { t.Helper() @@ -260,3 +295,11 @@ func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value require.Equal(t, value, actual) } + +func requireNoKey(t *testing.T, cfg *config.Config, keys []string) { + t.Helper() + + _, err := cfg.Get(keys) + var keyNotFoundError *config.KeyNotFoundError + require.ErrorAs(t, err, &keyNotFoundError) +} diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go index b6ee990fc..f873c6436 100644 --- a/internal/keyring/keyring.go +++ b/internal/keyring/keyring.go @@ -2,11 +2,14 @@ package keyring import ( + "errors" "time" "github.com/zalando/go-keyring" ) +var ErrNotFound = errors.New("secret not found in keyring") + type TimeoutError struct { message string } @@ -46,6 +49,9 @@ func Get(service, user string) (string, error) { }() select { case res := <-ch: + if errors.Is(res.err, keyring.ErrNotFound) { + return "", ErrNotFound + } return res.val, res.err case <-time.After(3 * time.Second): return "", &TimeoutError{"timeout while trying to get secret from keyring"}