Use go-gh config package (#5771)

This commit is contained in:
Sam Coe 2022-06-23 12:50:04 +01:00 committed by GitHub
parent 5227af0c99
commit cacff4ad6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
82 changed files with 1379 additions and 3334 deletions

View file

@ -14,6 +14,7 @@ import (
type configGetter interface {
Get(string, string) (string, error)
AuthToken(string) (string, string)
}
type HTTPClientOptions struct {
@ -52,7 +53,9 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
return nil, err
}
client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
if opts.Config != nil {
client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
}
return client, nil
}
@ -75,7 +78,7 @@ func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTrippe
func AddAuthTokenHeader(rt http.RoundTripper, cfg configGetter) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
hostname := ghinstance.NormalizeHostname(getHost(req))
if token, err := cfg.Get(hostname, "oauth_token"); err == nil && token != "" {
if token, _ := cfg.AuthToken(hostname); token != "" {
req.Header.Set("Authorization", fmt.Sprintf("token %s", token))
}
return rt.RoundTrip(req)

View file

@ -211,6 +211,10 @@ func (c tinyConfig) Get(host, key string) (string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], nil
}
func (c tinyConfig) AuthToken(host string) (string, string) {
return c[fmt.Sprintf("%s:%s", host, "oauth_token")], "oauth_token"
}
var requestAtRE = regexp.MustCompile(`(?m)^\* Request at .+`)
var dateRE = regexp.MustCompile(`(?m)^< Date: .+`)
var hostWithPortRE = regexp.MustCompile(`127\.0\.0\.1:\d+`)

View file

@ -98,10 +98,8 @@ func mainRun() exitCode {
return exitError
}
// TODO: remove after FromFullName has been revisited
if host, err := cfg.DefaultHost(); err == nil {
ghrepo.SetDefaultHost(host)
}
host, _ := cfg.DefaultHost()
ghrepo.SetDefaultHost(host)
expandedArgs := []string{}
if len(os.Args) > 0 {
@ -170,18 +168,17 @@ func mainRun() exitCode {
// provide completions for aliases and extensions
rootCmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
var results []string
if aliases, err := cfg.Aliases(); err == nil {
for aliasName, aliasValue := range aliases.All() {
if strings.HasPrefix(aliasName, toComplete) {
var s string
if strings.HasPrefix(aliasValue, "!") {
s = fmt.Sprintf("%s\tShell alias", aliasName)
} else {
aliasValue = text.Truncate(80, aliasValue)
s = fmt.Sprintf("%s\tAlias for %s", aliasName, aliasValue)
}
results = append(results, s)
aliases := cfg.Aliases()
for aliasName, aliasValue := range aliases.All() {
if strings.HasPrefix(aliasName, toComplete) {
var s string
if strings.HasPrefix(aliasValue, "!") {
s = fmt.Sprintf("%s\tShell alias", aliasName)
} else {
aliasValue = text.Truncate(80, aliasValue)
s = fmt.Sprintf("%s\tAlias for %s", aliasName, aliasValue)
}
results = append(results, s)
}
}
for _, ext := range cmdFactory.ExtensionManager.List() {

2
go.mod
View file

@ -9,7 +9,7 @@ require (
github.com/charmbracelet/glamour v0.4.0
github.com/charmbracelet/lipgloss v0.5.0
github.com/cli/browser v1.1.0
github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638
github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4
github.com/cli/oauth v0.9.0
github.com/cli/safeexec v1.0.0
github.com/cli/shurcooL-graphql v0.0.1

4
go.sum
View file

@ -58,8 +58,8 @@ github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY=
github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI=
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8=
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638 h1:7MXhocX2RDlWrjKZ1pZsy8eMNGa3xkZzPrGC1IPBfx4=
github.com/cli/go-gh v0.0.4-0.20220614183308-ef2bca923638/go.mod h1:Y/QFb/VxnXQH0W4VlP+507HVxMzQ430x8kdjUuVcono=
github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4 h1:6WrekNBE2Y+Xl9OCl7vsg49SSN68hwaVryfEawQevaQ=
github.com/cli/go-gh v0.0.4-0.20220623035622-91ca4ef447d4/go.mod h1:Y/QFb/VxnXQH0W4VlP+507HVxMzQ430x8kdjUuVcono=
github.com/cli/oauth v0.9.0 h1:nxBC0Df4tUzMkqffAB+uZvisOwT3/N9FpkfdTDtafxc=
github.com/cli/oauth v0.9.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=

View file

@ -32,9 +32,8 @@ var (
type iconfig interface {
Get(string, string) (string, error)
Set(string, string, string) error
Set(string, string, string)
Write() error
WriteHosts() error
}
func AuthFlowWithConfig(cfg iconfig, IO *iostreams.IOStreams, hostname, notice string, additionalScopes []string, isInteractive bool) (string, error) {
@ -55,16 +54,10 @@ func AuthFlowWithConfig(cfg iconfig, IO *iostreams.IOStreams, hostname, notice s
return "", err
}
err = cfg.Set(hostname, "user", userLogin)
if err != nil {
return "", err
}
err = cfg.Set(hostname, "oauth_token", token)
if err != nil {
return "", err
}
cfg.Set(hostname, "user", userLogin)
cfg.Set(hostname, "oauth_token", token)
return token, cfg.WriteHosts()
return token, cfg.Write()
}
func authFlow(oauthHost string, IO *iostreams.IOStreams, notice string, additionalScopes []string, isInteractive bool, browserLauncher string) (string, string, error) {

View file

@ -1,60 +0,0 @@
package config
import (
"fmt"
)
type AliasConfig struct {
ConfigMap
Parent Config
}
func (a *AliasConfig) Get(alias string) (string, bool) {
if a.Empty() {
return "", false
}
value, _ := a.GetStringValue(alias)
return value, value != ""
}
func (a *AliasConfig) Add(alias, expansion string) error {
err := a.SetStringValue(alias, expansion)
if err != nil {
return fmt.Errorf("failed to update config: %w", err)
}
err = a.Parent.Write()
if err != nil {
return fmt.Errorf("failed to write config: %w", err)
}
return nil
}
func (a *AliasConfig) Delete(alias string) error {
a.RemoveEntry(alias)
err := a.Parent.Write()
if err != nil {
return fmt.Errorf("failed to write config: %w", err)
}
return nil
}
func (a *AliasConfig) All() map[string]string {
out := map[string]string{}
if a.Empty() {
return out
}
for i := 0; i < len(a.Root.Content)-1; i += 2 {
key := a.Root.Content[i].Value
value := a.Root.Content[i+1].Value
out[key] = value
}
return out
}

223
internal/config/config.go Normal file
View file

@ -0,0 +1,223 @@
package config
import (
"os"
"path/filepath"
ghAuth "github.com/cli/go-gh/pkg/auth"
ghConfig "github.com/cli/go-gh/pkg/config"
)
const (
hosts = "hosts"
aliases = "aliases"
)
// This interface describes interacting with some persistent configuration for gh.
//go:generate moq -rm -out config_mock.go . Config
type Config interface {
AuthToken(string) (string, string)
Get(string, string) (string, error)
GetOrDefault(string, string) (string, error)
Set(string, string, string)
UnsetHost(string)
Hosts() []string
DefaultHost() (string, string)
Aliases() *AliasConfig
Write() error
}
func NewConfig() (Config, error) {
c, err := ghConfig.Read()
if err != nil {
return nil, err
}
return &cfg{c}, nil
}
// Implements Config interface
type cfg struct {
cfg *ghConfig.Config
}
func (c *cfg) AuthToken(hostname string) (string, string) {
return ghAuth.TokenForHost(hostname)
}
func (c *cfg) Get(hostname, key string) (string, error) {
if hostname != "" {
val, err := c.cfg.Get([]string{hosts, hostname, key})
if err == nil {
return val, err
}
}
return c.cfg.Get([]string{key})
}
func (c *cfg) GetOrDefault(hostname, key string) (string, error) {
var val string
var err error
if hostname != "" {
val, err = c.cfg.Get([]string{hosts, hostname, key})
if err == nil {
return val, err
}
}
val, err = c.cfg.Get([]string{key})
if err == nil {
return val, err
}
if defaultExists(key) {
return defaultFor(key), nil
}
return val, err
}
func (c *cfg) Set(hostname, key, value string) {
if hostname == "" {
c.cfg.Set([]string{key}, value)
}
c.cfg.Set([]string{hosts, hostname, key}, value)
}
func (c *cfg) UnsetHost(hostname string) {
if hostname == "" {
return
}
_ = c.cfg.Remove([]string{hosts, hostname})
}
func (c *cfg) Hosts() []string {
return ghAuth.KnownHosts()
}
func (c *cfg) DefaultHost() (string, string) {
return ghAuth.DefaultHost()
}
func (c *cfg) Aliases() *AliasConfig {
return &AliasConfig{cfg: c.cfg}
}
func (c *cfg) Write() error {
return ghConfig.Write(c.cfg)
}
func defaultFor(key string) string {
for _, co := range configOptions {
if co.Key == key {
return co.DefaultValue
}
}
return ""
}
func defaultExists(key string) bool {
for _, co := range configOptions {
if co.Key == key {
return true
}
}
return false
}
type AliasConfig struct {
cfg *ghConfig.Config
}
func (a *AliasConfig) Get(alias string) (string, error) {
return a.cfg.Get([]string{aliases, alias})
}
func (a *AliasConfig) Add(alias, expansion string) {
a.cfg.Set([]string{aliases, alias}, expansion)
}
func (a *AliasConfig) Delete(alias string) error {
return a.cfg.Remove([]string{aliases, alias})
}
func (a *AliasConfig) All() map[string]string {
out := map[string]string{}
keys, err := a.cfg.Keys([]string{aliases})
if err != nil {
return out
}
for _, key := range keys {
val, _ := a.cfg.Get([]string{aliases, key})
out[key] = val
}
return out
}
type ConfigOption struct {
Key string
Description string
DefaultValue string
AllowedValues []string
}
var configOptions = []ConfigOption{
{
Key: "git_protocol",
Description: "the protocol to use for git clone and push operations",
DefaultValue: "https",
AllowedValues: []string{"https", "ssh"},
},
{
Key: "editor",
Description: "the text editor program to use for authoring text",
DefaultValue: "",
},
{
Key: "prompt",
Description: "toggle interactive prompting in the terminal",
DefaultValue: "enabled",
AllowedValues: []string{"enabled", "disabled"},
},
{
Key: "pager",
Description: "the terminal pager program to send standard output to",
DefaultValue: "",
},
{
Key: "http_unix_socket",
Description: "the path to a Unix socket through which to make an HTTP connection",
DefaultValue: "",
},
{
Key: "browser",
Description: "the web browser to use for opening URLs",
DefaultValue: "",
},
}
func ConfigOptions() []ConfigOption {
return configOptions
}
func HomeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}
func StateDir() string {
return ghConfig.StateDir()
}
func DataDir() string {
return ghConfig.DataDir()
}
func ConfigDir() string {
return ghConfig.ConfigDir()
}

View file

@ -1,349 +0,0 @@
package config
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"syscall"
"gopkg.in/yaml.v3"
)
const (
GH_CONFIG_DIR = "GH_CONFIG_DIR"
XDG_CONFIG_HOME = "XDG_CONFIG_HOME"
XDG_STATE_HOME = "XDG_STATE_HOME"
XDG_DATA_HOME = "XDG_DATA_HOME"
APP_DATA = "AppData"
LOCAL_APP_DATA = "LocalAppData"
)
// Config path precedence
// 1. GH_CONFIG_DIR
// 2. XDG_CONFIG_HOME
// 3. AppData (windows only)
// 4. HOME
func ConfigDir() string {
var path string
if a := os.Getenv(GH_CONFIG_DIR); a != "" {
path = a
} else if b := os.Getenv(XDG_CONFIG_HOME); b != "" {
path = filepath.Join(b, "gh")
} else if c := os.Getenv(APP_DATA); runtime.GOOS == "windows" && c != "" {
path = filepath.Join(c, "GitHub CLI")
} else {
d, _ := os.UserHomeDir()
path = filepath.Join(d, ".config", "gh")
}
// If the path does not exist and the GH_CONFIG_DIR flag is not set try
// migrating config from default paths.
if !dirExists(path) && os.Getenv(GH_CONFIG_DIR) == "" {
_ = autoMigrateConfigDir(path)
}
return path
}
// State path precedence
// 1. XDG_STATE_HOME
// 2. LocalAppData (windows only)
// 3. HOME
func StateDir() string {
var path string
if a := os.Getenv(XDG_STATE_HOME); a != "" {
path = filepath.Join(a, "gh")
} else if b := os.Getenv(LOCAL_APP_DATA); runtime.GOOS == "windows" && b != "" {
path = filepath.Join(b, "GitHub CLI")
} else {
c, _ := os.UserHomeDir()
path = filepath.Join(c, ".local", "state", "gh")
}
// If the path does not exist try migrating state from default paths
if !dirExists(path) {
_ = autoMigrateStateDir(path)
}
return path
}
// Data path precedence
// 1. XDG_DATA_HOME
// 2. LocalAppData (windows only)
// 3. HOME
func DataDir() string {
var path string
if a := os.Getenv(XDG_DATA_HOME); a != "" {
path = filepath.Join(a, "gh")
} else if b := os.Getenv(LOCAL_APP_DATA); runtime.GOOS == "windows" && b != "" {
path = filepath.Join(b, "GitHub CLI")
} else {
c, _ := os.UserHomeDir()
path = filepath.Join(c, ".local", "share", "gh")
}
return path
}
var errSamePath = errors.New("same path")
var errNotExist = errors.New("not exist")
// Check default path, os.UserHomeDir, for existing configs
// If configs exist then move them to newPath
func autoMigrateConfigDir(newPath string) error {
path, err := os.UserHomeDir()
if oldPath := filepath.Join(path, ".config", "gh"); err == nil && dirExists(oldPath) {
return migrateDir(oldPath, newPath)
}
return errNotExist
}
// Check default path, os.UserHomeDir, for existing state file (state.yml)
// If state file exist then move it to newPath
func autoMigrateStateDir(newPath string) error {
path, err := os.UserHomeDir()
if oldPath := filepath.Join(path, ".config", "gh"); err == nil && dirExists(oldPath) {
return migrateFile(oldPath, newPath, "state.yml")
}
return errNotExist
}
func migrateFile(oldPath, newPath, file string) error {
if oldPath == newPath {
return errSamePath
}
oldFile := filepath.Join(oldPath, file)
newFile := filepath.Join(newPath, file)
if !fileExists(oldFile) {
return errNotExist
}
_ = os.MkdirAll(filepath.Dir(newFile), 0755)
return os.Rename(oldFile, newFile)
}
func migrateDir(oldPath, newPath string) error {
if oldPath == newPath {
return errSamePath
}
if !dirExists(oldPath) {
return errNotExist
}
_ = os.MkdirAll(filepath.Dir(newPath), 0755)
return os.Rename(oldPath, newPath)
}
func dirExists(path string) bool {
f, err := os.Stat(path)
return err == nil && f.IsDir()
}
func fileExists(path string) bool {
f, err := os.Stat(path)
return err == nil && !f.IsDir()
}
func ConfigFile() string {
return filepath.Join(ConfigDir(), "config.yml")
}
func HostsConfigFile() string {
return filepath.Join(ConfigDir(), "hosts.yml")
}
func ParseDefaultConfig() (Config, error) {
return parseConfig(ConfigFile())
}
func HomeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}
var ReadConfigFile = func(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, pathError(err)
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return nil, err
}
return data, nil
}
var WriteConfigFile = func(filename string, data []byte) error {
err := os.MkdirAll(filepath.Dir(filename), 0771)
if err != nil {
return pathError(err)
}
cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup
if err != nil {
return err
}
defer cfgFile.Close()
_, err = cfgFile.Write(data)
return err
}
var BackupConfigFile = func(filename string) error {
return os.Rename(filename, filename+".bak")
}
func parseConfigFile(filename string) ([]byte, *yaml.Node, error) {
data, err := ReadConfigFile(filename)
if err != nil {
return nil, nil, err
}
root, err := parseConfigData(data)
if err != nil {
return nil, nil, err
}
return data, root, err
}
func parseConfigData(data []byte) (*yaml.Node, error) {
var root yaml.Node
err := yaml.Unmarshal(data, &root)
if err != nil {
return nil, err
}
if len(root.Content) == 0 {
return &yaml.Node{
Kind: yaml.DocumentNode,
Content: []*yaml.Node{{Kind: yaml.MappingNode}},
}, nil
}
if root.Content[0].Kind != yaml.MappingNode {
return &root, fmt.Errorf("expected a top level map")
}
return &root, nil
}
func isLegacy(root *yaml.Node) bool {
for _, v := range root.Content[0].Content {
if v.Value == "github.com" {
return true
}
}
return false
}
func migrateConfig(filename string) error {
b, err := ReadConfigFile(filename)
if err != nil {
return err
}
var hosts map[string][]yaml.Node
err = yaml.Unmarshal(b, &hosts)
if err != nil {
return fmt.Errorf("error decoding legacy format: %w", err)
}
cfg := NewBlankConfig()
for hostname, entries := range hosts {
if len(entries) < 1 {
continue
}
mapContent := entries[0].Content
for i := 0; i < len(mapContent)-1; i += 2 {
if err := cfg.Set(hostname, mapContent[i].Value, mapContent[i+1].Value); err != nil {
return err
}
}
}
err = BackupConfigFile(filename)
if err != nil {
return fmt.Errorf("failed to back up existing config: %w", err)
}
return cfg.Write()
}
func parseConfig(filename string) (Config, error) {
_, root, err := parseConfigFile(filename)
if err != nil {
if os.IsNotExist(err) {
root = NewBlankRoot()
} else {
return nil, err
}
}
if isLegacy(root) {
err = migrateConfig(filename)
if err != nil {
return nil, fmt.Errorf("error migrating legacy config: %w", err)
}
_, root, err = parseConfigFile(filename)
if err != nil {
return nil, fmt.Errorf("failed to reparse migrated config: %w", err)
}
} else {
if _, hostsRoot, err := parseConfigFile(HostsConfigFile()); err == nil {
if len(hostsRoot.Content[0].Content) > 0 {
newContent := []*yaml.Node{
{Value: "hosts"},
hostsRoot.Content[0],
}
restContent := root.Content[0].Content
root.Content[0].Content = append(newContent, restContent...)
}
} else if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
}
return NewConfig(root), nil
}
func pathError(err error) error {
var pathError *os.PathError
if errors.As(err, &pathError) && errors.Is(pathError.Err, syscall.ENOTDIR) {
if p := findRegularFile(pathError.Path); p != "" {
return fmt.Errorf("remove or rename regular file `%s` (must be a directory)", p)
}
}
return err
}
func findRegularFile(p string) string {
for {
if s, err := os.Stat(p); err == nil && s.Mode().IsRegular() {
return p
}
newPath := filepath.Dir(p)
if newPath == p || newPath == "/" || newPath == "." {
break
}
p = newPath
}
return ""
}

View file

@ -1,576 +0,0 @@
package config
import (
"bytes"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func Test_parseConfig(t *testing.T) {
defer stubConfig(`---
hosts:
github.com:
user: monalisa
oauth_token: OTOKEN
`, "")()
config, err := parseConfig("config.yml")
assert.NoError(t, err)
user, err := config.Get("github.com", "user")
assert.NoError(t, err)
assert.Equal(t, "monalisa", user)
token, err := config.Get("github.com", "oauth_token")
assert.NoError(t, err)
assert.Equal(t, "OTOKEN", token)
}
func Test_parseConfig_multipleHosts(t *testing.T) {
defer stubConfig(`---
hosts:
example.com:
user: wronguser
oauth_token: NOTTHIS
github.com:
user: monalisa
oauth_token: OTOKEN
`, "")()
config, err := parseConfig("config.yml")
assert.NoError(t, err)
user, err := config.Get("github.com", "user")
assert.NoError(t, err)
assert.Equal(t, "monalisa", user)
token, err := config.Get("github.com", "oauth_token")
assert.NoError(t, err)
assert.Equal(t, "OTOKEN", token)
}
func Test_parseConfig_hostsFile(t *testing.T) {
defer stubConfig("", `---
github.com:
user: monalisa
oauth_token: OTOKEN
`)()
config, err := parseConfig("config.yml")
assert.NoError(t, err)
user, err := config.Get("github.com", "user")
assert.NoError(t, err)
assert.Equal(t, "monalisa", user)
token, err := config.Get("github.com", "oauth_token")
assert.NoError(t, err)
assert.Equal(t, "OTOKEN", token)
}
func Test_parseConfig_hostFallback(t *testing.T) {
defer stubConfig(`---
git_protocol: ssh
`, `---
github.com:
user: monalisa
oauth_token: OTOKEN
example.com:
user: wronguser
oauth_token: NOTTHIS
git_protocol: https
`)()
config, err := parseConfig("config.yml")
assert.NoError(t, err)
val, err := config.GetOrDefault("example.com", "git_protocol")
assert.NoError(t, err)
assert.Equal(t, "https", val)
val, err = config.GetOrDefault("github.com", "git_protocol")
assert.NoError(t, err)
assert.Equal(t, "ssh", val)
val, err = config.GetOrDefault("nonexistent.io", "git_protocol")
assert.NoError(t, err)
assert.Equal(t, "ssh", val)
}
func Test_parseConfig_migrateConfig(t *testing.T) {
defer stubConfig(`---
github.com:
- user: keiyuri
oauth_token: 123456
`, "")()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
defer StubBackupConfig()()
_, err := parseConfig("config.yml")
assert.NoError(t, err)
expectedHosts := `github.com:
user: keiyuri
oauth_token: "123456"
`
assert.Equal(t, expectedHosts, hostsBuf.String())
assert.NotContains(t, mainBuf.String(), "github.com")
assert.NotContains(t, mainBuf.String(), "oauth_token")
}
func Test_parseConfigFile(t *testing.T) {
tests := []struct {
contents string
wantsErr bool
}{
{
contents: "",
wantsErr: true,
},
{
contents: " ",
wantsErr: false,
},
{
contents: "\n",
wantsErr: false,
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("contents: %q", tt.contents), func(t *testing.T) {
defer stubConfig(tt.contents, "")()
_, yamlRoot, err := parseConfigFile("config.yml")
if tt.wantsErr != (err != nil) {
t.Fatalf("got error: %v", err)
}
if tt.wantsErr {
return
}
assert.Equal(t, yaml.MappingNode, yamlRoot.Content[0].Kind)
assert.Equal(t, 0, len(yamlRoot.Content[0].Content))
})
}
}
func Test_ConfigDir(t *testing.T) {
tempDir := t.TempDir()
tests := []struct {
name string
onlyWindows bool
env map[string]string
output string
}{
{
name: "HOME/USERPROFILE specified",
env: map[string]string{
"GH_CONFIG_DIR": "",
"XDG_CONFIG_HOME": "",
"AppData": "",
"USERPROFILE": tempDir,
"HOME": tempDir,
},
output: filepath.Join(tempDir, ".config", "gh"),
},
{
name: "GH_CONFIG_DIR specified",
env: map[string]string{
"GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
},
output: filepath.Join(tempDir, "gh_config_dir"),
},
{
name: "XDG_CONFIG_HOME specified",
env: map[string]string{
"XDG_CONFIG_HOME": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
{
name: "GH_CONFIG_DIR and XDG_CONFIG_HOME specified",
env: map[string]string{
"GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
"XDG_CONFIG_HOME": tempDir,
},
output: filepath.Join(tempDir, "gh_config_dir"),
},
{
name: "AppData specified",
onlyWindows: true,
env: map[string]string{
"AppData": tempDir,
},
output: filepath.Join(tempDir, "GitHub CLI"),
},
{
name: "GH_CONFIG_DIR and AppData specified",
onlyWindows: true,
env: map[string]string{
"GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
"AppData": tempDir,
},
output: filepath.Join(tempDir, "gh_config_dir"),
},
{
name: "XDG_CONFIG_HOME and AppData specified",
onlyWindows: true,
env: map[string]string{
"XDG_CONFIG_HOME": tempDir,
"AppData": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
}
for _, tt := range tests {
if tt.onlyWindows && runtime.GOOS != "windows" {
continue
}
t.Run(tt.name, func(t *testing.T) {
if tt.env != nil {
for k, v := range tt.env {
old := os.Getenv(k)
os.Setenv(k, v)
defer os.Setenv(k, old)
}
}
// Create directory to skip auto migration code
// which gets run when target directory does not exist
_ = os.MkdirAll(tt.output, 0755)
assert.Equal(t, tt.output, ConfigDir())
})
}
}
func Test_configFile_Write_toDisk(t *testing.T) {
configDir := filepath.Join(t.TempDir(), ".config", "gh")
_ = os.MkdirAll(configDir, 0755)
os.Setenv(GH_CONFIG_DIR, configDir)
defer os.Unsetenv(GH_CONFIG_DIR)
cfg := NewFromString(`pager: less`)
err := cfg.Write()
if err != nil {
t.Fatal(err)
}
expectedConfig := "pager: less\n"
if configBytes, err := os.ReadFile(filepath.Join(configDir, "config.yml")); err != nil {
t.Error(err)
} else if string(configBytes) != expectedConfig {
t.Errorf("expected config.yml %q, got %q", expectedConfig, string(configBytes))
}
if configBytes, err := os.ReadFile(filepath.Join(configDir, "hosts.yml")); err != nil {
t.Error(err)
} else if string(configBytes) != "" {
t.Errorf("unexpected hosts.yml: %q", string(configBytes))
}
}
func Test_configFile_WriteHosts_toDisk(t *testing.T) {
configDir := filepath.Join(t.TempDir(), ".config", "gh")
_ = os.MkdirAll(configDir, 0755)
os.Setenv(GH_CONFIG_DIR, configDir)
defer os.Unsetenv(GH_CONFIG_DIR)
cfg := NewFromString(heredoc.Doc(`
hosts:
github.com:
user: monalisa
oauth_token: TOKEN
`))
err := cfg.WriteHosts()
if err != nil {
t.Fatal(err)
}
expectedConfig := "github.com:\n user: monalisa\n oauth_token: TOKEN\n"
actualConfig, err := os.ReadFile(filepath.Join(configDir, "hosts.yml"))
assert.NoError(t, err)
assert.Equal(t, expectedConfig, string(actualConfig))
_, nonExistErr := os.Stat(filepath.Join(configDir, "config.yml"))
assert.Error(t, nonExistErr)
}
func Test_autoMigrateConfigDir_noMigration_notExist(t *testing.T) {
homeDir := t.TempDir()
migrateDir := t.TempDir()
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err := autoMigrateConfigDir(migrateDir)
assert.Equal(t, errNotExist, err)
files, err := os.ReadDir(migrateDir)
assert.NoError(t, err)
assert.Equal(t, 0, len(files))
}
func Test_autoMigrateConfigDir_noMigration_samePath(t *testing.T) {
homeDir := t.TempDir()
migrateDir := filepath.Join(homeDir, ".config", "gh")
err := os.MkdirAll(migrateDir, 0755)
assert.NoError(t, err)
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err = autoMigrateConfigDir(migrateDir)
assert.Equal(t, errSamePath, err)
files, err := os.ReadDir(migrateDir)
assert.NoError(t, err)
assert.Equal(t, 0, len(files))
}
func Test_autoMigrateConfigDir_migration(t *testing.T) {
homeDir := t.TempDir()
migrateDir := t.TempDir()
homeConfigDir := filepath.Join(homeDir, ".config", "gh")
migrateConfigDir := filepath.Join(migrateDir, ".config", "gh")
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err := os.MkdirAll(homeConfigDir, 0755)
assert.NoError(t, err)
f, err := os.CreateTemp(homeConfigDir, "")
assert.NoError(t, err)
f.Close()
err = autoMigrateConfigDir(migrateConfigDir)
assert.NoError(t, err)
_, err = os.ReadDir(homeConfigDir)
assert.True(t, os.IsNotExist(err))
files, err := os.ReadDir(migrateConfigDir)
assert.NoError(t, err)
assert.Equal(t, 1, len(files))
}
func Test_StateDir(t *testing.T) {
tempDir := t.TempDir()
tests := []struct {
name string
onlyWindows bool
env map[string]string
output string
}{
{
name: "HOME/USERPROFILE specified",
env: map[string]string{
"XDG_STATE_HOME": "",
"GH_CONFIG_DIR": "",
"XDG_CONFIG_HOME": "",
"LocalAppData": "",
"USERPROFILE": tempDir,
"HOME": tempDir,
},
output: filepath.Join(tempDir, ".local", "state", "gh"),
},
{
name: "XDG_STATE_HOME specified",
env: map[string]string{
"XDG_STATE_HOME": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
{
name: "LocalAppData specified",
onlyWindows: true,
env: map[string]string{
"LocalAppData": tempDir,
},
output: filepath.Join(tempDir, "GitHub CLI"),
},
{
name: "XDG_STATE_HOME and LocalAppData specified",
onlyWindows: true,
env: map[string]string{
"XDG_STATE_HOME": tempDir,
"LocalAppData": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
}
for _, tt := range tests {
if tt.onlyWindows && runtime.GOOS != "windows" {
continue
}
t.Run(tt.name, func(t *testing.T) {
if tt.env != nil {
for k, v := range tt.env {
old := os.Getenv(k)
os.Setenv(k, v)
defer os.Setenv(k, old)
}
}
// Create directory to skip auto migration code
// which gets run when target directory does not exist
_ = os.MkdirAll(tt.output, 0755)
assert.Equal(t, tt.output, StateDir())
})
}
}
func Test_autoMigrateStateDir_noMigration_notExist(t *testing.T) {
homeDir := t.TempDir()
migrateDir := t.TempDir()
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err := autoMigrateStateDir(migrateDir)
assert.Equal(t, errNotExist, err)
files, err := os.ReadDir(migrateDir)
assert.NoError(t, err)
assert.Equal(t, 0, len(files))
}
func Test_autoMigrateStateDir_noMigration_samePath(t *testing.T) {
homeDir := t.TempDir()
migrateDir := filepath.Join(homeDir, ".config", "gh")
err := os.MkdirAll(migrateDir, 0755)
assert.NoError(t, err)
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err = autoMigrateStateDir(migrateDir)
assert.Equal(t, errSamePath, err)
files, err := os.ReadDir(migrateDir)
assert.NoError(t, err)
assert.Equal(t, 0, len(files))
}
func Test_autoMigrateStateDir_migration(t *testing.T) {
homeDir := t.TempDir()
migrateDir := t.TempDir()
homeConfigDir := filepath.Join(homeDir, ".config", "gh")
migrateStateDir := filepath.Join(migrateDir, ".local", "state", "gh")
homeEnvVar := "HOME"
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
old := os.Getenv(homeEnvVar)
os.Setenv(homeEnvVar, homeDir)
defer os.Setenv(homeEnvVar, old)
err := os.MkdirAll(homeConfigDir, 0755)
assert.NoError(t, err)
err = os.WriteFile(filepath.Join(homeConfigDir, "state.yml"), nil, 0755)
assert.NoError(t, err)
err = autoMigrateStateDir(migrateStateDir)
assert.NoError(t, err)
files, err := os.ReadDir(homeConfigDir)
assert.NoError(t, err)
assert.Equal(t, 0, len(files))
files, err = os.ReadDir(migrateStateDir)
assert.NoError(t, err)
assert.Equal(t, 1, len(files))
assert.Equal(t, "state.yml", files[0].Name())
}
func Test_DataDir(t *testing.T) {
tempDir := t.TempDir()
tests := []struct {
name string
onlyWindows bool
env map[string]string
output string
}{
{
name: "HOME/USERPROFILE specified",
env: map[string]string{
"XDG_DATA_HOME": "",
"GH_CONFIG_DIR": "",
"XDG_CONFIG_HOME": "",
"LocalAppData": "",
"USERPROFILE": tempDir,
"HOME": tempDir,
},
output: filepath.Join(tempDir, ".local", "share", "gh"),
},
{
name: "XDG_DATA_HOME specified",
env: map[string]string{
"XDG_DATA_HOME": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
{
name: "LocalAppData specified",
onlyWindows: true,
env: map[string]string{
"LocalAppData": tempDir,
},
output: filepath.Join(tempDir, "GitHub CLI"),
},
{
name: "XDG_DATA_HOME and LocalAppData specified",
onlyWindows: true,
env: map[string]string{
"XDG_DATA_HOME": tempDir,
"LocalAppData": tempDir,
},
output: filepath.Join(tempDir, "gh"),
},
}
for _, tt := range tests {
if tt.onlyWindows && runtime.GOOS != "windows" {
continue
}
t.Run(tt.name, func(t *testing.T) {
if tt.env != nil {
for k, v := range tt.env {
old := os.Getenv(k)
os.Setenv(k, v)
defer os.Setenv(k, old)
}
}
assert.Equal(t, tt.output, DataDir())
})
}
}

View file

@ -1,113 +0,0 @@
package config
import (
"errors"
"gopkg.in/yaml.v3"
)
// This type implements a low-level get/set config that is backed by an in-memory tree of yaml
// nodes. It allows us to interact with a yaml-based config programmatically, preserving any
// comments that were present when the yaml was parsed.
type ConfigMap struct {
Root *yaml.Node
}
type ConfigEntry struct {
KeyNode *yaml.Node
ValueNode *yaml.Node
Index int
}
type NotFoundError struct {
error
}
func (cm *ConfigMap) Empty() bool {
return cm.Root == nil || len(cm.Root.Content) == 0
}
func (cm *ConfigMap) GetStringValue(key string) (string, error) {
entry, err := cm.FindEntry(key)
if err != nil {
return "", err
}
return entry.ValueNode.Value, nil
}
func (cm *ConfigMap) SetStringValue(key, value string) error {
entry, err := cm.FindEntry(key)
if err == nil {
entry.ValueNode.Value = value
return nil
}
var notFound *NotFoundError
if err != nil && !errors.As(err, &notFound) {
return err
}
keyNode := &yaml.Node{
Kind: yaml.ScalarNode,
Value: key,
}
valueNode := &yaml.Node{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: value,
}
cm.Root.Content = append(cm.Root.Content, keyNode, valueNode)
return nil
}
func (cm *ConfigMap) FindEntry(key string) (*ConfigEntry, error) {
ce := &ConfigEntry{}
if cm.Empty() {
return ce, &NotFoundError{errors.New("not found")}
}
// Content slice goes [key1, value1, key2, value2, ...].
topLevelPairs := cm.Root.Content
for i, v := range topLevelPairs {
// Skip every other slice item since we only want to check against keys.
if i%2 != 0 {
continue
}
if v.Value == key {
ce.KeyNode = v
ce.Index = i
if i+1 < len(topLevelPairs) {
ce.ValueNode = topLevelPairs[i+1]
}
return ce, nil
}
}
return ce, &NotFoundError{errors.New("not found")}
}
func (cm *ConfigMap) RemoveEntry(key string) {
if cm.Empty() {
return
}
newContent := []*yaml.Node{}
var skipNext bool
for i, v := range cm.Root.Content {
if skipNext {
skipNext = false
continue
}
if i%2 != 0 || v.Value != key {
newContent = append(newContent, v)
} else {
// Don't append current node and skip the next which is this key's value.
skipNext = true
}
}
cm.Root.Content = newContent
}

View file

@ -1,187 +0,0 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestFindEntry(t *testing.T) {
tests := []struct {
name string
key string
output string
wantErr bool
}{
{
name: "find key",
key: "valid",
output: "present",
},
{
name: "find key that is not present",
key: "invalid",
wantErr: true,
},
{
name: "find key with blank value",
key: "blank",
output: "",
},
{
name: "find key that has same content as a value",
key: "same",
output: "logical",
},
}
for _, tt := range tests {
cm := ConfigMap{Root: testYaml()}
t.Run(tt.name, func(t *testing.T) {
out, err := cm.FindEntry(tt.key)
if tt.wantErr {
assert.EqualError(t, err, "not found")
return
}
assert.NoError(t, err)
assert.Equal(t, tt.output, out.ValueNode.Value)
})
}
}
func TestEmpty(t *testing.T) {
cm := ConfigMap{}
assert.Equal(t, true, cm.Empty())
cm.Root = &yaml.Node{
Content: []*yaml.Node{
{
Value: "test",
},
},
}
assert.Equal(t, false, cm.Empty())
}
func TestGetStringValue(t *testing.T) {
tests := []struct {
name string
key string
wantValue string
wantErr bool
}{
{
name: "get key",
key: "valid",
wantValue: "present",
},
{
name: "get key that is not present",
key: "invalid",
wantErr: true,
},
{
name: "get key that has same content as a value",
key: "same",
wantValue: "logical",
},
}
for _, tt := range tests {
cm := ConfigMap{Root: testYaml()}
t.Run(tt.name, func(t *testing.T) {
val, err := cm.GetStringValue(tt.key)
if tt.wantErr {
assert.EqualError(t, err, "not found")
return
}
assert.Equal(t, tt.wantValue, val)
})
}
}
func TestSetStringValue(t *testing.T) {
tests := []struct {
name string
key string
value string
}{
{
name: "set key that is not present",
key: "notPresent",
value: "test1",
},
{
name: "set key that is present",
key: "erroneous",
value: "test2",
},
{
name: "set key that is blank",
key: "blank",
value: "test3",
},
{
name: "set key that has same content as a value",
key: "present",
value: "test4",
},
}
for _, tt := range tests {
cm := ConfigMap{Root: testYaml()}
t.Run(tt.name, func(t *testing.T) {
err := cm.SetStringValue(tt.key, tt.value)
assert.NoError(t, err)
val, err := cm.GetStringValue(tt.key)
assert.NoError(t, err)
assert.Equal(t, tt.value, val)
})
}
}
func TestRemoveEntry(t *testing.T) {
tests := []struct {
name string
key string
wantLength int
}{
{
name: "remove key",
key: "erroneous",
wantLength: 6,
},
{
name: "remove key that is not present",
key: "invalid",
wantLength: 8,
},
{
name: "remove key that has same content as a value",
key: "same",
wantLength: 6,
},
}
for _, tt := range tests {
cm := ConfigMap{Root: testYaml()}
t.Run(tt.name, func(t *testing.T) {
cm.RemoveEntry(tt.key)
assert.Equal(t, tt.wantLength, len(cm.Root.Content))
_, err := cm.FindEntry(tt.key)
assert.EqualError(t, err, "not found")
})
}
}
func testYaml() *yaml.Node {
var root yaml.Node
var data = `
valid: present
erroneous: same
blank:
same: logical
`
_ = yaml.Unmarshal([]byte(data), &root)
return root.Content[0]
}

View file

@ -0,0 +1,413 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package config
import (
"sync"
)
// Ensure, that ConfigMock does implement Config.
// If this is not the case, regenerate this file with moq.
var _ Config = &ConfigMock{}
// ConfigMock is a mock implementation of Config.
//
// func TestSomethingThatUsesConfig(t *testing.T) {
//
// // make and configure a mocked Config
// mockedConfig := &ConfigMock{
// AliasesFunc: func() *AliasConfig {
// panic("mock out the Aliases method")
// },
// AuthTokenFunc: func(s string) (string, string) {
// panic("mock out the AuthToken method")
// },
// DefaultHostFunc: func() (string, string) {
// panic("mock out the DefaultHost method")
// },
// GetFunc: func(s1 string, s2 string) (string, error) {
// panic("mock out the Get method")
// },
// GetOrDefaultFunc: func(s1 string, s2 string) (string, error) {
// panic("mock out the GetOrDefault method")
// },
// HostsFunc: func() []string {
// panic("mock out the Hosts method")
// },
// SetFunc: func(s1 string, s2 string, s3 string) {
// panic("mock out the Set method")
// },
// UnsetHostFunc: func(s string) {
// panic("mock out the UnsetHost method")
// },
// WriteFunc: func() error {
// panic("mock out the Write method")
// },
// }
//
// // use mockedConfig in code that requires Config
// // and then make assertions.
//
// }
type ConfigMock struct {
// AliasesFunc mocks the Aliases method.
AliasesFunc func() *AliasConfig
// AuthTokenFunc mocks the AuthToken method.
AuthTokenFunc func(s string) (string, string)
// DefaultHostFunc mocks the DefaultHost method.
DefaultHostFunc func() (string, string)
// GetFunc mocks the Get method.
GetFunc func(s1 string, s2 string) (string, error)
// GetOrDefaultFunc mocks the GetOrDefault method.
GetOrDefaultFunc func(s1 string, s2 string) (string, error)
// HostsFunc mocks the Hosts method.
HostsFunc func() []string
// SetFunc mocks the Set method.
SetFunc func(s1 string, s2 string, s3 string)
// UnsetHostFunc mocks the UnsetHost method.
UnsetHostFunc func(s string)
// WriteFunc mocks the Write method.
WriteFunc func() error
// calls tracks calls to the methods.
calls struct {
// Aliases holds details about calls to the Aliases method.
Aliases []struct {
}
// AuthToken holds details about calls to the AuthToken method.
AuthToken []struct {
// S is the s argument value.
S string
}
// DefaultHost holds details about calls to the DefaultHost method.
DefaultHost []struct {
}
// Get holds details about calls to the Get method.
Get []struct {
// S1 is the s1 argument value.
S1 string
// S2 is the s2 argument value.
S2 string
}
// GetOrDefault holds details about calls to the GetOrDefault method.
GetOrDefault []struct {
// S1 is the s1 argument value.
S1 string
// S2 is the s2 argument value.
S2 string
}
// Hosts holds details about calls to the Hosts method.
Hosts []struct {
}
// Set holds details about calls to the Set method.
Set []struct {
// S1 is the s1 argument value.
S1 string
// S2 is the s2 argument value.
S2 string
// S3 is the s3 argument value.
S3 string
}
// UnsetHost holds details about calls to the UnsetHost method.
UnsetHost []struct {
// S is the s argument value.
S string
}
// Write holds details about calls to the Write method.
Write []struct {
}
}
lockAliases sync.RWMutex
lockAuthToken sync.RWMutex
lockDefaultHost sync.RWMutex
lockGet sync.RWMutex
lockGetOrDefault sync.RWMutex
lockHosts sync.RWMutex
lockSet sync.RWMutex
lockUnsetHost sync.RWMutex
lockWrite sync.RWMutex
}
// Aliases calls AliasesFunc.
func (mock *ConfigMock) Aliases() *AliasConfig {
if mock.AliasesFunc == nil {
panic("ConfigMock.AliasesFunc: method is nil but Config.Aliases was just called")
}
callInfo := struct {
}{}
mock.lockAliases.Lock()
mock.calls.Aliases = append(mock.calls.Aliases, callInfo)
mock.lockAliases.Unlock()
return mock.AliasesFunc()
}
// AliasesCalls gets all the calls that were made to Aliases.
// Check the length with:
// len(mockedConfig.AliasesCalls())
func (mock *ConfigMock) AliasesCalls() []struct {
} {
var calls []struct {
}
mock.lockAliases.RLock()
calls = mock.calls.Aliases
mock.lockAliases.RUnlock()
return calls
}
// AuthToken calls AuthTokenFunc.
func (mock *ConfigMock) AuthToken(s string) (string, string) {
if mock.AuthTokenFunc == nil {
panic("ConfigMock.AuthTokenFunc: method is nil but Config.AuthToken was just called")
}
callInfo := struct {
S string
}{
S: s,
}
mock.lockAuthToken.Lock()
mock.calls.AuthToken = append(mock.calls.AuthToken, callInfo)
mock.lockAuthToken.Unlock()
return mock.AuthTokenFunc(s)
}
// AuthTokenCalls gets all the calls that were made to AuthToken.
// Check the length with:
// len(mockedConfig.AuthTokenCalls())
func (mock *ConfigMock) AuthTokenCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockAuthToken.RLock()
calls = mock.calls.AuthToken
mock.lockAuthToken.RUnlock()
return calls
}
// DefaultHost calls DefaultHostFunc.
func (mock *ConfigMock) DefaultHost() (string, string) {
if mock.DefaultHostFunc == nil {
panic("ConfigMock.DefaultHostFunc: method is nil but Config.DefaultHost was just called")
}
callInfo := struct {
}{}
mock.lockDefaultHost.Lock()
mock.calls.DefaultHost = append(mock.calls.DefaultHost, callInfo)
mock.lockDefaultHost.Unlock()
return mock.DefaultHostFunc()
}
// DefaultHostCalls gets all the calls that were made to DefaultHost.
// Check the length with:
// len(mockedConfig.DefaultHostCalls())
func (mock *ConfigMock) DefaultHostCalls() []struct {
} {
var calls []struct {
}
mock.lockDefaultHost.RLock()
calls = mock.calls.DefaultHost
mock.lockDefaultHost.RUnlock()
return calls
}
// Get calls GetFunc.
func (mock *ConfigMock) Get(s1 string, s2 string) (string, error) {
if mock.GetFunc == nil {
panic("ConfigMock.GetFunc: method is nil but Config.Get was just called")
}
callInfo := struct {
S1 string
S2 string
}{
S1: s1,
S2: s2,
}
mock.lockGet.Lock()
mock.calls.Get = append(mock.calls.Get, callInfo)
mock.lockGet.Unlock()
return mock.GetFunc(s1, s2)
}
// GetCalls gets all the calls that were made to Get.
// Check the length with:
// len(mockedConfig.GetCalls())
func (mock *ConfigMock) GetCalls() []struct {
S1 string
S2 string
} {
var calls []struct {
S1 string
S2 string
}
mock.lockGet.RLock()
calls = mock.calls.Get
mock.lockGet.RUnlock()
return calls
}
// GetOrDefault calls GetOrDefaultFunc.
func (mock *ConfigMock) GetOrDefault(s1 string, s2 string) (string, error) {
if mock.GetOrDefaultFunc == nil {
panic("ConfigMock.GetOrDefaultFunc: method is nil but Config.GetOrDefault was just called")
}
callInfo := struct {
S1 string
S2 string
}{
S1: s1,
S2: s2,
}
mock.lockGetOrDefault.Lock()
mock.calls.GetOrDefault = append(mock.calls.GetOrDefault, callInfo)
mock.lockGetOrDefault.Unlock()
return mock.GetOrDefaultFunc(s1, s2)
}
// GetOrDefaultCalls gets all the calls that were made to GetOrDefault.
// Check the length with:
// len(mockedConfig.GetOrDefaultCalls())
func (mock *ConfigMock) GetOrDefaultCalls() []struct {
S1 string
S2 string
} {
var calls []struct {
S1 string
S2 string
}
mock.lockGetOrDefault.RLock()
calls = mock.calls.GetOrDefault
mock.lockGetOrDefault.RUnlock()
return calls
}
// Hosts calls HostsFunc.
func (mock *ConfigMock) Hosts() []string {
if mock.HostsFunc == nil {
panic("ConfigMock.HostsFunc: method is nil but Config.Hosts was just called")
}
callInfo := struct {
}{}
mock.lockHosts.Lock()
mock.calls.Hosts = append(mock.calls.Hosts, callInfo)
mock.lockHosts.Unlock()
return mock.HostsFunc()
}
// HostsCalls gets all the calls that were made to Hosts.
// Check the length with:
// len(mockedConfig.HostsCalls())
func (mock *ConfigMock) HostsCalls() []struct {
} {
var calls []struct {
}
mock.lockHosts.RLock()
calls = mock.calls.Hosts
mock.lockHosts.RUnlock()
return calls
}
// Set calls SetFunc.
func (mock *ConfigMock) Set(s1 string, s2 string, s3 string) {
if mock.SetFunc == nil {
panic("ConfigMock.SetFunc: method is nil but Config.Set was just called")
}
callInfo := struct {
S1 string
S2 string
S3 string
}{
S1: s1,
S2: s2,
S3: s3,
}
mock.lockSet.Lock()
mock.calls.Set = append(mock.calls.Set, callInfo)
mock.lockSet.Unlock()
mock.SetFunc(s1, s2, s3)
}
// SetCalls gets all the calls that were made to Set.
// Check the length with:
// len(mockedConfig.SetCalls())
func (mock *ConfigMock) SetCalls() []struct {
S1 string
S2 string
S3 string
} {
var calls []struct {
S1 string
S2 string
S3 string
}
mock.lockSet.RLock()
calls = mock.calls.Set
mock.lockSet.RUnlock()
return calls
}
// UnsetHost calls UnsetHostFunc.
func (mock *ConfigMock) UnsetHost(s string) {
if mock.UnsetHostFunc == nil {
panic("ConfigMock.UnsetHostFunc: method is nil but Config.UnsetHost was just called")
}
callInfo := struct {
S string
}{
S: s,
}
mock.lockUnsetHost.Lock()
mock.calls.UnsetHost = append(mock.calls.UnsetHost, callInfo)
mock.lockUnsetHost.Unlock()
mock.UnsetHostFunc(s)
}
// UnsetHostCalls gets all the calls that were made to UnsetHost.
// Check the length with:
// len(mockedConfig.UnsetHostCalls())
func (mock *ConfigMock) UnsetHostCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockUnsetHost.RLock()
calls = mock.calls.UnsetHost
mock.lockUnsetHost.RUnlock()
return calls
}
// Write calls WriteFunc.
func (mock *ConfigMock) Write() error {
if mock.WriteFunc == nil {
panic("ConfigMock.WriteFunc: method is nil but Config.Write was just called")
}
callInfo := struct {
}{}
mock.lockWrite.Lock()
mock.calls.Write = append(mock.calls.Write, callInfo)
mock.lockWrite.Unlock()
return mock.WriteFunc()
}
// WriteCalls gets all the calls that were made to Write.
// Check the length with:
// len(mockedConfig.WriteCalls())
func (mock *ConfigMock) WriteCalls() []struct {
} {
var calls []struct {
}
mock.lockWrite.RLock()
calls = mock.calls.Write
mock.lockWrite.RUnlock()
return calls
}

View file

@ -1,218 +0,0 @@
package config
import (
"fmt"
"gopkg.in/yaml.v3"
)
// This interface describes interacting with some persistent configuration for gh.
type Config interface {
Get(string, string) (string, error)
GetOrDefault(string, string) (string, error)
GetWithSource(string, string) (string, string, error)
GetOrDefaultWithSource(string, string) (string, string, error)
Default(string) string
Set(string, string, string) error
UnsetHost(string)
Hosts() ([]string, error)
DefaultHost() (string, error)
DefaultHostWithSource() (string, string, error)
Aliases() (*AliasConfig, error)
CheckWriteable(string, string) error
Write() error
WriteHosts() error
}
type ConfigOption struct {
Key string
Description string
DefaultValue string
AllowedValues []string
}
var configOptions = []ConfigOption{
{
Key: "git_protocol",
Description: "the protocol to use for git clone and push operations",
DefaultValue: "https",
AllowedValues: []string{"https", "ssh"},
},
{
Key: "editor",
Description: "the text editor program to use for authoring text",
DefaultValue: "",
},
{
Key: "prompt",
Description: "toggle interactive prompting in the terminal",
DefaultValue: "enabled",
AllowedValues: []string{"enabled", "disabled"},
},
{
Key: "pager",
Description: "the terminal pager program to send standard output to",
DefaultValue: "",
},
{
Key: "http_unix_socket",
Description: "the path to a Unix socket through which to make an HTTP connection",
DefaultValue: "",
},
{
Key: "browser",
Description: "the web browser to use for opening URLs",
DefaultValue: "",
},
}
func ConfigOptions() []ConfigOption {
return configOptions
}
func ValidateKey(key string) error {
for _, configKey := range configOptions {
if key == configKey.Key {
return nil
}
}
return fmt.Errorf("invalid key")
}
type InvalidValueError struct {
ValidValues []string
}
func (e InvalidValueError) Error() string {
return "invalid value"
}
func ValidateValue(key, value string) error {
var validValues []string
for _, v := range configOptions {
if v.Key == key {
validValues = v.AllowedValues
break
}
}
if validValues == nil {
return nil
}
for _, v := range validValues {
if v == value {
return nil
}
}
return &InvalidValueError{ValidValues: validValues}
}
func NewConfig(root *yaml.Node) Config {
return &fileConfig{
ConfigMap: ConfigMap{Root: root.Content[0]},
documentRoot: root,
}
}
// NewFromString initializes a Config from a yaml string
func NewFromString(str string) Config {
root, err := parseConfigData([]byte(str))
if err != nil {
panic(err)
}
return NewConfig(root)
}
// NewBlankConfig initializes a config file pre-populated with comments and default values
func NewBlankConfig() Config {
return NewConfig(NewBlankRoot())
}
func NewBlankRoot() *yaml.Node {
return &yaml.Node{
Kind: yaml.DocumentNode,
Content: []*yaml.Node{
{
Kind: yaml.MappingNode,
Content: []*yaml.Node{
{
HeadComment: "What protocol to use when performing git operations. Supported values: ssh, https",
Kind: yaml.ScalarNode,
Value: "git_protocol",
},
{
Kind: yaml.ScalarNode,
Value: "https",
},
{
HeadComment: "What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.",
Kind: yaml.ScalarNode,
Value: "editor",
},
{
Kind: yaml.ScalarNode,
Value: "",
},
{
HeadComment: "When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled",
Kind: yaml.ScalarNode,
Value: "prompt",
},
{
Kind: yaml.ScalarNode,
Value: "enabled",
},
{
HeadComment: "A pager program to send command output to, e.g. \"less\". Set the value to \"cat\" to disable the pager.",
Kind: yaml.ScalarNode,
Value: "pager",
},
{
Kind: yaml.ScalarNode,
Value: "",
},
{
HeadComment: "Aliases allow you to create nicknames for gh commands",
Kind: yaml.ScalarNode,
Value: "aliases",
},
{
Kind: yaml.MappingNode,
Content: []*yaml.Node{
{
Kind: yaml.ScalarNode,
Value: "co",
},
{
Kind: yaml.ScalarNode,
Value: "pr checkout",
},
},
},
{
HeadComment: "The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.",
Kind: yaml.ScalarNode,
Value: "http_unix_socket",
},
{
Kind: yaml.ScalarNode,
Value: "",
},
{
HeadComment: "What web browser gh should use when opening URLs. If blank, will refer to environment.",
Kind: yaml.ScalarNode,
Value: "browser",
},
{
Kind: yaml.ScalarNode,
Value: "",
},
},
},
},
}
}

View file

@ -1,118 +0,0 @@
package config
import (
"bytes"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/stretchr/testify/assert"
)
func Test_fileConfig_Set(t *testing.T) {
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
c := NewBlankConfig()
assert.NoError(t, c.Set("", "editor", "nano"))
assert.NoError(t, c.Set("github.com", "git_protocol", "ssh"))
assert.NoError(t, c.Set("example.com", "editor", "vim"))
assert.NoError(t, c.Set("github.com", "user", "hubot"))
assert.NoError(t, c.Write())
assert.Contains(t, mainBuf.String(), "editor: nano")
assert.Contains(t, mainBuf.String(), "git_protocol: https")
assert.Equal(t, `github.com:
git_protocol: ssh
user: hubot
example.com:
editor: vim
`, hostsBuf.String())
}
func Test_defaultConfig(t *testing.T) {
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
cfg := NewBlankConfig()
assert.NoError(t, cfg.Write())
expected := heredoc.Doc(`
# What protocol to use when performing git operations. Supported values: ssh, https
git_protocol: https
# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.
editor:
# When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled
prompt: enabled
# A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager.
pager:
# Aliases allow you to create nicknames for gh commands
aliases:
co: pr checkout
# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.
http_unix_socket:
# What web browser gh should use when opening URLs. If blank, will refer to environment.
browser:
`)
assert.Equal(t, expected, mainBuf.String())
assert.Equal(t, "", hostsBuf.String())
proto, err := cfg.GetOrDefault("", "git_protocol")
assert.NoError(t, err)
assert.Equal(t, "https", proto)
editor, err := cfg.Get("", "editor")
assert.NoError(t, err)
assert.Equal(t, "", editor)
aliases, err := cfg.Aliases()
assert.NoError(t, err)
assert.Equal(t, len(aliases.All()), 1)
expansion, _ := aliases.Get("co")
assert.Equal(t, expansion, "pr checkout")
browser, err := cfg.Get("", "browser")
assert.NoError(t, err)
assert.Equal(t, "", browser)
}
func Test_ValidateValue(t *testing.T) {
err := ValidateValue("git_protocol", "sshpps")
assert.EqualError(t, err, "invalid value")
err = ValidateValue("git_protocol", "ssh")
assert.NoError(t, err)
err = ValidateValue("editor", "vim")
assert.NoError(t, err)
err = ValidateValue("got", "123")
assert.NoError(t, err)
err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate")
assert.NoError(t, err)
}
func Test_ValidateKey(t *testing.T) {
err := ValidateKey("invalid")
assert.EqualError(t, err, "invalid key")
err = ValidateKey("git_protocol")
assert.NoError(t, err)
err = ValidateKey("editor")
assert.NoError(t, err)
err = ValidateKey("prompt")
assert.NoError(t, err)
err = ValidateKey("pager")
assert.NoError(t, err)
err = ValidateKey("http_unix_socket")
assert.NoError(t, err)
err = ValidateKey("browser")
assert.NoError(t, err)
}

View file

@ -1,156 +0,0 @@
package config
import (
"fmt"
"os"
"sort"
"strconv"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/pkg/set"
)
const (
GH_HOST = "GH_HOST"
GH_TOKEN = "GH_TOKEN"
GITHUB_TOKEN = "GITHUB_TOKEN"
GH_ENTERPRISE_TOKEN = "GH_ENTERPRISE_TOKEN"
GITHUB_ENTERPRISE_TOKEN = "GITHUB_ENTERPRISE_TOKEN"
CODESPACES = "CODESPACES"
)
type ReadOnlyEnvError struct {
Variable string
}
func (e *ReadOnlyEnvError) Error() string {
return fmt.Sprintf("read-only value in %s", e.Variable)
}
func InheritEnv(c Config) Config {
return &envConfig{Config: c}
}
type envConfig struct {
Config
}
func (c *envConfig) Hosts() ([]string, error) {
hosts, err := c.Config.Hosts()
if err != nil {
return nil, err
}
hostSet := set.NewStringSet()
hostSet.AddValues(hosts)
// If GH_HOST is set then add it to list.
if host := os.Getenv(GH_HOST); host != "" {
hostSet.Add(host)
}
// If there is a valid environment variable token for the
// default host then add default host to list.
if token, _ := AuthTokenFromEnv(ghinstance.Default()); token != "" {
hostSet.Add(ghinstance.Default())
}
s := hostSet.ToSlice()
// If default host is in list then move it to the front.
sort.SliceStable(s, func(i, j int) bool { return s[i] == ghinstance.Default() })
return s, nil
}
func (c *envConfig) DefaultHost() (string, error) {
val, _, err := c.DefaultHostWithSource()
return val, err
}
func (c *envConfig) DefaultHostWithSource() (string, string, error) {
if host := os.Getenv(GH_HOST); host != "" {
return host, GH_HOST, nil
}
return c.Config.DefaultHostWithSource()
}
func (c *envConfig) Get(hostname, key string) (string, error) {
val, _, err := c.GetWithSource(hostname, key)
return val, err
}
func (c *envConfig) GetWithSource(hostname, key string) (string, string, error) {
if hostname != "" && key == "oauth_token" {
if token, env := AuthTokenFromEnv(hostname); token != "" {
return token, env, nil
}
}
return c.Config.GetWithSource(hostname, key)
}
func (c *envConfig) GetOrDefault(hostname, key string) (val string, err error) {
val, _, err = c.GetOrDefaultWithSource(hostname, key)
return
}
func (c *envConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) {
val, src, err = c.GetWithSource(hostname, key)
if err == nil && val == "" {
val = c.Default(key)
}
return
}
func (c *envConfig) Default(key string) string {
return c.Config.Default(key)
}
func (c *envConfig) CheckWriteable(hostname, key string) error {
if hostname != "" && key == "oauth_token" {
if token, env := AuthTokenFromEnv(hostname); token != "" {
return &ReadOnlyEnvError{Variable: env}
}
}
return c.Config.CheckWriteable(hostname, key)
}
func AuthTokenFromEnv(hostname string) (string, string) {
if ghinstance.IsEnterprise(hostname) {
if token := os.Getenv(GH_ENTERPRISE_TOKEN); token != "" {
return token, GH_ENTERPRISE_TOKEN
}
if token := os.Getenv(GITHUB_ENTERPRISE_TOKEN); token != "" {
return token, GITHUB_ENTERPRISE_TOKEN
}
if isCodespaces, _ := strconv.ParseBool(os.Getenv(CODESPACES)); isCodespaces {
return os.Getenv(GITHUB_TOKEN), GITHUB_TOKEN
}
return "", ""
}
if token := os.Getenv(GH_TOKEN); token != "" {
return token, GH_TOKEN
}
return os.Getenv(GITHUB_TOKEN), GITHUB_TOKEN
}
func AuthTokenProvidedFromEnv() bool {
return os.Getenv(GH_ENTERPRISE_TOKEN) != "" ||
os.Getenv(GITHUB_ENTERPRISE_TOKEN) != "" ||
os.Getenv(GH_TOKEN) != "" ||
os.Getenv(GITHUB_TOKEN) != ""
}
func IsHostEnv(src string) bool {
return src == GH_HOST
}
func IsEnterpriseEnv(src string) bool {
return src == GH_ENTERPRISE_TOKEN || src == GITHUB_ENTERPRISE_TOKEN
}

View file

@ -1,389 +0,0 @@
package config
import (
"os"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/stretchr/testify/assert"
)
func setenv(t *testing.T, key, newValue string) {
oldValue, hasValue := os.LookupEnv(key)
os.Setenv(key, newValue)
t.Cleanup(func() {
if hasValue {
os.Setenv(key, oldValue)
} else {
os.Unsetenv(key)
}
})
}
func TestInheritEnv(t *testing.T) {
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
orig_AppData := os.Getenv("AppData")
t.Cleanup(func() {
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN)
os.Setenv("AppData", orig_AppData)
})
type wants struct {
hosts []string
token string
source string
writeable bool
}
tests := []struct {
name string
baseConfig string
GH_HOST string
GITHUB_TOKEN string
GITHUB_ENTERPRISE_TOKEN string
GH_TOKEN string
GH_ENTERPRISE_TOKEN string
CODESPACES string
hostname string
wants wants
}{
{
name: "blank",
baseConfig: ``,
hostname: "github.com",
wants: wants{
hosts: []string{},
token: "",
source: ".config.gh.config.yml",
writeable: true,
},
},
{
name: "GITHUB_TOKEN over blank config",
baseConfig: ``,
GITHUB_TOKEN: "OTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "OTOKEN",
source: "GITHUB_TOKEN",
writeable: false,
},
},
{
name: "GH_TOKEN over blank config",
baseConfig: ``,
GH_TOKEN: "OTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "OTOKEN",
source: "GH_TOKEN",
writeable: false,
},
},
{
name: "GITHUB_TOKEN not applicable to GHE",
baseConfig: ``,
GITHUB_TOKEN: "OTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{"github.com"},
token: "",
source: ".config.gh.config.yml",
writeable: true,
},
},
{
name: "GH_TOKEN not applicable to GHE",
baseConfig: ``,
GH_TOKEN: "OTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{"github.com"},
token: "",
source: ".config.gh.config.yml",
writeable: true,
},
},
{
name: "GITHUB_TOKEN allowed in Codespaces",
baseConfig: ``,
GITHUB_TOKEN: "OTOKEN",
hostname: "example.org",
CODESPACES: "true",
wants: wants{
hosts: []string{"github.com"},
token: "OTOKEN",
source: "GITHUB_TOKEN",
writeable: false,
},
},
{
name: "GITHUB_ENTERPRISE_TOKEN over blank config",
baseConfig: ``,
GITHUB_ENTERPRISE_TOKEN: "ENTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{},
token: "ENTOKEN",
source: "GITHUB_ENTERPRISE_TOKEN",
writeable: false,
},
},
{
name: "GH_ENTERPRISE_TOKEN over blank config",
baseConfig: ``,
GH_ENTERPRISE_TOKEN: "ENTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{},
token: "ENTOKEN",
source: "GH_ENTERPRISE_TOKEN",
writeable: false,
},
},
{
name: "token from file",
baseConfig: heredoc.Doc(`
hosts:
github.com:
oauth_token: OTOKEN
`),
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "OTOKEN",
source: ".config.gh.hosts.yml",
writeable: true,
},
},
{
name: "GITHUB_TOKEN shadows token from file",
baseConfig: heredoc.Doc(`
hosts:
github.com:
oauth_token: OTOKEN
`),
GITHUB_TOKEN: "ENVTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "ENVTOKEN",
source: "GITHUB_TOKEN",
writeable: false,
},
},
{
name: "GH_TOKEN shadows token from file",
baseConfig: heredoc.Doc(`
hosts:
github.com:
oauth_token: OTOKEN
`),
GH_TOKEN: "ENVTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "ENVTOKEN",
source: "GH_TOKEN",
writeable: false,
},
},
{
name: "GITHUB_ENTERPRISE_TOKEN shadows token from file",
baseConfig: heredoc.Doc(`
hosts:
example.org:
oauth_token: OTOKEN
`),
GITHUB_ENTERPRISE_TOKEN: "ENVTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{"example.org"},
token: "ENVTOKEN",
source: "GITHUB_ENTERPRISE_TOKEN",
writeable: false,
},
},
{
name: "GH_ENTERPRISE_TOKEN shadows token from file",
baseConfig: heredoc.Doc(`
hosts:
example.org:
oauth_token: OTOKEN
`),
GH_ENTERPRISE_TOKEN: "ENVTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{"example.org"},
token: "ENVTOKEN",
source: "GH_ENTERPRISE_TOKEN",
writeable: false,
},
},
{
name: "GH_TOKEN shadows token from GITHUB_TOKEN",
baseConfig: ``,
GH_TOKEN: "GHTOKEN",
GITHUB_TOKEN: "GITHUBTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com"},
token: "GHTOKEN",
source: "GH_TOKEN",
writeable: false,
},
},
{
name: "GH_ENTERPRISE_TOKEN shadows token from GITHUB_ENTERPRISE_TOKEN",
baseConfig: ``,
GH_ENTERPRISE_TOKEN: "GHTOKEN",
GITHUB_ENTERPRISE_TOKEN: "GITHUBTOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{},
token: "GHTOKEN",
source: "GH_ENTERPRISE_TOKEN",
writeable: false,
},
},
{
name: "GITHUB_TOKEN adds host entry",
baseConfig: heredoc.Doc(`
hosts:
example.org:
oauth_token: OTOKEN
`),
GITHUB_TOKEN: "ENVTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com", "example.org"},
token: "ENVTOKEN",
source: "GITHUB_TOKEN",
writeable: false,
},
},
{
name: "GH_TOKEN adds host entry",
baseConfig: heredoc.Doc(`
hosts:
example.org:
oauth_token: OTOKEN
`),
GH_TOKEN: "ENVTOKEN",
hostname: "github.com",
wants: wants{
hosts: []string{"github.com", "example.org"},
token: "ENVTOKEN",
source: "GH_TOKEN",
writeable: false,
},
},
{
name: "GH_HOST adds host entry when paired with environment token",
baseConfig: ``,
GH_HOST: "example.org",
GH_ENTERPRISE_TOKEN: "GH_ENTERPRISE_TOKEN",
hostname: "example.org",
wants: wants{
hosts: []string{"example.org"},
token: "GH_ENTERPRISE_TOKEN",
source: "GH_ENTERPRISE_TOKEN",
writeable: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setenv(t, "GH_HOST", tt.GH_HOST)
setenv(t, "GITHUB_TOKEN", tt.GITHUB_TOKEN)
setenv(t, "GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN)
setenv(t, "GH_TOKEN", tt.GH_TOKEN)
setenv(t, "GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN)
setenv(t, "AppData", "")
setenv(t, "CODESPACES", tt.CODESPACES)
baseCfg := NewFromString(tt.baseConfig)
cfg := InheritEnv(baseCfg)
hosts, _ := cfg.Hosts()
assert.Equal(t, tt.wants.hosts, hosts)
val, source, _ := cfg.GetWithSource(tt.hostname, "oauth_token")
assert.Equal(t, tt.wants.token, val)
assert.Regexp(t, tt.wants.source, source)
val, _ = cfg.Get(tt.hostname, "oauth_token")
assert.Equal(t, tt.wants.token, val)
err := cfg.CheckWriteable(tt.hostname, "oauth_token")
if tt.wants.writeable != (err == nil) {
t.Errorf("CheckWriteable() = %v, wants %v", err, tt.wants.writeable)
}
})
}
}
func TestAuthTokenProvidedFromEnv(t *testing.T) {
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
t.Cleanup(func() {
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN)
})
tests := []struct {
name string
GITHUB_TOKEN string
GITHUB_ENTERPRISE_TOKEN string
GH_TOKEN string
GH_ENTERPRISE_TOKEN string
provided bool
}{
{
name: "no env tokens",
provided: false,
},
{
name: "GH_TOKEN",
GH_TOKEN: "TOKEN",
provided: true,
},
{
name: "GITHUB_TOKEN",
GITHUB_TOKEN: "TOKEN",
provided: true,
},
{
name: "GH_ENTERPRISE_TOKEN",
GH_ENTERPRISE_TOKEN: "TOKEN",
provided: true,
},
{
name: "GITHUB_ENTERPRISE_TOKEN",
GITHUB_ENTERPRISE_TOKEN: "TOKEN",
provided: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("GITHUB_TOKEN", tt.GITHUB_TOKEN)
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN)
os.Setenv("GH_TOKEN", tt.GH_TOKEN)
os.Setenv("GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN)
assert.Equal(t, tt.provided, AuthTokenProvidedFromEnv())
})
}
}

View file

@ -1,342 +0,0 @@
package config
import (
"bytes"
"errors"
"fmt"
"sort"
"strings"
"github.com/cli/cli/v2/internal/ghinstance"
"gopkg.in/yaml.v3"
)
// This type implements a Config interface and represents a config file on disk.
type fileConfig struct {
ConfigMap
documentRoot *yaml.Node
}
type HostConfig struct {
ConfigMap
Host string
}
func (c *fileConfig) Root() *yaml.Node {
return c.ConfigMap.Root
}
func (c *fileConfig) Get(hostname, key string) (string, error) {
val, _, err := c.GetWithSource(hostname, key)
return val, err
}
func (c *fileConfig) GetWithSource(hostname, key string) (string, string, error) {
if hostname != "" {
var notFound *NotFoundError
hostCfg, err := c.configForHost(hostname)
if err != nil && !errors.As(err, &notFound) {
return "", "", err
}
var hostValue string
if hostCfg != nil {
hostValue, err = hostCfg.GetStringValue(key)
if err != nil && !errors.As(err, &notFound) {
return "", "", err
}
}
if hostValue != "" {
return hostValue, HostsConfigFile(), nil
}
}
defaultSource := ConfigFile()
value, err := c.GetStringValue(key)
var notFound *NotFoundError
if err != nil && errors.As(err, &notFound) {
return defaultFor(key), defaultSource, nil
} else if err != nil {
return "", defaultSource, err
}
return value, defaultSource, nil
}
func (c *fileConfig) GetOrDefault(hostname, key string) (val string, err error) {
val, _, err = c.GetOrDefaultWithSource(hostname, key)
return
}
func (c *fileConfig) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) {
val, src, err = c.GetWithSource(hostname, key)
if err != nil && val == "" {
val = c.Default(key)
}
return
}
func (c *fileConfig) Default(key string) string {
return defaultFor(key)
}
func (c *fileConfig) Set(hostname, key, value string) error {
if hostname == "" {
return c.SetStringValue(key, value)
} else {
hostCfg, err := c.configForHost(hostname)
var notFound *NotFoundError
if errors.As(err, &notFound) {
hostCfg = c.makeConfigForHost(hostname)
} else if err != nil {
return err
}
return hostCfg.SetStringValue(key, value)
}
}
func (c *fileConfig) UnsetHost(hostname string) {
if hostname == "" {
return
}
hostsEntry, err := c.FindEntry("hosts")
if err != nil {
return
}
cm := ConfigMap{hostsEntry.ValueNode}
cm.RemoveEntry(hostname)
}
func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
hosts, err := c.hostEntries()
if err != nil {
return nil, err
}
for _, hc := range hosts {
if strings.EqualFold(hc.Host, hostname) {
return hc, nil
}
}
return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)}
}
func (c *fileConfig) CheckWriteable(hostname, key string) error {
// TODO: check filesystem permissions
return nil
}
func (c *fileConfig) Write() error {
mainData := yaml.Node{Kind: yaml.MappingNode}
nodes := c.documentRoot.Content[0].Content
for i := 0; i < len(nodes)-1; i += 2 {
if nodes[i].Value != "hosts" {
mainData.Content = append(mainData.Content, nodes[i], nodes[i+1])
}
}
mainBytes, err := yaml.Marshal(&mainData)
if err != nil {
return err
}
err = WriteConfigFile(ConfigFile(), yamlNormalize(mainBytes))
if err != nil {
return err
}
return c.WriteHosts()
}
// Write the hosts config file only, so as to allow logging in when the main
// config file is not writable.
func (c *fileConfig) WriteHosts() error {
hostsData := yaml.Node{Kind: yaml.MappingNode}
nodes := c.documentRoot.Content[0].Content
for i := 0; i < len(nodes)-1; i += 2 {
if nodes[i].Value == "hosts" {
hostsData.Content = append(hostsData.Content, nodes[i+1].Content...)
}
}
hostsBytes, err := yaml.Marshal(&hostsData)
if err != nil {
return err
}
return WriteConfigFile(HostsConfigFile(), yamlNormalize(hostsBytes))
}
func (c *fileConfig) Aliases() (*AliasConfig, error) {
// The complexity here is for dealing with either a missing or empty aliases key. It's something
// we'll likely want for other config sections at some point.
entry, err := c.FindEntry("aliases")
var nfe *NotFoundError
notFound := errors.As(err, &nfe)
if err != nil && !notFound {
return nil, err
}
toInsert := []*yaml.Node{}
keyNode := entry.KeyNode
valueNode := entry.ValueNode
if keyNode == nil {
keyNode = &yaml.Node{
Kind: yaml.ScalarNode,
Value: "aliases",
}
toInsert = append(toInsert, keyNode)
}
if valueNode == nil || valueNode.Kind != yaml.MappingNode {
valueNode = &yaml.Node{
Kind: yaml.MappingNode,
Value: "",
}
toInsert = append(toInsert, valueNode)
}
if len(toInsert) > 0 {
newContent := []*yaml.Node{}
if notFound {
newContent = append(c.Root().Content, keyNode, valueNode)
} else {
for i := 0; i < len(c.Root().Content); i++ {
if i == entry.Index {
newContent = append(newContent, keyNode, valueNode)
i++
} else {
newContent = append(newContent, c.Root().Content[i])
}
}
}
c.Root().Content = newContent
}
return &AliasConfig{
Parent: c,
ConfigMap: ConfigMap{Root: valueNode},
}, nil
}
func (c *fileConfig) hostEntries() ([]*HostConfig, error) {
entry, err := c.FindEntry("hosts")
if err != nil {
return []*HostConfig{}, nil
}
hostConfigs, err := c.parseHosts(entry.ValueNode)
if err != nil {
return nil, fmt.Errorf("could not parse hosts config: %w", err)
}
return hostConfigs, nil
}
// Hosts returns a list of all known hostnames configured in hosts.yml
func (c *fileConfig) Hosts() ([]string, error) {
entries, err := c.hostEntries()
if err != nil {
return nil, err
}
hostnames := []string{}
for _, entry := range entries {
hostnames = append(hostnames, entry.Host)
}
sort.SliceStable(hostnames, func(i, j int) bool { return hostnames[i] == ghinstance.Default() })
return hostnames, nil
}
func (c *fileConfig) DefaultHost() (string, error) {
val, _, err := c.DefaultHostWithSource()
return val, err
}
func (c *fileConfig) DefaultHostWithSource() (string, string, error) {
hosts, err := c.Hosts()
if err == nil && len(hosts) == 1 {
return hosts[0], HostsConfigFile(), nil
}
return ghinstance.Default(), "", nil
}
func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig {
hostRoot := &yaml.Node{Kind: yaml.MappingNode}
hostCfg := &HostConfig{
Host: hostname,
ConfigMap: ConfigMap{Root: hostRoot},
}
var notFound *NotFoundError
hostsEntry, err := c.FindEntry("hosts")
if errors.As(err, &notFound) {
hostsEntry.KeyNode = &yaml.Node{
Kind: yaml.ScalarNode,
Value: "hosts",
}
hostsEntry.ValueNode = &yaml.Node{Kind: yaml.MappingNode}
root := c.Root()
root.Content = append(root.Content, hostsEntry.KeyNode, hostsEntry.ValueNode)
} else if err != nil {
panic(err)
}
hostsEntry.ValueNode.Content = append(hostsEntry.ValueNode.Content,
&yaml.Node{
Kind: yaml.ScalarNode,
Value: hostname,
}, hostRoot)
return hostCfg
}
func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) {
hostConfigs := []*HostConfig{}
for i := 0; i < len(hostsEntry.Content)-1; i = i + 2 {
hostname := hostsEntry.Content[i].Value
hostRoot := hostsEntry.Content[i+1]
hostConfig := HostConfig{
ConfigMap: ConfigMap{Root: hostRoot},
Host: hostname,
}
hostConfigs = append(hostConfigs, &hostConfig)
}
if len(hostConfigs) == 0 {
return nil, errors.New("could not find any host configurations")
}
return hostConfigs, nil
}
func yamlNormalize(b []byte) []byte {
if bytes.Equal(b, []byte("{}\n")) {
return []byte{}
}
return b
}
func defaultFor(key string) string {
for _, co := range configOptions {
if co.Key == key {
return co.DefaultValue
}
}
return ""
}

View file

@ -1,15 +0,0 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_fileConfig_Hosts(t *testing.T) {
c := NewBlankConfig()
hosts, err := c.Hosts()
require.NoError(t, err)
assert.Equal(t, []string{}, hosts)
}

View file

@ -1,80 +1,107 @@
package config
import (
"errors"
"io"
"os"
"path/filepath"
"testing"
ghConfig "github.com/cli/go-gh/pkg/config"
)
type ConfigStub map[string]string
func NewBlankConfig() *ConfigMock {
defaultStr := `
# What protocol to use when performing git operations. Supported values: ssh, https
git_protocol: https
# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.
editor:
# When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled
prompt: enabled
# A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager.
pager:
# Aliases allow you to create nicknames for gh commands
aliases:
co: pr checkout
# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.
http_unix_socket:
# What web browser gh should use when opening URLs. If blank, will refer to environment.
browser:
`
return NewFromString(defaultStr)
}
func genKey(host, key string) string {
if host != "" {
return host + ":" + key
func NewFromString(cfgStr string) *ConfigMock {
c := ghConfig.ReadFromString(cfgStr)
cfg := cfg{c}
mock := &ConfigMock{}
mock.AuthTokenFunc = func(host string) (string, string) {
token, _ := c.Get([]string{"hosts", host, "oauth_token"})
return token, "oauth_token"
}
return key
}
func (c ConfigStub) Get(host, key string) (string, error) {
val, _, err := c.GetWithSource(host, key)
return val, err
}
func (c ConfigStub) GetWithSource(host, key string) (string, string, error) {
if v, found := c[genKey(host, key)]; found {
return v, "(memory)", nil
mock.GetFunc = func(host, key string) (string, error) {
return cfg.Get(host, key)
}
return "", "", errors.New("not found")
}
func (c ConfigStub) GetOrDefault(hostname, key string) (val string, err error) {
val, _, err = c.GetOrDefaultWithSource(hostname, key)
return
}
func (c ConfigStub) GetOrDefaultWithSource(hostname, key string) (val string, src string, err error) {
val, src, err = c.GetWithSource(hostname, key)
if err == nil && val == "" {
val = c.Default(key)
mock.GetOrDefaultFunc = func(host, key string) (string, error) {
return cfg.GetOrDefault(host, key)
}
return
mock.SetFunc = func(host, key, value string) {
cfg.Set(host, key, value)
}
mock.UnsetHostFunc = func(host string) {
cfg.UnsetHost(host)
}
mock.HostsFunc = func() []string {
keys, _ := c.Keys([]string{"hosts"})
return keys
}
mock.DefaultHostFunc = func() (string, string) {
return "github.com", "default"
}
mock.AliasesFunc = func() *AliasConfig {
return &AliasConfig{cfg: c}
}
mock.WriteFunc = func() error {
return cfg.Write()
}
return mock
}
func (c ConfigStub) Default(key string) string {
return defaultFor(key)
}
// StubWriteConfig stubs out the filesystem where config file are written.
// It then returns a function that will read in the config files into io.Writers.
// It automatically cleans up environment variables and written files.
func StubWriteConfig(t *testing.T) func(io.Writer, io.Writer) {
t.Helper()
tempDir := t.TempDir()
old := os.Getenv("GH_CONFIG_DIR")
os.Setenv("GH_CONFIG_DIR", tempDir)
t.Cleanup(func() { os.Setenv("GH_CONFIG_DIR", old) })
return func(wc io.Writer, wh io.Writer) {
config, err := os.Open(filepath.Join(tempDir, "config.yml"))
if err != nil {
return
}
defer config.Close()
configData, err := io.ReadAll(config)
if err != nil {
return
}
_, err = wc.Write(configData)
if err != nil {
return
}
func (c ConfigStub) Set(host, key, value string) error {
c[genKey(host, key)] = value
return nil
}
func (c ConfigStub) Aliases() (*AliasConfig, error) {
return nil, nil
}
func (c ConfigStub) Hosts() ([]string, error) {
return nil, nil
}
func (c ConfigStub) UnsetHost(hostname string) {
}
func (c ConfigStub) CheckWriteable(host, key string) error {
return nil
}
func (c ConfigStub) Write() error {
c["_written"] = "true"
return nil
}
func (c ConfigStub) WriteHosts() error {
return nil
}
func (c ConfigStub) DefaultHost() (string, error) {
return "", nil
}
func (c ConfigStub) DefaultHostWithSource() (string, string, error) {
return "", "", nil
hosts, err := os.Open(filepath.Join(tempDir, "hosts.yml"))
if err != nil {
return
}
defer hosts.Close()
hostsData, err := io.ReadAll(hosts)
if err != nil {
return
}
_, err = wh.Write(hostsData)
if err != nil {
return
}
}
}

View file

@ -1,64 +0,0 @@
package config
import (
"fmt"
"io"
"os"
"path/filepath"
)
func StubBackupConfig() func() {
orig := BackupConfigFile
BackupConfigFile = func(_ string) error {
return nil
}
return func() {
BackupConfigFile = orig
}
}
func StubWriteConfig(wc io.Writer, wh io.Writer) func() {
orig := WriteConfigFile
WriteConfigFile = func(fn string, data []byte) error {
switch filepath.Base(fn) {
case "config.yml":
_, err := wc.Write(data)
return err
case "hosts.yml":
_, err := wh.Write(data)
return err
default:
return fmt.Errorf("write to unstubbed file: %q", fn)
}
}
return func() {
WriteConfigFile = orig
}
}
func stubConfig(main, hosts string) func() {
orig := ReadConfigFile
ReadConfigFile = func(fn string) ([]byte, error) {
switch filepath.Base(fn) {
case "config.yml":
if main == "" {
return []byte(nil), os.ErrNotExist
} else {
return []byte(main), nil
}
case "hosts.yml":
if hosts == "" {
return []byte(nil), os.ErrNotExist
} else {
return []byte(hosts), nil
}
default:
return []byte(nil), fmt.Errorf("read from unstubbed file: %q", fn)
}
}
return func() {
ReadConfigFile = orig
}
}

View file

@ -45,13 +45,10 @@ func deleteRun(opts *DeleteOptions) error {
return err
}
aliasCfg, err := cfg.Aliases()
if err != nil {
return fmt.Errorf("couldn't read aliases config: %w", err)
}
aliasCfg := cfg.Aliases()
expansion, ok := aliasCfg.Get(opts.Name)
if !ok {
expansion, err := aliasCfg.Get(opts.Name)
if err != nil {
return fmt.Errorf("no such alias %s", opts.Name)
}
@ -61,6 +58,11 @@ func deleteRun(opts *DeleteOptions) error {
return fmt.Errorf("failed to delete alias %s: %w", opts.Name, err)
}
err = cfg.Write()
if err != nil {
return err
}
if opts.IO.IsStdoutTTY() {
cs := opts.IO.ColorScheme()
fmt.Fprintf(opts.IO.ErrOut, "%s Deleted alias %s; was %s\n", cs.SuccessIconWithColor(cs.Red), opts.Name, expansion)

View file

@ -15,6 +15,8 @@ import (
)
func TestAliasDelete(t *testing.T) {
_ = config.StubWriteConfig(t)
tests := []struct {
name string
config string
@ -48,8 +50,6 @@ func TestAliasDelete(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer config.StubWriteConfig(io.Discard, io.Discard)()
cfg := config.NewFromString(tt.config)
ios, _, stdout, stderr := iostreams.Test()

View file

@ -23,13 +23,10 @@ func ExpandAlias(cfg config.Config, args []string, findShFunc func() (string, er
}
expanded = args[1:]
aliases, err := cfg.Aliases()
if err != nil {
return
}
aliases := cfg.Aliases()
expansion, ok := aliases.Get(args[1])
if !ok {
expansion, getErr := aliases.Get(args[1])
if getErr != nil {
return
}

View file

@ -1,7 +1,6 @@
package list
import (
"fmt"
"sort"
"github.com/MakeNowJust/heredoc"
@ -48,18 +47,14 @@ func listRun(opts *ListOptions) error {
return err
}
aliasCfg, err := cfg.Aliases()
if err != nil {
return fmt.Errorf("couldn't read aliases config: %w", err)
}
aliasCfg := cfg.Aliases()
if aliasCfg.Empty() {
aliasMap := aliasCfg.All()
if len(aliasMap) == 0 {
return cmdutil.NewNoResultsError("no aliases configured")
}
tp := utils.NewTablePrinter(opts.IO)
aliasMap := aliasCfg.All()
keys := []string{}
for alias := range aliasMap {
keys = append(keys, alias)

View file

@ -44,10 +44,6 @@ func TestAliasList(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// TODO: change underlying config implementation so Write is not
// automatically called when editing aliases in-memory
defer config.StubWriteConfig(io.Discard, io.Discard)()
cfg := config.NewFromString(tt.config)
ios, _, stdout, stderr := iostreams.Test()

View file

@ -109,10 +109,7 @@ func setRun(opts *SetOptions) error {
return err
}
aliasCfg, err := cfg.Aliases()
if err != nil {
return err
}
aliasCfg := cfg.Aliases()
expansion, err := getExpansion(opts)
if err != nil {
@ -139,7 +136,7 @@ func setRun(opts *SetOptions) error {
}
successMsg := fmt.Sprintf("%s Added alias.", cs.SuccessIcon())
if oldExpansion, ok := aliasCfg.Get(opts.Name); ok {
if oldExpansion, err := aliasCfg.Get(opts.Name); err == nil {
successMsg = fmt.Sprintf("%s Changed alias %s from %s to %s",
cs.SuccessIcon(),
cs.Bold(opts.Name),
@ -148,9 +145,11 @@ func setRun(opts *SetOptions) error {
)
}
err = aliasCfg.Add(opts.Name, expansion)
aliasCfg.Add(opts.Name, expansion)
err = cfg.Write()
if err != nil {
return fmt.Errorf("could not create alias: %s", err)
return err
}
if isTerminal {

View file

@ -70,8 +70,6 @@ func runCommand(cfg config.Config, isTTY bool, cli string, in string) (*test.Cmd
}
func TestAliasSet_gh_command(t *testing.T) {
defer config.StubWriteConfig(io.Discard, io.Discard)()
cfg := config.NewFromString(``)
_, err := runCommand(cfg, true, "pr 'pr status'", "")
@ -79,8 +77,7 @@ func TestAliasSet_gh_command(t *testing.T) {
}
func TestAliasSet_empty_aliases(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(heredoc.Doc(`
aliases:
@ -93,6 +90,9 @@ func TestAliasSet_empty_aliases(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), "Added alias")
//nolint:staticcheck // prefer exact matchers over ExpectLines
@ -106,8 +106,7 @@ editor: vim
}
func TestAliasSet_existing_alias(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
_ = config.StubWriteConfig(t)
cfg := config.NewFromString(heredoc.Doc(`
aliases:
@ -122,14 +121,16 @@ func TestAliasSet_existing_alias(t *testing.T) {
}
func TestAliasSet_space_args(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(``)
output, err := runCommand(cfg, true, `il 'issue list -l "cool story"'`, "")
require.NoError(t, err)
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), `Adding alias for.*il.*issue list -l "cool story"`)
@ -138,6 +139,8 @@ func TestAliasSet_space_args(t *testing.T) {
}
func TestAliasSet_arg_processing(t *testing.T) {
readConfigs := config.StubWriteConfig(t)
cases := []struct {
Cmd string
ExpectedOutputLine string
@ -158,9 +161,6 @@ func TestAliasSet_arg_processing(t *testing.T) {
for _, c := range cases {
t.Run(c.Cmd, func(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
cfg := config.NewFromString(``)
output, err := runCommand(cfg, true, c.Cmd, "")
@ -168,6 +168,9 @@ func TestAliasSet_arg_processing(t *testing.T) {
t.Fatalf("got unexpected error running %s: %s", c.Cmd, err)
}
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), c.ExpectedOutputLine)
//nolint:staticcheck // prefer exact matchers over ExpectLines
@ -177,8 +180,7 @@ func TestAliasSet_arg_processing(t *testing.T) {
}
func TestAliasSet_init_alias_cfg(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(heredoc.Doc(`
editor: vim
@ -187,6 +189,9 @@ func TestAliasSet_init_alias_cfg(t *testing.T) {
output, err := runCommand(cfg, true, "diff 'pr diff'", "")
require.NoError(t, err)
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
expected := `editor: vim
aliases:
diff: pr diff
@ -198,8 +203,7 @@ aliases:
}
func TestAliasSet_existing_aliases(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(heredoc.Doc(`
aliases:
@ -209,6 +213,9 @@ func TestAliasSet_existing_aliases(t *testing.T) {
output, err := runCommand(cfg, true, "view 'pr view'", "")
require.NoError(t, err)
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
expected := `aliases:
foo: bar
view: pr view
@ -221,8 +228,6 @@ func TestAliasSet_existing_aliases(t *testing.T) {
}
func TestAliasSet_invalid_command(t *testing.T) {
defer config.StubWriteConfig(io.Discard, io.Discard)()
cfg := config.NewFromString(``)
_, err := runCommand(cfg, true, "co 'pe checkout'", "")
@ -230,8 +235,7 @@ func TestAliasSet_invalid_command(t *testing.T) {
}
func TestShellAlias_flag(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(``)
@ -240,6 +244,9 @@ func TestShellAlias_flag(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), "Adding alias for.*igrep")
@ -250,14 +257,16 @@ func TestShellAlias_flag(t *testing.T) {
}
func TestShellAlias_bang(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(``)
output, err := runCommand(cfg, true, "igrep '!gh issue list | grep'", "")
require.NoError(t, err)
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), "Adding alias for.*igrep")
@ -268,8 +277,7 @@ func TestShellAlias_bang(t *testing.T) {
}
func TestShellAlias_from_stdin(t *testing.T) {
mainBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, io.Discard)()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString(``)
@ -282,6 +290,9 @@ func TestShellAlias_from_stdin(t *testing.T) {
require.NoError(t, err)
mainBuf := bytes.Buffer{}
readConfigs(&mainBuf, io.Discard)
//nolint:staticcheck // prefer exact matchers over ExpectLines
test.ExpectLines(t, output.Stderr(), "Adding alias for.*users")

View file

@ -288,10 +288,7 @@ func apiRun(opts *ApiOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
if opts.Hostname != "" {
host = opts.Hostname

View file

@ -14,7 +14,8 @@ import (
const tokenUser = "x-access-token"
type config interface {
GetWithSource(string, string) (string, string, error)
AuthToken(string) (string, string)
Get(string, string) (string, error)
}
type CredentialOptions struct {
@ -102,16 +103,16 @@ func helperRun(opts *CredentialOptions) error {
lookupHost := wants["host"]
var gotUser string
gotToken, source, _ := cfg.GetWithSource(lookupHost, "oauth_token")
gotToken, source := cfg.AuthToken(lookupHost)
if gotToken == "" && strings.HasPrefix(lookupHost, "gist.") {
lookupHost = strings.TrimPrefix(lookupHost, "gist.")
gotToken, source, _ = cfg.GetWithSource(lookupHost, "oauth_token")
gotToken, source = cfg.AuthToken(lookupHost)
}
if strings.HasSuffix(source, "_TOKEN") {
gotUser = tokenUser
} else {
gotUser, _, _ = cfg.GetWithSource(lookupHost, "user")
gotUser, _ = cfg.Get(lookupHost, "user")
if gotUser == "" {
gotUser = tokenUser
}

View file

@ -11,8 +11,12 @@ import (
// why not just use the config stub argh
type tinyConfig map[string]string
func (c tinyConfig) GetWithSource(host, key string) (string, string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], c["_source"], nil
func (c tinyConfig) AuthToken(host string) (string, string) {
return c[fmt.Sprintf("%s:%s", host, "oauth_token")], c["_source"]
}
func (c tinyConfig) Get(host, key string) (string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], nil
}
func Test_helperRun(t *testing.T) {

View file

@ -1,7 +1,6 @@
package login
import (
"errors"
"fmt"
"io"
"net/http"
@ -136,14 +135,10 @@ func loginRun(opts *LoginOptions) error {
}
}
if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil {
var roErr *config.ReadOnlyEnvError
if errors.As(err, &roErr) {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable)
fmt.Fprint(opts.IO.ErrOut, "To have GitHub CLI store credentials instead, first clear the value from the environment.\n")
return cmdutil.SilentError
}
return err
if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src)
fmt.Fprint(opts.IO.ErrOut, "To have GitHub CLI store credentials instead, first clear the value from the environment.\n")
return cmdutil.SilentError
}
httpClient, err := opts.HttpClient()
@ -152,19 +147,16 @@ func loginRun(opts *LoginOptions) error {
}
if opts.Token != "" {
err := cfg.Set(hostname, "oauth_token", opts.Token)
if err != nil {
return err
}
cfg.Set(hostname, "oauth_token", opts.Token)
if err := shared.HasMinimumScopes(httpClient, hostname, opts.Token); err != nil {
return fmt.Errorf("error validating token: %w", err)
}
return cfg.WriteHosts()
return cfg.Write()
}
existingToken, _ := cfg.Get(hostname, "oauth_token")
existingToken, _ := cfg.AuthToken(hostname)
if existingToken != "" && opts.Interactive {
if err := shared.HasMinimumScopes(httpClient, hostname, existingToken); err == nil {
var keepGoing bool

View file

@ -216,7 +216,7 @@ func Test_loginRun_nontty(t *testing.T) {
name string
opts *LoginOptions
httpStubs func(*httpmock.Registry)
env map[string]string
cfgStubs func(*config.ConfigMock)
wantHosts string
wantErr string
wantStderr string
@ -282,8 +282,10 @@ func Test_loginRun_nontty(t *testing.T) {
Hostname: "github.com",
Token: "abc456",
},
env: map[string]string{
"GH_TOKEN": "value_from_env",
cfgStubs: func(c *config.ConfigMock) {
c.AuthTokenFunc = func(string) (string, string) {
return "value_from_env", "GH_TOKEN"
}
},
wantErr: "SilentError",
wantStderr: heredoc.Doc(`
@ -297,8 +299,10 @@ func Test_loginRun_nontty(t *testing.T) {
Hostname: "ghe.io",
Token: "abc456",
},
env: map[string]string{
"GH_ENTERPRISE_TOKEN": "value_from_env",
cfgStubs: func(c *config.ConfigMock) {
c.AuthTokenFunc = func(string) (string, string) {
return "value_from_env", "GH_ENTERPRISE_TOKEN"
}
},
wantErr: "SilentError",
wantStderr: heredoc.Doc(`
@ -310,37 +314,24 @@ func Test_loginRun_nontty(t *testing.T) {
for _, tt := range tests {
ios, _, stdout, stderr := iostreams.Test()
ios.SetStdinTTY(false)
ios.SetStdoutTTY(false)
tt.opts.Config = func() (config.Config, error) {
cfg := config.NewBlankConfig()
return config.InheritEnv(cfg), nil
}
tt.opts.IO = ios
t.Run(tt.name, func(t *testing.T) {
readConfigs := config.StubWriteConfig(t)
cfg := config.NewBlankConfig()
if tt.cfgStubs != nil {
tt.cfgStubs(cfg)
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
}
reg := &httpmock.Registry{}
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
old_GH_TOKEN := os.Getenv("GH_TOKEN")
os.Setenv("GH_TOKEN", tt.env["GH_TOKEN"])
old_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
os.Setenv("GITHUB_TOKEN", tt.env["GITHUB_TOKEN"])
old_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
os.Setenv("GH_ENTERPRISE_TOKEN", tt.env["GH_ENTERPRISE_TOKEN"])
old_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.env["GITHUB_ENTERPRISE_TOKEN"])
defer func() {
os.Setenv("GH_TOKEN", old_GH_TOKEN)
os.Setenv("GITHUB_TOKEN", old_GITHUB_TOKEN)
os.Setenv("GH_ENTERPRISE_TOKEN", old_GH_ENTERPRISE_TOKEN)
os.Setenv("GITHUB_ENTERPRISE_TOKEN", old_GITHUB_ENTERPRISE_TOKEN)
}()
if tt.httpStubs != nil {
tt.httpStubs(reg)
}
@ -348,10 +339,6 @@ func Test_loginRun_nontty(t *testing.T) {
_, restoreRun := run.Stub()
defer restoreRun(t)
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
err := loginRun(tt.opts)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
@ -359,6 +346,10 @@ func Test_loginRun_nontty(t *testing.T) {
assert.NoError(t, err)
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
readConfigs(&mainBuf, &hostsBuf)
assert.Equal(t, "", stdout.String())
assert.Equal(t, tt.wantStderr, stderr.String())
assert.Equal(t, tt.wantHosts, hostsBuf.String())
@ -378,27 +369,26 @@ func Test_loginRun_Survey(t *testing.T) {
runStubs func(*run.CommandStubber)
wantHosts string
wantErrOut *regexp.Regexp
cfg func(config.Config)
cfgStubs func(*config.ConfigMock)
}{
{
name: "already authenticated",
opts: &LoginOptions{
Interactive: true,
},
cfg: func(cfg config.Config) {
_ = cfg.Set("github.com", "oauth_token", "ghi789")
cfgStubs: func(c *config.ConfigMock) {
c.AuthTokenFunc = func(h string) (string, string) {
return "ghi789", "oauth_token"
}
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org"))
// reg.Register(
// httpmock.GraphQL(`query UserCurrent\b`),
// httpmock.StringResponse(`{"data":{"viewer":{"login":"jillv"}}}`))
},
askStubs: func(as *prompt.AskStubber) {
as.StubPrompt("What account do you want to log into?").AnswerWith("GitHub.com")
as.StubPrompt("You're already logged into github.com. Do you want to re-authenticate?").AnswerWith(false)
},
wantHosts: "", // nothing should have been written to hosts
wantHosts: "",
wantErrOut: nil,
},
{
@ -521,10 +511,11 @@ func Test_loginRun_Survey(t *testing.T) {
tt.opts.IO = ios
cfg := config.NewBlankConfig()
readConfigs := config.StubWriteConfig(t)
if tt.cfg != nil {
tt.cfg(cfg)
cfg := config.NewBlankConfig()
if tt.cfgStubs != nil {
tt.cfgStubs(cfg)
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
@ -544,10 +535,6 @@ func Test_loginRun_Survey(t *testing.T) {
httpmock.StringResponse(`{"data":{"viewer":{"login":"jillv"}}}`))
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
as := prompt.NewAskStubber(t)
if tt.askStubs != nil {
tt.askStubs(as)
@ -564,6 +551,10 @@ func Test_loginRun_Survey(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
readConfigs(&mainBuf, &hostsBuf)
assert.Equal(t, tt.wantHosts, hostsBuf.String())
if tt.wantErrOut == nil {
assert.Equal(t, "", stderr.String())

View file

@ -1,7 +1,6 @@
package logout
import (
"errors"
"fmt"
"net/http"
@ -9,6 +8,7 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/cmd/auth/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/prompt"
@ -70,10 +70,7 @@ func logoutRun(opts *LogoutOptions) error {
return err
}
candidates, err := cfg.Hosts()
if err != nil {
return err
}
candidates := cfg.Hosts()
if len(candidates) == 0 {
return fmt.Errorf("not logged in to any hosts")
}
@ -105,14 +102,10 @@ func logoutRun(opts *LogoutOptions) error {
}
}
if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil {
var roErr *config.ReadOnlyEnvError
if errors.As(err, &roErr) {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable)
fmt.Fprint(opts.IO.ErrOut, "To erase credentials stored in GitHub CLI, first clear the value from the environment.\n")
return cmdutil.SilentError
}
return err
if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src)
fmt.Fprint(opts.IO.ErrOut, "To erase credentials stored in GitHub CLI, first clear the value from the environment.\n")
return cmdutil.SilentError
}
httpClient, err := opts.HttpClient()
@ -134,7 +127,7 @@ func logoutRun(opts *LogoutOptions) error {
}
cfg.UnsetHost(hostname)
err = cfg.WriteHosts()
err = cfg.Write()
if err != nil {
return fmt.Errorf("failed to write config, authentication configuration not updated: %w", err)
}

View file

@ -114,6 +114,7 @@ func Test_logoutRun_tty(t *testing.T) {
name: "no arguments, one host",
opts: &LogoutOptions{},
cfgHosts: []string{"github.com"},
wantHosts: "{}\n",
wantErrOut: regexp.MustCompile(`Logged out of github.com account 'cybilb'`),
},
{
@ -134,34 +135,29 @@ func Test_logoutRun_tty(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, stderr := iostreams.Test()
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
tt.opts.IO = ios
cfg := config.NewBlankConfig()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString("")
for _, hostname := range tt.cfgHosts {
cfg.Set(hostname, "oauth_token", "abc123")
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
}
for _, hostname := range tt.cfgHosts {
_ = cfg.Set(hostname, "oauth_token", "abc123")
}
ios, _, _, stderr := iostreams.Test()
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
tt.opts.IO = ios
reg := &httpmock.Registry{}
reg.Register(
httpmock.GraphQL(`query UserCurrent\b`),
httpmock.StringResponse(`{"data":{"viewer":{"login":"cybilb"}}}`))
httpmock.StringResponse(`{"data":{"viewer":{"login":"cybilb"}}}`),
)
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
as := prompt.NewAskStubber(t)
if tt.askStubs != nil {
tt.askStubs(as)
@ -181,6 +177,10 @@ func Test_logoutRun_tty(t *testing.T) {
assert.True(t, tt.wantErrOut.MatchString(stderr.String()))
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
readConfigs(&mainBuf, &hostsBuf)
assert.Equal(t, tt.wantHosts, hostsBuf.String())
reg.Verify(t)
})
@ -201,7 +201,8 @@ func Test_logoutRun_nontty(t *testing.T) {
opts: &LogoutOptions{
Hostname: "harry.mason",
},
cfgHosts: []string{"harry.mason"},
cfgHosts: []string{"harry.mason"},
wantHosts: "{}\n",
},
{
name: "hostname, multiple hosts",
@ -222,30 +223,25 @@ func Test_logoutRun_nontty(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, stderr := iostreams.Test()
ios.SetStdinTTY(false)
ios.SetStdoutTTY(false)
tt.opts.IO = ios
cfg := config.NewBlankConfig()
readConfigs := config.StubWriteConfig(t)
cfg := config.NewFromString("")
for _, hostname := range tt.cfgHosts {
cfg.Set(hostname, "oauth_token", "abc123")
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
}
for _, hostname := range tt.cfgHosts {
_ = cfg.Set(hostname, "oauth_token", "abc123")
}
ios, _, _, stderr := iostreams.Test()
ios.SetStdinTTY(false)
ios.SetStdoutTTY(false)
tt.opts.IO = ios
reg := &httpmock.Registry{}
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
err := logoutRun(tt.opts)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
@ -255,6 +251,10 @@ func Test_logoutRun_nontty(t *testing.T) {
assert.Equal(t, "", stderr.String())
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
readConfigs(&mainBuf, &hostsBuf)
assert.Equal(t, tt.wantHosts, hostsBuf.String())
reg.Verify(t)
})

View file

@ -1,7 +1,6 @@
package refresh
import (
"errors"
"fmt"
"net/http"
"strings"
@ -85,10 +84,7 @@ func refreshRun(opts *RefreshOptions) error {
return err
}
candidates, err := cfg.Hosts()
if err != nil {
return err
}
candidates := cfg.Hosts()
if len(candidates) == 0 {
return fmt.Errorf("not logged in to any hosts. Use 'gh auth login' to authenticate with a host")
}
@ -121,18 +117,14 @@ func refreshRun(opts *RefreshOptions) error {
}
}
if err := cfg.CheckWriteable(hostname, "oauth_token"); err != nil {
var roErr *config.ReadOnlyEnvError
if errors.As(err, &roErr) {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", roErr.Variable)
fmt.Fprint(opts.IO.ErrOut, "To refresh credentials stored in GitHub CLI, first clear the value from the environment.\n")
return cmdutil.SilentError
}
return err
if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable {
fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src)
fmt.Fprint(opts.IO.ErrOut, "To refresh credentials stored in GitHub CLI, first clear the value from the environment.\n")
return cmdutil.SilentError
}
var additionalScopes []string
if oldToken, _ := cfg.Get(hostname, "oauth_token"); oldToken != "" {
if oldToken, _ := cfg.AuthToken(hostname); oldToken != "" {
if oldScopes, err := shared.GetScopes(opts.httpClient, hostname, oldToken); err == nil {
for _, s := range strings.Split(oldScopes, ",") {
s = strings.TrimSpace(s)
@ -163,7 +155,7 @@ func refreshRun(opts *RefreshOptions) error {
if credentialFlow.ShouldSetup() {
username, _ := cfg.Get(hostname, "user")
password, _ := cfg.Get(hostname, "oauth_token")
password, _ := cfg.AuthToken(hostname)
if err := credentialFlow.Setup(hostname, username, password); err != nil {
return err
}

View file

@ -238,19 +238,19 @@ func Test_refreshRun(t *testing.T) {
return nil
}
ios, _, _, _ := iostreams.Test()
ios.SetStdinTTY(!tt.nontty)
ios.SetStdoutTTY(!tt.nontty)
tt.opts.IO = ios
cfg := config.NewBlankConfig()
_ = config.StubWriteConfig(t)
cfg := config.NewFromString("")
for _, hostname := range tt.cfgHosts {
cfg.Set(hostname, "oauth_token", "abc123")
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
}
for _, hostname := range tt.cfgHosts {
_ = cfg.Set(hostname, "oauth_token", "abc123")
}
ios, _, _, _ := iostreams.Test()
ios.SetStdinTTY(!tt.nontty)
ios.SetStdoutTTY(!tt.nontty)
tt.opts.IO = ios
httpReg := &httpmock.Registry{}
httpReg.Register(
@ -272,10 +272,6 @@ func Test_refreshRun(t *testing.T) {
)
tt.opts.httpClient = &http.Client{Transport: httpReg}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
as := prompt.NewAskStubber(t)
if tt.askStubs != nil {
tt.askStubs(as)

View file

@ -54,10 +54,7 @@ func setupGitRun(opts *SetupGitOptions) error {
return err
}
hostnames, err := cfg.Hosts()
if err != nil {
return err
}
hostnames := cfg.Hosts()
stderr := opts.IO.ErrOut
cs := opts.IO.ColorScheme()

View file

@ -7,7 +7,6 @@ import (
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockGitConfigurer struct {
@ -35,8 +34,16 @@ func Test_setupGitRun(t *testing.T) {
expectedErr: "oops",
},
{
name: "no authenticated hostnames",
opts: &SetupGitOptions{},
name: "no authenticated hostnames",
opts: &SetupGitOptions{
Config: func() (config.Config, error) {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{}
}
return cfg, nil
},
},
expectedErr: "SilentError",
expectedErrOut: "You are not logged into any GitHub hosts. Run gh auth login to authenticate.\n",
},
@ -45,8 +52,10 @@ func Test_setupGitRun(t *testing.T) {
opts: &SetupGitOptions{
Hostname: "foo",
Config: func() (config.Config, error) {
cfg := config.NewBlankConfig()
require.NoError(t, cfg.Set("bar", "", ""))
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"bar"}
}
return cfg, nil
},
},
@ -60,8 +69,10 @@ func Test_setupGitRun(t *testing.T) {
setupErr: fmt.Errorf("broken"),
},
Config: func() (config.Config, error) {
cfg := config.NewBlankConfig()
require.NoError(t, cfg.Set("bar", "", ""))
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"bar"}
}
return cfg, nil
},
},
@ -73,8 +84,10 @@ func Test_setupGitRun(t *testing.T) {
opts: &SetupGitOptions{
gitConfigure: &mockGitConfigurer{},
Config: func() (config.Config, error) {
cfg := config.NewBlankConfig()
require.NoError(t, cfg.Set("bar", "", ""))
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"bar"}
}
return cfg, nil
},
},
@ -85,9 +98,10 @@ func Test_setupGitRun(t *testing.T) {
Hostname: "yes",
gitConfigure: &mockGitConfigurer{},
Config: func() (config.Config, error) {
cfg := config.NewBlankConfig()
require.NoError(t, cfg.Set("bar", "", ""))
require.NoError(t, cfg.Set("yes", "", ""))
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"bar", "yes"}
}
return cfg, nil
},
},
@ -98,7 +112,7 @@ func Test_setupGitRun(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.opts.Config == nil {
tt.opts.Config = func() (config.Config, error) {
return config.NewBlankConfig(), nil
return &config.ConfigMock{}, nil
}
}

View file

@ -21,9 +21,8 @@ const defaultSSHKeyTitle = "GitHub CLI"
type iconfig interface {
Get(string, string) (string, error)
Set(string, string, string) error
Set(string, string, string)
Write() error
WriteHosts() error
}
type LoginOptions struct {
@ -175,9 +174,7 @@ func Login(opts *LoginOptions) error {
return fmt.Errorf("error validating token: %w", err)
}
if err := cfg.Set(hostname, "oauth_token", authToken); err != nil {
return err
}
cfg.Set(hostname, "oauth_token", authToken)
}
var username string
@ -191,22 +188,16 @@ func Login(opts *LoginOptions) error {
return fmt.Errorf("error using api: %w", err)
}
err = cfg.Set(hostname, "user", username)
if err != nil {
return err
}
cfg.Set(hostname, "user", username)
}
if gitProtocol != "" {
fmt.Fprintf(opts.IO.ErrOut, "- gh config set -h %s git_protocol %s\n", hostname, gitProtocol)
err := cfg.Set(hostname, "git_protocol", gitProtocol)
if err != nil {
return err
}
cfg.Set(hostname, "git_protocol", gitProtocol)
fmt.Fprintf(opts.IO.ErrOut, "%s Configured git protocol\n", cs.SuccessIcon())
}
err := cfg.WriteHosts()
err := cfg.Write()
if err != nil {
return err
}

View file

@ -22,19 +22,14 @@ func (c tinyConfig) Get(host, key string) (string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], nil
}
func (c tinyConfig) Set(host string, key string, value string) error {
func (c tinyConfig) Set(host string, key string, value string) {
c[fmt.Sprintf("%s:%s", host, key)] = value
return nil
}
func (c tinyConfig) Write() error {
return nil
}
func (c tinyConfig) WriteHosts() error {
return nil
}
func TestLogin_ssh(t *testing.T) {
dir := t.TempDir()
ios, _, stdout, stderr := iostreams.Test()

View file

@ -0,0 +1,14 @@
package shared
import (
"github.com/cli/cli/v2/internal/config"
)
const (
oauthToken = "oauth_token"
)
func AuthTokenWriteable(cfg config.Config, hostname string) (string, bool) {
token, src := cfg.AuthToken(hostname)
return src, (token == "" || src == oauthToken)
}

View file

@ -68,10 +68,7 @@ func statusRun(opts *StatusOptions) error {
statusInfo := map[string][]string{}
hostnames, err := cfg.Hosts()
if err != nil {
return err
}
hostnames := cfg.Hosts()
if len(hostnames) == 0 {
fmt.Fprintf(stderr,
"You are not logged into any GitHub hosts. Run %s to authenticate.\n", cs.Bold("gh auth login"))
@ -92,8 +89,8 @@ func statusRun(opts *StatusOptions) error {
}
isHostnameFound = true
token, tokenSource, _ := cfg.GetWithSource(hostname, "oauth_token")
tokenIsWriteable := cfg.CheckWriteable(hostname, "oauth_token") == nil
token, tokenSource := cfg.AuthToken(hostname)
_, tokenIsWriteable := shared.AuthTokenWriteable(cfg, hostname)
statusInfo[hostname] = []string{}
addMsg := func(x string, ys ...interface{}) {

View file

@ -71,11 +71,13 @@ func Test_NewCmdStatus(t *testing.T) {
}
func Test_statusRun(t *testing.T) {
readConfigs := config.StubWriteConfig(t)
tests := []struct {
name string
opts *StatusOptions
httpStubs func(*httpmock.Registry)
cfg func(config.Config)
cfgStubs func(*config.ConfigMock)
wantErr string
wantErrOut *regexp.Regexp
}{
@ -84,9 +86,9 @@ func Test_statusRun(t *testing.T) {
opts: &StatusOptions{
Hostname: "joel.miller",
},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "abc123")
cfgStubs: func(c *config.ConfigMock) {
c.Set("joel.miller", "oauth_token", "abc123")
c.Set("github.com", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org"))
@ -99,9 +101,9 @@ func Test_statusRun(t *testing.T) {
{
name: "missing scope",
opts: &StatusOptions{},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "abc123")
cfgStubs: func(c *config.ConfigMock) {
c.Set("joel.miller", "oauth_token", "abc123")
c.Set("github.com", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo"))
@ -116,9 +118,9 @@ func Test_statusRun(t *testing.T) {
{
name: "bad token",
opts: &StatusOptions{},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "abc123")
cfgStubs: func(c *config.ConfigMock) {
c.Set("joel.miller", "oauth_token", "abc123")
c.Set("github.com", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.StatusStringResponse(400, "no bueno"))
@ -133,9 +135,9 @@ func Test_statusRun(t *testing.T) {
{
name: "all good",
opts: &StatusOptions{},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "abc123")
cfgStubs: func(c *config.ConfigMock) {
c.Set("github.com", "oauth_token", "abc123")
c.Set("joel.miller", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org"))
@ -152,9 +154,9 @@ func Test_statusRun(t *testing.T) {
{
name: "hide token",
opts: &StatusOptions{},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "xyz456")
cfgStubs: func(c *config.ConfigMock) {
c.Set("joel.miller", "oauth_token", "abc123")
c.Set("github.com", "oauth_token", "xyz456")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org"))
@ -173,9 +175,9 @@ func Test_statusRun(t *testing.T) {
opts: &StatusOptions{
ShowToken: true,
},
cfg: func(c config.Config) {
_ = c.Set("joel.miller", "oauth_token", "abc123")
_ = c.Set("github.com", "oauth_token", "xyz456")
cfgStubs: func(c *config.ConfigMock) {
c.Set("github.com", "oauth_token", "xyz456")
c.Set("joel.miller", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org"))
@ -188,13 +190,14 @@ func Test_statusRun(t *testing.T) {
httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`))
},
wantErrOut: regexp.MustCompile(`(?s)Token: xyz456.*Token: abc123`),
}, {
},
{
name: "missing hostname",
opts: &StatusOptions{
Hostname: "github.example.com",
},
cfg: func(c config.Config) {
_ = c.Set("github.com", "oauth_token", "abc123")
cfgStubs: func(c *config.ConfigMock) {
c.Set("github.com", "oauth_token", "abc123")
},
httpStubs: func(reg *httpmock.Registry) {},
wantErrOut: regexp.MustCompile(`(?s)Hostname "github.example.com" not found among authenticated GitHub hosts`),
@ -213,13 +216,11 @@ func Test_statusRun(t *testing.T) {
ios.SetStdinTTY(true)
ios.SetStderrTTY(true)
ios.SetStdoutTTY(true)
tt.opts.IO = ios
cfg := config.NewBlankConfig()
if tt.cfg != nil {
tt.cfg(cfg)
cfg := config.NewFromString("")
if tt.cfgStubs != nil {
tt.cfgStubs(cfg)
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil
@ -232,9 +233,6 @@ func Test_statusRun(t *testing.T) {
if tt.httpStubs != nil {
tt.httpStubs(reg)
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
err := statusRun(tt.opts)
if tt.wantErr != "" {
@ -250,6 +248,10 @@ func Test_statusRun(t *testing.T) {
assert.True(t, tt.wantErrOut.MatchString(stderr.String()))
}
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
readConfigs(&mainBuf, &hostsBuf)
assert.Equal(t, "", mainBuf.String())
assert.Equal(t, "", hostsBuf.String())

View file

@ -42,7 +42,7 @@ func TestNewCmdConfigGet(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
f := &cmdutil.Factory{
Config: func() (config.Config, error) {
return config.ConfigStub{}, nil
return config.NewBlankConfig(), nil
},
}
@ -86,9 +86,11 @@ func Test_getRun(t *testing.T) {
name: "get key",
input: &GetOptions{
Key: "editor",
Config: config.ConfigStub{
"editor": "ed",
},
Config: func() config.Config {
cfg := config.NewBlankConfig()
cfg.Set("", "editor", "ed")
return cfg
}(),
},
stdout: "ed\n",
},
@ -97,10 +99,12 @@ func Test_getRun(t *testing.T) {
input: &GetOptions{
Hostname: "github.com",
Key: "editor",
Config: config.ConfigStub{
"editor": "ed",
"github.com:editor": "vim",
},
Config: func() config.Config {
cfg := config.NewBlankConfig()
cfg.Set("", "editor", "ed")
cfg.Set("github.com", "editor", "vim")
return cfg
}(),
},
stdout: "vim\n",
},
@ -115,10 +119,6 @@ func Test_getRun(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tt.stdout, stdout.String())
assert.Equal(t, tt.stderr, stderr.String())
_, err = tt.input.Config.GetOrDefault("", "_written")
assert.Error(t, err)
_, err = tt.input.Config.Get("", "_written")
assert.Error(t, err)
})
}
}

View file

@ -51,10 +51,7 @@ func listRun(opts *ListOptions) error {
if opts.Hostname != "" {
host = opts.Hostname
} else {
host, err = cfg.DefaultHost()
if err != nil {
return err
}
host, _ = cfg.DefaultHost()
}
configOptions := config.ConfigOptions()

View file

@ -36,7 +36,7 @@ func TestNewCmdConfigList(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
f := &cmdutil.Factory{
Config: func() (config.Config, error) {
return config.ConfigStub{}, nil
return config.NewBlankConfig(), nil
},
}
@ -71,21 +71,23 @@ func Test_listRun(t *testing.T) {
tests := []struct {
name string
input *ListOptions
config config.ConfigStub
config config.Config
stdout string
wantErr bool
}{
{
name: "list",
config: config.ConfigStub{
"HOST:git_protocol": "ssh",
"HOST:editor": "/usr/bin/vim",
"HOST:prompt": "disabled",
"HOST:pager": "less",
"HOST:http_unix_socket": "",
"HOST:browser": "brave",
},
input: &ListOptions{Hostname: "HOST"}, // ConfigStub gives empty DefaultHost
config: func() config.Config {
cfg := config.NewBlankConfig()
cfg.Set("HOST", "git_protocol", "ssh")
cfg.Set("HOST", "editor", "/usr/bin/vim")
cfg.Set("HOST", "prompt", "disabled")
cfg.Set("HOST", "pager", "less")
cfg.Set("HOST", "http_unix_socket", "")
cfg.Set("HOST", "browser", "brave")
return cfg
}(),
input: &ListOptions{Hostname: "HOST"},
stdout: `git_protocol=ssh
editor=/usr/bin/vim
prompt=disabled

View file

@ -59,15 +59,15 @@ func NewCmdConfigSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Co
}
func setRun(opts *SetOptions) error {
err := config.ValidateKey(opts.Key)
err := ValidateKey(opts.Key)
if err != nil {
warningIcon := opts.IO.ColorScheme().WarningIcon()
fmt.Fprintf(opts.IO.ErrOut, "%s warning: '%s' is not a known configuration key\n", warningIcon, opts.Key)
}
err = config.ValidateValue(opts.Key, opts.Value)
err = ValidateValue(opts.Key, opts.Value)
if err != nil {
var invalidValue *config.InvalidValueError
var invalidValue InvalidValueError
if errors.As(err, &invalidValue) {
var values []string
for _, v := range invalidValue.ValidValues {
@ -77,10 +77,7 @@ func setRun(opts *SetOptions) error {
}
}
err = opts.Config.Set(opts.Hostname, opts.Key, opts.Value)
if err != nil {
return fmt.Errorf("failed to set %q to %q: %w", opts.Key, opts.Value, err)
}
opts.Config.Set(opts.Hostname, opts.Key, opts.Value)
err = opts.Config.Write()
if err != nil {
@ -88,3 +85,44 @@ func setRun(opts *SetOptions) error {
}
return nil
}
func ValidateKey(key string) error {
for _, configKey := range config.ConfigOptions() {
if key == configKey.Key {
return nil
}
}
return fmt.Errorf("invalid key")
}
type InvalidValueError struct {
ValidValues []string
}
func (e InvalidValueError) Error() string {
return "invalid value"
}
func ValidateValue(key, value string) error {
var validValues []string
for _, v := range config.ConfigOptions() {
if v.Key == key {
validValues = v.AllowedValues
break
}
}
if validValues == nil {
return nil
}
for _, v := range validValues {
if v == value {
return nil
}
}
return InvalidValueError{ValidValues: validValues}
}

View file

@ -46,9 +46,11 @@ func TestNewCmdConfigSet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = config.StubWriteConfig(t)
f := &cmdutil.Factory{
Config: func() (config.Config, error) {
return config.ConfigStub{}, nil
return config.NewBlankConfig(), nil
},
}
@ -94,7 +96,7 @@ func Test_setRun(t *testing.T) {
{
name: "set key value",
input: &SetOptions{
Config: config.ConfigStub{},
Config: config.NewBlankConfig(),
Key: "editor",
Value: "vim",
},
@ -103,7 +105,7 @@ func Test_setRun(t *testing.T) {
{
name: "set key value scoped by host",
input: &SetOptions{
Config: config.ConfigStub{},
Config: config.NewBlankConfig(),
Hostname: "github.com",
Key: "editor",
Value: "vim",
@ -113,7 +115,7 @@ func Test_setRun(t *testing.T) {
{
name: "set unknown key",
input: &SetOptions{
Config: config.ConfigStub{},
Config: config.NewBlankConfig(),
Key: "unknownKey",
Value: "someValue",
},
@ -123,7 +125,7 @@ func Test_setRun(t *testing.T) {
{
name: "set invalid value",
input: &SetOptions{
Config: config.ConfigStub{},
Config: config.NewBlankConfig(),
Key: "git_protocol",
Value: "invalid",
},
@ -132,10 +134,12 @@ func Test_setRun(t *testing.T) {
},
}
for _, tt := range tests {
ios, _, stdout, stderr := iostreams.Test()
tt.input.IO = ios
t.Run(tt.name, func(t *testing.T) {
_ = config.StubWriteConfig(t)
ios, _, stdout, stderr := iostreams.Test()
tt.input.IO = ios
err := setRun(tt.input)
if tt.wantsErr {
assert.EqualError(t, err, tt.errMsg)
@ -148,10 +152,46 @@ func Test_setRun(t *testing.T) {
val, err := tt.input.Config.GetOrDefault(tt.input.Hostname, tt.input.Key)
assert.NoError(t, err)
assert.Equal(t, tt.expectedValue, val)
val, err = tt.input.Config.GetOrDefault("", "_written")
assert.NoError(t, err)
assert.Equal(t, "true", val)
})
}
}
func Test_ValidateValue(t *testing.T) {
err := ValidateValue("git_protocol", "sshpps")
assert.EqualError(t, err, "invalid value")
err = ValidateValue("git_protocol", "ssh")
assert.NoError(t, err)
err = ValidateValue("editor", "vim")
assert.NoError(t, err)
err = ValidateValue("got", "123")
assert.NoError(t, err)
err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate")
assert.NoError(t, err)
}
func Test_ValidateKey(t *testing.T) {
err := ValidateKey("invalid")
assert.EqualError(t, err, "invalid key")
err = ValidateKey("git_protocol")
assert.NoError(t, err)
err = ValidateKey("editor")
assert.NoError(t, err)
err = ValidateKey("prompt")
assert.NoError(t, err)
err = ValidateKey("pager")
assert.NoError(t, err)
err = ValidateKey("http_unix_socket")
assert.NoError(t, err)
err = ValidateKey("browser")
assert.NoError(t, err)
}

View file

@ -702,10 +702,7 @@ func (m *Manager) goBinScaffolding(gitExe, name string) error {
return err
}
host, err := m.config.DefaultHost()
if err != nil {
return err
}
host, _ := m.config.DefaultHost()
currentUser, err := api.CurrentLoginName(api.NewClientFromHTTP(m.client), host)
if err != nil {

View file

@ -1,7 +1,6 @@
package factory
import (
"errors"
"fmt"
"net/http"
"os"
@ -134,12 +133,7 @@ func configFunc() func() (config.Config, error) {
if cachedConfig != nil || configError != nil {
return cachedConfig, configError
}
cachedConfig, configError = config.ParseDefaultConfig()
if errors.Is(configError, os.ErrNotExist) {
cachedConfig = config.NewBlankConfig()
configError = nil
}
cachedConfig = config.InheritEnv(cachedConfig)
cachedConfig, configError = config.NewConfig()
return cachedConfig, configError
}
}

View file

@ -7,7 +7,6 @@ import (
"os"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/cmdutil"
@ -17,15 +16,9 @@ import (
)
func Test_BaseRepo(t *testing.T) {
orig_GH_HOST := os.Getenv("GH_HOST")
t.Cleanup(func() {
os.Setenv("GH_HOST", orig_GH_HOST)
})
tests := []struct {
name string
remotes git.RemoteSet
config config.Config
override string
wantsErr bool
wantsName string
@ -37,7 +30,6 @@ func Test_BaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://nonsense.com/owner/repo.git"),
},
config: defaultConfig(),
wantsName: "repo",
wantsOwner: "owner",
wantsHost: "nonsense.com",
@ -47,7 +39,6 @@ func Test_BaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://test.com/owner/repo.git"),
},
config: defaultConfig(),
wantsErr: true,
},
{
@ -55,7 +46,6 @@ func Test_BaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://test.com/owner/repo.git"),
},
config: defaultConfig(),
override: "test.com",
wantsName: "repo",
wantsOwner: "owner",
@ -66,7 +56,6 @@ func Test_BaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://nonsense.com/owner/repo.git"),
},
config: defaultConfig(),
override: "test.com",
wantsErr: true,
},
@ -74,18 +63,30 @@ func Test_BaseRepo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.override != "" {
os.Setenv("GH_HOST", tt.override)
} else {
os.Unsetenv("GH_HOST")
}
f := New("1")
rr := &remoteResolver{
readRemotes: func() (git.RemoteSet, error) {
return tt.remotes, nil
},
getConfig: func() (config.Config, error) {
return tt.config, nil
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
hosts := []string{"nonsense.com"}
if tt.override != "" {
hosts = append([]string{tt.override}, hosts...)
}
return hosts
}
cfg.DefaultHostFunc = func() (string, string) {
if tt.override != "" {
return tt.override, "GH_HOST"
}
return "nonsense.com", "hosts"
}
cfg.AuthTokenFunc = func(string) (string, string) {
return "", ""
}
return cfg, nil
},
}
f.Remotes = rr.Resolver()
@ -105,15 +106,10 @@ func Test_BaseRepo(t *testing.T) {
func Test_SmartBaseRepo(t *testing.T) {
pu, _ := url.Parse("https://test.com/newowner/newrepo.git")
orig_GH_HOST := os.Getenv("GH_HOST")
t.Cleanup(func() {
os.Setenv("GH_HOST", orig_GH_HOST)
})
tests := []struct {
name string
remotes git.RemoteSet
config config.Config
override string
wantsErr bool
wantsName string
@ -125,7 +121,6 @@ func Test_SmartBaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://test.com/owner/repo.git"),
},
config: defaultConfig(),
override: "test.com",
wantsName: "repo",
wantsOwner: "owner",
@ -139,7 +134,6 @@ func Test_SmartBaseRepo(t *testing.T) {
FetchURL: pu,
PushURL: pu},
},
config: defaultConfig(),
override: "test.com",
wantsName: "newrepo",
wantsOwner: "newowner",
@ -153,7 +147,6 @@ func Test_SmartBaseRepo(t *testing.T) {
FetchURL: pu,
PushURL: pu},
},
config: defaultConfig(),
override: "test.com",
wantsName: "test",
wantsOwner: "johnny",
@ -164,7 +157,6 @@ func Test_SmartBaseRepo(t *testing.T) {
remotes: git.RemoteSet{
git.NewRemote("origin", "https://example.com/owner/repo.git"),
},
config: defaultConfig(),
override: "test.com",
wantsErr: true,
},
@ -172,18 +164,27 @@ func Test_SmartBaseRepo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.override != "" {
os.Setenv("GH_HOST", tt.override)
} else {
os.Unsetenv("GH_HOST")
}
f := New("1")
rr := &remoteResolver{
readRemotes: func() (git.RemoteSet, error) {
return tt.remotes, nil
},
getConfig: func() (config.Config, error) {
return tt.config, nil
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
hosts := []string{"nonsense.com"}
if tt.override != "" {
hosts = append([]string{tt.override}, hosts...)
}
return hosts
}
cfg.DefaultHostFunc = func() (string, string) {
if tt.override != "" {
return tt.override, "GH_HOST"
}
return "nonsense.com", "hosts"
}
return cfg, nil
},
}
f.HttpClient = func() (*http.Client, error) { return nil, nil }
@ -204,11 +205,6 @@ func Test_SmartBaseRepo(t *testing.T) {
// Defined in pkg/cmdutil/repo_override.go but test it along with other BaseRepo functions
func Test_OverrideBaseRepo(t *testing.T) {
orig_GH_HOST := os.Getenv("GH_REPO")
t.Cleanup(func() {
os.Setenv("GH_REPO", orig_GH_HOST)
})
tests := []struct {
name string
remotes git.RemoteSet
@ -249,9 +245,9 @@ func Test_OverrideBaseRepo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envOverride != "" {
old := os.Getenv("GH_REPO")
os.Setenv("GH_REPO", tt.envOverride)
} else {
os.Unsetenv("GH_REPO")
defer os.Setenv("GH_REPO", old)
}
f := New("1")
rr := &remoteResolver{
@ -511,12 +507,10 @@ func TestSSOURL(t *testing.T) {
}
}
func defaultConfig() config.Config {
return config.InheritEnv(config.NewFromString(heredoc.Doc(`
hosts:
nonsense.com:
oauth_token: BLAH
`)))
func defaultConfig() *config.ConfigMock {
cfg := config.NewFromString("")
cfg.Set("nonsense.com", "oauth_token", "BLAH")
return cfg
}
func pagerConfig() config.Config {

View file

@ -13,6 +13,10 @@ import (
"github.com/cli/go-gh/pkg/ssh"
)
const (
GH_HOST = "GH_HOST"
)
type remoteResolver struct {
readRemotes func() (git.RemoteSet, error)
getConfig func() (config.Config, error)
@ -49,14 +53,12 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) {
return nil, err
}
authedHosts, err := cfg.Hosts()
if err != nil {
return nil, err
}
defaultHost, src, err := cfg.DefaultHostWithSource()
if err != nil {
return nil, err
authedHosts := cfg.Hosts()
if len(authedHosts) == 0 {
return nil, errors.New("could not find any host configurations")
}
defaultHost, src := cfg.DefaultHost()
// Use set to dedupe list of hosts
hostsSet := set.NewStringSet()
hostsSet.AddValues(authedHosts)
@ -72,18 +74,19 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) {
// Filter again by default host if one is set
// For config file default host fallback to cachedRemotes if none match
// For enviornment default host (GH_HOST) do not fallback to cachedRemotes if none match
if src != "" {
if src != "default" {
filteredRemotes := cachedRemotes.FilterByHosts([]string{defaultHost})
if config.IsHostEnv(src) || len(filteredRemotes) > 0 {
if isHostEnv(src) || len(filteredRemotes) > 0 {
cachedRemotes = filteredRemotes
}
}
if len(cachedRemotes) == 0 {
dummyHostname := "example.com" // any non-github.com hostname is fine here
if config.IsHostEnv(src) {
// Any non-github.com hostname is fine here
dummyHostname := "example.com"
if isHostEnv(src) {
return nil, fmt.Errorf("none of the git remotes configured for this repository correspond to the %s environment variable. Try adding a matching remote or unsetting the variable.", src)
} else if v, src, _ := cfg.GetWithSource(dummyHostname, "oauth_token"); v != "" && config.IsEnterpriseEnv(src) {
} else if v, _ := cfg.AuthToken(dummyHostname); v != "" {
return nil, errors.New("set the GH_HOST environment variable to specify which GitHub host to use")
}
return nil, errors.New("none of the git remotes configured for this repository point to a known GitHub host. To tell gh about a new GitHub host, please use `gh auth login`")
@ -92,3 +95,7 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) {
return cachedRemotes, nil
}
}
func isHostEnv(src string) bool {
return src == GH_HOST
}

View file

@ -5,7 +5,6 @@ import (
"os"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/config"
"github.com/stretchr/testify/assert"
@ -26,8 +25,7 @@ func Test_remoteResolver(t *testing.T) {
tests := []struct {
name string
remotes func() (git.RemoteSet, error)
config func() (config.Config, error)
override string
config config.Config
output []string
wantsErr bool
}{
@ -38,9 +36,16 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://github.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`hosts:`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{}
}
cfg.DefaultHostFunc = func() (string, string) {
return "github.com", "default"
}
return cfg
}(),
wantsErr: true,
},
{
@ -48,13 +53,16 @@ func Test_remoteResolver(t *testing.T) {
remotes: func() (git.RemoteSet, error) {
return git.RemoteSet{}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "example.com", "hosts"
}
return cfg
}(),
wantsErr: true,
},
{
@ -64,13 +72,19 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://test.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "example.com", "hosts"
}
cfg.AuthTokenFunc = func(string) (string, string) {
return "", ""
}
return cfg
}(),
wantsErr: true,
},
{
@ -80,30 +94,35 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://github.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "example.com", "hosts"
}
return cfg
}(),
output: []string{"origin"},
},
{
name: "one authenticated host with matching git remote",
remotes: func() (git.RemoteSet, error) {
return git.RemoteSet{
git.NewRemote("upstream", "https://github.com/owner/repo.git"),
git.NewRemote("origin", "https://example.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "example.com", "default"
}
return cfg
}(),
output: []string{"origin"},
},
{
@ -116,13 +135,16 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("fork", "https://example.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "example.com", "default"
}
return cfg
}(),
output: []string{"upstream", "github", "origin", "fork"},
},
{
@ -132,15 +154,19 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://test.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
github.com:
oauth_token: GHTOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com", "github.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "github.com", "default"
}
cfg.AuthTokenFunc = func(string) (string, string) {
return "", ""
}
return cfg
}(),
wantsErr: true,
},
{
@ -151,15 +177,16 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://example.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
github.com:
oauth_token: GHTOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com", "github.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "github.com", "default"
}
return cfg
}(),
output: []string{"origin"},
},
{
@ -173,15 +200,16 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("test", "https://test.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
github.com:
oauth_token: GHTOKEN
`)), nil
},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com", "github.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "github.com", "default"
}
return cfg
}(),
output: []string{"upstream", "github", "origin", "fork"},
},
{
@ -191,14 +219,16 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://example.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.InheritEnv(config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`))), nil
},
override: "test.com",
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "test.com", "GH_HOST"
}
return cfg
}(),
wantsErr: true,
},
{
@ -209,15 +239,17 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://test.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.InheritEnv(config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`))), nil
},
override: "test.com",
output: []string{"origin"},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "test.com", "GH_HOST"
}
return cfg
}(),
output: []string{"origin"},
},
{
name: "override host with multiple matching git remotes",
@ -228,26 +260,25 @@ func Test_remoteResolver(t *testing.T) {
git.NewRemote("origin", "https://test.com/owner/repo.git"),
}, nil
},
config: func() (config.Config, error) {
return config.InheritEnv(config.NewFromString(heredoc.Doc(`
hosts:
example.com:
oauth_token: GHETOKEN
`))), nil
},
override: "test.com",
output: []string{"upstream", "origin"},
config: func() config.Config {
cfg := &config.ConfigMock{}
cfg.HostsFunc = func() []string {
return []string{"example.com", "test.com"}
}
cfg.DefaultHostFunc = func() (string, string) {
return "test.com", "GH_HOST"
}
return cfg
}(),
output: []string{"upstream", "origin"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.override != "" {
os.Setenv("GH_HOST", tt.override)
}
rr := &remoteResolver{
readRemotes: tt.remotes,
getConfig: tt.config,
getConfig: func() (config.Config, error) { return tt.config, nil },
urlTranslator: identityTranslator{},
}
resolver := rr.Resolver()

View file

@ -75,10 +75,7 @@ func cloneRun(opts *CloneOptions) error {
if err != nil {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
protocol, err := cfg.GetOrDefault(hostname, "git_protocol")
if err != nil {
return err

View file

@ -143,10 +143,7 @@ func createRun(opts *CreateOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
opts.IO.StartProgressIndicator()
gist, err := createGist(httpClient, host, opts.Description, opts.Public, files)

View file

@ -64,10 +64,7 @@ func deleteRun(opts *DeleteOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
apiClient := api.NewClientFromHTTP(client)
if err := deleteGist(apiClient, host, gistID); err != nil {

View file

@ -107,10 +107,7 @@ func editRun(opts *EditOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
gist, err := shared.GetGist(client, host, gistID)
if err != nil {

View file

@ -76,10 +76,7 @@ func listRun(opts *ListOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
gists, err := shared.ListGists(client, host, opts.Limit, opts.Visibility)
if err != nil {

View file

@ -86,10 +86,7 @@ func viewRun(opts *ViewOptions) error {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
cs := opts.IO.ColorScheme()
if gistID == "" {

View file

@ -76,10 +76,7 @@ func runAdd(opts *AddOptions) error {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
err = gpgKeyUpload(httpClient, hostname, keyReader)
if err != nil {

View file

@ -53,10 +53,7 @@ func listRun(opts *ListOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
gpgKeys, err := userKeys(apiClient, host, "")
if err != nil {

View file

@ -83,10 +83,7 @@ func archiveRun(opts *ArchiveOptions) error {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
currentUser, err := api.CurrentLoginName(apiClient, hostname)
if err != nil {

View file

@ -111,10 +111,7 @@ func cloneRun(opts *CloneOptions) error {
if repositoryIsFullName {
fullName = opts.Repository
} else {
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
currentUser, err := api.CurrentLoginName(apiClient, host)
if err != nil {
return err

View file

@ -191,10 +191,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
hostname, err := cfg.DefaultHost()
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
hostname, _ := cfg.DefaultHost()
results, err := listGitIgnoreTemplates(httpClient, hostname)
if err != nil {
return nil, cobra.ShellCompDirectiveError
@ -211,10 +208,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
hostname, err := cfg.DefaultHost()
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
hostname, _ := cfg.DefaultHost()
licenses, err := listLicenseTemplates(httpClient, hostname)
if err != nil {
return nil, cobra.ShellCompDirectiveError
@ -266,10 +260,7 @@ func createFromScratch(opts *CreateOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
if opts.Interactive {
opts.Name, opts.Description, opts.Visibility, err = interactiveRepoInfo("")
@ -409,10 +400,7 @@ func createFromLocal(opts *CreateOptions) error {
if err != nil {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
if opts.Interactive {
opts.Source, err = interactiveSource()

View file

@ -225,10 +225,7 @@ func forkRun(opts *ForkOptions) error {
if err != nil {
return err
}
protocol, err := cfg.Get(repoToFork.RepoHost(), "git_protocol")
if err != nil {
return err
}
protocol, _ := cfg.Get(repoToFork.RepoHost(), "git_protocol")
if inParent {
remotes, err := opts.Remotes()
@ -248,7 +245,7 @@ func forkRun(opts *ForkOptions) error {
if scheme != "" {
protocol = scheme
} else {
protocol = cfg.Default("git_protocol")
protocol = "https"
}
}
}

View file

@ -210,7 +210,7 @@ func TestRepoFork(t *testing.T) {
httpStubs func(*httpmock.Registry)
execStubs func(*run.CommandStubber)
askStubs func(*prompt.AskStubber)
cfg func(config.Config) config.Config
cfgStubs func(*config.ConfigMock)
remotes []*context.Remote
wantOut string
wantErrOut string
@ -253,9 +253,8 @@ func TestRepoFork(t *testing.T) {
Repo: ghrepo.New("OWNER", "REPO"),
},
},
cfg: func(c config.Config) config.Config {
_ = c.Set("", "git_protocol", "")
return c
cfgStubs: func(c *config.ConfigMock) {
c.Set("", "git_protocol", "")
},
httpStubs: forkPost,
execStubs: func(cs *run.CommandStubber) {
@ -679,8 +678,8 @@ func TestRepoFork(t *testing.T) {
}
cfg := config.NewBlankConfig()
if tt.cfg != nil {
cfg = tt.cfg(cfg)
if tt.cfgStubs != nil {
tt.cfgStubs(cfg)
}
tt.opts.Config = func() (config.Config, error) {
return cfg, nil

View file

@ -155,11 +155,7 @@ func gardenRun(opts *GardenOptions) error {
if err != nil {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
currentUser, err := api.CurrentLoginName(apiClient, hostname)
if err != nil {
return err

View file

@ -119,10 +119,7 @@ func listRun(opts *ListOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
if opts.Detector == nil {
cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24)

View file

@ -421,7 +421,7 @@ func TestRepoList_noVisibilityField(t *testing.T) {
return &http.Client{Transport: reg}, nil
},
Config: func() (config.Config, error) {
return config.InheritEnv(config.NewBlankConfig()), nil
return config.NewBlankConfig(), nil
},
Now: func() time.Time {
t, _ := time.Parse(time.RFC822, "19 Feb 21 15:00 UTC")

View file

@ -97,11 +97,7 @@ func viewRun(opts *ViewOptions) error {
if err != nil {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
currentUser, err := api.CurrentLoginName(apiClient, hostname)
if err != nil {
return err

View file

@ -40,10 +40,7 @@ func Searcher(f *cmdutil.Factory) (search.Searcher, error) {
if err != nil {
return nil, err
}
host, err := cfg.DefaultHost()
if err != nil {
return nil, err
}
host, _ := cfg.DefaultHost()
client, err := f.HttpClient()
if err != nil {
return nil, err

View file

@ -122,10 +122,7 @@ func removeRun(opts *DeleteOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
err = client.REST(host, "DELETE", path, nil, nil)
if err != nil {

View file

@ -123,10 +123,7 @@ func listRun(opts *ListOptions) error {
return err
}
host, err = cfg.DefaultHost()
if err != nil {
return err
}
host, _ = cfg.DefaultHost()
if secretEntity == shared.User {
secrets, err = getUserSecrets(client, host, showSelectedRepoInfo)

View file

@ -186,11 +186,7 @@ func setRun(opts *SetOptions) error {
if err != nil {
return err
}
host, err = cfg.DefaultHost()
if err != nil {
return err
}
host, _ = cfg.DefaultHost()
}
secretEntity, err := shared.GetSecretEntity(orgName, envName, opts.UserSecrets)

View file

@ -77,10 +77,7 @@ func runAdd(opts *AddOptions) error {
return err
}
hostname, err := cfg.DefaultHost()
if err != nil {
return err
}
hostname, _ := cfg.DefaultHost()
err = SSHKeyUpload(httpClient, hostname, keyReader, opts.Title)
if err != nil {

View file

@ -51,10 +51,7 @@ func listRun(opts *ListOptions) error {
return err
}
host, err := cfg.DefaultHost()
if err != nil {
return err
}
host, _ := cfg.DefaultHost()
sshKeys, err := userKeys(apiClient, host, "")
if err != nil {

View file

@ -26,7 +26,7 @@ import (
)
type hostConfig interface {
DefaultHost() (string, error)
DefaultHost() (string, string)
}
type StatusOptions struct {
@ -619,10 +619,7 @@ func statusRun(opts *StatusOptions) error {
return fmt.Errorf("could not create client: %w", err)
}
hostname, err := opts.HostConfig.DefaultHost()
if err != nil {
return err
}
hostname, _ := opts.HostConfig.DefaultHost()
sg := NewStatusGetter(client, hostname, opts)

View file

@ -19,8 +19,8 @@ import (
type testHostConfig string
func (c testHostConfig) DefaultHost() (string, error) {
return string(c), nil
func (c testHostConfig) DefaultHost() (string, string) {
return string(c), ""
}
func TestNewCmdStatus(t *testing.T) {

View file

@ -14,20 +14,17 @@ func DisableAuthCheck(cmd *cobra.Command) {
}
func CheckAuth(cfg config.Config) bool {
if config.AuthTokenProvidedFromEnv() {
// This will check if there are any environment variable
// authentication tokens set for enterprise hosts.
// Any non-github.com hostname is fine here
dummyHostname := "example.com"
token, _ := cfg.AuthToken(dummyHostname)
if token != "" {
return true
}
hosts, err := cfg.Hosts()
if err != nil {
return false
}
for _, hostname := range hosts {
token, _ := cfg.Get(hostname, "oauth_token")
if token != "" {
return true
}
if len(cfg.Hosts()) > 0 {
return true
}
return false

View file

@ -1,7 +1,6 @@
package cmdutil
import (
"os"
"testing"
"github.com/cli/cli/v2/internal/config"
@ -9,56 +8,38 @@ import (
)
func Test_CheckAuth(t *testing.T) {
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
t.Cleanup(func() {
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
})
tests := []struct {
name string
cfg func(config.Config)
envToken bool
cfgStubs func(*config.ConfigMock)
expected bool
}{
{
name: "no hosts",
cfg: func(c config.Config) {},
envToken: false,
name: "no known hosts, no env auth token",
cfgStubs: func(c *config.ConfigMock) {},
expected: false,
},
{name: "no hosts, env auth token",
cfg: func(c config.Config) {},
envToken: true,
{
name: "no known hosts, env auth token",
cfgStubs: func(c *config.ConfigMock) {
c.AuthTokenFunc = func(string) (string, string) {
return "token", "GITHUB_TOKEN"
}
},
expected: true,
},
{
name: "host, no token",
cfg: func(c config.Config) {
_ = c.Set("github.com", "oauth_token", "")
name: "known host",
cfgStubs: func(c *config.ConfigMock) {
c.Set("github.com", "oauth_token", "token")
},
envToken: false,
expected: false,
},
{
name: "host, token",
cfg: func(c config.Config) {
_ = c.Set("github.com", "oauth_token", "a token")
},
envToken: false,
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envToken {
os.Setenv("GITHUB_TOKEN", "TOKEN")
} else {
os.Setenv("GITHUB_TOKEN", "")
}
cfg := config.NewBlankConfig()
tt.cfg(cfg)
tt.cfgStubs(cfg)
result := CheckAuth(cfg)
assert.Equal(t, tt.expected, result)
})

View file

@ -31,10 +31,7 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) {
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
defaultHost, err := config.DefaultHost()
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
defaultHost, _ := config.DefaultHost()
var results []string
for _, remote := range remotes {