Remove host entries without valid tokens during migration
This commit is contained in:
parent
4f33d88c5f
commit
06c36a74c2
3 changed files with 87 additions and 13 deletions
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/cli/go-gh/v2/pkg/config"
|
||||
)
|
||||
|
||||
var noTokenError = errors.New("no token found")
|
||||
|
||||
type CowardlyRefusalError struct {
|
||||
err error
|
||||
}
|
||||
|
|
@ -21,6 +23,11 @@ func (e CowardlyRefusalError) Error() string {
|
|||
|
||||
var hostsKey = []string{"hosts"}
|
||||
|
||||
type tokenSource struct {
|
||||
token string
|
||||
inKeyring bool
|
||||
}
|
||||
|
||||
// This migration exists to take a hosts section of the following structure:
|
||||
//
|
||||
// github.com:
|
||||
|
|
@ -95,12 +102,21 @@ func (m MultiAccount) Do(c *config.Config) error {
|
|||
|
||||
// Otherwise let's get to the business of migrating!
|
||||
for _, hostname := range hostnames {
|
||||
token, inKeyring, err := getToken(c, hostname)
|
||||
tokenSource, err := getToken(c, hostname)
|
||||
// If no token existed for this host we'll remove the entry from the hosts file
|
||||
// by deleting it and moving on to the next one.
|
||||
if errors.Is(err, noTokenError) {
|
||||
// The only error that can be returned here is the key not existing, which
|
||||
// we know can't be true.
|
||||
_ = c.Remove(append(hostsKey, hostname))
|
||||
continue
|
||||
}
|
||||
// For any other error we'll error out
|
||||
if err != nil {
|
||||
return CowardlyRefusalError{fmt.Errorf("couldn't find oauth token for %q: %w", hostname, err)}
|
||||
}
|
||||
|
||||
username, err := getUsername(c, hostname, token, m.Transport)
|
||||
username, err := getUsername(c, hostname, tokenSource.token, m.Transport)
|
||||
if err != nil {
|
||||
return CowardlyRefusalError{fmt.Errorf("couldn't get user name for %q: %w", hostname, err)}
|
||||
}
|
||||
|
|
@ -109,7 +125,7 @@ func (m MultiAccount) Do(c *config.Config) error {
|
|||
return CowardlyRefusalError{fmt.Errorf("couldn't not migrate config for %q: %w", hostname, err)}
|
||||
}
|
||||
|
||||
if err := migrateToken(hostname, username, token, inKeyring); err != nil {
|
||||
if err := migrateToken(hostname, username, tokenSource); err != nil {
|
||||
return CowardlyRefusalError{fmt.Errorf("couldn't not migrate oauth token for %q: %w", hostname, err)}
|
||||
}
|
||||
}
|
||||
|
|
@ -117,18 +133,27 @@ func (m MultiAccount) Do(c *config.Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getToken(c *config.Config, hostname string) (string, bool, error) {
|
||||
func getToken(c *config.Config, hostname string) (tokenSource, error) {
|
||||
if token, _ := c.Get(append(hostsKey, hostname, "oauth_token")); token != "" {
|
||||
return token, false, nil
|
||||
return tokenSource{token: token, inKeyring: false}, nil
|
||||
}
|
||||
token, err := keyring.Get(keyringServiceName(hostname), "")
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
|
||||
// If we have an error and it's not relating to there being no token
|
||||
// then we'll return the error cause that's really unexpected.
|
||||
if err != nil && !errors.Is(err, keyring.ErrNotFound) {
|
||||
return tokenSource{}, err
|
||||
}
|
||||
if token == "" {
|
||||
return "", false, errors.New("token not found in config or keyring")
|
||||
|
||||
// Otherwise we'll return a sentinel error
|
||||
if err != nil || token == "" {
|
||||
return tokenSource{}, noTokenError
|
||||
}
|
||||
return token, true, nil
|
||||
|
||||
return tokenSource{
|
||||
token: token,
|
||||
inKeyring: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getUsername(c *config.Config, hostname, token string, transport http.RoundTripper) (string, error) {
|
||||
|
|
@ -157,14 +182,14 @@ func getUsername(c *config.Config, hostname, token string, transport http.RoundT
|
|||
return query.Viewer.Login, nil
|
||||
}
|
||||
|
||||
func migrateToken(hostname, username, token string, inKeyring bool) error {
|
||||
func migrateToken(hostname, username string, tokenSource tokenSource) error {
|
||||
// If token is not currently stored in the keyring do not migrate it,
|
||||
// as it is being stored in the config and is being handled when
|
||||
// when migrating the config.
|
||||
if !inKeyring {
|
||||
if !tokenSource.inKeyring {
|
||||
return nil
|
||||
}
|
||||
return keyring.Set(keyringServiceName(hostname), username, token)
|
||||
return keyring.Set(keyringServiceName(hostname), username, tokenSource.token)
|
||||
}
|
||||
|
||||
func migrateConfig(c *config.Config, hostname, username string) error {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package migration_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
|
|
@ -252,6 +253,40 @@ hosts:
|
|||
requireKeyWithValue(t, cfg, []string{"hosts", "github.com", "users", "monalisa", "git_protocol"}, "ssh")
|
||||
}
|
||||
|
||||
func TestMigrationRemovesHostsWithInvalidTokens(t *testing.T) {
|
||||
// Simulates config when user is logged in securely
|
||||
// but no token entry is in the keyring.
|
||||
keyring.MockInit()
|
||||
cfg := config.ReadFromString(`
|
||||
hosts:
|
||||
github.com:
|
||||
user: user1
|
||||
git_protocol: ssh
|
||||
`)
|
||||
|
||||
m := migration.MultiAccount{}
|
||||
require.NoError(t, m.Do(cfg))
|
||||
|
||||
requireNoKey(t, cfg, []string{"hosts", "github.com"})
|
||||
}
|
||||
|
||||
func TestMigrationErrorsWhenUnableToGetExpectedSecureToken(t *testing.T) {
|
||||
// Simulates config when user is logged in securely
|
||||
// but no token entry is in the keyring.
|
||||
keyring.MockInitWithError(errors.New("keyring test error"))
|
||||
cfg := config.ReadFromString(`
|
||||
hosts:
|
||||
github.com:
|
||||
user: user1
|
||||
git_protocol: ssh
|
||||
`)
|
||||
|
||||
m := migration.MultiAccount{}
|
||||
err := m.Do(cfg)
|
||||
|
||||
require.ErrorContains(t, err, `couldn't find oauth token for "github.com": keyring test error`)
|
||||
}
|
||||
|
||||
func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value string) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -260,3 +295,11 @@ func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value
|
|||
|
||||
require.Equal(t, value, actual)
|
||||
}
|
||||
|
||||
func requireNoKey(t *testing.T, cfg *config.Config, keys []string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := cfg.Get(keys)
|
||||
var keyNotFoundError *config.KeyNotFoundError
|
||||
require.ErrorAs(t, err, &keyNotFoundError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,14 @@
|
|||
package keyring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zalando/go-keyring"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("secret not found in keyring")
|
||||
|
||||
type TimeoutError struct {
|
||||
message string
|
||||
}
|
||||
|
|
@ -46,6 +49,9 @@ func Get(service, user string) (string, error) {
|
|||
}()
|
||||
select {
|
||||
case res := <-ch:
|
||||
if errors.Is(res.err, keyring.ErrNotFound) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return res.val, res.err
|
||||
case <-time.After(3 * time.Second):
|
||||
return "", &TimeoutError{"timeout while trying to get secret from keyring"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue