diff --git a/command/root.go b/command/root.go index d9fa859ee..f4c227cb5 100644 --- a/command/root.go +++ b/command/root.go @@ -5,7 +5,6 @@ import ( "os" "github.com/github/gh-cli/context" - "github.com/github/gh-cli/git" "github.com/spf13/cobra" ) @@ -27,8 +26,6 @@ func initContext() { repo = os.Getenv("GH_REPO") } ctx.SetBaseRepo(repo) - - git.InitSSHAliasMap(nil) } // RootCmd is the entry point of command-line execution diff --git a/context/config_file_test.go b/context/config_file_test.go index 7163e6b2a..5a915e78d 100644 --- a/context/config_file_test.go +++ b/context/config_file_test.go @@ -8,6 +8,7 @@ import ( ) func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() if !reflect.DeepEqual(got, expected) { t.Errorf("expected: %v, got: %v", expected, got) } diff --git a/context/context.go b/context/context.go index 3e6dd009e..8a5c6ccc2 100644 --- a/context/context.go +++ b/context/context.go @@ -109,7 +109,8 @@ func (c *fsContext) Remotes() (Remotes, error) { if err != nil { return nil, err } - c.remotes = parseRemotes(gitRemotes) + sshTranslate := git.ParseSSHConfig().Translator() + c.remotes = translateRemotes(gitRemotes, sshTranslate) } return c.remotes, nil } diff --git a/context/remote.go b/context/remote.go index 8ffcef567..f30a0958b 100644 --- a/context/remote.go +++ b/context/remote.go @@ -2,7 +2,7 @@ package context import ( "fmt" - "regexp" + "net/url" "strings" "github.com/github/gh-cli/git" @@ -27,74 +27,44 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) { // Remote represents a git remote mapped to a GitHub repository type Remote struct { - Name string + *git.Remote Owner string Repo string } -func (r *Remote) String() string { - return r.Name -} - // GitHubRepository represents a GitHub respository type GitHubRepository struct { Name string Owner string } -func parseRemotes(gitRemotes []string) (remotes Remotes) { - re := regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) - - names := []string{} - remotesMap := make(map[string]map[string]string) +// TODO: accept an interface instead of git.RemoteSet +func translateRemotes(gitRemotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) (remotes Remotes) { for _, r := range gitRemotes { - if re.MatchString(r) { - match := re.FindStringSubmatch(r) - name := strings.TrimSpace(match[1]) - url := strings.TrimSpace(match[2]) - urlType := strings.TrimSpace(match[3]) - utm, ok := remotesMap[name] - if !ok { - utm = make(map[string]string) - remotesMap[name] = utm - names = append(names, name) - } - utm[urlType] = url + var owner string + var repo string + if r.FetchURL != nil { + owner, repo, _ = repoFromURL(urlTranslate(r.FetchURL)) } + if r.PushURL != nil && owner == "" { + owner, repo, _ = repoFromURL(urlTranslate(r.PushURL)) + } + remotes = append(remotes, &Remote{ + Remote: r, + Owner: owner, + Repo: repo, + }) } - - for _, name := range names { - urlMap := remotesMap[name] - repo, err := repoFromURL(urlMap["fetch"]) - if err != nil { - repo, err = repoFromURL(urlMap["push"]) - } - if err == nil { - remotes = append(remotes, &Remote{ - Name: name, - Owner: repo.Owner, - Repo: repo.Name, - }) - } - } - return } -func repoFromURL(u string) (*GitHubRepository, error) { - url, err := git.ParseURL(u) - if err != nil { - return nil, err +func repoFromURL(u *url.URL) (string, string, error) { + if !strings.EqualFold(u.Hostname(), defaultHostname) { + return "", "", fmt.Errorf("unsupported hostname: %s", u.Hostname()) } - if url.Hostname() != defaultHostname { - return nil, fmt.Errorf("invalid hostname: %s", url.Hostname()) - } - parts := strings.SplitN(strings.TrimPrefix(url.Path, "/"), "/", 3) + parts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", 3) if len(parts) < 2 { - return nil, fmt.Errorf("invalid path: %s", url.Path) + return "", "", fmt.Errorf("invalid path: %s", u.Path) } - return &GitHubRepository{ - Owner: parts[0], - Name: strings.TrimSuffix(parts[1], ".git"), - }, nil + return parts[0], strings.TrimSuffix(parts[1], ".git"), nil } diff --git a/context/remote_test.go b/context/remote_test.go index 70b49c4e5..359fcaa7f 100644 --- a/context/remote_test.go +++ b/context/remote_test.go @@ -2,67 +2,43 @@ package context import ( "errors" + "net/url" "testing" "github.com/github/gh-cli/git" ) func Test_repoFromURL(t *testing.T) { - git.InitSSHAliasMap(nil) - - r, err := repoFromURL("http://github.com/monalisa/octo-cat.git") + u, _ := url.Parse("http://github.com/monalisa/octo-cat.git") + owner, repo, err := repoFromURL(u) eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) + eq(t, owner, "monalisa") + eq(t, repo, "octo-cat") } func Test_repoFromURL_invalid(t *testing.T) { - git.InitSSHAliasMap(nil) - - _, err := repoFromURL("https://example.com/one/two") - eq(t, err, errors.New(`invalid hostname: example.com`)) - - _, err = repoFromURL("/path/to/disk") - eq(t, err, errors.New(`invalid hostname: `)) -} - -func Test_repoFromURL_SSH(t *testing.T) { - git.InitSSHAliasMap(map[string]string{ - "gh": "github.com", - "github.com": "ssh.github.com", - }) - - r, err := repoFromURL("git@gh:monalisa/octo-cat") - eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) - - r, err = repoFromURL("git@github.com:monalisa/octo-cat") - eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) -} - -func Test_parseRemotes(t *testing.T) { - git.InitSSHAliasMap(nil) - - remoteList := []string{ - "mona\tgit@github.com:monalisa/myfork.git (fetch)", - "origin\thttps://github.com/monalisa/octo-cat.git (fetch)", - "origin\thttps://github.com/monalisa/octo-cat-push.git (push)", - "upstream\thttps://example.com/nowhere.git (fetch)", - "upstream\thttps://github.com/hubot/tools (push)", + cases := [][]string{ + []string{ + "https://example.com/one/two", + "unsupported hostname: example.com", + }, + []string{ + "/path/to/disk", + "unsupported hostname: ", + }, + } + for _, c := range cases { + u, _ := url.Parse(c[0]) + _, _, err := repoFromURL(u) + eq(t, err, errors.New(c[1])) } - r := parseRemotes(remoteList) - eq(t, len(r), 3) - - eq(t, r[0], &Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"}) - eq(t, r[1], &Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"}) - eq(t, r[2], &Remote{Name: "upstream", Owner: "hubot", Repo: "tools"}) } func Test_Remotes_FindByName(t *testing.T) { list := Remotes{ - &Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"}, - &Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"}, - &Remote{Name: "upstream", Owner: "hubot", Repo: "tools"}, + &Remote{Remote: &git.Remote{Name: "mona"}, Owner: "monalisa", Repo: "myfork"}, + &Remote{Remote: &git.Remote{Name: "origin"}, Owner: "monalisa", Repo: "octo-cat"}, + &Remote{Remote: &git.Remote{Name: "upstream"}, Owner: "hubot", Repo: "tools"}, } r, err := list.FindByName("upstream", "origin") diff --git a/git/git.go b/git/git.go index b9585024c..b3b63388e 100644 --- a/git/git.go +++ b/git/git.go @@ -165,7 +165,7 @@ func Log(sha1, sha2 string) (string, error) { return string(outputs), nil } -func Remotes() ([]string, error) { +func listRemotes() ([]string, error) { remoteCmd := exec.Command("git", "remote", "-v") remoteCmd.Stderr = nil output, err := remoteCmd.Output() diff --git a/git/remote.go b/git/remote.go new file mode 100644 index 000000000..ba29049c2 --- /dev/null +++ b/git/remote.go @@ -0,0 +1,69 @@ +package git + +import ( + "net/url" + "regexp" + "strings" +) + +var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) + +// RemoteSet is a slice of git remotes +type RemoteSet []*Remote + +// Remote is a parsed git remote +type Remote struct { + Name string + FetchURL *url.URL + PushURL *url.URL +} + +func (r *Remote) String() string { + return r.Name +} + +// Remotes gets the git remotes set for the current repo +func Remotes() (RemoteSet, error) { + list, err := listRemotes() + if err != nil { + return nil, err + } + return parseRemotes(list), nil +} + +func parseRemotes(gitRemotes []string) (remotes RemoteSet) { + for _, r := range gitRemotes { + match := remoteRE.FindStringSubmatch(r) + if match == nil { + continue + } + name := strings.TrimSpace(match[1]) + urlStr := strings.TrimSpace(match[2]) + urlType := strings.TrimSpace(match[3]) + + var rem *Remote + if len(remotes) > 0 { + rem = remotes[len(remotes)-1] + if name != rem.Name { + rem = nil + } + } + if rem == nil { + rem = &Remote{Name: name} + remotes = append(remotes, rem) + } + + u, err := ParseURL(urlStr) + if err != nil { + continue + } + + switch urlType { + case "fetch": + rem.FetchURL = u + case "push": + rem.PushURL = u + } + } + return +} diff --git a/git/remote_test.go b/git/remote_test.go new file mode 100644 index 000000000..2e7d30cb6 --- /dev/null +++ b/git/remote_test.go @@ -0,0 +1,31 @@ +package git + +import "testing" + +func Test_parseRemotes(t *testing.T) { + remoteList := []string{ + "mona\tgit@github.com:monalisa/myfork.git (fetch)", + "origin\thttps://github.com/monalisa/octo-cat.git (fetch)", + "origin\thttps://github.com/monalisa/octo-cat-push.git (push)", + "upstream\thttps://example.com/nowhere.git (fetch)", + "upstream\thttps://github.com/hubot/tools (push)", + "zardoz\thttps://example.com/zed.git (push)", + } + r := parseRemotes(remoteList) + eq(t, len(r), 4) + + eq(t, r[0].Name, "mona") + eq(t, r[0].FetchURL.String(), "ssh://git@github.com/monalisa/myfork.git") + if r[0].PushURL != nil { + t.Errorf("expected no PushURL, got %q", r[0].PushURL) + } + eq(t, r[1].Name, "origin") + eq(t, r[1].FetchURL.Path, "/monalisa/octo-cat.git") + eq(t, r[1].PushURL.Path, "/monalisa/octo-cat-push.git") + + eq(t, r[2].Name, "upstream") + eq(t, r[2].FetchURL.Host, "example.com") + eq(t, r[2].PushURL.Host, "github.com") + + eq(t, r[3].Name, "zardoz") +} diff --git a/git/ssh_config.go b/git/ssh_config.go index 47d46ac46..1ac5e828e 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -3,6 +3,7 @@ package git import ( "bufio" "io" + "net/url" "os" "path/filepath" "regexp" @@ -21,9 +22,32 @@ func init() { sshTokenRE = regexp.MustCompile(`%[%h]`) } -type sshAliasMap map[string]string +// SSHAliasMap encapsulates the translation of SSH hostname aliases +type SSHAliasMap map[string]string -func sshParseFiles() sshAliasMap { +// Translator returns a function that applies hostname aliases to URLs +func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { + return func(u *url.URL) *url.URL { + if u.Scheme != "ssh" { + return u + } + resolvedHost, ok := m[u.Hostname()] + if !ok { + return u + } + // FIXME: cleanup domain logic + if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") { + return u + } + newURL, _ := url.Parse(u.String()) + newURL.Host = resolvedHost + return newURL + } +} + +// ParseSSHConfig constructs a map of SSH hostname aliases based on user and +// system configuration files +func ParseSSHConfig() SSHAliasMap { configFiles := []string{ "/etc/ssh_config", "/etc/ssh/ssh_config", @@ -45,15 +69,15 @@ func sshParseFiles() sshAliasMap { return sshParse(openFiles...) } -func sshParse(r ...io.Reader) sshAliasMap { - config := sshAliasMap{} +func sshParse(r ...io.Reader) SSHAliasMap { + config := SSHAliasMap{} for _, file := range r { sshParseConfig(config, file) } return config } -func sshParseConfig(c sshAliasMap, file io.Reader) error { +func sshParseConfig(c SSHAliasMap, file io.Reader) error { hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 12de53bd8..35a0c93e6 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,6 +1,7 @@ package git import ( + "net/url" "reflect" "strings" "testing" @@ -8,6 +9,7 @@ import ( // TODO: extract assertion helpers into a shared package func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() if !reflect.DeepEqual(got, expected) { t.Errorf("expected: %v, got: %v", expected, got) } @@ -25,3 +27,24 @@ func Test_sshParse(t *testing.T) { eq(t, m["bar"], "%bar.net%") eq(t, m["nonexist"], "") } + +func Test_Translator(t *testing.T) { + m := SSHAliasMap{ + "gh": "github.com", + "github.com": "ssh.github.com", + } + tr := m.Translator() + + cases := [][]string{ + []string{"ssh://gh/o/r", "ssh://github.com/o/r"}, + []string{"ssh://github.com/o/r", "ssh://github.com/o/r"}, + []string{"https://gh/o/r", "https://gh/o/r"}, + } + for _, c := range cases { + u, _ := url.Parse(c[0]) + got := tr(u) + if got.String() != c[1] { + t.Errorf("%q: expected %q, got %q", c[0], c[1], got) + } + } +} diff --git a/git/url.go b/git/url.go index 792d75350..55e11c08f 100644 --- a/git/url.go +++ b/git/url.go @@ -7,8 +7,7 @@ import ( ) var ( - cachedSSHConfig sshAliasMap - protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://") + protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://") ) // ParseURL normalizes git remote urls @@ -41,28 +40,5 @@ func ParseURL(rawURL string) (u *url.URL, err error) { u.Host = u.Host[0:idx] } - if cachedSSHConfig == nil { - return - } - sshHost := cachedSSHConfig[u.Host] - // ignore replacing host that fixes for limited network - // https://help.github.com/articles/using-ssh-over-the-https-port - ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com" - if !ignoredHost && sshHost != "" { - u.Host = sshHost - } - return } - -// InitSSHAliasMap prepares globally cached SSH hostname alias mappings -func InitSSHAliasMap(m map[string]string) { - if m == nil { - cachedSSHConfig = sshParseFiles() - return - } - cachedSSHConfig = sshAliasMap{} - for k, v := range m { - cachedSSHConfig[k] = v - } -}