Test SSH config parser
This commit is contained in:
parent
79e8766d8f
commit
344906bf03
4 changed files with 75 additions and 43 deletions
|
|
@ -5,6 +5,7 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/github/gh-cli/context"
|
||||
"github.com/github/gh-cli/git"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -26,6 +27,7 @@ func initContext() {
|
|||
repo = os.Getenv("GH_REPO")
|
||||
}
|
||||
ctx.SetBaseRepo(repo)
|
||||
git.InitSSHAliasMap(nil)
|
||||
}
|
||||
|
||||
// RootCmd is the entry point of command-line execution
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package git
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
|
@ -10,13 +11,19 @@ import (
|
|||
"github.com/mitchellh/go-homedir"
|
||||
)
|
||||
|
||||
const (
|
||||
hostReStr = "(?i)^[ \t]*(host|hostname)[ \t]+(.+)$"
|
||||
var (
|
||||
sshHostRE,
|
||||
sshTokenRE *regexp.Regexp
|
||||
)
|
||||
|
||||
type SSHConfig map[string]string
|
||||
func init() {
|
||||
sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$")
|
||||
sshTokenRE = regexp.MustCompile(`%[%h]`)
|
||||
}
|
||||
|
||||
func newSSHConfigReader() *SSHConfigReader {
|
||||
type sshAliasMap map[string]string
|
||||
|
||||
func sshParseFiles() sshAliasMap {
|
||||
configFiles := []string{
|
||||
"/etc/ssh_config",
|
||||
"/etc/ssh/ssh_config",
|
||||
|
|
@ -25,38 +32,33 @@ func newSSHConfigReader() *SSHConfigReader {
|
|||
userConfig := filepath.Join(homedir, ".ssh", "config")
|
||||
configFiles = append([]string{userConfig}, configFiles...)
|
||||
}
|
||||
return &SSHConfigReader{
|
||||
Files: configFiles,
|
||||
|
||||
openFiles := []io.Reader{}
|
||||
for _, file := range configFiles {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
openFiles = append(openFiles, f)
|
||||
}
|
||||
return sshParse(openFiles...)
|
||||
}
|
||||
|
||||
type SSHConfigReader struct {
|
||||
Files []string
|
||||
}
|
||||
|
||||
func (r *SSHConfigReader) Read() SSHConfig {
|
||||
config := make(SSHConfig)
|
||||
hostRe := regexp.MustCompile(hostReStr)
|
||||
|
||||
for _, filename := range r.Files {
|
||||
r.readFile(config, hostRe, filename)
|
||||
func sshParse(r ...io.Reader) sshAliasMap {
|
||||
config := sshAliasMap{}
|
||||
for _, file := range r {
|
||||
sshParseConfig(config, file)
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) error {
|
||||
file, err := os.Open(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
func sshParseConfig(c sshAliasMap, file io.Reader) error {
|
||||
hosts := []string{"*"}
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
match := re.FindStringSubmatch(line)
|
||||
match := sshHostRE.FindStringSubmatch(line)
|
||||
if match == nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -67,7 +69,7 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err
|
|||
} else {
|
||||
for _, host := range hosts {
|
||||
for _, name := range names {
|
||||
c[host] = expandTokens(name, host)
|
||||
c[host] = sshExpandTokens(name, host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -76,9 +78,8 @@ func (r *SSHConfigReader) readFile(c SSHConfig, re *regexp.Regexp, f string) err
|
|||
return scanner.Err()
|
||||
}
|
||||
|
||||
func expandTokens(text, host string) string {
|
||||
re := regexp.MustCompile(`%[%h]`)
|
||||
return re.ReplaceAllStringFunc(text, func(match string) string {
|
||||
func sshExpandTokens(text, host string) string {
|
||||
return sshTokenRE.ReplaceAllStringFunc(text, func(match string) string {
|
||||
switch match {
|
||||
case "%h":
|
||||
return host
|
||||
|
|
|
|||
27
git/ssh_config_test.go
Normal file
27
git/ssh_config_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TODO: extract assertion helpers into a shared package
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshParse(t *testing.T) {
|
||||
m := sshParse(strings.NewReader(`
|
||||
Host foo bar
|
||||
HostName example.com
|
||||
`), strings.NewReader(`
|
||||
Host bar baz
|
||||
hostname %%%h.net%%
|
||||
`))
|
||||
eq(t, m["foo"], "example.com")
|
||||
eq(t, m["bar"], "%bar.net%")
|
||||
eq(t, m["nonexist"], "")
|
||||
}
|
||||
30
git/url.go
30
git/url.go
|
|
@ -7,15 +7,12 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
cachedSSHConfig SSHConfig
|
||||
cachedSSHConfig sshAliasMap
|
||||
protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://")
|
||||
)
|
||||
|
||||
type URLParser struct {
|
||||
SSHConfig SSHConfig
|
||||
}
|
||||
|
||||
func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
|
||||
// ParseURL normalizes git remote urls
|
||||
func ParseURL(rawURL string) (u *url.URL, err error) {
|
||||
if !protocolRe.MatchString(rawURL) &&
|
||||
strings.Contains(rawURL, ":") &&
|
||||
// not a Windows path
|
||||
|
|
@ -44,7 +41,10 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
|
|||
u.Host = u.Host[0:idx]
|
||||
}
|
||||
|
||||
sshHost := p.SSHConfig[u.Host]
|
||||
if cachedSSHConfig == nil {
|
||||
return
|
||||
}
|
||||
sshHost := cachedSSHConfig[u.Host]
|
||||
// ignore replacing host that fixes for limited network
|
||||
// https://help.github.com/articles/using-ssh-over-the-https-port
|
||||
ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com"
|
||||
|
|
@ -55,12 +55,14 @@ func (p *URLParser) Parse(rawURL string) (u *url.URL, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func ParseURL(rawURL string) (u *url.URL, err error) {
|
||||
if cachedSSHConfig == nil {
|
||||
cachedSSHConfig = newSSHConfigReader().Read()
|
||||
// InitSSHAliasMap prepares globally cached SSH hostname alias mappings
|
||||
func InitSSHAliasMap(m map[string]string) {
|
||||
if m == nil {
|
||||
cachedSSHConfig = sshParseFiles()
|
||||
return
|
||||
}
|
||||
cachedSSHConfig = sshAliasMap{}
|
||||
for k, v := range m {
|
||||
cachedSSHConfig[k] = v
|
||||
}
|
||||
|
||||
p := &URLParser{cachedSSHConfig}
|
||||
|
||||
return p.Parse(rawURL)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue