Refactor ssh parser for format compatibility & testability
- Per ssh_config(5), keywords and arguments may be separated by an `=` sign as well as whitespace. - When following the `Include` directive, skip directories that were returned as the result of globbing. - Respect the `Host` context when recursing into `Include`s - Avoid having tests read from the actual filesystem. - Avoid repeatedly looking up the home directory.
This commit is contained in:
parent
dc8698ee46
commit
935f6444ae
7 changed files with 178 additions and 140 deletions
|
|
@ -1,6 +1,17 @@
|
|||
package git
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TODO: extract assertion helpers into a shared package
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseRemotes(t *testing.T) {
|
||||
remoteList := []string{
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package git
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -12,13 +13,10 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
sshTokenRE *regexp.Regexp
|
||||
sshConfigLineRE = regexp.MustCompile(`\A\s*(?P<keyword>[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P<argument>.+)`)
|
||||
sshTokenRE = regexp.MustCompile(`%[%h]`)
|
||||
)
|
||||
|
||||
func init() {
|
||||
sshTokenRE = regexp.MustCompile(`%[%h]`)
|
||||
}
|
||||
|
||||
// SSHAliasMap encapsulates the translation of SSH hostname aliases
|
||||
type SSHAliasMap map[string]string
|
||||
|
||||
|
|
@ -42,42 +40,75 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL {
|
|||
}
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
type sshParser struct {
|
||||
homeDir string
|
||||
|
||||
aliasMap SSHAliasMap
|
||||
hosts []string
|
||||
|
||||
open func(string) (io.Reader, error)
|
||||
glob func(string) ([]string, error)
|
||||
}
|
||||
|
||||
func (p *parser) read(fileName string) error {
|
||||
file, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return err
|
||||
func (p *sshParser) read(fileName string) error {
|
||||
var file io.Reader
|
||||
if p.open == nil {
|
||||
f, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
file = f
|
||||
} else {
|
||||
var err error
|
||||
file, err = p.open(fileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(p.hosts) == 0 {
|
||||
p.hosts = []string{"*"}
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hosts := []string{"*"}
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Fields(line)
|
||||
|
||||
if len(fields) < 2 {
|
||||
m := sshConfigLineRE.FindStringSubmatch(scanner.Text())
|
||||
if len(m) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
directive, params := fields[0], fields[1:]
|
||||
switch {
|
||||
case strings.EqualFold(directive, "Host"):
|
||||
hosts = params
|
||||
case strings.EqualFold(directive, "Hostname"):
|
||||
for _, host := range hosts {
|
||||
for _, name := range params {
|
||||
keyword, arguments := strings.ToLower(m[1]), m[2]
|
||||
switch keyword {
|
||||
case "host":
|
||||
p.hosts = strings.Fields(arguments)
|
||||
case "hostname":
|
||||
for _, host := range p.hosts {
|
||||
for _, name := range strings.Fields(arguments) {
|
||||
if p.aliasMap == nil {
|
||||
p.aliasMap = make(SSHAliasMap)
|
||||
}
|
||||
p.aliasMap[host] = sshExpandTokens(name, host)
|
||||
}
|
||||
}
|
||||
case strings.EqualFold(directive, "Include"):
|
||||
for _, path := range absolutePaths(fileName, params) {
|
||||
fileNames, err := filepath.Glob(path)
|
||||
if err != nil {
|
||||
continue
|
||||
case "include":
|
||||
for _, arg := range strings.Fields(arguments) {
|
||||
path := p.absolutePath(fileName, arg)
|
||||
|
||||
var fileNames []string
|
||||
if p.glob == nil {
|
||||
paths, _ := filepath.Glob(path)
|
||||
for _, p := range paths {
|
||||
if s, err := os.Stat(p); err == nil && !s.IsDir() {
|
||||
fileNames = append(fileNames, p)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
fileNames, err = p.glob(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, fileName := range fileNames {
|
||||
|
|
@ -90,38 +121,20 @@ func (p *parser) read(fileName string) error {
|
|||
return scanner.Err()
|
||||
}
|
||||
|
||||
func isSystem(path string) bool {
|
||||
return strings.HasPrefix(path, "/etc/ssh")
|
||||
}
|
||||
|
||||
func absolutePaths(parentFile string, paths []string) []string {
|
||||
absPaths := make([]string, len(paths))
|
||||
|
||||
for i, path := range paths {
|
||||
switch {
|
||||
case filepath.IsAbs(path):
|
||||
absPaths[i] = path
|
||||
case strings.HasPrefix(path, "~"):
|
||||
absPaths[i], _ = homedir.Expand(path)
|
||||
case isSystem(parentFile):
|
||||
absPaths[i] = filepath.Join("/etc", "ssh", path)
|
||||
default:
|
||||
dir, _ := homedir.Dir()
|
||||
absPaths[i] = filepath.Join(dir, ".ssh", path)
|
||||
}
|
||||
func (p *sshParser) absolutePath(parentFile, path string) string {
|
||||
if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") {
|
||||
return path
|
||||
}
|
||||
|
||||
return absPaths
|
||||
}
|
||||
|
||||
func parse(files ...string) SSHAliasMap {
|
||||
p := parser{aliasMap: make(SSHAliasMap)}
|
||||
|
||||
for _, file := range files {
|
||||
_ = p.read(file)
|
||||
if strings.HasPrefix(path, "~") {
|
||||
return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~"))
|
||||
}
|
||||
|
||||
return p.aliasMap
|
||||
if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") {
|
||||
return filepath.Join("/etc/ssh", path)
|
||||
}
|
||||
|
||||
return filepath.Join(p.homeDir, ".ssh", path)
|
||||
}
|
||||
|
||||
// ParseSSHConfig constructs a map of SSH hostname aliases based on user and
|
||||
|
|
@ -131,12 +144,19 @@ func ParseSSHConfig() SSHAliasMap {
|
|||
"/etc/ssh_config",
|
||||
"/etc/ssh/ssh_config",
|
||||
}
|
||||
|
||||
p := sshParser{}
|
||||
|
||||
if homedir, err := homedir.Dir(); err == nil {
|
||||
userConfig := filepath.Join(homedir, ".ssh", "config")
|
||||
configFiles = append([]string{userConfig}, configFiles...)
|
||||
p.homeDir = homedir
|
||||
}
|
||||
|
||||
return parse(configFiles...)
|
||||
for _, file := range configFiles {
|
||||
_ = p.read(file)
|
||||
}
|
||||
return p.aliasMap
|
||||
}
|
||||
|
||||
func sshExpandTokens(text, host string) string {
|
||||
|
|
|
|||
|
|
@ -1,110 +1,124 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
)
|
||||
|
||||
// TODO: extract assertion helpers into a shared package
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
func Test_sshParser_read(t *testing.T) {
|
||||
testFiles := map[string]string{
|
||||
"/etc/ssh/config": heredoc.Doc(`
|
||||
Include sites/*
|
||||
`),
|
||||
"/etc/ssh/sites/cfg1": heredoc.Doc(`
|
||||
Host s1
|
||||
Hostname=site1.net
|
||||
`),
|
||||
"/etc/ssh/sites/cfg2": heredoc.Doc(`
|
||||
Host s2
|
||||
Hostname = site2.net
|
||||
`),
|
||||
"HOME/.ssh/config": heredoc.Doc(`
|
||||
Host *
|
||||
Host gh gittyhubby
|
||||
Hostname github.com
|
||||
#Hostname example.com
|
||||
Host ex
|
||||
Include ex_config/*
|
||||
`),
|
||||
"HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(`
|
||||
Hostname example.com
|
||||
`),
|
||||
}
|
||||
globResults := map[string][]string{
|
||||
"/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"},
|
||||
"HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"},
|
||||
}
|
||||
|
||||
p := &sshParser{
|
||||
homeDir: "HOME",
|
||||
open: func(s string) (io.Reader, error) {
|
||||
if contents, ok := testFiles[filepath.ToSlash(s)]; ok {
|
||||
return bytes.NewBufferString(contents), nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("no test file stub found: %q", s)
|
||||
}
|
||||
},
|
||||
glob: func(p string) ([]string, error) {
|
||||
if results, ok := globResults[filepath.ToSlash(p)]; ok {
|
||||
return results, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("no glob stubs found: %q", p)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if err := p.read("/etc/ssh/config"); err != nil {
|
||||
t.Fatalf("read(global config) = %v", err)
|
||||
}
|
||||
if err := p.read("HOME/.ssh/config"); err != nil {
|
||||
t.Fatalf("read(user config) = %v", err)
|
||||
}
|
||||
|
||||
if got := p.aliasMap["gh"]; got != "github.com" {
|
||||
t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got)
|
||||
}
|
||||
if got := p.aliasMap["gittyhubby"]; got != "github.com" {
|
||||
t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got)
|
||||
}
|
||||
if got := p.aliasMap["example.com"]; got != "" {
|
||||
t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got)
|
||||
}
|
||||
if got := p.aliasMap["ex"]; got != "example.com" {
|
||||
t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got)
|
||||
}
|
||||
if got := p.aliasMap["s1"]; got != "site1.net" {
|
||||
t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got)
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, prefix string) *os.File {
|
||||
t.Helper()
|
||||
|
||||
dir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
t.Errorf("Could not find homedir: %s", err)
|
||||
}
|
||||
|
||||
tempFile, err := ioutil.TempFile(filepath.Join(dir, ".ssh"), prefix)
|
||||
if err != nil {
|
||||
t.Errorf("Could create a temp file: %s", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
tempFile.Close()
|
||||
os.Remove(tempFile.Name())
|
||||
})
|
||||
|
||||
return tempFile
|
||||
}
|
||||
|
||||
func Test_parse(t *testing.T) {
|
||||
includedTempFile := createTempFile(t, "included")
|
||||
includedConfigFile := `
|
||||
Host webapp
|
||||
HostName webapp.example.com
|
||||
`
|
||||
fmt.Fprint(includedTempFile, includedConfigFile)
|
||||
|
||||
m := parse(
|
||||
"testdata/ssh_config1.conf",
|
||||
"testdata/ssh_config2.conf",
|
||||
"testdata/ssh_config3.conf",
|
||||
)
|
||||
|
||||
eq(t, m["foo"], "example.com")
|
||||
eq(t, m["bar"], "%bar.net%")
|
||||
eq(t, m["nonexistent"], "")
|
||||
}
|
||||
|
||||
func Test_absolutePaths(t *testing.T) {
|
||||
dir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
t.Errorf("Could not find homedir: %s", err)
|
||||
}
|
||||
func Test_sshParser_absolutePath(t *testing.T) {
|
||||
dir := "HOME"
|
||||
p := &sshParser{homeDir: dir}
|
||||
|
||||
tests := map[string]struct {
|
||||
parentFile string
|
||||
Input []string
|
||||
Want []string
|
||||
arg string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
"absolute path": {
|
||||
parentFile: "/etc/ssh/ssh_config",
|
||||
Input: []string{"/etc/ssh/config"},
|
||||
Want: []string{"/etc/ssh/config"},
|
||||
arg: "/etc/ssh/config",
|
||||
want: "/etc/ssh/config",
|
||||
},
|
||||
"system relative path": {
|
||||
parentFile: "/etc/ssh/config",
|
||||
Input: []string{"configs/*.conf"},
|
||||
Want: []string{"/etc/ssh/configs/*.conf"},
|
||||
arg: "configs/*.conf",
|
||||
want: filepath.Join("/etc", "ssh", "configs", "*.conf"),
|
||||
},
|
||||
"user relative path": {
|
||||
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
|
||||
Input: []string{"configs/*.conf"},
|
||||
Want: []string{filepath.Join(dir, ".ssh", "configs/*.conf")},
|
||||
arg: "configs/*.conf",
|
||||
want: filepath.Join(dir, ".ssh", "configs/*.conf"),
|
||||
},
|
||||
"shell-like ~ rerefence": {
|
||||
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
|
||||
Input: []string{"~/.ssh/*.conf"},
|
||||
Want: []string{filepath.Join(dir, ".ssh", "*.conf")},
|
||||
arg: "~/.ssh/*.conf",
|
||||
want: filepath.Join(dir, ".ssh", "*.conf"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
paths := absolutePaths(test.parentFile, test.Input)
|
||||
|
||||
if len(paths) != len(test.Input) {
|
||||
t.Errorf("Expected %d, got %d", len(test.Input), len(paths))
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
if path != test.Want[i] {
|
||||
t.Errorf("Expected %q, got %q", test.Want[i], path)
|
||||
}
|
||||
if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want {
|
||||
t.Errorf("absolutePath(): %q, wants %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
2
git/testdata/included.conf
vendored
2
git/testdata/included.conf
vendored
|
|
@ -1,2 +0,0 @@
|
|||
Host webapp
|
||||
HostName webapp.example.com
|
||||
2
git/testdata/ssh_config1.conf
vendored
2
git/testdata/ssh_config1.conf
vendored
|
|
@ -1,2 +0,0 @@
|
|||
Host foo bar
|
||||
HostName example.com
|
||||
2
git/testdata/ssh_config2.conf
vendored
2
git/testdata/ssh_config2.conf
vendored
|
|
@ -1,2 +0,0 @@
|
|||
Host bar baz
|
||||
hostname %%%h.net%%
|
||||
1
git/testdata/ssh_config3.conf
vendored
1
git/testdata/ssh_config3.conf
vendored
|
|
@ -1 +0,0 @@
|
|||
Include ~/.ssh/included*
|
||||
Loading…
Add table
Add a link
Reference in a new issue