diff --git a/git/remote_test.go b/git/remote_test.go index 2e7d30cb6..e8c091653 100644 --- a/git/remote_test.go +++ b/git/remote_test.go @@ -1,6 +1,17 @@ package git -import "testing" +import ( + "reflect" + "testing" +) + +// TODO: extract assertion helpers into a shared package +func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() + if !reflect.DeepEqual(got, expected) { + t.Errorf("expected: %v, got: %v", expected, got) + } +} func Test_parseRemotes(t *testing.T) { remoteList := []string{ diff --git a/git/ssh_config.go b/git/ssh_config.go index b65d8c895..317ff6059 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -2,6 +2,7 @@ package git import ( "bufio" + "io" "net/url" "os" "path/filepath" @@ -12,13 +13,10 @@ import ( ) var ( - sshTokenRE *regexp.Regexp + sshConfigLineRE = regexp.MustCompile(`\A\s*(?P[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P.+)`) + sshTokenRE = regexp.MustCompile(`%[%h]`) ) -func init() { - sshTokenRE = regexp.MustCompile(`%[%h]`) -} - // SSHAliasMap encapsulates the translation of SSH hostname aliases type SSHAliasMap map[string]string @@ -42,42 +40,75 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { } } -type parser struct { +type sshParser struct { + homeDir string + aliasMap SSHAliasMap + hosts []string + + open func(string) (io.Reader, error) + glob func(string) ([]string, error) } -func (p *parser) read(fileName string) error { - file, err := os.Open(fileName) - if err != nil { - return err +func (p *sshParser) read(fileName string) error { + var file io.Reader + if p.open == nil { + f, err := os.Open(fileName) + if err != nil { + return err + } + defer f.Close() + file = f + } else { + var err error + file, err = p.open(fileName) + if err != nil { + return err + } + } + + if len(p.hosts) == 0 { + p.hosts = []string{"*"} } - defer file.Close() - hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - - if len(fields) < 2 { + m := sshConfigLineRE.FindStringSubmatch(scanner.Text()) + if len(m) < 3 { continue } - directive, params := fields[0], fields[1:] - switch { - case strings.EqualFold(directive, "Host"): - hosts = params - case strings.EqualFold(directive, "Hostname"): - for _, host := range hosts { - for _, name := range params { + keyword, arguments := strings.ToLower(m[1]), m[2] + switch keyword { + case "host": + p.hosts = strings.Fields(arguments) + case "hostname": + for _, host := range p.hosts { + for _, name := range strings.Fields(arguments) { + if p.aliasMap == nil { + p.aliasMap = make(SSHAliasMap) + } p.aliasMap[host] = sshExpandTokens(name, host) } } - case strings.EqualFold(directive, "Include"): - for _, path := range absolutePaths(fileName, params) { - fileNames, err := filepath.Glob(path) - if err != nil { - continue + case "include": + for _, arg := range strings.Fields(arguments) { + path := p.absolutePath(fileName, arg) + + var fileNames []string + if p.glob == nil { + paths, _ := filepath.Glob(path) + for _, p := range paths { + if s, err := os.Stat(p); err == nil && !s.IsDir() { + fileNames = append(fileNames, p) + } + } + } else { + var err error + fileNames, err = p.glob(path) + if err != nil { + continue + } } for _, fileName := range fileNames { @@ -90,38 +121,20 @@ func (p *parser) read(fileName string) error { return scanner.Err() } -func isSystem(path string) bool { - return strings.HasPrefix(path, "/etc/ssh") -} - -func absolutePaths(parentFile string, paths []string) []string { - absPaths := make([]string, len(paths)) - - for i, path := range paths { - switch { - case filepath.IsAbs(path): - absPaths[i] = path - case strings.HasPrefix(path, "~"): - absPaths[i], _ = homedir.Expand(path) - case isSystem(parentFile): - absPaths[i] = filepath.Join("/etc", "ssh", path) - default: - dir, _ := homedir.Dir() - absPaths[i] = filepath.Join(dir, ".ssh", path) - } +func (p *sshParser) absolutePath(parentFile, path string) string { + if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") { + return path } - return absPaths -} - -func parse(files ...string) SSHAliasMap { - p := parser{aliasMap: make(SSHAliasMap)} - - for _, file := range files { - _ = p.read(file) + if strings.HasPrefix(path, "~") { + return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~")) } - return p.aliasMap + if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") { + return filepath.Join("/etc/ssh", path) + } + + return filepath.Join(p.homeDir, ".ssh", path) } // ParseSSHConfig constructs a map of SSH hostname aliases based on user and @@ -131,12 +144,19 @@ func ParseSSHConfig() SSHAliasMap { "/etc/ssh_config", "/etc/ssh/ssh_config", } + + p := sshParser{} + if homedir, err := homedir.Dir(); err == nil { userConfig := filepath.Join(homedir, ".ssh", "config") configFiles = append([]string{userConfig}, configFiles...) + p.homeDir = homedir } - return parse(configFiles...) + for _, file := range configFiles { + _ = p.read(file) + } + return p.aliasMap } func sshExpandTokens(text, host string) string { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 55874fb6e..f05ca303b 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,110 +1,124 @@ package git import ( + "bytes" "fmt" - "io/ioutil" + "io" "net/url" - "os" "path/filepath" - "reflect" "testing" - "github.com/mitchellh/go-homedir" + "github.com/MakeNowJust/heredoc" ) -// TODO: extract assertion helpers into a shared package -func eq(t *testing.T, got interface{}, expected interface{}) { - t.Helper() - if !reflect.DeepEqual(got, expected) { - t.Errorf("expected: %v, got: %v", expected, got) +func Test_sshParser_read(t *testing.T) { + testFiles := map[string]string{ + "/etc/ssh/config": heredoc.Doc(` + Include sites/* + `), + "/etc/ssh/sites/cfg1": heredoc.Doc(` + Host s1 + Hostname=site1.net + `), + "/etc/ssh/sites/cfg2": heredoc.Doc(` + Host s2 + Hostname = site2.net + `), + "HOME/.ssh/config": heredoc.Doc(` + Host * + Host gh gittyhubby + Hostname github.com + #Hostname example.com + Host ex + Include ex_config/* + `), + "HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(` + Hostname example.com + `), + } + globResults := map[string][]string{ + "/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"}, + "HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"}, + } + + p := &sshParser{ + homeDir: "HOME", + open: func(s string) (io.Reader, error) { + if contents, ok := testFiles[filepath.ToSlash(s)]; ok { + return bytes.NewBufferString(contents), nil + } else { + return nil, fmt.Errorf("no test file stub found: %q", s) + } + }, + glob: func(p string) ([]string, error) { + if results, ok := globResults[filepath.ToSlash(p)]; ok { + return results, nil + } else { + return nil, fmt.Errorf("no glob stubs found: %q", p) + } + }, + } + + if err := p.read("/etc/ssh/config"); err != nil { + t.Fatalf("read(global config) = %v", err) + } + if err := p.read("HOME/.ssh/config"); err != nil { + t.Fatalf("read(user config) = %v", err) + } + + if got := p.aliasMap["gh"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got) + } + if got := p.aliasMap["gittyhubby"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got) + } + if got := p.aliasMap["example.com"]; got != "" { + t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got) + } + if got := p.aliasMap["ex"]; got != "example.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got) + } + if got := p.aliasMap["s1"]; got != "site1.net" { + t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got) } } -func createTempFile(t *testing.T, prefix string) *os.File { - t.Helper() - - dir, err := homedir.Dir() - if err != nil { - t.Errorf("Could not find homedir: %s", err) - } - - tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix) - if err != nil { - t.Errorf("Could create a temp file: %s", err) - } - - t.Cleanup(func() { - tempFile.Close() - os.Remove(tempFile.Name()) - }) - - return tempFile -} - -func Test_parse(t *testing.T) { - includedTempFile := createTempFile(t, "included") - includedConfigFile := ` -Host webapp - HostName webapp.example.com - ` - fmt.Fprint(includedTempFile, includedConfigFile) - - m := parse( - "testdata/ssh_config1.conf", - "testdata/ssh_config2.conf", - "testdata/ssh_config3.conf", - ) - - eq(t, m["foo"], "example.com") - eq(t, m["bar"], "%bar.net%") - eq(t, m["nonexistent"], "") -} - -func Test_absolutePaths(t *testing.T) { - dir, err := homedir.Dir() - if err != nil { - t.Errorf("Could not find homedir: %s", err) - } +func Test_sshParser_absolutePath(t *testing.T) { + dir := "HOME" + p := &sshParser{homeDir: dir} tests := map[string]struct { parentFile string - Input []string - Want []string + arg string + want string + wantErr bool }{ "absolute path": { parentFile: "/etc/ssh/ssh_config", - Input: []string{"/etc/ssh/config"}, - Want: []string{"/etc/ssh/config"}, + arg: "/etc/ssh/config", + want: "/etc/ssh/config", }, "system relative path": { parentFile: "/etc/ssh/config", - Input: []string{"configs/*.conf"}, - Want: []string{"/etc/ssh/configs/*.conf"}, + arg: "configs/*.conf", + want: filepath.Join("/etc", "ssh", "configs", "*.conf"), }, "user relative path": { parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - Input: []string{"configs/*.conf"}, - Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")}, + arg: "configs/*.conf", + want: filepath.Join(dir, ".ssh", "configs/*.conf"), }, "shell-like ~ rerefence": { parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - Input: []string{"~/.ssh/*.conf"}, - Want: []string{filepath.Join(dir, ".ssh", "*.conf")}, + arg: "~/.ssh/*.conf", + want: filepath.Join(dir, ".ssh", "*.conf"), }, } - for name, test := range tests { + for name, tt := range tests { t.Run(name, func(t *testing.T) { - paths := absolutePaths(test.parentFile, test.Input) - - if len(paths) != len(test.Input) { - t.Errorf("Expected %d, got %d", len(test.Input), len(paths)) - } - - for i, path := range paths { - if path != test.Want[i] { - t.Errorf("Expected %q, got %q", test.Want[i], path) - } + if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want { + t.Errorf("absolutePath(): %q, wants %q", got, tt.want) } }) } diff --git a/git/testdata/included.conf b/git/testdata/included.conf deleted file mode 100644 index 2ded9d103..000000000 --- a/git/testdata/included.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host webapp - HostName webapp.example.com diff --git a/git/testdata/ssh_config1.conf b/git/testdata/ssh_config1.conf deleted file mode 100644 index 7249b01fb..000000000 --- a/git/testdata/ssh_config1.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host foo bar - HostName example.com diff --git a/git/testdata/ssh_config2.conf b/git/testdata/ssh_config2.conf deleted file mode 100644 index 3884f3d15..000000000 --- a/git/testdata/ssh_config2.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host bar baz -hostname %%%h.net%% diff --git a/git/testdata/ssh_config3.conf b/git/testdata/ssh_config3.conf deleted file mode 100644 index 7b55743ea..000000000 --- a/git/testdata/ssh_config3.conf +++ /dev/null @@ -1 +0,0 @@ -Include ~/.ssh/included*