From dc8698ee46b6fa2b0311a1352e429cbbf9e2599d Mon Sep 17 00:00:00 2001 From: Alisson Santos Date: Mon, 19 Oct 2020 15:09:48 +0200 Subject: [PATCH] Make ssh parser to parse included config files --- git/ssh_config.go | 129 ++++++++++++++++++++++------------ git/ssh_config_test.go | 100 +++++++++++++++++++++++--- git/testdata/included.conf | 2 + git/testdata/ssh_config1.conf | 2 + git/testdata/ssh_config2.conf | 2 + git/testdata/ssh_config3.conf | 1 + 6 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 git/testdata/included.conf create mode 100644 git/testdata/ssh_config1.conf create mode 100644 git/testdata/ssh_config2.conf create mode 100644 git/testdata/ssh_config3.conf diff --git a/git/ssh_config.go b/git/ssh_config.go index 287298cd9..b65d8c895 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -2,7 +2,6 @@ package git import ( "bufio" - "io" "net/url" "os" "path/filepath" @@ -13,12 +12,10 @@ import ( ) var ( - sshHostRE, sshTokenRE *regexp.Regexp ) func init() { - sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$") sshTokenRE = regexp.MustCompile(`%[%h]`) } @@ -45,6 +42,88 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { } } +type parser struct { + aliasMap SSHAliasMap +} + +func (p *parser) read(fileName string) error { + file, err := os.Open(fileName) + if err != nil { + return err + } + defer file.Close() + + hosts := []string{"*"} + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + + if len(fields) < 2 { + continue + } + + directive, params := fields[0], fields[1:] + switch { + case strings.EqualFold(directive, "Host"): + hosts = params + case strings.EqualFold(directive, "Hostname"): + for _, host := range hosts { + for _, name := range params { + p.aliasMap[host] = sshExpandTokens(name, host) + } + } + case strings.EqualFold(directive, "Include"): + for _, path := range absolutePaths(fileName, params) { + fileNames, err := filepath.Glob(path) + if err != nil { + continue + } + + for _, fileName := range fileNames { + _ = p.read(fileName) + } + } + } + } + + return scanner.Err() +} + +func isSystem(path string) bool { + return strings.HasPrefix(path, "/etc/ssh") +} + +func absolutePaths(parentFile string, paths []string) []string { + absPaths := make([]string, len(paths)) + + for i, path := range paths { + switch { + case filepath.IsAbs(path): + absPaths[i] = path + case strings.HasPrefix(path, "~"): + absPaths[i], _ = homedir.Expand(path) + case isSystem(parentFile): + absPaths[i] = filepath.Join("/etc", "ssh", path) + default: + dir, _ := homedir.Dir() + absPaths[i] = filepath.Join(dir, ".ssh", path) + } + } + + return absPaths +} + +func parse(files ...string) SSHAliasMap { + p := parser{aliasMap: make(SSHAliasMap)} + + for _, file := range files { + _ = p.read(file) + } + + return p.aliasMap +} + // ParseSSHConfig constructs a map of SSH hostname aliases based on user and // system configuration files func ParseSSHConfig() SSHAliasMap { @@ -57,49 +136,7 @@ func ParseSSHConfig() SSHAliasMap { configFiles = append([]string{userConfig}, configFiles...) } - openFiles := make([]io.Reader, 0, len(configFiles)) - for _, file := range configFiles { - f, err := os.Open(file) - if err != nil { - continue - } - defer f.Close() - openFiles = append(openFiles, f) - } - return sshParse(openFiles...) -} - -func sshParse(r ...io.Reader) SSHAliasMap { - config := make(SSHAliasMap) - for _, file := range r { - _ = sshParseConfig(config, file) - } - return config -} - -func sshParseConfig(c SSHAliasMap, file io.Reader) error { - hosts := []string{"*"} - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - match := sshHostRE.FindStringSubmatch(line) - if match == nil { - continue - } - - names := strings.Fields(match[2]) - if strings.EqualFold(match[1], "host") { - hosts = names - } else { - for _, host := range hosts { - for _, name := range names { - c[host] = sshExpandTokens(name, host) - } - } - } - } - - return scanner.Err() + return parse(configFiles...) } func sshExpandTokens(text, host string) string { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 7aafc5b21..55874fb6e 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,10 +1,15 @@ package git import ( + "fmt" + "io/ioutil" "net/url" + "os" + "path/filepath" "reflect" - "strings" "testing" + + "github.com/mitchellh/go-homedir" ) // TODO: extract assertion helpers into a shared package @@ -15,19 +20,96 @@ func eq(t *testing.T, got interface{}, expected interface{}) { } } -func Test_sshParse(t *testing.T) { - m := sshParse(strings.NewReader(` - Host foo bar - HostName example.com - `), strings.NewReader(` - Host bar baz - hostname %%%h.net%% - `)) +func createTempFile(t *testing.T, prefix string) *os.File { + t.Helper() + + dir, err := homedir.Dir() + if err != nil { + t.Errorf("Could not find homedir: %s", err) + } + + tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix) + if err != nil { + t.Errorf("Could create a temp file: %s", err) + } + + t.Cleanup(func() { + tempFile.Close() + os.Remove(tempFile.Name()) + }) + + return tempFile +} + +func Test_parse(t *testing.T) { + includedTempFile := createTempFile(t, "included") + includedConfigFile := ` +Host webapp + HostName webapp.example.com + ` + fmt.Fprint(includedTempFile, includedConfigFile) + + m := parse( + "testdata/ssh_config1.conf", + "testdata/ssh_config2.conf", + "testdata/ssh_config3.conf", + ) + eq(t, m["foo"], "example.com") eq(t, m["bar"], "%bar.net%") eq(t, m["nonexistent"], "") } +func Test_absolutePaths(t *testing.T) { + dir, err := homedir.Dir() + if err != nil { + t.Errorf("Could not find homedir: %s", err) + } + + tests := map[string]struct { + parentFile string + Input []string + Want []string + }{ + "absolute path": { + parentFile: "/etc/ssh/ssh_config", + Input: []string{"/etc/ssh/config"}, + Want: []string{"/etc/ssh/config"}, + }, + "system relative path": { + parentFile: "/etc/ssh/config", + Input: []string{"configs/*.conf"}, + Want: []string{"/etc/ssh/configs/*.conf"}, + }, + "user relative path": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + Input: []string{"configs/*.conf"}, + Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")}, + }, + "shell-like ~ rerefence": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + Input: []string{"~/.ssh/*.conf"}, + Want: []string{filepath.Join(dir, ".ssh", "*.conf")}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + paths := absolutePaths(test.parentFile, test.Input) + + if len(paths) != len(test.Input) { + t.Errorf("Expected %d, got %d", len(test.Input), len(paths)) + } + + for i, path := range paths { + if path != test.Want[i] { + t.Errorf("Expected %q, got %q", test.Want[i], path) + } + } + }) + } +} + func Test_Translator(t *testing.T) { m := SSHAliasMap{ "gh": "github.com", diff --git a/git/testdata/included.conf b/git/testdata/included.conf new file mode 100644 index 000000000..2ded9d103 --- /dev/null +++ b/git/testdata/included.conf @@ -0,0 +1,2 @@ +Host webapp + HostName webapp.example.com diff --git a/git/testdata/ssh_config1.conf b/git/testdata/ssh_config1.conf new file mode 100644 index 000000000..7249b01fb --- /dev/null +++ b/git/testdata/ssh_config1.conf @@ -0,0 +1,2 @@ +Host foo bar + HostName example.com diff --git a/git/testdata/ssh_config2.conf b/git/testdata/ssh_config2.conf new file mode 100644 index 000000000..3884f3d15 --- /dev/null +++ b/git/testdata/ssh_config2.conf @@ -0,0 +1,2 @@ +Host bar baz +hostname %%%h.net%% diff --git a/git/testdata/ssh_config3.conf b/git/testdata/ssh_config3.conf new file mode 100644 index 000000000..7b55743ea --- /dev/null +++ b/git/testdata/ssh_config3.conf @@ -0,0 +1 @@ +Include ~/.ssh/included*