Test SSH config parser

This commit is contained in:
Mislav Marohnić 2019-10-17 15:49:50 +02:00
parent 79e8766d8f
commit 344906bf03
4 changed files with 75 additions and 43 deletions

View file

@ -5,6 +5,7 @@ import (
"os"
"github.com/github/gh-cli/context"
"github.com/github/gh-cli/git"
"github.com/spf13/cobra"
)
@ -26,6 +27,7 @@ func initContext() {
repo = os.Getenv("GH_REPO")
}
ctx.SetBaseRepo(repo)
git.InitSSHAliasMap(nil)
}
// RootCmd is the entry point of command-line execution

View file

@ -2,6 +2,7 @@ package git
import (
"bufio"
"io"
"os"
"path/filepath"
"regexp"
@ -10,13 +11,19 @@ import (
"github.com/mitchellh/go-homedir"
)
const (
hostReStr = "(?i)^[ \t]*(host|hostname)[ \t]+(.+)$"
var (
sshHostRE,
sshTokenRE *regexp.Regexp
)
type SSHConfig map[string]string
func init() {
sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$")
sshTokenRE = regexp.MustCompile(`%[%h]`)
}
func newSSHConfigReader() *SSHConfigReader {
type sshAliasMap map[string]string
func sshParseFiles() sshAliasMap {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
@ -25,38 +32,33 @@ func newSSHConfigReader() *SSHConfigReader {
userConfig := filepath.Join(homedir, ".ssh", "config")
configFiles = append([]string{userConfig}, configFiles...)
}
return &SSHConfigReader{
Files: configFiles,
openFiles := []io.Reader{}
for _, file := range configFiles {
f, err := os.Open(file)
if err != nil {
continue
}
defer f.Close()
openFiles = append(openFiles, f)
}
return sshParse(openFiles...)
}
type SSHConfigReader struct {
Files []string
}
func (r *SSHConfigReader) Read() SSHConfig {
config := make(SSHConfig)
hostRe := regexp.MustCompile(hostReStr)
for _, filename := range r.Files {
r.readFile(config, hostRe, filename)
func sshParse(r ...io.Reader) sshAliasMap {
config := sshAliasMap{}
for _, file := range r {
sshParseConfig(config, file)
}
return config
}
func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) error {
file, err := os.Open(f)
if err != nil {
return err
}
defer file.Close()
func sshParseConfig(c sshAliasMap, file io.Reader) error {
hosts := []string{"*"}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
match := re.FindStringSubmatch(line)
match := sshHostRE.FindStringSubmatch(line)
if match == nil {
continue
}
@ -67,7 +69,7 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err
} else {
for _, host := range hosts {
for _, name := range names {
c[host] = expandTokens(name, host)
c[host] = sshExpandTokens(name, host)
}
}
}
@ -76,9 +78,8 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err
return scanner.Err()
}
func expandTokens(text, host string) string {
re := regexp.MustCompile(`%[%h]`)
return re.ReplaceAllStringFunc(text, func(match string) string {
func sshExpandTokens(text, host string) string {
return sshTokenRE.ReplaceAllStringFunc(text, func(match string) string {
switch match {
case "%h":
return host

27
git/ssh_config_test.go Normal file
View file

@ -0,0 +1,27 @@
package git
import (
"reflect"
"strings"
"testing"
)
// TODO: extract assertion helpers into a shared package
func eq(t *testing.T, got interface{}, expected interface{}) {
if !reflect.DeepEqual(got, expected) {
t.Errorf("expected: %v, got: %v", expected, got)
}
}
func Test_sshParse(t *testing.T) {
m := sshParse(strings.NewReader(`
Host foo bar
HostName example.com
`), strings.NewReader(`
Host bar baz
hostname %%%h.net%%
`))
eq(t, m["foo"], "example.com")
eq(t, m["bar"], "%bar.net%")
eq(t, m["nonexist"], "")
}

View file

@ -7,15 +7,12 @@ import (
)
var (
cachedSSHConfig SSHConfig
cachedSSHConfig sshAliasMap
protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://")
)
type URLParser struct {
SSHConfig SSHConfig
}
func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
// ParseURL normalizes git remote urls
func ParseURL(rawURL string) (u *url.URL, err error) {
if !protocolRe.MatchString(rawURL) &&
strings.Contains(rawURL, ":") &&
// not a Windows path
@ -44,7 +41,10 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
u.Host = u.Host[0:idx]
}
sshHost := p.SSHConfig[u.Host]
if cachedSSHConfig == nil {
return
}
sshHost := cachedSSHConfig[u.Host]
// ignore replacing host that fixes for limited network
// https://help.github.com/articles/using-ssh-over-the-https-port
ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com"
@ -55,12 +55,14 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
return
}
func ParseURL(rawURL string) (u *url.URL, err error) {
if cachedSSHConfig == nil {
cachedSSHConfig = newSSHConfigReader().Read()
// InitSSHAliasMap prepares globally cached SSH hostname alias mappings
func InitSSHAliasMap(m map[string]string) {
if m == nil {
cachedSSHConfig = sshParseFiles()
return
}
cachedSSHConfig = sshAliasMap{}
for k, v := range m {
cachedSSHConfig[k] = v
}
p := &URLParser{cachedSSHConfig}
return p.Parse(rawURL)
}