Merge pull request #1077 from cli/auth-split
Write per-host config info to `hosts.yml` instead of `config.yml`
This commit is contained in:
commit
1036666266
15 changed files with 418 additions and 288 deletions
4
Makefile
4
Makefile
|
|
@ -11,8 +11,8 @@ endif
|
|||
LDFLAGS := -X github.com/cli/cli/command.Version=$(GH_VERSION) $(LDFLAGS)
|
||||
LDFLAGS := -X github.com/cli/cli/command.BuildDate=$(BUILD_DATE) $(LDFLAGS)
|
||||
ifdef GH_OAUTH_CLIENT_SECRET
|
||||
LDFLAGS := -X github.com/cli/cli/context.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS)
|
||||
LDFLAGS := -X github.com/cli/cli/context.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS)
|
||||
LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS)
|
||||
LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS)
|
||||
endif
|
||||
|
||||
bin/gh: $(BUILD_FILES)
|
||||
|
|
|
|||
|
|
@ -165,6 +165,10 @@ func (c Client) HasScopes(wantedScopes ...string) (bool, string, error) {
|
|||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return false, "", handleHTTPError(res)
|
||||
}
|
||||
|
||||
appID := res.Header.Get("X-Oauth-Client-Id")
|
||||
hasScopes := strings.Split(res.Header.Get("X-Oauth-Scopes"), ",")
|
||||
|
||||
|
|
|
|||
|
|
@ -66,9 +66,9 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
|
|||
fmt.Fprintf(os.Stderr, "Please open the following URL manually:\n%s\n", startURL)
|
||||
fmt.Fprintf(os.Stderr, "")
|
||||
// TODO: Temporary workaround for https://github.com/cli/cli/issues/297
|
||||
fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:")
|
||||
fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system")
|
||||
fmt.Fprintf(os.Stderr, " 2. Copy the contents of ~/.config/gh/config.yml to this system")
|
||||
fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:\n")
|
||||
fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system;\n")
|
||||
fmt.Fprintf(os.Stderr, " 2. Copy the contents of `~/.config/gh/hosts.yml` to this system.\n")
|
||||
}
|
||||
|
||||
_ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ import (
|
|||
func TestAliasSet_gh_command(t *testing.T) {
|
||||
initBlankContext("", "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
_, err := RunCommand("alias set pr pr status")
|
||||
if err == nil {
|
||||
|
|
@ -26,15 +27,14 @@ func TestAliasSet_gh_command(t *testing.T) {
|
|||
func TestAliasSet_empty_aliases(t *testing.T) {
|
||||
cfg := `---
|
||||
aliases:
|
||||
hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
editor: vim
|
||||
`
|
||||
initBlankContext(cfg, "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("alias set co pr checkout")
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -45,12 +45,9 @@ hosts:
|
|||
|
||||
expected := `aliases:
|
||||
co: pr checkout
|
||||
hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
editor: vim
|
||||
`
|
||||
eq(t, buf.String(), expected)
|
||||
eq(t, mainBuf.String(), expected)
|
||||
}
|
||||
|
||||
func TestAliasSet_existing_alias(t *testing.T) {
|
||||
|
|
@ -64,8 +61,10 @@ aliases:
|
|||
`
|
||||
initBlankContext(cfg, "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("alias set co pr checkout -Rcool/repo")
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -78,8 +77,9 @@ aliases:
|
|||
func TestAliasSet_space_args(t *testing.T) {
|
||||
initBlankContext("", "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand(`alias set il issue list -l 'cool story'`)
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ func TestAliasSet_space_args(t *testing.T) {
|
|||
|
||||
test.ExpectLines(t, output.String(), `Adding alias for il: issue list -l "cool story"`)
|
||||
|
||||
test.ExpectLines(t, buf.String(), `il: issue list -l "cool story"`)
|
||||
test.ExpectLines(t, mainBuf.String(), `il: issue list -l "cool story"`)
|
||||
}
|
||||
|
||||
func TestAliasSet_arg_processing(t *testing.T) {
|
||||
|
|
@ -118,77 +118,66 @@ func TestAliasSet_arg_processing(t *testing.T) {
|
|||
`ix: issue list --author=\$1 --label=\$2`},
|
||||
}
|
||||
|
||||
var buf *bytes.Buffer
|
||||
for _, c := range cases {
|
||||
buf = bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand(c.Cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("got unexpected error running %s: %s", c.Cmd, err)
|
||||
}
|
||||
|
||||
test.ExpectLines(t, output.String(), c.ExpectedOutputLine)
|
||||
test.ExpectLines(t, buf.String(), c.ExpectedConfigLine)
|
||||
test.ExpectLines(t, mainBuf.String(), c.ExpectedConfigLine)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasSet_init_alias_cfg(t *testing.T) {
|
||||
cfg := `---
|
||||
hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
editor: vim
|
||||
`
|
||||
initBlankContext(cfg, "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("alias set diff pr diff")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
expected := `editor: vim
|
||||
aliases:
|
||||
diff: pr diff
|
||||
`
|
||||
|
||||
test.ExpectLines(t, output.String(), "Adding alias for diff: pr diff", "Added alias.")
|
||||
eq(t, buf.String(), expected)
|
||||
eq(t, mainBuf.String(), expected)
|
||||
}
|
||||
|
||||
func TestAliasSet_existing_aliases(t *testing.T) {
|
||||
cfg := `---
|
||||
hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
aliases:
|
||||
foo: bar
|
||||
`
|
||||
initBlankContext(cfg, "OWNER/REPO", "trunk")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("alias set view pr view")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: token123
|
||||
aliases:
|
||||
expected := `aliases:
|
||||
foo: bar
|
||||
view: pr view
|
||||
`
|
||||
|
||||
test.ExpectLines(t, output.String(), "Adding alias for view: pr view", "Added alias.")
|
||||
eq(t, buf.String(), expected)
|
||||
eq(t, mainBuf.String(), expected)
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -49,23 +49,31 @@ func TestConfigGet_not_found(t *testing.T) {
|
|||
func TestConfigSet(t *testing.T) {
|
||||
initBlankContext("", "OWNER/REPO", "master")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("config set editor ed")
|
||||
if err != nil {
|
||||
t.Fatalf("error running command `config set editor ed`: %v", err)
|
||||
}
|
||||
|
||||
eq(t, output.String(), "")
|
||||
if len(output.String()) > 0 {
|
||||
t.Errorf("expected output to be blank: %q", output.String())
|
||||
}
|
||||
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: 1234567890
|
||||
editor: ed
|
||||
expectedMain := "editor: ed\n"
|
||||
expectedHosts := `github.com:
|
||||
user: OWNER
|
||||
oauth_token: "1234567890"
|
||||
`
|
||||
|
||||
eq(t, buf.String(), expected)
|
||||
if mainBuf.String() != expectedMain {
|
||||
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
|
||||
}
|
||||
if hostsBuf.String() != expectedHosts {
|
||||
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSet_update(t *testing.T) {
|
||||
|
|
@ -79,23 +87,31 @@ editor: ed
|
|||
|
||||
initBlankContext(cfg, "OWNER/REPO", "master")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("config set editor vim")
|
||||
if err != nil {
|
||||
t.Fatalf("error running command `config get editor`: %v", err)
|
||||
}
|
||||
|
||||
eq(t, output.String(), "")
|
||||
if len(output.String()) > 0 {
|
||||
t.Errorf("expected output to be blank: %q", output.String())
|
||||
}
|
||||
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: MUSTBEHIGHCUZIMATOKEN
|
||||
editor: vim
|
||||
expectedMain := "editor: vim\n"
|
||||
expectedHosts := `github.com:
|
||||
user: OWNER
|
||||
oauth_token: MUSTBEHIGHCUZIMATOKEN
|
||||
`
|
||||
eq(t, buf.String(), expected)
|
||||
|
||||
if mainBuf.String() != expectedMain {
|
||||
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
|
||||
}
|
||||
if hostsBuf.String() != expectedHosts {
|
||||
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigGetHost(t *testing.T) {
|
||||
|
|
@ -141,23 +157,32 @@ git_protocol: ssh
|
|||
func TestConfigSetHost(t *testing.T) {
|
||||
initBlankContext("", "OWNER/REPO", "master")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("config set -hgithub.com git_protocol ssh")
|
||||
if err != nil {
|
||||
t.Fatalf("error running command `config set editor ed`: %v", err)
|
||||
}
|
||||
|
||||
eq(t, output.String(), "")
|
||||
if len(output.String()) > 0 {
|
||||
t.Errorf("expected output to be blank: %q", output.String())
|
||||
}
|
||||
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: 1234567890
|
||||
git_protocol: ssh
|
||||
expectedMain := ""
|
||||
expectedHosts := `github.com:
|
||||
user: OWNER
|
||||
oauth_token: "1234567890"
|
||||
git_protocol: ssh
|
||||
`
|
||||
|
||||
eq(t, buf.String(), expected)
|
||||
if mainBuf.String() != expectedMain {
|
||||
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
|
||||
}
|
||||
if hostsBuf.String() != expectedHosts {
|
||||
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSetHost_update(t *testing.T) {
|
||||
|
|
@ -171,21 +196,30 @@ hosts:
|
|||
|
||||
initBlankContext(cfg, "OWNER/REPO", "master")
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer config.StubWriteConfig(buf)()
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer config.StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
output, err := RunCommand("config set -hgithub.com git_protocol https")
|
||||
if err != nil {
|
||||
t.Fatalf("error running command `config get editor`: %v", err)
|
||||
}
|
||||
|
||||
eq(t, output.String(), "")
|
||||
if len(output.String()) > 0 {
|
||||
t.Errorf("expected output to be blank: %q", output.String())
|
||||
}
|
||||
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
git_protocol: https
|
||||
user: OWNER
|
||||
oauth_token: MUSTBEHIGHCUZIMATOKEN
|
||||
expectedMain := ""
|
||||
expectedHosts := `github.com:
|
||||
git_protocol: https
|
||||
user: OWNER
|
||||
oauth_token: MUSTBEHIGHCUZIMATOKEN
|
||||
`
|
||||
eq(t, buf.String(), expected)
|
||||
|
||||
if mainBuf.String() != expectedMain {
|
||||
t.Errorf("expected config.yml to be %q, got %q", expectedMain, mainBuf.String())
|
||||
}
|
||||
if hostsBuf.String() != expectedHosts {
|
||||
t.Errorf("expected hosts.yml to be %q, got %q", expectedHosts, hostsBuf.String())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -181,24 +181,16 @@ var apiClientForContext = func(ctx context.Context) (*api.Client, error) {
|
|||
|
||||
checkScopesFunc := func(appID string) error {
|
||||
if config.IsGitHubApp(appID) && !tokenFromEnv() && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) {
|
||||
newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg, err := ctx.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = cfg.Set(defaultHostname, "oauth_token", newToken)
|
||||
_ = cfg.Set(defaultHostname, "user", loginHandle)
|
||||
// update config file on disk
|
||||
err = cfg.Write()
|
||||
newToken, err := config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// update configuration in memory
|
||||
token = newToken
|
||||
config.AuthFlowComplete()
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "Warning: gh now requires the `read:org` OAuth scope.")
|
||||
fmt.Fprintln(os.Stderr, "Visit https://github.com/settings/tokens and edit your token to enable `read:org`")
|
||||
|
|
@ -235,28 +227,19 @@ var ensureScopes = func(ctx context.Context, client *api.Client, wantedScopes ..
|
|||
tokenFromEnv := len(os.Getenv("GITHUB_TOKEN")) > 0
|
||||
|
||||
if config.IsGitHubApp(appID) && !tokenFromEnv && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) {
|
||||
newToken, loginHandle, err := config.AuthFlow("Notice: additional authorization required")
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
cfg, err := ctx.Config()
|
||||
if err != nil {
|
||||
return client, err
|
||||
return nil, err
|
||||
}
|
||||
_ = cfg.Set(defaultHostname, "oauth_token", newToken)
|
||||
_ = cfg.Set(defaultHostname, "user", loginHandle)
|
||||
// update config file on disk
|
||||
err = cfg.Write()
|
||||
_, err = config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required")
|
||||
if err != nil {
|
||||
return client, err
|
||||
return nil, err
|
||||
}
|
||||
// update configuration in memory
|
||||
config.AuthFlowComplete()
|
||||
|
||||
reloadedClient, err := apiClientForContext(ctx)
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
|
||||
return reloadedClient, nil
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, fmt.Sprintf("Warning: gh now requires %s OAuth scopes.", wantedScopes))
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
const defaultTestConfig = `hosts:
|
||||
github.com:
|
||||
user: OWNER
|
||||
oauth_token: 1234567890
|
||||
oauth_token: "1234567890"
|
||||
`
|
||||
|
||||
type askStubber struct {
|
||||
|
|
@ -88,7 +88,7 @@ func initBlankContext(cfg, repo, branch string) {
|
|||
|
||||
// NOTE we are not restoring the original readConfig; we never want to touch the config file on
|
||||
// disk during tests.
|
||||
config.StubConfig(cfg)
|
||||
config.StubConfig(cfg, "")
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ type blankContext struct {
|
|||
}
|
||||
|
||||
func (c *blankContext) Config() (config.Config, error) {
|
||||
cfg, err := config.ParseConfig("boom.txt")
|
||||
cfg, err := config.ParseConfig("config.yml")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to parse config during tests. did you remember to stub? error: %s", err))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package context
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/cli/cli/api"
|
||||
|
|
@ -162,11 +163,13 @@ type fsContext struct {
|
|||
|
||||
func (c *fsContext) Config() (config.Config, error) {
|
||||
if c.config == nil {
|
||||
config, err := config.ParseOrSetupConfigFile(config.ConfigFile())
|
||||
if err != nil {
|
||||
cfg, err := config.ParseDefaultConfig()
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
cfg = config.NewBlankConfig()
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.config = config
|
||||
c.config = cfg
|
||||
c.authToken = ""
|
||||
}
|
||||
return c.config, nil
|
||||
|
|
@ -182,8 +185,12 @@ func (c *fsContext) AuthToken() (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
var notFound *config.NotFoundError
|
||||
token, err := cfg.Get(defaultHostname, "oauth_token")
|
||||
if token == "" || err != nil {
|
||||
if token == "" || errors.As(err, ¬Found) {
|
||||
// interactive OAuth flow
|
||||
return config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: authentication required")
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,20 +21,16 @@ func ConfigFile() string {
|
|||
return path.Join(ConfigDir(), "config.yml")
|
||||
}
|
||||
|
||||
func ParseOrSetupConfigFile(fn string) (Config, error) {
|
||||
config, err := ParseConfig(fn)
|
||||
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||
return setupConfigFile(fn)
|
||||
}
|
||||
return config, err
|
||||
func hostsConfigFile(filename string) string {
|
||||
return path.Join(path.Dir(filename), "hosts.yml")
|
||||
}
|
||||
|
||||
func ParseDefaultConfig() (Config, error) {
|
||||
return ParseConfig(ConfigFile())
|
||||
}
|
||||
|
||||
var ReadConfigFile = func(fn string) ([]byte, error) {
|
||||
f, err := os.Open(fn)
|
||||
var ReadConfigFile = func(filename string) ([]byte, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -48,8 +44,13 @@ var ReadConfigFile = func(fn string) ([]byte, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
var WriteConfigFile = func(fn string, data []byte) error {
|
||||
cfgFile, err := os.OpenFile(ConfigFile(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup
|
||||
var WriteConfigFile = func(filename string, data []byte) error {
|
||||
err := os.MkdirAll(path.Dir(filename), 0771)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) // cargo coded from setup
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -63,12 +64,12 @@ var WriteConfigFile = func(fn string, data []byte) error {
|
|||
return err
|
||||
}
|
||||
|
||||
var BackupConfigFile = func(fn string) error {
|
||||
return os.Rename(fn, fn+".bak")
|
||||
var BackupConfigFile = func(filename string) error {
|
||||
return os.Rename(filename, filename+".bak")
|
||||
}
|
||||
|
||||
func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
|
||||
data, err := ReadConfigFile(fn)
|
||||
func parseConfigFile(filename string) ([]byte, *yaml.Node, error) {
|
||||
data, err := ReadConfigFile(filename)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
@ -78,8 +79,11 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
|
|||
if err != nil {
|
||||
return data, nil, err
|
||||
}
|
||||
if len(root.Content) < 1 {
|
||||
return data, &root, fmt.Errorf("malformed config")
|
||||
if len(root.Content) == 0 {
|
||||
return data, &yaml.Node{
|
||||
Kind: yaml.DocumentNode,
|
||||
Content: []*yaml.Node{{Kind: yaml.MappingNode}},
|
||||
}, nil
|
||||
}
|
||||
if root.Content[0].Kind != yaml.MappingNode {
|
||||
return data, &root, fmt.Errorf("expected a top level map")
|
||||
|
|
@ -90,91 +94,76 @@ func parseConfigFile(fn string) ([]byte, *yaml.Node, error) {
|
|||
|
||||
func isLegacy(root *yaml.Node) bool {
|
||||
for _, v := range root.Content[0].Content {
|
||||
if v.Value == "hosts" {
|
||||
return false
|
||||
if v.Value == "github.com" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
func migrateConfig(fn string, root *yaml.Node) error {
|
||||
type ConfigEntry map[string]string
|
||||
type ConfigHash map[string]ConfigEntry
|
||||
|
||||
newConfigData := map[string]ConfigHash{}
|
||||
newConfigData["hosts"] = ConfigHash{}
|
||||
|
||||
topLevelKeys := root.Content[0].Content
|
||||
|
||||
for i, x := range topLevelKeys {
|
||||
if x.Value == "" {
|
||||
continue
|
||||
}
|
||||
if i+1 == len(topLevelKeys) {
|
||||
break
|
||||
}
|
||||
hostname := x.Value
|
||||
newConfigData["hosts"][hostname] = ConfigEntry{}
|
||||
|
||||
authKeys := topLevelKeys[i+1].Content[0].Content
|
||||
|
||||
for j, y := range authKeys {
|
||||
if j+1 == len(authKeys) {
|
||||
break
|
||||
}
|
||||
switch y.Value {
|
||||
case "user":
|
||||
newConfigData["hosts"][hostname]["user"] = authKeys[j+1].Value
|
||||
case "oauth_token":
|
||||
newConfigData["hosts"][hostname]["oauth_token"] = authKeys[j+1].Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := newConfigData["hosts"][defaultHostname]; !ok {
|
||||
return errors.New("could not find default host configuration")
|
||||
}
|
||||
|
||||
defaultHostConfig := newConfigData["hosts"][defaultHostname]
|
||||
|
||||
if _, ok := defaultHostConfig["user"]; !ok {
|
||||
return errors.New("default host configuration missing user")
|
||||
}
|
||||
|
||||
if _, ok := defaultHostConfig["oauth_token"]; !ok {
|
||||
return errors.New("default host configuration missing oauth_token")
|
||||
}
|
||||
|
||||
newConfig, err := yaml.Marshal(newConfigData)
|
||||
func migrateConfig(filename string) error {
|
||||
b, err := ReadConfigFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = BackupConfigFile(fn)
|
||||
var hosts map[string][]yaml.Node
|
||||
err = yaml.Unmarshal(b, &hosts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding legacy format: %w", err)
|
||||
}
|
||||
|
||||
cfg := NewBlankConfig()
|
||||
for hostname, entries := range hosts {
|
||||
if len(entries) < 1 {
|
||||
continue
|
||||
}
|
||||
mapContent := entries[0].Content
|
||||
for i := 0; i < len(mapContent)-1; i += 2 {
|
||||
if err := cfg.Set(hostname, mapContent[i].Value, mapContent[i+1].Value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = BackupConfigFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to back up existing config: %w", err)
|
||||
}
|
||||
|
||||
return WriteConfigFile(fn, newConfig)
|
||||
return cfg.Write()
|
||||
}
|
||||
|
||||
func ParseConfig(fn string) (Config, error) {
|
||||
_, root, err := parseConfigFile(fn)
|
||||
func ParseConfig(filename string) (Config, error) {
|
||||
_, root, err := parseConfigFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isLegacy(root) {
|
||||
err = migrateConfig(fn, root)
|
||||
err = migrateConfig(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("error migrating legacy config: %w", err)
|
||||
}
|
||||
|
||||
_, root, err = parseConfigFile(fn)
|
||||
_, root, err = parseConfigFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to reparse migrated config: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, hostsRoot, err := parseConfigFile(hostsConfigFile(filename)); err == nil {
|
||||
if len(hostsRoot.Content[0].Content) > 0 {
|
||||
newContent := []*yaml.Node{
|
||||
{Value: "hosts"},
|
||||
hostsRoot.Content[0],
|
||||
}
|
||||
restContent := root.Content[0].Content
|
||||
root.Content[0].Content = append(newContent, restContent...)
|
||||
}
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewConfig(root), nil
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package config
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
|
@ -22,8 +23,8 @@ hosts:
|
|||
github.com:
|
||||
user: monalisa
|
||||
oauth_token: OTOKEN
|
||||
`)()
|
||||
config, err := ParseConfig("filename")
|
||||
`, "")()
|
||||
config, err := ParseConfig("config.yml")
|
||||
eq(t, err, nil)
|
||||
user, err := config.Get("github.com", "user")
|
||||
eq(t, err, nil)
|
||||
|
|
@ -42,8 +43,24 @@ hosts:
|
|||
github.com:
|
||||
user: monalisa
|
||||
oauth_token: OTOKEN
|
||||
`, "")()
|
||||
config, err := ParseConfig("config.yml")
|
||||
eq(t, err, nil)
|
||||
user, err := config.Get("github.com", "user")
|
||||
eq(t, err, nil)
|
||||
eq(t, user, "monalisa")
|
||||
token, err := config.Get("github.com", "oauth_token")
|
||||
eq(t, err, nil)
|
||||
eq(t, token, "OTOKEN")
|
||||
}
|
||||
|
||||
func Test_parseConfig_hostsFile(t *testing.T) {
|
||||
defer StubConfig("", `---
|
||||
github.com:
|
||||
user: monalisa
|
||||
oauth_token: OTOKEN
|
||||
`)()
|
||||
config, err := ParseConfig("filename")
|
||||
config, err := ParseConfig("config.yml")
|
||||
eq(t, err, nil)
|
||||
user, err := config.Get("github.com", "user")
|
||||
eq(t, err, nil)
|
||||
|
|
@ -59,38 +76,47 @@ hosts:
|
|||
example.com:
|
||||
user: wronguser
|
||||
oauth_token: NOTTHIS
|
||||
`)()
|
||||
config, err := ParseConfig("filename")
|
||||
`, "")()
|
||||
config, err := ParseConfig("config.yml")
|
||||
eq(t, err, nil)
|
||||
_, err = config.Get("github.com", "user")
|
||||
eq(t, err, errors.New(`could not find config entry for "github.com"`))
|
||||
eq(t, err, &NotFoundError{errors.New(`could not find config entry for "github.com"`)})
|
||||
}
|
||||
|
||||
func Test_migrateConfig(t *testing.T) {
|
||||
oldStyle := `---
|
||||
func Test_ParseConfig_migrateConfig(t *testing.T) {
|
||||
defer StubConfig(`---
|
||||
github.com:
|
||||
- user: keiyuri
|
||||
oauth_token: 123456`
|
||||
|
||||
var root yaml.Node
|
||||
err := yaml.Unmarshal([]byte(oldStyle), &root)
|
||||
if err != nil {
|
||||
panic("failed to parse test yaml")
|
||||
}
|
||||
|
||||
buf := bytes.NewBufferString("")
|
||||
defer StubWriteConfig(buf)()
|
||||
oauth_token: 123456
|
||||
`, "")()
|
||||
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
defer StubBackupConfig()()
|
||||
|
||||
err = migrateConfig("boom.txt", &root)
|
||||
_, err := ParseConfig("config.yml")
|
||||
eq(t, err, nil)
|
||||
|
||||
expected := `hosts:
|
||||
github.com:
|
||||
oauth_token: "123456"
|
||||
user: keiyuri
|
||||
expectedMain := ""
|
||||
expectedHosts := `github.com:
|
||||
user: keiyuri
|
||||
oauth_token: "123456"
|
||||
`
|
||||
|
||||
eq(t, buf.String(), expected)
|
||||
eq(t, mainBuf.String(), expectedMain)
|
||||
eq(t, hostsBuf.String(), expectedHosts)
|
||||
}
|
||||
|
||||
func Test_parseConfigFile(t *testing.T) {
|
||||
fileContents := []string{"", " ", "\n"}
|
||||
for _, contents := range fileContents {
|
||||
t.Run(fmt.Sprintf("contents: %q", contents), func(t *testing.T) {
|
||||
defer StubConfig(contents, "")()
|
||||
_, yamlRoot, err := parseConfigFile("config.yml")
|
||||
eq(t, err, nil)
|
||||
eq(t, yamlRoot.Content[0].Kind, yaml.MappingNode)
|
||||
eq(t, len(yamlRoot.Content[0].Content), 0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,16 +5,10 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/api"
|
||||
"github.com/cli/cli/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthHost = "github.com"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -31,7 +25,31 @@ func IsGitHubApp(id string) bool {
|
|||
return id == "178c6fc778ccc68e1d6a" || id == "4d747ba5675d5d66553f"
|
||||
}
|
||||
|
||||
func AuthFlow(notice string) (string, string, error) {
|
||||
func AuthFlowWithConfig(cfg Config, hostname, notice string) (string, error) {
|
||||
token, userLogin, err := authFlow(hostname, notice)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = cfg.Set(hostname, "user", userLogin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = cfg.Set(hostname, "oauth_token", token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = cfg.Write()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
AuthFlowComplete()
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func authFlow(oauthHost, notice string) (string, string, error) {
|
||||
var verboseStream io.Writer
|
||||
if strings.Contains(os.Getenv("DEBUG"), "oauth") {
|
||||
verboseStream = os.Stderr
|
||||
|
|
@ -69,53 +87,6 @@ func AuthFlowComplete() {
|
|||
_ = waitForEnter(os.Stdin)
|
||||
}
|
||||
|
||||
// FIXME: make testable
|
||||
func setupConfigFile(filename string) (Config, error) {
|
||||
token, userLogin, err := AuthFlow("Notice: authentication required")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO this sucks. It precludes us laying out a nice config with comments and such.
|
||||
type yamlConfig struct {
|
||||
Hosts map[string]map[string]string
|
||||
}
|
||||
|
||||
yamlHosts := map[string]map[string]string{}
|
||||
yamlHosts[oauthHost] = map[string]string{}
|
||||
yamlHosts[oauthHost]["user"] = userLogin
|
||||
yamlHosts[oauthHost]["oauth_token"] = token
|
||||
|
||||
defaultConfig := yamlConfig{
|
||||
Hosts: yamlHosts,
|
||||
}
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(filename), 0771)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfgFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cfgFile.Close()
|
||||
|
||||
yamlData, err := yaml.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = cfgFile.Write(yamlData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
AuthFlowComplete()
|
||||
|
||||
// TODO cleaner error handling? this "should" always work given that we /just/ wrote the file...
|
||||
return ParseConfig(filename)
|
||||
}
|
||||
|
||||
func getViewer(token string) (string, error) {
|
||||
http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token)))
|
||||
return api.CurrentLoginName(http)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const defaultHostname = "github.com"
|
||||
const defaultGitProtocol = "https"
|
||||
|
||||
// This interface describes interacting with some persistent configuration for gh.
|
||||
type Config interface {
|
||||
Hosts() ([]*HostConfig, error)
|
||||
Get(string, string) (string, error)
|
||||
Set(string, string, string) error
|
||||
Aliases() (*AliasConfig, error)
|
||||
|
|
@ -61,6 +60,7 @@ func (cm *ConfigMap) SetStringValue(key, value string) error {
|
|||
}
|
||||
valueNode = &yaml.Node{
|
||||
Kind: yaml.ScalarNode,
|
||||
Tag: "!!str",
|
||||
Value: "",
|
||||
}
|
||||
|
||||
|
|
@ -107,11 +107,17 @@ func NewConfig(root *yaml.Node) Config {
|
|||
}
|
||||
}
|
||||
|
||||
func NewBlankConfig() Config {
|
||||
return NewConfig(&yaml.Node{
|
||||
Kind: yaml.DocumentNode,
|
||||
Content: []*yaml.Node{{Kind: yaml.MappingNode}},
|
||||
})
|
||||
}
|
||||
|
||||
// This type implements a Config interface and represents a config file on disk.
|
||||
type fileConfig struct {
|
||||
ConfigMap
|
||||
documentRoot *yaml.Node
|
||||
hosts []*HostConfig
|
||||
}
|
||||
|
||||
func (c *fileConfig) Root() *yaml.Node {
|
||||
|
|
@ -159,7 +165,10 @@ func (c *fileConfig) Set(hostname, key, value string) error {
|
|||
return c.SetStringValue(key, value)
|
||||
} else {
|
||||
hostCfg, err := c.configForHost(hostname)
|
||||
if err != nil {
|
||||
var notFound *NotFoundError
|
||||
if errors.As(err, ¬Found) {
|
||||
hostCfg = c.makeConfigForHost(hostname)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
return hostCfg.SetStringValue(key, value)
|
||||
|
|
@ -167,7 +176,7 @@ func (c *fileConfig) Set(hostname, key, value string) error {
|
|||
}
|
||||
|
||||
func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
|
||||
hosts, err := c.Hosts()
|
||||
hosts, err := c.hostEntries()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse hosts config: %w", err)
|
||||
}
|
||||
|
|
@ -177,16 +186,46 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
|
|||
return hc, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("could not find config entry for %q", hostname)
|
||||
return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)}
|
||||
}
|
||||
|
||||
func (c *fileConfig) Write() error {
|
||||
marshalled, err := yaml.Marshal(c.documentRoot)
|
||||
mainData := yaml.Node{Kind: yaml.MappingNode}
|
||||
hostsData := yaml.Node{Kind: yaml.MappingNode}
|
||||
|
||||
nodes := c.documentRoot.Content[0].Content
|
||||
for i := 0; i < len(nodes)-1; i += 2 {
|
||||
if nodes[i].Value == "hosts" {
|
||||
hostsData.Content = append(hostsData.Content, nodes[i+1].Content...)
|
||||
} else {
|
||||
mainData.Content = append(mainData.Content, nodes[i], nodes[i+1])
|
||||
}
|
||||
}
|
||||
|
||||
mainBytes, err := yaml.Marshal(&mainData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteConfigFile(ConfigFile(), marshalled)
|
||||
filename := ConfigFile()
|
||||
err = WriteConfigFile(filename, yamlNormalize(mainBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hostsBytes, err := yaml.Marshal(&hostsData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteConfigFile(hostsConfigFile(filename), yamlNormalize(hostsBytes))
|
||||
}
|
||||
|
||||
func yamlNormalize(b []byte) []byte {
|
||||
if bytes.Equal(b, []byte("{}\n")) {
|
||||
return []byte{}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (c *fileConfig) Aliases() (*AliasConfig, error) {
|
||||
|
|
@ -243,11 +282,7 @@ func (c *fileConfig) Aliases() (*AliasConfig, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (c *fileConfig) Hosts() ([]*HostConfig, error) {
|
||||
if len(c.hosts) > 0 {
|
||||
return c.hosts, nil
|
||||
}
|
||||
|
||||
func (c *fileConfig) hostEntries() ([]*HostConfig, error) {
|
||||
entry, err := c.FindEntry("hosts")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find hosts config: %w", err)
|
||||
|
|
@ -258,11 +293,39 @@ func (c *fileConfig) Hosts() ([]*HostConfig, error) {
|
|||
return nil, fmt.Errorf("could not parse hosts config: %w", err)
|
||||
}
|
||||
|
||||
c.hosts = hostConfigs
|
||||
|
||||
return hostConfigs, nil
|
||||
}
|
||||
|
||||
func (c *fileConfig) makeConfigForHost(hostname string) *HostConfig {
|
||||
hostRoot := &yaml.Node{Kind: yaml.MappingNode}
|
||||
hostCfg := &HostConfig{
|
||||
Host: hostname,
|
||||
ConfigMap: ConfigMap{Root: hostRoot},
|
||||
}
|
||||
|
||||
var notFound *NotFoundError
|
||||
hostsEntry, err := c.FindEntry("hosts")
|
||||
if errors.As(err, ¬Found) {
|
||||
hostsEntry.KeyNode = &yaml.Node{
|
||||
Kind: yaml.ScalarNode,
|
||||
Value: "hosts",
|
||||
}
|
||||
hostsEntry.ValueNode = &yaml.Node{Kind: yaml.MappingNode}
|
||||
root := c.Root()
|
||||
root.Content = append(root.Content, hostsEntry.KeyNode, hostsEntry.ValueNode)
|
||||
} else if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
hostsEntry.ValueNode.Content = append(hostsEntry.ValueNode.Content,
|
||||
&yaml.Node{
|
||||
Kind: yaml.ScalarNode,
|
||||
Value: hostname,
|
||||
}, hostRoot)
|
||||
|
||||
return hostCfg
|
||||
}
|
||||
|
||||
func (c *fileConfig) parseHosts(hostsEntry *yaml.Node) ([]*HostConfig, error) {
|
||||
hostConfigs := []*HostConfig{}
|
||||
|
||||
|
|
|
|||
41
internal/config/config_type_test.go
Normal file
41
internal/config/config_type_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_fileConfig_Set(t *testing.T) {
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
c := NewBlankConfig()
|
||||
assert.NoError(t, c.Set("", "editor", "nano"))
|
||||
assert.NoError(t, c.Set("github.com", "git_protocol", "ssh"))
|
||||
assert.NoError(t, c.Set("example.com", "editor", "vim"))
|
||||
assert.NoError(t, c.Set("github.com", "user", "hubot"))
|
||||
assert.NoError(t, c.Write())
|
||||
|
||||
assert.Equal(t, "editor: nano\n", mainBuf.String())
|
||||
assert.Equal(t, `github.com:
|
||||
git_protocol: ssh
|
||||
user: hubot
|
||||
example.com:
|
||||
editor: vim
|
||||
`, hostsBuf.String())
|
||||
}
|
||||
|
||||
func Test_fileConfig_Write(t *testing.T) {
|
||||
mainBuf := bytes.Buffer{}
|
||||
hostsBuf := bytes.Buffer{}
|
||||
defer StubWriteConfig(&mainBuf, &hostsBuf)()
|
||||
|
||||
c := NewBlankConfig()
|
||||
assert.NoError(t, c.Write())
|
||||
|
||||
assert.Equal(t, "", mainBuf.String())
|
||||
assert.Equal(t, "", hostsBuf.String())
|
||||
}
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
func StubBackupConfig() func() {
|
||||
|
|
@ -15,21 +18,41 @@ func StubBackupConfig() func() {
|
|||
}
|
||||
}
|
||||
|
||||
func StubWriteConfig(w io.Writer) func() {
|
||||
func StubWriteConfig(wc io.Writer, wh io.Writer) func() {
|
||||
orig := WriteConfigFile
|
||||
WriteConfigFile = func(fn string, data []byte) error {
|
||||
_, err := w.Write(data)
|
||||
return err
|
||||
switch path.Base(fn) {
|
||||
case "config.yml":
|
||||
_, err := wc.Write(data)
|
||||
return err
|
||||
case "hosts.yml":
|
||||
_, err := wh.Write(data)
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("write to unstubbed file: %q", fn)
|
||||
}
|
||||
}
|
||||
return func() {
|
||||
WriteConfigFile = orig
|
||||
}
|
||||
}
|
||||
|
||||
func StubConfig(content string) func() {
|
||||
func StubConfig(main, hosts string) func() {
|
||||
orig := ReadConfigFile
|
||||
ReadConfigFile = func(fn string) ([]byte, error) {
|
||||
return []byte(content), nil
|
||||
switch path.Base(fn) {
|
||||
case "config.yml":
|
||||
return []byte(main), nil
|
||||
case "hosts.yml":
|
||||
if hosts == "" {
|
||||
return []byte(nil), os.ErrNotExist
|
||||
} else {
|
||||
return []byte(hosts), nil
|
||||
}
|
||||
default:
|
||||
return []byte(nil), fmt.Errorf("read from unstubbed file: %q", fn)
|
||||
}
|
||||
|
||||
}
|
||||
return func() {
|
||||
ReadConfigFile = orig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue