Merge pull request #34 from github/no-global
Eliminate package-level global state
This commit is contained in:
commit
eefb6d13ee
24 changed files with 598 additions and 412 deletions
150
api/client.go
150
api/client.go
|
|
@ -4,14 +4,69 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/github/gh-cli/context"
|
||||
"github.com/github/gh-cli/version"
|
||||
)
|
||||
|
||||
// ClientOption represents an argument to NewClient
|
||||
type ClientOption = func(http.RoundTripper) http.RoundTripper
|
||||
|
||||
// NewClient initializes a Client
|
||||
func NewClient(opts ...ClientOption) *Client {
|
||||
tr := http.DefaultTransport
|
||||
for _, opt := range opts {
|
||||
tr = opt(tr)
|
||||
}
|
||||
http := &http.Client{Transport: tr}
|
||||
client := &Client{http: http}
|
||||
return client
|
||||
}
|
||||
|
||||
// AddHeader turns a RoundTripper into one that adds a request header
|
||||
func AddHeader(name, value string) ClientOption {
|
||||
return func(tr http.RoundTripper) http.RoundTripper {
|
||||
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Add(name, value)
|
||||
return tr.RoundTrip(req)
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
// VerboseLog enables request/response logging within a RoundTripper
|
||||
func VerboseLog(out io.Writer) ClientOption {
|
||||
return func(tr http.RoundTripper) http.RoundTripper {
|
||||
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
fmt.Fprintf(out, "> %s %s\n", req.Method, req.URL.RequestURI())
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err == nil {
|
||||
fmt.Fprintf(out, "< HTTP %s\n", res.Status)
|
||||
}
|
||||
return res, err
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
// ReplaceTripper substitutes the underlying RoundTripper with a custom one
|
||||
func ReplaceTripper(tr http.RoundTripper) ClientOption {
|
||||
return func(http.RoundTripper) http.RoundTripper {
|
||||
return tr
|
||||
}
|
||||
}
|
||||
|
||||
type funcTripper struct {
|
||||
roundTrip func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return tr.roundTrip(req)
|
||||
}
|
||||
|
||||
// Client facilitates making HTTP requests to the GitHub API
|
||||
type Client struct {
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
type graphQLResponse struct {
|
||||
Data interface{}
|
||||
Errors []struct {
|
||||
|
|
@ -19,32 +74,8 @@ type graphQLResponse struct {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
GraphQL: Declared as an external variable so it can be mocked in tests
|
||||
|
||||
type repoResponse struct {
|
||||
Repository struct {
|
||||
CreatedAt string
|
||||
}
|
||||
}
|
||||
|
||||
query := `query {
|
||||
repository(owner: "golang", name: "go") {
|
||||
createdAt
|
||||
}
|
||||
}`
|
||||
|
||||
variables := map[string]string{}
|
||||
|
||||
var resp repoResponse
|
||||
err := graphql(query, map[string]string{}, &resp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", resp)
|
||||
*/
|
||||
var GraphQL = func(query string, variables map[string]string, data interface{}) error {
|
||||
// GraphQL performs a GraphQL request and parses the response
|
||||
func (c Client) GraphQL(query string, variables map[string]interface{}, data interface{}) error {
|
||||
url := "https://api.github.com/graphql"
|
||||
reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables})
|
||||
if err != nil {
|
||||
|
|
@ -56,42 +87,31 @@ var GraphQL = func(query string, variables map[string]string, data interface{})
|
|||
return err
|
||||
}
|
||||
|
||||
token, err := context.Current().AuthToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "token "+token)
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("User-Agent", "GitHub CLI "+version.Version)
|
||||
|
||||
debugRequest(req, string(reqBody))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return handleResponse(resp, data)
|
||||
}
|
||||
|
||||
func handleResponse(resp *http.Response, data interface{}) error {
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
if !success {
|
||||
return handleHTTPError(resp)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
debugResponse(resp, string(body))
|
||||
return handleResponse(resp, body, data)
|
||||
}
|
||||
|
||||
func handleResponse(resp *http.Response, body []byte, data interface{}) error {
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
if !success {
|
||||
return handleHTTPError(resp, body)
|
||||
}
|
||||
|
||||
gr := &graphQLResponse{Data: data}
|
||||
err := json.Unmarshal(body, &gr)
|
||||
err = json.Unmarshal(body, &gr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -107,12 +127,16 @@ func handleResponse(resp *http.Response, body []byte, data interface{}) error {
|
|||
|
||||
}
|
||||
|
||||
func handleHTTPError(resp *http.Response, body []byte) error {
|
||||
func handleHTTPError(resp *http.Response) error {
|
||||
var message string
|
||||
var parsedBody struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
err := json.Unmarshal(body, &parsedBody)
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(body, &parsedBody)
|
||||
if err != nil {
|
||||
message = string(body)
|
||||
} else {
|
||||
|
|
@ -121,19 +145,3 @@ func handleHTTPError(resp *http.Response, body []byte) error {
|
|||
|
||||
return fmt.Errorf("http error, '%s' failed (%d): '%s'", resp.Request.URL, resp.StatusCode, message)
|
||||
}
|
||||
|
||||
func debugRequest(req *http.Request, body string) {
|
||||
if _, ok := os.LookupEnv("DEBUG"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG: GraphQL request to %s:\n %s\n\n", req.URL, body)
|
||||
}
|
||||
|
||||
func debugResponse(resp *http.Response, body string) {
|
||||
if _, ok := os.LookupEnv("DEBUG"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG: GraphQL response:\n%+v\n\n%s\n\n", resp, body)
|
||||
}
|
||||
|
|
|
|||
51
api/client_test.go
Normal file
51
api/client_test.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphQL(t *testing.T) {
|
||||
http := &FakeHTTP{}
|
||||
client := NewClient(
|
||||
ReplaceTripper(http),
|
||||
AddHeader("Authorization", "token OTOKEN"),
|
||||
)
|
||||
|
||||
vars := map[string]interface{}{"name": "Mona"}
|
||||
response := struct {
|
||||
Viewer struct {
|
||||
Login string
|
||||
}
|
||||
}{}
|
||||
|
||||
http.StubResponse(200, bytes.NewBufferString(`{"data":{"viewer":{"login":"hubot"}}}`))
|
||||
err := client.GraphQL("QUERY", vars, &response)
|
||||
eq(t, err, nil)
|
||||
eq(t, response.Viewer.Login, "hubot")
|
||||
|
||||
req := http.Requests[0]
|
||||
reqBody, _ := ioutil.ReadAll(req.Body)
|
||||
eq(t, string(reqBody), `{"query":"QUERY","variables":{"name":"Mona"}}`)
|
||||
eq(t, req.Header.Get("Authorization"), "token OTOKEN")
|
||||
}
|
||||
|
||||
func TestGraphQLError(t *testing.T) {
|
||||
http := &FakeHTTP{}
|
||||
client := NewClient(ReplaceTripper(http))
|
||||
|
||||
response := struct{}{}
|
||||
http.StubResponse(200, bytes.NewBufferString(`{"errors":[{"message":"OH NO"}]}`))
|
||||
err := client.GraphQL("", nil, &response)
|
||||
eq(t, err, fmt.Errorf("graphql error: 'OH NO'"))
|
||||
}
|
||||
37
api/fake_http.go
Normal file
37
api/fake_http.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// FakeHTTP provides a mechanism by which to stub HTTP responses through
|
||||
type FakeHTTP struct {
|
||||
// Requests stores references to sequental requests that RoundTrip has received
|
||||
Requests []*http.Request
|
||||
count int
|
||||
responseStubs []*http.Response
|
||||
}
|
||||
|
||||
// StubResponse pre-records an HTTP response
|
||||
func (f *FakeHTTP) StubResponse(status int, body io.Reader) {
|
||||
resp := &http.Response{
|
||||
StatusCode: status,
|
||||
Body: ioutil.NopCloser(body),
|
||||
}
|
||||
f.responseStubs = append(f.responseStubs, resp)
|
||||
}
|
||||
|
||||
// RoundTrip satisfies http.RoundTripper
|
||||
func (f *FakeHTTP) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if len(f.responseStubs) <= f.count {
|
||||
return nil, fmt.Errorf("FakeHTTP: missing response stub for request %d", f.count)
|
||||
}
|
||||
resp := f.responseStubs[f.count]
|
||||
f.count++
|
||||
resp.Request = req
|
||||
f.Requests = append(f.Requests, req)
|
||||
return resp, nil
|
||||
}
|
||||
|
|
@ -2,8 +2,6 @@ package api
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/github/gh-cli/context"
|
||||
)
|
||||
|
||||
type PullRequestsPayload struct {
|
||||
|
|
@ -19,7 +17,12 @@ type PullRequest struct {
|
|||
HeadRefName string
|
||||
}
|
||||
|
||||
func PullRequests() (*PullRequestsPayload, error) {
|
||||
type Repo interface {
|
||||
RepoName() string
|
||||
RepoOwner() string
|
||||
}
|
||||
|
||||
func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername string) (*PullRequestsPayload, error) {
|
||||
type edges struct {
|
||||
Edges []struct {
|
||||
Node PullRequest
|
||||
|
|
@ -48,7 +51,7 @@ func PullRequests() (*PullRequestsPayload, error) {
|
|||
|
||||
query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
pullRequests(headRefName: $headRefName, first: 1) {
|
||||
pullRequests(headRefName: $headRefName, states: OPEN, first: 1) {
|
||||
edges {
|
||||
node {
|
||||
...pr
|
||||
|
|
@ -79,26 +82,13 @@ func PullRequests() (*PullRequestsPayload, error) {
|
|||
}
|
||||
`
|
||||
|
||||
ghRepo, err := context.Current().BaseRepo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentBranch, err := context.Current().Branch()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentUsername, err := context.Current().AuthLogin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
owner := ghRepo.Owner
|
||||
repo := ghRepo.Name
|
||||
owner := ghRepo.RepoOwner()
|
||||
repo := ghRepo.RepoName()
|
||||
|
||||
viewerQuery := fmt.Sprintf("repo:%s/%s state:open is:pr author:%s", owner, repo, currentUsername)
|
||||
reviewerQuery := fmt.Sprintf("repo:%s/%s state:open review-requested:%s", owner, repo, currentUsername)
|
||||
|
||||
variables := map[string]string{
|
||||
variables := map[string]interface{}{
|
||||
"viewerQuery": viewerQuery,
|
||||
"reviewerQuery": reviewerQuery,
|
||||
"owner": owner,
|
||||
|
|
@ -107,7 +97,7 @@ func PullRequests() (*PullRequestsPayload, error) {
|
|||
}
|
||||
|
||||
var resp response
|
||||
err = GraphQL(query, variables, &resp)
|
||||
err := client.GraphQL(query, variables, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -135,3 +125,49 @@ func PullRequests() (*PullRequestsPayload, error) {
|
|||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) {
|
||||
type response struct {
|
||||
Repository struct {
|
||||
PullRequests struct {
|
||||
Edges []struct {
|
||||
Node PullRequest
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query := `
|
||||
query($owner: String!, $repo: String!, $headRefName: String!) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
pullRequests(headRefName: $headRefName, states: OPEN, first: 1) {
|
||||
edges {
|
||||
node {
|
||||
number
|
||||
title
|
||||
url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
variables := map[string]interface{}{
|
||||
"owner": ghRepo.RepoOwner(),
|
||||
"repo": ghRepo.RepoName(),
|
||||
"headRefName": branch,
|
||||
}
|
||||
|
||||
var resp response
|
||||
err := client.GraphQL(query, variables, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prs := []PullRequest{}
|
||||
for _, edge := range resp.Repository.PullRequests.Edges {
|
||||
prs = append(prs, edge.Node)
|
||||
}
|
||||
|
||||
return prs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/context"
|
||||
"github.com/github/gh-cli/utils"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -38,7 +37,26 @@ work with pull requests.`,
|
|||
}
|
||||
|
||||
func prList(cmd *cobra.Command, args []string) error {
|
||||
prPayload, err := api.PullRequests()
|
||||
ctx := contextForCommand(cmd)
|
||||
apiClient, err := apiClientForContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseRepo, err := ctx.BaseRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentBranch, err := ctx.Branch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentUser, err := ctx.AuthLogin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prPayload, err := api.PullRequests(apiClient, baseRepo, currentBranch, currentUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -47,10 +65,6 @@ func prList(cmd *cobra.Command, args []string) error {
|
|||
if prPayload.CurrentPR != nil {
|
||||
printPrs(*prPayload.CurrentPR)
|
||||
} else {
|
||||
currentBranch, err := context.Current().Branch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentBranch+"]"))
|
||||
printMessage(message)
|
||||
}
|
||||
|
|
@ -76,7 +90,8 @@ func prList(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
|
||||
func prView(cmd *cobra.Command, args []string) error {
|
||||
baseRepo, err := context.Current().BaseRepo()
|
||||
ctx := contextForCommand(cmd)
|
||||
baseRepo, err := ctx.BaseRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -85,23 +100,27 @@ func prView(cmd *cobra.Command, args []string) error {
|
|||
if len(args) > 0 {
|
||||
if prNumber, err := strconv.Atoi(args[0]); err == nil {
|
||||
// TODO: move URL generation into GitHubRepository
|
||||
openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.Owner, baseRepo.Name, prNumber)
|
||||
openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.RepoOwner(), baseRepo.RepoName(), prNumber)
|
||||
} else {
|
||||
return fmt.Errorf("invalid pull request number: '%s'", args[0])
|
||||
}
|
||||
} else {
|
||||
prPayload, err := api.PullRequests()
|
||||
apiClient, err := apiClientForContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if prPayload.CurrentPR == nil {
|
||||
branch, err := context.Current().Branch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("The [%s] branch has no open PRs", branch)
|
||||
return nil
|
||||
}
|
||||
openURL = prPayload.CurrentPR.URL
|
||||
currentBranch, err := ctx.Branch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prs, err := api.PullRequestsForBranch(apiClient, baseRepo, currentBranch)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(prs) < 1 {
|
||||
return fmt.Errorf("the '%s' branch has no open pull requests", currentBranch)
|
||||
}
|
||||
openURL = prs[0].URL
|
||||
}
|
||||
|
||||
fmt.Printf("Opening %s in your browser.\n", openURL)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,40 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/context"
|
||||
"github.com/github/gh-cli/test"
|
||||
"github.com/github/gh-cli/utils"
|
||||
)
|
||||
|
||||
func TestPRList(t *testing.T) {
|
||||
ctx := context.InitBlankContext()
|
||||
ctx.SetBaseRepo("github/FAKE-GITHUB-REPO-NAME")
|
||||
ctx.SetBranch("master")
|
||||
func initBlankContext(repo, branch string) {
|
||||
initContext = func() context.Context {
|
||||
ctx := context.NewBlank()
|
||||
ctx.SetBaseRepo(repo)
|
||||
ctx.SetBranch(branch)
|
||||
return ctx
|
||||
}
|
||||
}
|
||||
|
||||
teardown := test.MockGraphQLResponse("test/fixtures/prList.json")
|
||||
defer teardown()
|
||||
func initFakeHTTP() *api.FakeHTTP {
|
||||
http := &api.FakeHTTP{}
|
||||
apiClientForContext = func(context.Context) (*api.Client, error) {
|
||||
return api.NewClient(api.ReplaceTripper(http)), nil
|
||||
}
|
||||
return http
|
||||
}
|
||||
|
||||
func TestPRList(t *testing.T) {
|
||||
initBlankContext("OWNER/REPO", "master")
|
||||
http := initFakeHTTP()
|
||||
|
||||
jsonFile, _ := os.Open("../test/fixtures/prList.json")
|
||||
defer jsonFile.Close()
|
||||
http.StubResponse(200, jsonFile)
|
||||
|
||||
output, err := test.RunCommand(RootCmd, "pr list")
|
||||
if err != nil {
|
||||
|
|
@ -37,11 +56,12 @@ func TestPRList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPRView(t *testing.T) {
|
||||
teardown := test.MockGraphQLResponse("test/fixtures/prView.json")
|
||||
defer teardown()
|
||||
initBlankContext("OWNER/REPO", "master")
|
||||
http := initFakeHTTP()
|
||||
|
||||
gitRepo := test.UseTempGitRepo()
|
||||
defer gitRepo.TearDown()
|
||||
jsonFile, _ := os.Open("../test/fixtures/prView.json")
|
||||
defer jsonFile.Close()
|
||||
http.StubResponse(200, jsonFile)
|
||||
|
||||
teardown, callCount := mockOpenInBrowser()
|
||||
defer teardown()
|
||||
|
|
@ -61,24 +81,21 @@ func TestPRView(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPRView_NoActiveBranch(t *testing.T) {
|
||||
teardown := test.MockGraphQLResponse("test/fixtures/prView_NoActiveBranch.json")
|
||||
defer teardown()
|
||||
initBlankContext("OWNER/REPO", "master")
|
||||
http := initFakeHTTP()
|
||||
|
||||
gitRepo := test.UseTempGitRepo()
|
||||
defer gitRepo.TearDown()
|
||||
jsonFile, _ := os.Open("../test/fixtures/prView_NoActiveBranch.json")
|
||||
defer jsonFile.Close()
|
||||
http.StubResponse(200, jsonFile)
|
||||
|
||||
teardown, callCount := mockOpenInBrowser()
|
||||
defer teardown()
|
||||
|
||||
output, err := test.RunCommand(RootCmd, "pr view")
|
||||
if err != nil {
|
||||
if err == nil || err.Error() != "the 'master' branch has no open pull requests" {
|
||||
t.Errorf("error running command `pr view`: %v", err)
|
||||
}
|
||||
|
||||
if output == "" {
|
||||
t.Errorf("command output expected got an empty string")
|
||||
}
|
||||
|
||||
if *callCount > 0 {
|
||||
t.Errorf("OpenInBrowser should NOT be called but was called %d time(s)", *callCount)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,31 +4,18 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/context"
|
||||
"github.com/github/gh-cli/git"
|
||||
"github.com/github/gh-cli/version"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
currentRepo string
|
||||
currentBranch string
|
||||
)
|
||||
|
||||
func init() {
|
||||
RootCmd.PersistentFlags().StringVarP(¤tRepo, "repo", "R", "", "current GitHub repository")
|
||||
RootCmd.PersistentFlags().StringVarP(¤tBranch, "current-branch", "B", "", "current git branch")
|
||||
}
|
||||
|
||||
func initContext() {
|
||||
ctx := context.InitDefaultContext()
|
||||
ctx.SetBranch(currentBranch)
|
||||
repo := currentRepo
|
||||
if repo == "" {
|
||||
repo = os.Getenv("GH_REPO")
|
||||
}
|
||||
ctx.SetBaseRepo(repo)
|
||||
|
||||
git.InitSSHAliasMap(nil)
|
||||
RootCmd.PersistentFlags().StringP("repo", "R", "", "current GitHub repository")
|
||||
RootCmd.PersistentFlags().StringP("current-branch", "B", "", "current git branch")
|
||||
// TODO:
|
||||
// RootCmd.PersistentFlags().BoolP("verbose", "V", false, "enable verbose output")
|
||||
}
|
||||
|
||||
// RootCmd is the entry point of command-line execution
|
||||
|
|
@ -37,10 +24,43 @@ var RootCmd = &cobra.Command{
|
|||
Short: "GitHub CLI",
|
||||
Long: `Do things with GitHub from your terminal`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||
initContext()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Println("root")
|
||||
},
|
||||
}
|
||||
|
||||
// overriden in tests
|
||||
var initContext = func() context.Context {
|
||||
ctx := context.New()
|
||||
if repo := os.Getenv("GH_REPO"); repo != "" {
|
||||
ctx.SetBaseRepo(repo)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func contextForCommand(cmd *cobra.Command) context.Context {
|
||||
ctx := initContext()
|
||||
if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" {
|
||||
ctx.SetBaseRepo(repo)
|
||||
}
|
||||
if branch, err := cmd.Flags().GetString("current-branch"); err == nil && branch != "" {
|
||||
ctx.SetBranch(branch)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// overriden in tests
|
||||
var apiClientForContext = func(ctx context.Context) (*api.Client, error) {
|
||||
token, err := ctx.AuthToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := []api.ClientOption{
|
||||
api.AddHeader("Authorization", fmt.Sprintf("token %s", token)),
|
||||
api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", version.Version)),
|
||||
}
|
||||
if verbose := os.Getenv("DEBUG"); verbose != "" {
|
||||
opts = append(opts, api.VerboseLog(os.Stderr))
|
||||
}
|
||||
return api.NewClient(opts...), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,13 +5,9 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// InitBlankContext initializes a blank context for testing
|
||||
func InitBlankContext() Context {
|
||||
currentContext = &blankContext{
|
||||
authToken: "OTOKEN",
|
||||
authLogin: "monalisa",
|
||||
}
|
||||
return currentContext
|
||||
// NewBlank initializes a blank Context suitable for testing
|
||||
func NewBlank() Context {
|
||||
return &blankContext{}
|
||||
}
|
||||
|
||||
// A Context implementation that queries the filesystem
|
||||
|
|
@ -19,7 +15,19 @@ type blankContext struct {
|
|||
authToken string
|
||||
authLogin string
|
||||
branch string
|
||||
baseRepo *GitHubRepository
|
||||
baseRepo GitHubRepository
|
||||
}
|
||||
|
||||
type ghRepo struct {
|
||||
owner string
|
||||
name string
|
||||
}
|
||||
|
||||
func (r ghRepo) RepoOwner() string {
|
||||
return r.owner
|
||||
}
|
||||
func (r ghRepo) RepoName() string {
|
||||
return r.name
|
||||
}
|
||||
|
||||
func (c *blankContext) AuthToken() (string, error) {
|
||||
|
|
@ -49,7 +57,7 @@ func (c *blankContext) Remotes() (Remotes, error) {
|
|||
return Remotes{}, nil
|
||||
}
|
||||
|
||||
func (c *blankContext) BaseRepo() (*GitHubRepository, error) {
|
||||
func (c *blankContext) BaseRepo() (GitHubRepository, error) {
|
||||
if c.baseRepo == nil {
|
||||
return nil, fmt.Errorf("base repo was not initialized")
|
||||
}
|
||||
|
|
@ -59,9 +67,6 @@ func (c *blankContext) BaseRepo() (*GitHubRepository, error) {
|
|||
func (c *blankContext) SetBaseRepo(nwo string) {
|
||||
parts := strings.SplitN(nwo, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
c.baseRepo = &GitHubRepository{
|
||||
Owner: parts[0],
|
||||
Name: parts[1],
|
||||
}
|
||||
c.baseRepo = &ghRepo{parts[0], parts[1]}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
)
|
||||
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
|
@ -20,6 +21,7 @@ const (
|
|||
)
|
||||
|
||||
// TODO: have a conversation about whether this belongs in the "context" package
|
||||
// FIXME: make testable
|
||||
func setupConfigFile(filename string) (*configEntry, error) {
|
||||
flow := &auth.OAuthFlow{
|
||||
Hostname: oauthHost,
|
||||
|
|
@ -38,12 +40,12 @@ func setupConfigFile(filename string) (*configEntry, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
u, err := getViewer(token)
|
||||
userLogin, err := getViewer(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := configEntry{
|
||||
User: u.Login,
|
||||
User: userLogin,
|
||||
Token: token,
|
||||
}
|
||||
data := make(map[string][]configEntry)
|
||||
|
|
@ -74,6 +76,18 @@ func setupConfigFile(filename string) (*configEntry, error) {
|
|||
return &entry, err
|
||||
}
|
||||
|
||||
func getViewer(token string) (string, error) {
|
||||
http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token)))
|
||||
|
||||
response := struct {
|
||||
Viewer struct {
|
||||
Login string
|
||||
}
|
||||
}{}
|
||||
err := http.GraphQL("{ viewer { login } }", nil, &response)
|
||||
return response.Viewer.Login, err
|
||||
}
|
||||
|
||||
func waitForEnter(r io.Reader) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Scan()
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
package context
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type viewer struct {
|
||||
Login string
|
||||
}
|
||||
type responseData struct {
|
||||
Data struct {
|
||||
Viewer *viewer
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: figure out how to enable using the "api" package here
|
||||
//
|
||||
// Right now "api" is coupled to "context", so we can't import "api" from here.
|
||||
func getViewer(token string) (user *viewer, err error) {
|
||||
url := "https://api.github.com/graphql"
|
||||
query := `{ viewer { login } }`
|
||||
|
||||
reqBody, err := json.Marshal(map[string]interface{}{"query": query})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "token "+token)
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
client := http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := responseData{}
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
user = data.Data.Viewer
|
||||
return
|
||||
}
|
||||
|
|
@ -15,24 +15,19 @@ type Context interface {
|
|||
Branch() (string, error)
|
||||
SetBranch(string)
|
||||
Remotes() (Remotes, error)
|
||||
BaseRepo() (*GitHubRepository, error)
|
||||
BaseRepo() (GitHubRepository, error)
|
||||
SetBaseRepo(string)
|
||||
}
|
||||
|
||||
var currentContext Context
|
||||
|
||||
// Current returns the currently initialized Context instance
|
||||
func Current() Context {
|
||||
return currentContext
|
||||
// GitHubRepository is anything that can be mapped to an OWNER/REPO pair
|
||||
type GitHubRepository interface {
|
||||
RepoOwner() string
|
||||
RepoName() string
|
||||
}
|
||||
|
||||
// InitDefaultContext initializes the default filesystem context
|
||||
func InitDefaultContext() Context {
|
||||
ctx := &fsContext{}
|
||||
if currentContext == nil {
|
||||
currentContext = ctx
|
||||
}
|
||||
return ctx
|
||||
// New initializes a Context that reads from the filesystem
|
||||
func New() Context {
|
||||
return &fsContext{}
|
||||
}
|
||||
|
||||
// A Context implementation that queries the filesystem
|
||||
|
|
@ -40,7 +35,7 @@ type fsContext struct {
|
|||
config *configEntry
|
||||
remotes Remotes
|
||||
branch string
|
||||
baseRepo *GitHubRepository
|
||||
baseRepo GitHubRepository
|
||||
authToken string
|
||||
}
|
||||
|
||||
|
|
@ -109,12 +104,13 @@ func (c *fsContext) Remotes() (Remotes, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.remotes = parseRemotes(gitRemotes)
|
||||
sshTranslate := git.ParseSSHConfig().Translator()
|
||||
c.remotes = translateRemotes(gitRemotes, sshTranslate)
|
||||
}
|
||||
return c.remotes, nil
|
||||
}
|
||||
|
||||
func (c *fsContext) BaseRepo() (*GitHubRepository, error) {
|
||||
func (c *fsContext) BaseRepo() (GitHubRepository, error) {
|
||||
if c.baseRepo != nil {
|
||||
return c.baseRepo, nil
|
||||
}
|
||||
|
|
@ -128,19 +124,13 @@ func (c *fsContext) BaseRepo() (*GitHubRepository, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
c.baseRepo = &GitHubRepository{
|
||||
Owner: rem.Owner,
|
||||
Name: rem.Repo,
|
||||
}
|
||||
c.baseRepo = rem
|
||||
return c.baseRepo, nil
|
||||
}
|
||||
|
||||
func (c *fsContext) SetBaseRepo(nwo string) {
|
||||
parts := strings.SplitN(nwo, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
c.baseRepo = &GitHubRepository{
|
||||
Owner: parts[0],
|
||||
Name: parts[1],
|
||||
}
|
||||
c.baseRepo = &ghRepo{parts[0], parts[1]}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package context
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/github/gh-cli/git"
|
||||
|
|
@ -27,74 +27,48 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) {
|
|||
|
||||
// Remote represents a git remote mapped to a GitHub repository
|
||||
type Remote struct {
|
||||
Name string
|
||||
*git.Remote
|
||||
Owner string
|
||||
Repo string
|
||||
}
|
||||
|
||||
func (r *Remote) String() string {
|
||||
return r.Name
|
||||
// RepoName is the name of the GitHub repository
|
||||
func (r Remote) RepoName() string {
|
||||
return r.Repo
|
||||
}
|
||||
|
||||
// GitHubRepository represents a GitHub respository
|
||||
type GitHubRepository struct {
|
||||
Name string
|
||||
Owner string
|
||||
// RepoOwner is the name of the GitHub account that owns the repo
|
||||
func (r Remote) RepoOwner() string {
|
||||
return r.Owner
|
||||
}
|
||||
|
||||
func parseRemotes(gitRemotes []string) (remotes Remotes) {
|
||||
re := regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`)
|
||||
|
||||
names := []string{}
|
||||
remotesMap := make(map[string]map[string]string)
|
||||
// TODO: accept an interface instead of git.RemoteSet
|
||||
func translateRemotes(gitRemotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) (remotes Remotes) {
|
||||
for _, r := range gitRemotes {
|
||||
if re.MatchString(r) {
|
||||
match := re.FindStringSubmatch(r)
|
||||
name := strings.TrimSpace(match[1])
|
||||
url := strings.TrimSpace(match[2])
|
||||
urlType := strings.TrimSpace(match[3])
|
||||
utm, ok := remotesMap[name]
|
||||
if !ok {
|
||||
utm = make(map[string]string)
|
||||
remotesMap[name] = utm
|
||||
names = append(names, name)
|
||||
}
|
||||
utm[urlType] = url
|
||||
var owner string
|
||||
var repo string
|
||||
if r.FetchURL != nil {
|
||||
owner, repo, _ = repoFromURL(urlTranslate(r.FetchURL))
|
||||
}
|
||||
if r.PushURL != nil && owner == "" {
|
||||
owner, repo, _ = repoFromURL(urlTranslate(r.PushURL))
|
||||
}
|
||||
remotes = append(remotes, &Remote{
|
||||
Remote: r,
|
||||
Owner: owner,
|
||||
Repo: repo,
|
||||
})
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
urlMap := remotesMap[name]
|
||||
repo, err := repoFromURL(urlMap["fetch"])
|
||||
if err != nil {
|
||||
repo, err = repoFromURL(urlMap["push"])
|
||||
}
|
||||
if err == nil {
|
||||
remotes = append(remotes, &Remote{
|
||||
Name: name,
|
||||
Owner: repo.Owner,
|
||||
Repo: repo.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func repoFromURL(u string) (*GitHubRepository, error) {
|
||||
url, err := git.ParseURL(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func repoFromURL(u *url.URL) (string, string, error) {
|
||||
if !strings.EqualFold(u.Hostname(), defaultHostname) {
|
||||
return "", "", fmt.Errorf("unsupported hostname: %s", u.Hostname())
|
||||
}
|
||||
if url.Hostname() != defaultHostname {
|
||||
return nil, fmt.Errorf("invalid hostname: %s", url.Hostname())
|
||||
}
|
||||
parts := strings.SplitN(strings.TrimPrefix(url.Path, "/"), "/", 3)
|
||||
parts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", 3)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("invalid path: %s", url.Path)
|
||||
return "", "", fmt.Errorf("invalid path: %s", u.Path)
|
||||
}
|
||||
return &GitHubRepository{
|
||||
Owner: parts[0],
|
||||
Name: strings.TrimSuffix(parts[1], ".git"),
|
||||
}, nil
|
||||
return parts[0], strings.TrimSuffix(parts[1], ".git"), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,67 +2,43 @@ package context
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/github/gh-cli/git"
|
||||
)
|
||||
|
||||
func Test_repoFromURL(t *testing.T) {
|
||||
git.InitSSHAliasMap(nil)
|
||||
|
||||
r, err := repoFromURL("http://github.com/monalisa/octo-cat.git")
|
||||
u, _ := url.Parse("http://github.com/monalisa/octo-cat.git")
|
||||
owner, repo, err := repoFromURL(u)
|
||||
eq(t, err, nil)
|
||||
eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"})
|
||||
eq(t, owner, "monalisa")
|
||||
eq(t, repo, "octo-cat")
|
||||
}
|
||||
|
||||
func Test_repoFromURL_invalid(t *testing.T) {
|
||||
git.InitSSHAliasMap(nil)
|
||||
|
||||
_, err := repoFromURL("https://example.com/one/two")
|
||||
eq(t, err, errors.New(`invalid hostname: example.com`))
|
||||
|
||||
_, err = repoFromURL("/path/to/disk")
|
||||
eq(t, err, errors.New(`invalid hostname: `))
|
||||
}
|
||||
|
||||
func Test_repoFromURL_SSH(t *testing.T) {
|
||||
git.InitSSHAliasMap(map[string]string{
|
||||
"gh": "github.com",
|
||||
"github.com": "ssh.github.com",
|
||||
})
|
||||
|
||||
r, err := repoFromURL("git@gh:monalisa/octo-cat")
|
||||
eq(t, err, nil)
|
||||
eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"})
|
||||
|
||||
r, err = repoFromURL("git@github.com:monalisa/octo-cat")
|
||||
eq(t, err, nil)
|
||||
eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"})
|
||||
}
|
||||
|
||||
func Test_parseRemotes(t *testing.T) {
|
||||
git.InitSSHAliasMap(nil)
|
||||
|
||||
remoteList := []string{
|
||||
"mona\tgit@github.com:monalisa/myfork.git (fetch)",
|
||||
"origin\thttps://github.com/monalisa/octo-cat.git (fetch)",
|
||||
"origin\thttps://github.com/monalisa/octo-cat-push.git (push)",
|
||||
"upstream\thttps://example.com/nowhere.git (fetch)",
|
||||
"upstream\thttps://github.com/hubot/tools (push)",
|
||||
cases := [][]string{
|
||||
[]string{
|
||||
"https://example.com/one/two",
|
||||
"unsupported hostname: example.com",
|
||||
},
|
||||
[]string{
|
||||
"/path/to/disk",
|
||||
"unsupported hostname: ",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
u, _ := url.Parse(c[0])
|
||||
_, _, err := repoFromURL(u)
|
||||
eq(t, err, errors.New(c[1]))
|
||||
}
|
||||
r := parseRemotes(remoteList)
|
||||
eq(t, len(r), 3)
|
||||
|
||||
eq(t, r[0], &Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"})
|
||||
eq(t, r[1], &Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"})
|
||||
eq(t, r[2], &Remote{Name: "upstream", Owner: "hubot", Repo: "tools"})
|
||||
}
|
||||
|
||||
func Test_Remotes_FindByName(t *testing.T) {
|
||||
list := Remotes{
|
||||
&Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"},
|
||||
&Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"},
|
||||
&Remote{Name: "upstream", Owner: "hubot", Repo: "tools"},
|
||||
&Remote{Remote: &git.Remote{Name: "mona"}, Owner: "monalisa", Repo: "myfork"},
|
||||
&Remote{Remote: &git.Remote{Name: "origin"}, Owner: "monalisa", Repo: "octo-cat"},
|
||||
&Remote{Remote: &git.Remote{Name: "upstream"}, Owner: "hubot", Repo: "tools"},
|
||||
}
|
||||
|
||||
r, err := list.FindByName("upstream", "origin")
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ func Log(sha1, sha2 string) (string, error) {
|
|||
return string(outputs), nil
|
||||
}
|
||||
|
||||
func Remotes() ([]string, error) {
|
||||
func listRemotes() ([]string, error) {
|
||||
remoteCmd := exec.Command("git", "remote", "-v")
|
||||
remoteCmd.Stderr = nil
|
||||
output, err := remoteCmd.Output()
|
||||
|
|
|
|||
69
git/remote.go
Normal file
69
git/remote.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`)
|
||||
|
||||
// RemoteSet is a slice of git remotes
|
||||
type RemoteSet []*Remote
|
||||
|
||||
// Remote is a parsed git remote
|
||||
type Remote struct {
|
||||
Name string
|
||||
FetchURL *url.URL
|
||||
PushURL *url.URL
|
||||
}
|
||||
|
||||
func (r *Remote) String() string {
|
||||
return r.Name
|
||||
}
|
||||
|
||||
// Remotes gets the git remotes set for the current repo
|
||||
func Remotes() (RemoteSet, error) {
|
||||
list, err := listRemotes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseRemotes(list), nil
|
||||
}
|
||||
|
||||
func parseRemotes(gitRemotes []string) (remotes RemoteSet) {
|
||||
for _, r := range gitRemotes {
|
||||
match := remoteRE.FindStringSubmatch(r)
|
||||
if match == nil {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(match[1])
|
||||
urlStr := strings.TrimSpace(match[2])
|
||||
urlType := strings.TrimSpace(match[3])
|
||||
|
||||
var rem *Remote
|
||||
if len(remotes) > 0 {
|
||||
rem = remotes[len(remotes)-1]
|
||||
if name != rem.Name {
|
||||
rem = nil
|
||||
}
|
||||
}
|
||||
if rem == nil {
|
||||
rem = &Remote{Name: name}
|
||||
remotes = append(remotes, rem)
|
||||
}
|
||||
|
||||
u, err := ParseURL(urlStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch urlType {
|
||||
case "fetch":
|
||||
rem.FetchURL = u
|
||||
case "push":
|
||||
rem.PushURL = u
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
31
git/remote_test.go
Normal file
31
git/remote_test.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package git
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_parseRemotes(t *testing.T) {
|
||||
remoteList := []string{
|
||||
"mona\tgit@github.com:monalisa/myfork.git (fetch)",
|
||||
"origin\thttps://github.com/monalisa/octo-cat.git (fetch)",
|
||||
"origin\thttps://github.com/monalisa/octo-cat-push.git (push)",
|
||||
"upstream\thttps://example.com/nowhere.git (fetch)",
|
||||
"upstream\thttps://github.com/hubot/tools (push)",
|
||||
"zardoz\thttps://example.com/zed.git (push)",
|
||||
}
|
||||
r := parseRemotes(remoteList)
|
||||
eq(t, len(r), 4)
|
||||
|
||||
eq(t, r[0].Name, "mona")
|
||||
eq(t, r[0].FetchURL.String(), "ssh://git@github.com/monalisa/myfork.git")
|
||||
if r[0].PushURL != nil {
|
||||
t.Errorf("expected no PushURL, got %q", r[0].PushURL)
|
||||
}
|
||||
eq(t, r[1].Name, "origin")
|
||||
eq(t, r[1].FetchURL.Path, "/monalisa/octo-cat.git")
|
||||
eq(t, r[1].PushURL.Path, "/monalisa/octo-cat-push.git")
|
||||
|
||||
eq(t, r[2].Name, "upstream")
|
||||
eq(t, r[2].FetchURL.Host, "example.com")
|
||||
eq(t, r[2].PushURL.Host, "github.com")
|
||||
|
||||
eq(t, r[3].Name, "zardoz")
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package git
|
|||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
|
@ -21,9 +22,32 @@ func init() {
|
|||
sshTokenRE = regexp.MustCompile(`%[%h]`)
|
||||
}
|
||||
|
||||
type sshAliasMap map[string]string
|
||||
// SSHAliasMap encapsulates the translation of SSH hostname aliases
|
||||
type SSHAliasMap map[string]string
|
||||
|
||||
func sshParseFiles() sshAliasMap {
|
||||
// Translator returns a function that applies hostname aliases to URLs
|
||||
func (m SSHAliasMap) Translator() func(*url.URL) *url.URL {
|
||||
return func(u *url.URL) *url.URL {
|
||||
if u.Scheme != "ssh" {
|
||||
return u
|
||||
}
|
||||
resolvedHost, ok := m[u.Hostname()]
|
||||
if !ok {
|
||||
return u
|
||||
}
|
||||
// FIXME: cleanup domain logic
|
||||
if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") {
|
||||
return u
|
||||
}
|
||||
newURL, _ := url.Parse(u.String())
|
||||
newURL.Host = resolvedHost
|
||||
return newURL
|
||||
}
|
||||
}
|
||||
|
||||
// ParseSSHConfig constructs a map of SSH hostname aliases based on user and
|
||||
// system configuration files
|
||||
func ParseSSHConfig() SSHAliasMap {
|
||||
configFiles := []string{
|
||||
"/etc/ssh_config",
|
||||
"/etc/ssh/ssh_config",
|
||||
|
|
@ -45,15 +69,15 @@ func sshParseFiles() sshAliasMap {
|
|||
return sshParse(openFiles...)
|
||||
}
|
||||
|
||||
func sshParse(r ...io.Reader) sshAliasMap {
|
||||
config := sshAliasMap{}
|
||||
func sshParse(r ...io.Reader) SSHAliasMap {
|
||||
config := SSHAliasMap{}
|
||||
for _, file := range r {
|
||||
sshParseConfig(config, file)
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func sshParseConfig(c sshAliasMap, file io.Reader) error {
|
||||
func sshParseConfig(c SSHAliasMap, file io.Reader) error {
|
||||
hosts := []string{"*"}
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -8,6 +9,7 @@ import (
|
|||
|
||||
// TODO: extract assertion helpers into a shared package
|
||||
func eq(t *testing.T, got interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("expected: %v, got: %v", expected, got)
|
||||
}
|
||||
|
|
@ -25,3 +27,24 @@ func Test_sshParse(t *testing.T) {
|
|||
eq(t, m["bar"], "%bar.net%")
|
||||
eq(t, m["nonexist"], "")
|
||||
}
|
||||
|
||||
func Test_Translator(t *testing.T) {
|
||||
m := SSHAliasMap{
|
||||
"gh": "github.com",
|
||||
"github.com": "ssh.github.com",
|
||||
}
|
||||
tr := m.Translator()
|
||||
|
||||
cases := [][]string{
|
||||
[]string{"ssh://gh/o/r", "ssh://github.com/o/r"},
|
||||
[]string{"ssh://github.com/o/r", "ssh://github.com/o/r"},
|
||||
[]string{"https://gh/o/r", "https://gh/o/r"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
u, _ := url.Parse(c[0])
|
||||
got := tr(u)
|
||||
if got.String() != c[1] {
|
||||
t.Errorf("%q: expected %q, got %q", c[0], c[1], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
26
git/url.go
26
git/url.go
|
|
@ -7,8 +7,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
cachedSSHConfig sshAliasMap
|
||||
protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://")
|
||||
protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://")
|
||||
)
|
||||
|
||||
// ParseURL normalizes git remote urls
|
||||
|
|
@ -41,28 +40,5 @@ func ParseURL(rawURL string) (u *url.URL, err error) {
|
|||
u.Host = u.Host[0:idx]
|
||||
}
|
||||
|
||||
if cachedSSHConfig == nil {
|
||||
return
|
||||
}
|
||||
sshHost := cachedSSHConfig[u.Host]
|
||||
// ignore replacing host that fixes for limited network
|
||||
// https://help.github.com/articles/using-ssh-over-the-https-port
|
||||
ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com"
|
||||
if !ignoredHost && sshHost != "" {
|
||||
u.Host = sshHost
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// InitSSHAliasMap prepares globally cached SSH hostname alias mappings
|
||||
func InitSSHAliasMap(m map[string]string) {
|
||||
if m == nil {
|
||||
cachedSSHConfig = sshParseFiles()
|
||||
return
|
||||
}
|
||||
cachedSSHConfig = sshAliasMap{}
|
||||
for k, v := range m {
|
||||
cachedSSHConfig[k] = v
|
||||
}
|
||||
}
|
||||
|
|
|
|||
4
test/fixtures/prList.json
vendored
4
test/fixtures/prList.json
vendored
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
{"data":{
|
||||
"repository": {
|
||||
"pullRequests": {
|
||||
"edges": [
|
||||
|
|
@ -47,4 +47,4 @@
|
|||
],
|
||||
"pageInfo": { "hasNextPage": false }
|
||||
}
|
||||
}
|
||||
}}
|
||||
4
test/fixtures/prView.json
vendored
4
test/fixtures/prView.json
vendored
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
{"data":{
|
||||
"repository": {
|
||||
"pullRequests": {
|
||||
"edges": [
|
||||
|
|
@ -47,4 +47,4 @@
|
|||
],
|
||||
"pageInfo": { "hasNextPage": false }
|
||||
}
|
||||
}
|
||||
}}
|
||||
4
test/fixtures/prView_NoActiveBranch.json
vendored
4
test/fixtures/prView_NoActiveBranch.json
vendored
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
{"data":{
|
||||
"repository": {
|
||||
"pullRequests": {
|
||||
"edges": []
|
||||
|
|
@ -12,4 +12,4 @@
|
|||
"edges": [],
|
||||
"pageInfo": { "hasNextPage": false }
|
||||
}
|
||||
}
|
||||
}}
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
|
@ -9,7 +8,6 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -67,30 +65,6 @@ func UseTempGitRepo() *TempGitRepo {
|
|||
return &TempGitRepo{Remote: remotePath, TearDown: tearDown}
|
||||
}
|
||||
|
||||
func MockGraphQLResponse(fixturePath string) (teardown func()) {
|
||||
pwd, _ := os.Getwd()
|
||||
fixturePath = filepath.Join(pwd, "..", fixturePath)
|
||||
|
||||
originalGraphQL := api.GraphQL
|
||||
api.GraphQL = func(query string, variables map[string]string, v interface{}) error {
|
||||
contents, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
json.Unmarshal(contents, &v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
api.GraphQL = originalGraphQL
|
||||
}
|
||||
}
|
||||
|
||||
func RunCommand(root *cobra.Command, s string) (string, error) {
|
||||
var err error
|
||||
output := captureOutput(func() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue