From dc8698ee46b6fa2b0311a1352e429cbbf9e2599d Mon Sep 17 00:00:00 2001 From: Alisson Santos Date: Mon, 19 Oct 2020 15:09:48 +0200 Subject: [PATCH 1/2] Make ssh parser to parse included config files --- git/ssh_config.go | 129 ++++++++++++++++++++++------------ git/ssh_config_test.go | 100 +++++++++++++++++++++++--- git/testdata/included.conf | 2 + git/testdata/ssh_config1.conf | 2 + git/testdata/ssh_config2.conf | 2 + git/testdata/ssh_config3.conf | 1 + 6 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 git/testdata/included.conf create mode 100644 git/testdata/ssh_config1.conf create mode 100644 git/testdata/ssh_config2.conf create mode 100644 git/testdata/ssh_config3.conf diff --git a/git/ssh_config.go b/git/ssh_config.go index 287298cd9..b65d8c895 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -2,7 +2,6 @@ package git import ( "bufio" - "io" "net/url" "os" "path/filepath" @@ -13,12 +12,10 @@ import ( ) var ( - sshHostRE, sshTokenRE *regexp.Regexp ) func init() { - sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$") sshTokenRE = regexp.MustCompile(`%[%h]`) } @@ -45,6 +42,88 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { } } +type parser struct { + aliasMap SSHAliasMap +} + +func (p *parser) read(fileName string) error { + file, err := os.Open(fileName) + if err != nil { + return err + } + defer file.Close() + + hosts := []string{"*"} + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + + if len(fields) < 2 { + continue + } + + directive, params := fields[0], fields[1:] + switch { + case strings.EqualFold(directive, "Host"): + hosts = params + case strings.EqualFold(directive, "Hostname"): + for _, host := range hosts { + for _, name := range params { + p.aliasMap[host] = sshExpandTokens(name, host) + } + } + case strings.EqualFold(directive, "Include"): + for _, path := range absolutePaths(fileName, params) { + fileNames, err := filepath.Glob(path) + if err != nil { + continue + } + + for _, fileName := range fileNames { + _ = p.read(fileName) + } + } + } + } + + return scanner.Err() +} + +func isSystem(path string) bool { + return strings.HasPrefix(path, "/etc/ssh") +} + +func absolutePaths(parentFile string, paths []string) []string { + absPaths := make([]string, len(paths)) + + for i, path := range paths { + switch { + case filepath.IsAbs(path): + absPaths[i] = path + case strings.HasPrefix(path, "~"): + absPaths[i], _ = homedir.Expand(path) + case isSystem(parentFile): + absPaths[i] = filepath.Join("/etc", "ssh", path) + default: + dir, _ := homedir.Dir() + absPaths[i] = filepath.Join(dir, ".ssh", path) + } + } + + return absPaths +} + +func parse(files ...string) SSHAliasMap { + p := parser{aliasMap: make(SSHAliasMap)} + + for _, file := range files { + _ = p.read(file) + } + + return p.aliasMap +} + // ParseSSHConfig constructs a map of SSH hostname aliases based on user and // system configuration files func ParseSSHConfig() SSHAliasMap { @@ -57,49 +136,7 @@ func ParseSSHConfig() SSHAliasMap { configFiles = append([]string{userConfig}, configFiles...) } - openFiles := make([]io.Reader, 0, len(configFiles)) - for _, file := range configFiles { - f, err := os.Open(file) - if err != nil { - continue - } - defer f.Close() - openFiles = append(openFiles, f) - } - return sshParse(openFiles...) -} - -func sshParse(r ...io.Reader) SSHAliasMap { - config := make(SSHAliasMap) - for _, file := range r { - _ = sshParseConfig(config, file) - } - return config -} - -func sshParseConfig(c SSHAliasMap, file io.Reader) error { - hosts := []string{"*"} - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - match := sshHostRE.FindStringSubmatch(line) - if match == nil { - continue - } - - names := strings.Fields(match[2]) - if strings.EqualFold(match[1], "host") { - hosts = names - } else { - for _, host := range hosts { - for _, name := range names { - c[host] = sshExpandTokens(name, host) - } - } - } - } - - return scanner.Err() + return parse(configFiles...) } func sshExpandTokens(text, host string) string { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 7aafc5b21..55874fb6e 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,10 +1,15 @@ package git import ( + "fmt" + "io/ioutil" "net/url" + "os" + "path/filepath" "reflect" - "strings" "testing" + + "github.com/mitchellh/go-homedir" ) // TODO: extract assertion helpers into a shared package @@ -15,19 +20,96 @@ func eq(t *testing.T, got interface{}, expected interface{}) { } } -func Test_sshParse(t *testing.T) { - m := sshParse(strings.NewReader(` - Host foo bar - HostName example.com - `), strings.NewReader(` - Host bar baz - hostname %%%h.net%% - `)) +func createTempFile(t *testing.T, prefix string) *os.File { + t.Helper() + + dir, err := homedir.Dir() + if err != nil { + t.Errorf("Could not find homedir: %s", err) + } + + tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix) + if err != nil { + t.Errorf("Could create a temp file: %s", err) + } + + t.Cleanup(func() { + tempFile.Close() + os.Remove(tempFile.Name()) + }) + + return tempFile +} + +func Test_parse(t *testing.T) { + includedTempFile := createTempFile(t, "included") + includedConfigFile := ` +Host webapp + HostName webapp.example.com + ` + fmt.Fprint(includedTempFile, includedConfigFile) + + m := parse( + "testdata/ssh_config1.conf", + "testdata/ssh_config2.conf", + "testdata/ssh_config3.conf", + ) + eq(t, m["foo"], "example.com") eq(t, m["bar"], "%bar.net%") eq(t, m["nonexistent"], "") } +func Test_absolutePaths(t *testing.T) { + dir, err := homedir.Dir() + if err != nil { + t.Errorf("Could not find homedir: %s", err) + } + + tests := map[string]struct { + parentFile string + Input []string + Want []string + }{ + "absolute path": { + parentFile: "/etc/ssh/ssh_config", + Input: []string{"/etc/ssh/config"}, + Want: []string{"/etc/ssh/config"}, + }, + "system relative path": { + parentFile: "/etc/ssh/config", + Input: []string{"configs/*.conf"}, + Want: []string{"/etc/ssh/configs/*.conf"}, + }, + "user relative path": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + Input: []string{"configs/*.conf"}, + Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")}, + }, + "shell-like ~ rerefence": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + Input: []string{"~/.ssh/*.conf"}, + Want: []string{filepath.Join(dir, ".ssh", "*.conf")}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + paths := absolutePaths(test.parentFile, test.Input) + + if len(paths) != len(test.Input) { + t.Errorf("Expected %d, got %d", len(test.Input), len(paths)) + } + + for i, path := range paths { + if path != test.Want[i] { + t.Errorf("Expected %q, got %q", test.Want[i], path) + } + } + }) + } +} + func Test_Translator(t *testing.T) { m := SSHAliasMap{ "gh": "github.com", diff --git a/git/testdata/included.conf b/git/testdata/included.conf new file mode 100644 index 000000000..2ded9d103 --- /dev/null +++ b/git/testdata/included.conf @@ -0,0 +1,2 @@ +Host webapp + HostName webapp.example.com diff --git a/git/testdata/ssh_config1.conf b/git/testdata/ssh_config1.conf new file mode 100644 index 000000000..7249b01fb --- /dev/null +++ b/git/testdata/ssh_config1.conf @@ -0,0 +1,2 @@ +Host foo bar + HostName example.com diff --git a/git/testdata/ssh_config2.conf b/git/testdata/ssh_config2.conf new file mode 100644 index 000000000..3884f3d15 --- /dev/null +++ b/git/testdata/ssh_config2.conf @@ -0,0 +1,2 @@ +Host bar baz +hostname %%%h.net%% diff --git a/git/testdata/ssh_config3.conf b/git/testdata/ssh_config3.conf new file mode 100644 index 000000000..7b55743ea --- /dev/null +++ b/git/testdata/ssh_config3.conf @@ -0,0 +1 @@ +Include ~/.ssh/included* From 935f6444ae330b8ba89fd048c4ef5ef145e5910f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Tue, 15 Dec 2020 15:02:49 +0100 Subject: [PATCH 2/2] Refactor ssh parser for format compatibility & testability - Per ssh_config(5), keywords and arguments may be separated by an `=` sign as well as whitespace. - When following the `Include` directive, skip directories that were returned as the result of globbing. - Respect the `Host` context when recursing into `Include`s - Avoid having tests read from the actual filesystem. - Avoid repeatedly looking up the home directory. --- git/remote_test.go | 13 ++- git/ssh_config.go | 134 +++++++++++++++------------ git/ssh_config_test.go | 164 ++++++++++++++++++---------------- git/testdata/included.conf | 2 - git/testdata/ssh_config1.conf | 2 - git/testdata/ssh_config2.conf | 2 - git/testdata/ssh_config3.conf | 1 - 7 files changed, 178 insertions(+), 140 deletions(-) delete mode 100644 git/testdata/included.conf delete mode 100644 git/testdata/ssh_config1.conf delete mode 100644 git/testdata/ssh_config2.conf delete mode 100644 git/testdata/ssh_config3.conf diff --git a/git/remote_test.go b/git/remote_test.go index 2e7d30cb6..e8c091653 100644 --- a/git/remote_test.go +++ b/git/remote_test.go @@ -1,6 +1,17 @@ package git -import "testing" +import ( + "reflect" + "testing" +) + +// TODO: extract assertion helpers into a shared package +func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() + if !reflect.DeepEqual(got, expected) { + t.Errorf("expected: %v, got: %v", expected, got) + } +} func Test_parseRemotes(t *testing.T) { remoteList := []string{ diff --git a/git/ssh_config.go b/git/ssh_config.go index b65d8c895..317ff6059 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -2,6 +2,7 @@ package git import ( "bufio" + "io" "net/url" "os" "path/filepath" @@ -12,13 +13,10 @@ import ( ) var ( - sshTokenRE *regexp.Regexp + sshConfigLineRE = regexp.MustCompile(`\A\s*(?P[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P.+)`) + sshTokenRE = regexp.MustCompile(`%[%h]`) ) -func init() { - sshTokenRE = regexp.MustCompile(`%[%h]`) -} - // SSHAliasMap encapsulates the translation of SSH hostname aliases type SSHAliasMap map[string]string @@ -42,42 +40,75 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { } } -type parser struct { +type sshParser struct { + homeDir string + aliasMap SSHAliasMap + hosts []string + + open func(string) (io.Reader, error) + glob func(string) ([]string, error) } -func (p *parser) read(fileName string) error { - file, err := os.Open(fileName) - if err != nil { - return err +func (p *sshParser) read(fileName string) error { + var file io.Reader + if p.open == nil { + f, err := os.Open(fileName) + if err != nil { + return err + } + defer f.Close() + file = f + } else { + var err error + file, err = p.open(fileName) + if err != nil { + return err + } + } + + if len(p.hosts) == 0 { + p.hosts = []string{"*"} } - defer file.Close() - hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - - if len(fields) < 2 { + m := sshConfigLineRE.FindStringSubmatch(scanner.Text()) + if len(m) < 3 { continue } - directive, params := fields[0], fields[1:] - switch { - case strings.EqualFold(directive, "Host"): - hosts = params - case strings.EqualFold(directive, "Hostname"): - for _, host := range hosts { - for _, name := range params { + keyword, arguments := strings.ToLower(m[1]), m[2] + switch keyword { + case "host": + p.hosts = strings.Fields(arguments) + case "hostname": + for _, host := range p.hosts { + for _, name := range strings.Fields(arguments) { + if p.aliasMap == nil { + p.aliasMap = make(SSHAliasMap) + } p.aliasMap[host] = sshExpandTokens(name, host) } } - case strings.EqualFold(directive, "Include"): - for _, path := range absolutePaths(fileName, params) { - fileNames, err := filepath.Glob(path) - if err != nil { - continue + case "include": + for _, arg := range strings.Fields(arguments) { + path := p.absolutePath(fileName, arg) + + var fileNames []string + if p.glob == nil { + paths, _ := filepath.Glob(path) + for _, p := range paths { + if s, err := os.Stat(p); err == nil && !s.IsDir() { + fileNames = append(fileNames, p) + } + } + } else { + var err error + fileNames, err = p.glob(path) + if err != nil { + continue + } } for _, fileName := range fileNames { @@ -90,38 +121,20 @@ func (p *parser) read(fileName string) error { return scanner.Err() } -func isSystem(path string) bool { - return strings.HasPrefix(path, "/etc/ssh") -} - -func absolutePaths(parentFile string, paths []string) []string { - absPaths := make([]string, len(paths)) - - for i, path := range paths { - switch { - case filepath.IsAbs(path): - absPaths[i] = path - case strings.HasPrefix(path, "~"): - absPaths[i], _ = homedir.Expand(path) - case isSystem(parentFile): - absPaths[i] = filepath.Join("/etc", "ssh", path) - default: - dir, _ := homedir.Dir() - absPaths[i] = filepath.Join(dir, ".ssh", path) - } +func (p *sshParser) absolutePath(parentFile, path string) string { + if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") { + return path } - return absPaths -} - -func parse(files ...string) SSHAliasMap { - p := parser{aliasMap: make(SSHAliasMap)} - - for _, file := range files { - _ = p.read(file) + if strings.HasPrefix(path, "~") { + return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~")) } - return p.aliasMap + if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") { + return filepath.Join("/etc/ssh", path) + } + + return filepath.Join(p.homeDir, ".ssh", path) } // ParseSSHConfig constructs a map of SSH hostname aliases based on user and @@ -131,12 +144,19 @@ func ParseSSHConfig() SSHAliasMap { "/etc/ssh_config", "/etc/ssh/ssh_config", } + + p := sshParser{} + if homedir, err := homedir.Dir(); err == nil { userConfig := filepath.Join(homedir, ".ssh", "config") configFiles = append([]string{userConfig}, configFiles...) + p.homeDir = homedir } - return parse(configFiles...) + for _, file := range configFiles { + _ = p.read(file) + } + return p.aliasMap } func sshExpandTokens(text, host string) string { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 55874fb6e..f05ca303b 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,110 +1,124 @@ package git import ( + "bytes" "fmt" - "io/ioutil" + "io" "net/url" - "os" "path/filepath" - "reflect" "testing" - "github.com/mitchellh/go-homedir" + "github.com/MakeNowJust/heredoc" ) -// TODO: extract assertion helpers into a shared package -func eq(t *testing.T, got interface{}, expected interface{}) { - t.Helper() - if !reflect.DeepEqual(got, expected) { - t.Errorf("expected: %v, got: %v", expected, got) +func Test_sshParser_read(t *testing.T) { + testFiles := map[string]string{ + "/etc/ssh/config": heredoc.Doc(` + Include sites/* + `), + "/etc/ssh/sites/cfg1": heredoc.Doc(` + Host s1 + Hostname=site1.net + `), + "/etc/ssh/sites/cfg2": heredoc.Doc(` + Host s2 + Hostname = site2.net + `), + "HOME/.ssh/config": heredoc.Doc(` + Host * + Host gh gittyhubby + Hostname github.com + #Hostname example.com + Host ex + Include ex_config/* + `), + "HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(` + Hostname example.com + `), + } + globResults := map[string][]string{ + "/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"}, + "HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"}, + } + + p := &sshParser{ + homeDir: "HOME", + open: func(s string) (io.Reader, error) { + if contents, ok := testFiles[filepath.ToSlash(s)]; ok { + return bytes.NewBufferString(contents), nil + } else { + return nil, fmt.Errorf("no test file stub found: %q", s) + } + }, + glob: func(p string) ([]string, error) { + if results, ok := globResults[filepath.ToSlash(p)]; ok { + return results, nil + } else { + return nil, fmt.Errorf("no glob stubs found: %q", p) + } + }, + } + + if err := p.read("/etc/ssh/config"); err != nil { + t.Fatalf("read(global config) = %v", err) + } + if err := p.read("HOME/.ssh/config"); err != nil { + t.Fatalf("read(user config) = %v", err) + } + + if got := p.aliasMap["gh"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got) + } + if got := p.aliasMap["gittyhubby"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got) + } + if got := p.aliasMap["example.com"]; got != "" { + t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got) + } + if got := p.aliasMap["ex"]; got != "example.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got) + } + if got := p.aliasMap["s1"]; got != "site1.net" { + t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got) } } -func createTempFile(t *testing.T, prefix string) *os.File { - t.Helper() - - dir, err := homedir.Dir() - if err != nil { - t.Errorf("Could not find homedir: %s", err) - } - - tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix) - if err != nil { - t.Errorf("Could create a temp file: %s", err) - } - - t.Cleanup(func() { - tempFile.Close() - os.Remove(tempFile.Name()) - }) - - return tempFile -} - -func Test_parse(t *testing.T) { - includedTempFile := createTempFile(t, "included") - includedConfigFile := ` -Host webapp - HostName webapp.example.com - ` - fmt.Fprint(includedTempFile, includedConfigFile) - - m := parse( - "testdata/ssh_config1.conf", - "testdata/ssh_config2.conf", - "testdata/ssh_config3.conf", - ) - - eq(t, m["foo"], "example.com") - eq(t, m["bar"], "%bar.net%") - eq(t, m["nonexistent"], "") -} - -func Test_absolutePaths(t *testing.T) { - dir, err := homedir.Dir() - if err != nil { - t.Errorf("Could not find homedir: %s", err) - } +func Test_sshParser_absolutePath(t *testing.T) { + dir := "HOME" + p := &sshParser{homeDir: dir} tests := map[string]struct { parentFile string - Input []string - Want []string + arg string + want string + wantErr bool }{ "absolute path": { parentFile: "/etc/ssh/ssh_config", - Input: []string{"/etc/ssh/config"}, - Want: []string{"/etc/ssh/config"}, + arg: "/etc/ssh/config", + want: "/etc/ssh/config", }, "system relative path": { parentFile: "/etc/ssh/config", - Input: []string{"configs/*.conf"}, - Want: []string{"/etc/ssh/configs/*.conf"}, + arg: "configs/*.conf", + want: filepath.Join("/etc", "ssh", "configs", "*.conf"), }, "user relative path": { parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - Input: []string{"configs/*.conf"}, - Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")}, + arg: "configs/*.conf", + want: filepath.Join(dir, ".ssh", "configs/*.conf"), }, "shell-like ~ rerefence": { parentFile: filepath.Join(dir, ".ssh", "ssh_config"), - Input: []string{"~/.ssh/*.conf"}, - Want: []string{filepath.Join(dir, ".ssh", "*.conf")}, + arg: "~/.ssh/*.conf", + want: filepath.Join(dir, ".ssh", "*.conf"), }, } - for name, test := range tests { + for name, tt := range tests { t.Run(name, func(t *testing.T) { - paths := absolutePaths(test.parentFile, test.Input) - - if len(paths) != len(test.Input) { - t.Errorf("Expected %d, got %d", len(test.Input), len(paths)) - } - - for i, path := range paths { - if path != test.Want[i] { - t.Errorf("Expected %q, got %q", test.Want[i], path) - } + if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want { + t.Errorf("absolutePath(): %q, wants %q", got, tt.want) } }) } diff --git a/git/testdata/included.conf b/git/testdata/included.conf deleted file mode 100644 index 2ded9d103..000000000 --- a/git/testdata/included.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host webapp - HostName webapp.example.com diff --git a/git/testdata/ssh_config1.conf b/git/testdata/ssh_config1.conf deleted file mode 100644 index 7249b01fb..000000000 --- a/git/testdata/ssh_config1.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host foo bar - HostName example.com diff --git a/git/testdata/ssh_config2.conf b/git/testdata/ssh_config2.conf deleted file mode 100644 index 3884f3d15..000000000 --- a/git/testdata/ssh_config2.conf +++ /dev/null @@ -1,2 +0,0 @@ -Host bar baz -hostname %%%h.net%% diff --git a/git/testdata/ssh_config3.conf b/git/testdata/ssh_config3.conf deleted file mode 100644 index 7b55743ea..000000000 --- a/git/testdata/ssh_config3.conf +++ /dev/null @@ -1 +0,0 @@ -Include ~/.ssh/included*