Refactor ssh parser for format compatibility & testability

- Per ssh_config(5), keywords and arguments may be separated by an `=`
  sign as well as whitespace.
- When following the `Include` directive, skip directories that were
  returned as the result of globbing.
- Respect the `Host` context when recursing into `Include`s
- Avoid having tests read from the actual filesystem.
- Avoid repeatedly looking up the home directory.
This commit is contained in:
Mislav Marohnić 2020-12-15 15:02:49 +01:00
parent dc8698ee46
commit 935f6444ae
7 changed files with 178 additions and 140 deletions

View file

@ -1,6 +1,17 @@
package git
import "testing"
import (
"reflect"
"testing"
)
// TODO: extract assertion helpers into a shared package
func eq(t *testing.T, got interface{}, expected interface{}) {
t.Helper()
if !reflect.DeepEqual(got, expected) {
t.Errorf("expected: %v, got: %v", expected, got)
}
}
func Test_parseRemotes(t *testing.T) {
remoteList := []string{

View file

@ -2,6 +2,7 @@ package git
import (
"bufio"
"io"
"net/url"
"os"
"path/filepath"
@ -12,13 +13,10 @@ import (
)
var (
sshTokenRE *regexp.Regexp
sshConfigLineRE = regexp.MustCompile(`\A\s*(?P<keyword>[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P<argument>.+)`)
sshTokenRE = regexp.MustCompile(`%[%h]`)
)
func init() {
sshTokenRE = regexp.MustCompile(`%[%h]`)
}
// SSHAliasMap encapsulates the translation of SSH hostname aliases
type SSHAliasMap map[string]string
@ -42,42 +40,75 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL {
}
}
type parser struct {
type sshParser struct {
homeDir string
aliasMap SSHAliasMap
hosts []string
open func(string) (io.Reader, error)
glob func(string) ([]string, error)
}
func (p *parser) read(fileName string) error {
file, err := os.Open(fileName)
if err != nil {
return err
func (p *sshParser) read(fileName string) error {
var file io.Reader
if p.open == nil {
f, err := os.Open(fileName)
if err != nil {
return err
}
defer f.Close()
file = f
} else {
var err error
file, err = p.open(fileName)
if err != nil {
return err
}
}
if len(p.hosts) == 0 {
p.hosts = []string{"*"}
}
defer file.Close()
hosts := []string{"*"}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) < 2 {
m := sshConfigLineRE.FindStringSubmatch(scanner.Text())
if len(m) < 3 {
continue
}
directive, params := fields[0], fields[1:]
switch {
case strings.EqualFold(directive, "Host"):
hosts = params
case strings.EqualFold(directive, "Hostname"):
for _, host := range hosts {
for _, name := range params {
keyword, arguments := strings.ToLower(m[1]), m[2]
switch keyword {
case "host":
p.hosts = strings.Fields(arguments)
case "hostname":
for _, host := range p.hosts {
for _, name := range strings.Fields(arguments) {
if p.aliasMap == nil {
p.aliasMap = make(SSHAliasMap)
}
p.aliasMap[host] = sshExpandTokens(name, host)
}
}
case strings.EqualFold(directive, "Include"):
for _, path := range absolutePaths(fileName, params) {
fileNames, err := filepath.Glob(path)
if err != nil {
continue
case "include":
for _, arg := range strings.Fields(arguments) {
path := p.absolutePath(fileName, arg)
var fileNames []string
if p.glob == nil {
paths, _ := filepath.Glob(path)
for _, p := range paths {
if s, err := os.Stat(p); err == nil && !s.IsDir() {
fileNames = append(fileNames, p)
}
}
} else {
var err error
fileNames, err = p.glob(path)
if err != nil {
continue
}
}
for _, fileName := range fileNames {
@ -90,38 +121,20 @@ func (p *parser) read(fileName string) error {
return scanner.Err()
}
func isSystem(path string) bool {
return strings.HasPrefix(path, "/etc/ssh")
}
func absolutePaths(parentFile string, paths []string) []string {
absPaths := make([]string, len(paths))
for i, path := range paths {
switch {
case filepath.IsAbs(path):
absPaths[i] = path
case strings.HasPrefix(path, "~"):
absPaths[i], _ = homedir.Expand(path)
case isSystem(parentFile):
absPaths[i] = filepath.Join("/etc", "ssh", path)
default:
dir, _ := homedir.Dir()
absPaths[i] = filepath.Join(dir, ".ssh", path)
}
func (p *sshParser) absolutePath(parentFile, path string) string {
if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") {
return path
}
return absPaths
}
func parse(files ...string) SSHAliasMap {
p := parser{aliasMap: make(SSHAliasMap)}
for _, file := range files {
_ = p.read(file)
if strings.HasPrefix(path, "~") {
return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~"))
}
return p.aliasMap
if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") {
return filepath.Join("/etc/ssh", path)
}
return filepath.Join(p.homeDir, ".ssh", path)
}
// ParseSSHConfig constructs a map of SSH hostname aliases based on user and
@ -131,12 +144,19 @@ func ParseSSHConfig() SSHAliasMap {
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}
p := sshParser{}
if homedir, err := homedir.Dir(); err == nil {
userConfig := filepath.Join(homedir, ".ssh", "config")
configFiles = append([]string{userConfig}, configFiles...)
p.homeDir = homedir
}
return parse(configFiles...)
for _, file := range configFiles {
_ = p.read(file)
}
return p.aliasMap
}
func sshExpandTokens(text, host string) string {

View file

@ -1,110 +1,124 @@
package git
import (
"bytes"
"fmt"
"io/ioutil"
"io"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/mitchellh/go-homedir"
"github.com/MakeNowJust/heredoc"
)
// TODO: extract assertion helpers into a shared package
func eq(t *testing.T, got interface{}, expected interface{}) {
t.Helper()
if !reflect.DeepEqual(got, expected) {
t.Errorf("expected: %v, got: %v", expected, got)
func Test_sshParser_read(t *testing.T) {
testFiles := map[string]string{
"/etc/ssh/config": heredoc.Doc(`
Include sites/*
`),
"/etc/ssh/sites/cfg1": heredoc.Doc(`
Host s1
Hostname=site1.net
`),
"/etc/ssh/sites/cfg2": heredoc.Doc(`
Host s2
Hostname = site2.net
`),
"HOME/.ssh/config": heredoc.Doc(`
Host *
Host gh gittyhubby
Hostname github.com
#Hostname example.com
Host ex
Include ex_config/*
`),
"HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(`
Hostname example.com
`),
}
globResults := map[string][]string{
"/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"},
"HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"},
}
p := &sshParser{
homeDir: "HOME",
open: func(s string) (io.Reader, error) {
if contents, ok := testFiles[filepath.ToSlash(s)]; ok {
return bytes.NewBufferString(contents), nil
} else {
return nil, fmt.Errorf("no test file stub found: %q", s)
}
},
glob: func(p string) ([]string, error) {
if results, ok := globResults[filepath.ToSlash(p)]; ok {
return results, nil
} else {
return nil, fmt.Errorf("no glob stubs found: %q", p)
}
},
}
if err := p.read("/etc/ssh/config"); err != nil {
t.Fatalf("read(global config) = %v", err)
}
if err := p.read("HOME/.ssh/config"); err != nil {
t.Fatalf("read(user config) = %v", err)
}
if got := p.aliasMap["gh"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got)
}
if got := p.aliasMap["gittyhubby"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got)
}
if got := p.aliasMap["example.com"]; got != "" {
t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got)
}
if got := p.aliasMap["ex"]; got != "example.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got)
}
if got := p.aliasMap["s1"]; got != "site1.net" {
t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got)
}
}
func createTempFile(t *testing.T, prefix string) *os.File {
t.Helper()
dir, err := homedir.Dir()
if err != nil {
t.Errorf("Could not find homedir: %s", err)
}
tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix)
if err != nil {
t.Errorf("Could create a temp file: %s", err)
}
t.Cleanup(func() {
tempFile.Close()
os.Remove(tempFile.Name())
})
return tempFile
}
func Test_parse(t *testing.T) {
includedTempFile := createTempFile(t, "included")
includedConfigFile := `
Host webapp
HostName webapp.example.com
`
fmt.Fprint(includedTempFile, includedConfigFile)
m := parse(
"testdata/ssh_config1.conf",
"testdata/ssh_config2.conf",
"testdata/ssh_config3.conf",
)
eq(t, m["foo"], "example.com")
eq(t, m["bar"], "%bar.net%")
eq(t, m["nonexistent"], "")
}
func Test_absolutePaths(t *testing.T) {
dir, err := homedir.Dir()
if err != nil {
t.Errorf("Could not find homedir: %s", err)
}
func Test_sshParser_absolutePath(t *testing.T) {
dir := "HOME"
p := &sshParser{homeDir: dir}
tests := map[string]struct {
parentFile string
Input []string
Want []string
arg string
want string
wantErr bool
}{
"absolute path": {
parentFile: "/etc/ssh/ssh_config",
Input: []string{"/etc/ssh/config"},
Want: []string{"/etc/ssh/config"},
arg: "/etc/ssh/config",
want: "/etc/ssh/config",
},
"system relative path": {
parentFile: "/etc/ssh/config",
Input: []string{"configs/*.conf"},
Want: []string{"/etc/ssh/configs/*.conf"},
arg: "configs/*.conf",
want: filepath.Join("/etc", "ssh", "configs", "*.conf"),
},
"user relative path": {
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
Input: []string{"configs/*.conf"},
Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")},
arg: "configs/*.conf",
want: filepath.Join(dir, ".ssh", "configs/*.conf"),
},
"shell-like ~ rerefence": {
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
Input: []string{"~/.ssh/*.conf"},
Want: []string{filepath.Join(dir, ".ssh", "*.conf")},
arg: "~/.ssh/*.conf",
want: filepath.Join(dir, ".ssh", "*.conf"),
},
}
for name, test := range tests {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
paths := absolutePaths(test.parentFile, test.Input)
if len(paths) != len(test.Input) {
t.Errorf("Expected %d, got %d", len(test.Input), len(paths))
}
for i, path := range paths {
if path != test.Want[i] {
t.Errorf("Expected %q, got %q", test.Want[i], path)
}
if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want {
t.Errorf("absolutePath(): %q, wants %q", got, tt.want)
}
})
}

View file

@ -1,2 +0,0 @@
Host webapp
HostName webapp.example.com

View file

@ -1,2 +0,0 @@
Host foo bar
HostName example.com

View file

@ -1,2 +0,0 @@
Host bar baz
hostname %%%h.net%%

View file

@ -1 +0,0 @@
Include ~/.ssh/included*