Merge pull request #1077 from cli/auth-split

Write per-host config info to `hosts.yml` instead of `config.yml`
This commit is contained in:
Mislav Marohnić 2020-06-04 12:49:12 +02:00 committed by GitHub
commit 1036666266
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 418 additions and 288 deletions

View file

@ -11,8 +11,8 @@ endif
LDFLAGS := -X github.com/cli/cli/command.Version=$(GH_VERSION) $(LDFLAGS)
LDFLAGS := -X github.com/cli/cli/command.BuildDate=$(BUILD_DATE) $(LDFLAGS)
ifdef GH_OAUTH_CLIENT_SECRET
LDFLAGS := -X github.com/cli/cli/context.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS)
LDFLAGS := -X github.com/cli/cli/context.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS)
LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS)
LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS)
endif
bin/gh: $(BUILD_FILES)

View file

@ -165,6 +165,10 @@ func (c Client) HasScopes(wantedScopes ...string) (bool, string, error) {
}
defer res.Body.Close()
if res.StatusCode != 200 {
return false, "", handleHTTPError(res)
}
appID := res.Header.Get("X-Oauth-Client-Id")
hasScopes := strings.Split(res.Header.Get("X-Oauth-Scopes"), ",")

View file

@ -66,9 +66,9 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
fmt.Fprintf(os.Stderr, "Please open the following URL manually:\n%s\n", startURL)
fmt.Fprintf(os.Stderr, "")
// TODO: Temporary workaround for https://github.com/cli/cli/issues/297
fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:")
fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system")
fmt.Fprintf(os.Stderr, " 2. Copy the contents of ~/.config/gh/config.yml to this system")
fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:\n")
fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system;\n")
fmt.Fprintf(os.Stderr, " 2. Copy the contents of `~/.config/gh/hosts.yml` to this system.\n")
}
_ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View file

@ -12,8 +12,9 @@ import (
func TestAliasSet_gh_command(t *testing.T) {
initBlankContext("", "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
_, err := RunCommand("alias set pr pr status")
if err == nil {
@ -26,15 +27,14 @@ func TestAliasSet_gh_command(t *testing.T) {
func TestAliasSet_empty_aliases(t *testing.T) {
cfg := `---
aliases:
hosts:
github.com:
user: OWNER
oauth_token: token123
editor: vim
`
initBlankContext(cfg, "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("alias set co pr checkout")
if err != nil {
@ -45,12 +45,9 @@ hosts:
expected := `aliases:
co: pr checkout
hosts:
github.com:
user: OWNER
oauth_token: token123
editor: vim
`
eq(t, buf.String(), expected)
eq(t, mainBuf.String(), expected)
}
func TestAliasSet_existing_alias(t *testing.T) {
@ -64,8 +61,10 @@ aliases:
`
initBlankContext(cfg, "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("alias set co pr checkout -Rcool/repo")
if err != nil {
@ -78,8 +77,9 @@ aliases:
func TestAliasSet_space_args(t *testing.T) {
initBlankContext("", "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand(`alias set il issue list -l 'cool story'`)
@ -89,7 +89,7 @@ func TestAliasSet_space_args(t *testing.T) {
test.ExpectLines(t, output.String(), `Adding alias for il: issue list -l "cool story"`)
test.ExpectLines(t, buf.String(), `il: issue list -l "cool story"`)
test.ExpectLines(t, mainBuf.String(), `il: issue list -l "cool story"`)
}
func TestAliasSet_arg_processing(t *testing.T) {
@ -118,77 +118,66 @@ func TestAliasSet_arg_processing(t *testing.T) {
`ix: issue list --author=\$1 --label=\$2`},
}
var buf *bytes.Buffer
for _, c := range cases {
buf = bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand(c.Cmd)
if err != nil {
t.Fatalf("got unexpected error running %s: %s", c.Cmd, err)
}
test.ExpectLines(t, output.String(), c.ExpectedOutputLine)
test.ExpectLines(t, buf.String(), c.ExpectedConfigLine)
test.ExpectLines(t, mainBuf.String(), c.ExpectedConfigLine)
}
}
func TestAliasSet_init_alias_cfg(t *testing.T) {
cfg := `---
hosts:
github.com:
user: OWNER
oauth_token: token123
editor: vim
`
initBlankContext(cfg, "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("alias set diff pr diff")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expected := `hosts:
github.com:
user: OWNER
oauth_token: token123
expected := `editor: vim
aliases:
diff: pr diff
`
test.ExpectLines(t, output.String(), "Adding alias for diff: pr diff", "Added alias.")
eq(t, buf.String(), expected)
eq(t, mainBuf.String(), expected)
}
func TestAliasSet_existing_aliases(t *testing.T) {
cfg := `---
hosts:
github.com:
user: OWNER
oauth_token: token123
aliases:
foo: bar
`
initBlankContext(cfg, "OWNER/REPO", "trunk")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("alias set view pr view")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expected := `hosts:
github.com:
user: OWNER
oauth_token: token123
aliases:
expected := `aliases:
foo: bar
view: pr view
`
test.ExpectLines(t, output.String(), "Adding alias for view: pr view", "Added alias.")
eq(t, buf.String(), expected)
eq(t, mainBuf.String(), expected)
}

View file

@ -49,23 +49,31 @@ func TestConfigGet_not_found(t *testing.T) {
func TestConfigSet(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("config set editor ed")
if err != nil {
t.Fatalf("error running command `config set editor ed`: %v", err)
}
eq(t, output.String(), "")
if len(output.String()) > 0 {
t.Errorf("expected output to be blank: %q", output.String())
}
expected := `hosts:
github.com:
user: OWNER
oauth_token: 1234567890
editor: ed
expectedMain := "editor: ed\n"
expectedHosts := `github.com:
user: OWNER
oauth_token: "1234567890"
`
eq(t, buf.String(), expected)
if mainBuf.String() != expectedMain {
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
}
if hostsBuf.String() != expectedHosts {
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
}
}
func TestConfigSet_update(t *testing.T) {
@ -79,23 +87,31 @@ editor: ed
initBlankContext(cfg, "OWNER/REPO", "master")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("config set editor vim")
if err != nil {
t.Fatalf("error running command `config get editor`: %v", err)
}
eq(t, output.String(), "")
if len(output.String()) > 0 {
t.Errorf("expected output to be blank: %q", output.String())
}
expected := `hosts:
github.com:
user: OWNER
oauth_token: MUSTBEHIGHCUZIMATOKEN
editor: vim
expectedMain := "editor: vim\n"
expectedHosts := `github.com:
user: OWNER
oauth_token: MUSTBEHIGHCUZIMATOKEN
`
eq(t, buf.String(), expected)
if mainBuf.String() != expectedMain {
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
}
if hostsBuf.String() != expectedHosts {
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
}
}
func TestConfigGetHost(t *testing.T) {
@ -141,23 +157,32 @@ git_protocol: ssh
func TestConfigSetHost(t *testing.T) {
initBlankContext("", "OWNER/REPO", "master")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("config set -hgithub.com git_protocol ssh")
if err != nil {
t.Fatalf("error running command `config set editor ed`: %v", err)
}
eq(t, output.String(), "")
if len(output.String()) > 0 {
t.Errorf("expected output to be blank: %q", output.String())
}
expected := `hosts:
github.com:
user: OWNER
oauth_token: 1234567890
git_protocol: ssh
expectedMain := ""
expectedHosts := `github.com:
user: OWNER
oauth_token: "1234567890"
git_protocol: ssh
`
eq(t, buf.String(), expected)
if mainBuf.String() != expectedMain {
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
}
if hostsBuf.String() != expectedHosts {
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
}
}
func TestConfigSetHost_update(t *testing.T) {
@ -171,21 +196,30 @@ hosts:
initBlankContext(cfg, "OWNER/REPO", "master")
buf := bytes.NewBufferString("")
defer config.StubWriteConfig(buf)()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
output, err := RunCommand("config set -hgithub.com git_protocol https")
if err != nil {
t.Fatalf("error running command `config get editor`: %v", err)
}
eq(t, output.String(), "")
if len(output.String()) > 0 {
t.Errorf("expected output to be blank: %q", output.String())
}
expected := `hosts:
github.com:
git_protocol: https
user: OWNER
oauth_token: MUSTBEHIGHCUZIMATOKEN
expectedMain := ""
expectedHosts := `github.com:
git_protocol: https
user: OWNER
oauth_token: MUSTBEHIGHCUZIMATOKEN
`
eq(t, buf.String(), expected)
if mainBuf.String() != expectedMain {
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
}
if hostsBuf.String() != expectedHosts {
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
}
}

View file

@ -181,24 +181,16 @@ var apiClientForContext = func(ctx context.Context) (*api.Client, error) {
checkScopesFunc := func(appID string) error {
if config.IsGitHubApp(appID) && !tokenFromEnv() && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) {
newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required")
if err != nil {
return err
}
cfg, err := ctx.Config()
if err != nil {
return err
}
_ = cfg.Set(defaultHostname, "oauth_token", newToken)
_ = cfg.Set(defaultHostname, "user", loginHandle)
// update config file on disk
err = cfg.Write()
newToken, err := config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required")
if err != nil {
return err
}
// update configuration in memory
token = newToken
config.AuthFlowComplete()
} else {
fmt.Fprintln(os.Stderr, "Warning: gh now requires the `read:org` OAuth scope.")
fmt.Fprintln(os.Stderr, "Visit https://github.com/settings/tokens and edit your token to enable `read:org`")
@ -235,28 +227,19 @@ var ensureScopes = func(ctx context.Context, client *api.Client, wantedScopes ..
tokenFromEnv := len(os.Getenv("GITHUB_TOKEN")) > 0
if config.IsGitHubApp(appID) && !tokenFromEnv && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) {
newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required")
if err != nil {
return client, err
}
cfg, err := ctx.Config()
if err != nil {
return client, err
return nil, err
}
_ = cfg.Set(defaultHostname, "oauth_token", newToken)
_ = cfg.Set(defaultHostname, "user", loginHandle)
// update config file on disk
err = cfg.Write()
_, err = config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required")
if err != nil {
return client, err
return nil, err
}
// update configuration in memory
config.AuthFlowComplete()
reloadedClient, err := apiClientForContext(ctx)
if err != nil {
return client, err
}
return reloadedClient, nil
} else {
fmt.Fprintln(os.Stderr, fmt.Sprintf("Warning: gh now requires %s OAuth scopes.", wantedScopes))

View file

@ -19,7 +19,7 @@ import (
const defaultTestConfig = `hosts:
github.com:
user: OWNER
oauth_token: 1234567890
oauth_token: "1234567890"
`
type askStubber struct {
@ -88,7 +88,7 @@ func initBlankContext(cfg, repo, branch string) {
// NOTE we are not restoring the original readConfig; we never want to touch the config file on
// disk during tests.
config.StubConfig(cfg)
config.StubConfig(cfg, "")
return ctx
}

View file

@ -23,7 +23,7 @@ type blankContext struct {
}
func (c *blankContext) Config() (config.Config, error) {
cfg, err := config.ParseConfig("boom.txt")
cfg, err := config.ParseConfig("config.yml")
if err != nil {
panic(fmt.Sprintf("failed to parse config during tests. did you remember to stub? error: %s", err))
}

View file

@ -3,6 +3,7 @@ package context
import (
"errors"
"fmt"
"os"
"sort"
"github.com/cli/cli/api"
@ -162,11 +163,13 @@ type fsContext struct {
func (c *fsContext) Config() (config.Config, error) {
if c.config == nil {
config, err := config.ParseOrSetupConfigFile(config.ConfigFile())
if err != nil {
cfg, err := config.ParseDefaultConfig()
if errors.Is(err, os.ErrNotExist) {
cfg = config.NewBlankConfig()
} else if err != nil {
return nil, err
}
c.config = config
c.config = cfg
c.authToken = ""
}
return c.config, nil
@ -182,8 +185,12 @@ func (c *fsContext) AuthToken() (string, error) {
return "", err
}
var notFound *config.NotFoundError
token, err := cfg.Get(defaultHostname, "oauth_token")
if token == "" || err != nil {
if token == "" || errors.As(err, &notFound) {
// interactive OAuth flow
return config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: authentication required")
} else if err != nil {
return "", err
}

View file

@ -21,20 +21,16 @@ func ConfigFile() string {
return path.Join(ConfigDir(), "config.yml")
}
func ParseOrSetupConfigFile(fn string) (Config, error) {
config, err := ParseConfig(fn)
if err != nil && errors.Is(err, os.ErrNotExist) {
return setupConfigFile(fn)
}
return config, err
func hostsConfigFile(filename string) string {
return path.Join(path.Dir(filename), "hosts.yml")
}
func ParseDefaultConfig() (Config, error) {
return ParseConfig(ConfigFile())
}
var ReadConfigFile = func(fn string) ([]byte, error) {
f, err := os.Open(fn)
var ReadConfigFile = func(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
@ -48,8 +44,13 @@ var ReadConfigFile = func(fn string) ([]byte, error) {
return data, nil
}
var WriteConfigFile = func(fn string, data []byte) error {
cfgFile, err := os.OpenFile(ConfigFile(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup
var WriteConfigFile = func(filename string, data []byte) error {
err := os.MkdirAll(path.Dir(filename), 0771)
if err != nil {
return err
}
cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup
if err != nil {
return err
}
@ -63,12 +64,12 @@ var WriteConfigFile = func(fn string, data []byte) error {
return err
}
var BackupConfigFile = func(fn string) error {
return os.Rename(fn, fn+".bak")
var BackupConfigFile = func(filename string) error {
return os.Rename(filename, filename+".bak")
}
func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
data, err := ReadConfigFile(fn)
func parseConfigFile(filename string) ([]byte, *yaml.Node, error) {
data, err := ReadConfigFile(filename)
if err != nil {
return nil, nil, err
}
@ -78,8 +79,11 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
if err != nil {
return data, nil, err
}
if len(root.Content) < 1 {
return data, &root, fmt.Errorf("malformed config")
if len(root.Content) == 0 {
return data, &yaml.Node{
Kind: yaml.DocumentNode,
Content: []*yaml.Node{{Kind: yaml.MappingNode}},
}, nil
}
if root.Content[0].Kind != yaml.MappingNode {
return data, &root, fmt.Errorf("expected a top level map")
@ -90,91 +94,76 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
func isLegacy(root *yaml.Node) bool {
for _, v := range root.Content[0].Content {
if v.Value == "hosts" {
return false
if v.Value == "github.com" {
return true
}
}
return true
return false
}
func migrateConfig(fn string, root *yaml.Node) error {
type ConfigEntry map[string]string
type ConfigHash map[string]ConfigEntry
newConfigData := map[string]ConfigHash{}
newConfigData["hosts"] = ConfigHash{}
topLevelKeys := root.Content[0].Content
for i, x := range topLevelKeys {
if x.Value == "" {
continue
}
if i+1 == len(topLevelKeys) {
break
}
hostname := x.Value
newConfigData["hosts"][hostname] = ConfigEntry{}
authKeys := topLevelKeys[i+1].Content[0].Content
for j, y := range authKeys {
if j+1 == len(authKeys) {
break
}
switch y.Value {
case "user":
newConfigData["hosts"][hostname]["user"] = authKeys[j+1].Value
case "oauth_token":
newConfigData["hosts"][hostname]["oauth_token"] = authKeys[j+1].Value
}
}
}
if _, ok := newConfigData["hosts"][defaultHostname]; !ok {
return errors.New("could not find default host configuration")
}
defaultHostConfig := newConfigData["hosts"][defaultHostname]
if _, ok := defaultHostConfig["user"]; !ok {
return errors.New("default host configuration missing user")
}
if _, ok := defaultHostConfig["oauth_token"]; !ok {
return errors.New("default host configuration missing oauth_token")
}
newConfig, err := yaml.Marshal(newConfigData)
func migrateConfig(filename string) error {
b, err := ReadConfigFile(filename)
if err != nil {
return err
}
err = BackupConfigFile(fn)
var hosts map[string][]yaml.Node
err = yaml.Unmarshal(b, &hosts)
if err != nil {
return fmt.Errorf("error decoding legacy format: %w", err)
}
cfg := NewBlankConfig()
for hostname, entries := range hosts {
if len(entries) < 1 {
continue
}
mapContent := entries[0].Content
for i := 0; i < len(mapContent)-1; i += 2 {
if err := cfg.Set(hostname, mapContent[i].Value, mapContent[i+1].Value); err != nil {
return err
}
}
}
err = BackupConfigFile(filename)
if err != nil {
return fmt.Errorf("failed to back up existing config: %w", err)
}
return WriteConfigFile(fn, newConfig)
return cfg.Write()
}
func ParseConfig(fn string) (Config, error) {
_, root, err := parseConfigFile(fn)
func ParseConfig(filename string) (Config, error) {
_, root, err := parseConfigFile(filename)
if err != nil {
return nil, err
}
if isLegacy(root) {
err = migrateConfig(fn, root)
err = migrateConfig(filename)
if err != nil {
return nil, err
return nil, fmt.Errorf("error migrating legacy config: %w", err)
}
_, root, err = parseConfigFile(fn)
_, root, err = parseConfigFile(filename)
if err != nil {
return nil, fmt.Errorf("failed to reparse migrated config: %w", err)
}
} else {
if _, hostsRoot, err := parseConfigFile(hostsConfigFile(filename)); err == nil {
if len(hostsRoot.Content[0].Content) > 0 {
newContent := []*yaml.Node{
{Value: "hosts"},
hostsRoot.Content[0],
}
restContent := root.Content[0].Content
root.Content[0].Content = append(newContent, restContent...)
}
} else if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
}
return NewConfig(root), nil

View file

@ -3,6 +3,7 @@ package config
import (
"bytes"
"errors"
"fmt"
"reflect"
"testing"
@ -22,8 +23,8 @@ hosts:
github.com:
user: monalisa
oauth_token: OTOKEN
`)()
config, err := ParseConfig("filename")
`, "")()
config, err := ParseConfig("config.yml")
eq(t, err, nil)
user, err := config.Get("github.com", "user")
eq(t, err, nil)
@ -42,8 +43,24 @@ hosts:
github.com:
user: monalisa
oauth_token: OTOKEN
`, "")()
config, err := ParseConfig("config.yml")
eq(t, err, nil)
user, err := config.Get("github.com", "user")
eq(t, err, nil)
eq(t, user, "monalisa")
token, err := config.Get("github.com", "oauth_token")
eq(t, err, nil)
eq(t, token, "OTOKEN")
}
func Test_parseConfig_hostsFile(t *testing.T) {
defer StubConfig("", `---
github.com:
user: monalisa
oauth_token: OTOKEN
`)()
config, err := ParseConfig("filename")
config, err := ParseConfig("config.yml")
eq(t, err, nil)
user, err := config.Get("github.com", "user")
eq(t, err, nil)
@ -59,38 +76,47 @@ hosts:
example.com:
user: wronguser
oauth_token: NOTTHIS
`)()
config, err := ParseConfig("filename")
`, "")()
config, err := ParseConfig("config.yml")
eq(t, err, nil)
_, err = config.Get("github.com", "user")
eq(t, err, errors.New(`could not find config entry for "github.com"`))
eq(t, err, &NotFoundError{errors.New(`could not find config entry for "github.com"`)})
}
func Test_migrateConfig(t *testing.T) {
oldStyle := `---
func Test_ParseConfig_migrateConfig(t *testing.T) {
defer StubConfig(`---
github.com:
- user: keiyuri
oauth_token: 123456`
var root yaml.Node
err := yaml.Unmarshal([]byte(oldStyle), &root)
if err != nil {
panic("failed to parse test yaml")
}
buf := bytes.NewBufferString("")
defer StubWriteConfig(buf)()
oauth_token: 123456
`, "")()
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
defer StubBackupConfig()()
err = migrateConfig("boom.txt", &root)
_, err := ParseConfig("config.yml")
eq(t, err, nil)
expected := `hosts:
github.com:
oauth_token: "123456"
user: keiyuri
expectedMain := ""
expectedHosts := `github.com:
user: keiyuri
oauth_token: "123456"
`
eq(t, buf.String(), expected)
eq(t, mainBuf.String(), expectedMain)
eq(t, hostsBuf.String(), expectedHosts)
}
func Test_parseConfigFile(t *testing.T) {
fileContents := []string{"", " ", "\n"}
for _, contents := range fileContents {
t.Run(fmt.Sprintf("contents: %q", contents), func(t *testing.T) {
defer StubConfig(contents, "")()
_, yamlRoot, err := parseConfigFile("config.yml")
eq(t, err, nil)
eq(t, yamlRoot.Content[0].Kind, yaml.MappingNode)
eq(t, len(yamlRoot.Content[0].Content), 0)
})
}
}

View file

@ -5,16 +5,10 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/cli/cli/api"
"github.com/cli/cli/auth"
"gopkg.in/yaml.v3"
)
const (
oauthHost = "github.com"
)
var (
@ -31,7 +25,31 @@ func IsGitHubApp(id string) bool {
return id == "178c6fc778ccc68e1d6a" || id == "4d747ba5675d5d66553f"
}
func AuthFlow(notice string) (string, string, error) {
func AuthFlowWithConfig(cfg Config, hostname, notice string) (string, error) {
token, userLogin, err := authFlow(hostname, notice)
if err != nil {
return "", err
}
err = cfg.Set(hostname, "user", userLogin)
if err != nil {
return "", err
}
err = cfg.Set(hostname, "oauth_token", token)
if err != nil {
return "", err
}
err = cfg.Write()
if err != nil {
return "", err
}
AuthFlowComplete()
return token, nil
}
func authFlow(oauthHost, notice string) (string, string, error) {
var verboseStream io.Writer
if strings.Contains(os.Getenv("DEBUG"), "oauth") {
verboseStream = os.Stderr
@ -69,53 +87,6 @@ func AuthFlowComplete() {
_ = waitForEnter(os.Stdin)
}
// FIXME: make testable
func setupConfigFile(filename string) (Config, error) {
token, userLogin, err := AuthFlow("Notice: authentication required")
if err != nil {
return nil, err
}
// TODO this sucks. It precludes us laying out a nice config with comments and such.
type yamlConfig struct {
Hosts map[string]map[string]string
}
yamlHosts := map[string]map[string]string{}
yamlHosts[oauthHost] = map[string]string{}
yamlHosts[oauthHost]["user"] = userLogin
yamlHosts[oauthHost]["oauth_token"] = token
defaultConfig := yamlConfig{
Hosts: yamlHosts,
}
err = os.MkdirAll(filepath.Dir(filename), 0771)
if err != nil {
return nil, err
}
cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return nil, err
}
defer cfgFile.Close()
yamlData, err := yaml.Marshal(defaultConfig)
if err != nil {
return nil, err
}
_, err = cfgFile.Write(yamlData)
if err != nil {
return nil, err
}
AuthFlowComplete()
// TODO cleaner error handling? this "should" always work given that we /just/ wrote the file...
return ParseConfig(filename)
}
func getViewer(token string) (string, error) {
http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token)))
return api.CurrentLoginName(http)

View file

@ -1,18 +1,17 @@
package config
import (
"bytes"
"errors"
"fmt"
"gopkg.in/yaml.v3"
)
const defaultHostname = "github.com"
const defaultGitProtocol = "https"
// This interface describes interacting with some persistent configuration for gh.
type Config interface {
Hosts() ([]*HostConfig, error)
Get(string, string) (string, error)
Set(string, string, string) error
Aliases() (*AliasConfig, error)
@ -61,6 +60,7 @@ func (cm *ConfigMap) SetStringValue(key, value string) error {
}
valueNode = &yaml.Node{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "",
}
@ -107,11 +107,17 @@ func NewConfig(root *yaml.Node) Config {
}
}
func NewBlankConfig() Config {
return NewConfig(&yaml.Node{
Kind: yaml.DocumentNode,
Content: []*yaml.Node{{Kind: yaml.MappingNode}},
})
}
// This type implements a Config interface and represents a config file on disk.
type fileConfig struct {
ConfigMap
documentRoot *yaml.Node
hosts []*HostConfig
}
func (c *fileConfig) Root() *yaml.Node {
@ -159,7 +165,10 @@ func (c *fileConfig) Set(hostname, key, value string) error {
return c.SetStringValue(key, value)
} else {
hostCfg, err := c.configForHost(hostname)
if err != nil {
var notFound *NotFoundError
if errors.As(err, &notFound) {
hostCfg = c.makeConfigForHost(hostname)
} else if err != nil {
return err
}
return hostCfg.SetStringValue(key, value)
@ -167,7 +176,7 @@ func (c *fileConfig) Set(hostname, key, value string) error {
}
func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
hosts, err := c.Hosts()
hosts, err := c.hostEntries()
if err != nil {
return nil, fmt.Errorf("failed to parse hosts config: %w", err)
}
@ -177,16 +186,46 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
return hc, nil
}
}
return nil, fmt.Errorf("could not find config entry for %q", hostname)
return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)}
}
func (c *fileConfig) Write() error {
marshalled, err := yaml.Marshal(c.documentRoot)
mainData := yaml.Node{Kind: yaml.MappingNode}
hostsData := yaml.Node{Kind: yaml.MappingNode}
nodes := c.documentRoot.Content[0].Content
for i := 0; i < len(nodes)-1; i += 2 {
if nodes[i].Value == "hosts" {
hostsData.Content = append(hostsData.Content, nodes[i+1].Content...)
} else {
mainData.Content = append(mainData.Content, nodes[i], nodes[i+1])
}
}
mainBytes, err := yaml.Marshal(&mainData)
if err != nil {
return err
}
return WriteConfigFile(ConfigFile(), marshalled)
filename := ConfigFile()
err = WriteConfigFile(filename, yamlNormalize(mainBytes))
if err != nil {
return err
}
hostsBytes, err := yaml.Marshal(&hostsData)
if err != nil {
return err
}
return WriteConfigFile(hostsConfigFile(filename), yamlNormalize(hostsBytes))
}
func yamlNormalize(b []byte) []byte {
if bytes.Equal(b, []byte("{}\n")) {
return []byte{}
}
return b
}
func (c *fileConfig) Aliases() (*AliasConfig, error) {
@ -243,11 +282,7 @@ func (c *fileConfig) Aliases() (*AliasConfig, error) {
}, nil
}
func (c *fileConfig) Hosts() ([]*HostConfig, error) {
if len(c.hosts) > 0 {
return c.hosts, nil
}
func (c *fileConfig) hostEntries() ([]*HostConfig, error) {
entry, err := c.FindEntry("hosts")
if err != nil {
return nil, fmt.Errorf("could not find hosts config: %w", err)
@ -258,11 +293,39 @@ func (c *fileConfig) Hosts() ([]*HostConfig, error) {
return nil, fmt.Errorf("could not parse hosts config: %w", err)
}
c.hosts = hostConfigs
return hostConfigs, nil
}
func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig {
hostRoot := &yaml.Node{Kind: yaml.MappingNode}
hostCfg := &HostConfig{
Host: hostname,
ConfigMap: ConfigMap{Root: hostRoot},
}
var notFound *NotFoundError
hostsEntry, err := c.FindEntry("hosts")
if errors.As(err, &notFound) {
hostsEntry.KeyNode = &yaml.Node{
Kind: yaml.ScalarNode,
Value: "hosts",
}
hostsEntry.ValueNode = &yaml.Node{Kind: yaml.MappingNode}
root := c.Root()
root.Content = append(root.Content, hostsEntry.KeyNode, hostsEntry.ValueNode)
} else if err != nil {
panic(err)
}
hostsEntry.ValueNode.Content = append(hostsEntry.ValueNode.Content,
&yaml.Node{
Kind: yaml.ScalarNode,
Value: hostname,
}, hostRoot)
return hostCfg
}
func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) {
hostConfigs := []*HostConfig{}

View file

@ -0,0 +1,41 @@
package config
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_fileConfig_Set(t *testing.T) {
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
c := NewBlankConfig()
assert.NoError(t, c.Set("", "editor", "nano"))
assert.NoError(t, c.Set("github.com", "git_protocol", "ssh"))
assert.NoError(t, c.Set("example.com", "editor", "vim"))
assert.NoError(t, c.Set("github.com", "user", "hubot"))
assert.NoError(t, c.Write())
assert.Equal(t, "editor: nano\n", mainBuf.String())
assert.Equal(t, `github.com:
git_protocol: ssh
user: hubot
example.com:
editor: vim
`, hostsBuf.String())
}
func Test_fileConfig_Write(t *testing.T) {
mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
defer StubWriteConfig(&mainBuf, &hostsBuf)()
c := NewBlankConfig()
assert.NoError(t, c.Write())
assert.Equal(t, "", mainBuf.String())
assert.Equal(t, "", hostsBuf.String())
}

View file

@ -1,7 +1,10 @@
package config
import (
"fmt"
"io"
"os"
"path"
)
func StubBackupConfig() func() {
@ -15,21 +18,41 @@ func StubBackupConfig() func() {
}
}
func StubWriteConfig(w io.Writer) func() {
func StubWriteConfig(wc io.Writer, wh io.Writer) func() {
orig := WriteConfigFile
WriteConfigFile = func(fn string, data []byte) error {
_, err := w.Write(data)
return err
switch path.Base(fn) {
case "config.yml":
_, err := wc.Write(data)
return err
case "hosts.yml":
_, err := wh.Write(data)
return err
default:
return fmt.Errorf("write to unstubbed file: %q", fn)
}
}
return func() {
WriteConfigFile = orig
}
}
func StubConfig(content string) func() {
func StubConfig(main, hosts string) func() {
orig := ReadConfigFile
ReadConfigFile = func(fn string) ([]byte, error) {
return []byte(content), nil
switch path.Base(fn) {
case "config.yml":
return []byte(main), nil
case "hosts.yml":
if hosts == "" {
return []byte(nil), os.ErrNotExist
} else {
return []byte(hosts), nil
}
default:
return []byte(nil), fmt.Errorf("read from unstubbed file: %q", fn)
}
}
return func() {
ReadConfigFile = orig