From 344906bf0333d5a8b656a960bc3b68f99b5f710d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 17 Oct 2019 15:49:50 +0200 Subject: [PATCH] Test SSH config parser --- command/root.go | 2 ++ git/ssh_config.go | 59 +++++++++++++++++++++--------------------- git/ssh_config_test.go | 27 +++++++++++++++++++ git/url.go | 30 +++++++++++---------- 4 files changed, 75 insertions(+), 43 deletions(-) create mode 100644 git/ssh_config_test.go diff --git a/command/root.go b/command/root.go index f4c227cb5..1df3a2b0d 100644 --- a/command/root.go +++ b/command/root.go @@ -5,6 +5,7 @@ import ( "os" "github.com/github/gh-cli/context" + "github.com/github/gh-cli/git" "github.com/spf13/cobra" ) @@ -26,6 +27,7 @@ func initContext() { repo = os.Getenv("GH_REPO") } ctx.SetBaseRepo(repo) + git.InitSSHAliasMap(nil) } // RootCmd is the entry point of command-line execution diff --git a/git/ssh_config.go b/git/ssh_config.go index 6749d90cc..47d46ac46 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -2,6 +2,7 @@ package git import ( "bufio" + "io" "os" "path/filepath" "regexp" @@ -10,13 +11,19 @@ import ( "github.com/mitchellh/go-homedir" ) -const ( - hostReStr = "(?i)^[ \t]*(host|hostname)[ \t]+(.+)$" +var ( + sshHostRE, + sshTokenRE *regexp.Regexp ) -type SSHConfig map[string]string +func init() { + sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$") + sshTokenRE = regexp.MustCompile(`%[%h]`) +} -func newSSHConfigReader() *SSHConfigReader { +type sshAliasMap map[string]string + +func sshParseFiles() sshAliasMap { configFiles := []string{ "/etc/ssh_config", "/etc/ssh/ssh_config", @@ -25,38 +32,33 @@ func newSSHConfigReader() *SSHConfigReader { userConfig := filepath.Join(homedir, ".ssh", "config") configFiles = append([]string{userConfig}, configFiles...) } - return &SSHConfigReader{ - Files: configFiles, + + openFiles := []io.Reader{} + for _, file := range configFiles { + f, err := os.Open(file) + if err != nil { + continue + } + defer f.Close() + openFiles = append(openFiles, f) } + return sshParse(openFiles...) } -type SSHConfigReader struct { - Files []string -} - -func (r *SSHConfigReader) Read() SSHConfig { - config := make(SSHConfig) - hostRe := regexp.MustCompile(hostReStr) - - for _, filename := range r.Files { - r.readFile(config, hostRe, filename) +func sshParse(r ...io.Reader) sshAliasMap { + config := sshAliasMap{} + for _, file := range r { + sshParseConfig(config, file) } - return config } -func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) error { - file, err := os.Open(f) - if err != nil { - return err - } - defer file.Close() - +func sshParseConfig(c sshAliasMap, file io.Reader) error { hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() - match := re.FindStringSubmatch(line) + match := sshHostRE.FindStringSubmatch(line) if match == nil { continue } @@ -67,7 +69,7 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err } else { for _, host := range hosts { for _, name := range names { - c[host] = expandTokens(name, host) + c[host] = sshExpandTokens(name, host) } } } @@ -76,9 +78,8 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err return scanner.Err() } -func expandTokens(text, host string) string { - re := regexp.MustCompile(`%[%h]`) - return re.ReplaceAllStringFunc(text, func(match string) string { +func sshExpandTokens(text, host string) string { + return sshTokenRE.ReplaceAllStringFunc(text, func(match string) string { switch match { case "%h": return host diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go new file mode 100644 index 000000000..12de53bd8 --- /dev/null +++ b/git/ssh_config_test.go @@ -0,0 +1,27 @@ +package git + +import ( + "reflect" + "strings" + "testing" +) + +// TODO: extract assertion helpers into a shared package +func eq(t *testing.T, got interface{}, expected interface{}) { + if !reflect.DeepEqual(got, expected) { + t.Errorf("expected: %v, got: %v", expected, got) + } +} + +func Test_sshParse(t *testing.T) { + m := sshParse(strings.NewReader(` + Host foo bar + HostName example.com + `), strings.NewReader(` + Host bar baz + hostname %%%h.net%% + `)) + eq(t, m["foo"], "example.com") + eq(t, m["bar"], "%bar.net%") + eq(t, m["nonexist"], "") +} diff --git a/git/url.go b/git/url.go index f3f4adf99..792d75350 100644 --- a/git/url.go +++ b/git/url.go @@ -7,15 +7,12 @@ import ( ) var ( - cachedSSHConfig SSHConfig + cachedSSHConfig sshAliasMap protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://") ) -type URLParser struct { - SSHConfig SSHConfig -} - -func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) { +// ParseURL normalizes git remote urls +func ParseURL(rawURL string) (u *url.URL, err error) { if !protocolRe.MatchString(rawURL) && strings.Contains(rawURL, ":") && // not a Windows path @@ -44,7 +41,10 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) { u.Host = u.Host[0:idx] } - sshHost := p.SSHConfig[u.Host] + if cachedSSHConfig == nil { + return + } + sshHost := cachedSSHConfig[u.Host] // ignore replacing host that fixes for limited network // https://help.github.com/articles/using-ssh-over-the-https-port ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com" @@ -55,12 +55,14 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) { return } -func ParseURL(rawURL string) (u *url.URL, err error) { - if cachedSSHConfig == nil { - cachedSSHConfig = newSSHConfigReader().Read() +// InitSSHAliasMap prepares globally cached SSH hostname alias mappings +func InitSSHAliasMap(m map[string]string) { + if m == nil { + cachedSSHConfig = sshParseFiles() + return + } + cachedSSHConfig = sshAliasMap{} + for k, v := range m { + cachedSSHConfig[k] = v } - - p := &URLParser{cachedSSHConfig} - - return p.Parse(rawURL) }