diff --git a/api/client.go b/api/client.go index cafbbb507..a3fa8e72a 100644 --- a/api/client.go +++ b/api/client.go @@ -4,14 +4,69 @@ import ( "bytes" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" - "os" - - "github.com/github/gh-cli/context" - "github.com/github/gh-cli/version" ) +// ClientOption represents an argument to NewClient +type ClientOption = func(http.RoundTripper) http.RoundTripper + +// NewClient initializes a Client +func NewClient(opts ...ClientOption) *Client { + tr := http.DefaultTransport + for _, opt := range opts { + tr = opt(tr) + } + http := &http.Client{Transport: tr} + client := &Client{http: http} + return client +} + +// AddHeader turns a RoundTripper into one that adds a request header +func AddHeader(name, value string) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + req.Header.Add(name, value) + return tr.RoundTrip(req) + }} + } +} + +// VerboseLog enables request/response logging within a RoundTripper +func VerboseLog(out io.Writer) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + fmt.Fprintf(out, "> %s %s\n", req.Method, req.URL.RequestURI()) + res, err := tr.RoundTrip(req) + if err == nil { + fmt.Fprintf(out, "< HTTP %s\n", res.Status) + } + return res, err + }} + } +} + +// ReplaceTripper substitutes the underlying RoundTripper with a custom one +func ReplaceTripper(tr http.RoundTripper) ClientOption { + return func(http.RoundTripper) http.RoundTripper { + return tr + } +} + +type funcTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return tr.roundTrip(req) +} + +// Client facilitates making HTTP requests to the GitHub API +type Client struct { + http *http.Client +} + type graphQLResponse struct { Data interface{} Errors []struct { @@ -19,32 +74,8 @@ type graphQLResponse struct { } } -/* -GraphQL: Declared as an external variable so it can be mocked in tests - -type repoResponse struct { - Repository struct { - CreatedAt string - } -} - -query := `query { - repository(owner: "golang", name: "go") { - createdAt - } -}` - -variables := map[string]string{} - -var resp repoResponse -err := graphql(query, map[string]string{}, &resp) -if err != nil { - panic(err) -} - -fmt.Printf("%+v\n", resp) -*/ -var GraphQL = func(query string, variables map[string]string, data interface{}) error { +// GraphQL performs a GraphQL request and parses the response +func (c Client) GraphQL(query string, variables map[string]interface{}, data interface{}) error { url := "https://api.github.com/graphql" reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) if err != nil { @@ -56,42 +87,31 @@ var GraphQL = func(query string, variables map[string]string, data interface{}) return err } - token, err := context.Current().AuthToken() - if err != nil { - return err - } - - req.Header.Set("Authorization", "token "+token) req.Header.Set("Content-Type", "application/json; charset=utf-8") - req.Header.Set("User-Agent", "GitHub CLI "+version.Version) - debugRequest(req, string(reqBody)) - - client := &http.Client{} - resp, err := client.Do(req) + resp, err := c.http.Do(req) if err != nil { return err } defer resp.Body.Close() + return handleResponse(resp, data) +} + +func handleResponse(resp *http.Response, data interface{}) error { + success := resp.StatusCode >= 200 && resp.StatusCode < 300 + + if !success { + return handleHTTPError(resp) + } + body, err := ioutil.ReadAll(resp.Body) if err != nil { return err } - debugResponse(resp, string(body)) - return handleResponse(resp, body, data) -} - -func handleResponse(resp *http.Response, body []byte, data interface{}) error { - success := resp.StatusCode >= 200 && resp.StatusCode < 300 - - if !success { - return handleHTTPError(resp, body) - } - gr := &graphQLResponse{Data: data} - err := json.Unmarshal(body, &gr) + err = json.Unmarshal(body, &gr) if err != nil { return err } @@ -107,12 +127,16 @@ func handleResponse(resp *http.Response, body []byte, data interface{}) error { } -func handleHTTPError(resp *http.Response, body []byte) error { +func handleHTTPError(resp *http.Response) error { var message string var parsedBody struct { Message string `json:"message"` } - err := json.Unmarshal(body, &parsedBody) + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + err = json.Unmarshal(body, &parsedBody) if err != nil { message = string(body) } else { @@ -121,19 +145,3 @@ func handleHTTPError(resp *http.Response, body []byte) error { return fmt.Errorf("http error, '%s' failed (%d): '%s'", resp.Request.URL, resp.StatusCode, message) } - -func debugRequest(req *http.Request, body string) { - if _, ok := os.LookupEnv("DEBUG"); !ok { - return - } - - fmt.Printf("DEBUG: GraphQL request to %s:\n %s\n\n", req.URL, body) -} - -func debugResponse(resp *http.Response, body string) { - if _, ok := os.LookupEnv("DEBUG"); !ok { - return - } - - fmt.Printf("DEBUG: GraphQL response:\n%+v\n\n%s\n\n", resp, body) -} diff --git a/api/client_test.go b/api/client_test.go new file mode 100644 index 000000000..4d7c64917 --- /dev/null +++ b/api/client_test.go @@ -0,0 +1,51 @@ +package api + +import ( + "bytes" + "fmt" + "io/ioutil" + "reflect" + "testing" +) + +func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() + if !reflect.DeepEqual(got, expected) { + t.Errorf("expected: %v, got: %v", expected, got) + } +} + +func TestGraphQL(t *testing.T) { + http := &FakeHTTP{} + client := NewClient( + ReplaceTripper(http), + AddHeader("Authorization", "token OTOKEN"), + ) + + vars := map[string]interface{}{"name": "Mona"} + response := struct { + Viewer struct { + Login string + } + }{} + + http.StubResponse(200, bytes.NewBufferString(`{"data":{"viewer":{"login":"hubot"}}}`)) + err := client.GraphQL("QUERY", vars, &response) + eq(t, err, nil) + eq(t, response.Viewer.Login, "hubot") + + req := http.Requests[0] + reqBody, _ := ioutil.ReadAll(req.Body) + eq(t, string(reqBody), `{"query":"QUERY","variables":{"name":"Mona"}}`) + eq(t, req.Header.Get("Authorization"), "token OTOKEN") +} + +func TestGraphQLError(t *testing.T) { + http := &FakeHTTP{} + client := NewClient(ReplaceTripper(http)) + + response := struct{}{} + http.StubResponse(200, bytes.NewBufferString(`{"errors":[{"message":"OH NO"}]}`)) + err := client.GraphQL("", nil, &response) + eq(t, err, fmt.Errorf("graphql error: 'OH NO'")) +} diff --git a/api/fake_http.go b/api/fake_http.go new file mode 100644 index 000000000..96e38aab3 --- /dev/null +++ b/api/fake_http.go @@ -0,0 +1,37 @@ +package api + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" +) + +// FakeHTTP provides a mechanism by which to stub HTTP responses through +type FakeHTTP struct { + // Requests stores references to sequental requests that RoundTrip has received + Requests []*http.Request + count int + responseStubs []*http.Response +} + +// StubResponse pre-records an HTTP response +func (f *FakeHTTP) StubResponse(status int, body io.Reader) { + resp := &http.Response{ + StatusCode: status, + Body: ioutil.NopCloser(body), + } + f.responseStubs = append(f.responseStubs, resp) +} + +// RoundTrip satisfies http.RoundTripper +func (f *FakeHTTP) RoundTrip(req *http.Request) (*http.Response, error) { + if len(f.responseStubs) <= f.count { + return nil, fmt.Errorf("FakeHTTP: missing response stub for request %d", f.count) + } + resp := f.responseStubs[f.count] + f.count++ + resp.Request = req + f.Requests = append(f.Requests, req) + return resp, nil +} diff --git a/api/queries.go b/api/queries.go index b9140670a..e1e078138 100644 --- a/api/queries.go +++ b/api/queries.go @@ -2,8 +2,6 @@ package api import ( "fmt" - - "github.com/github/gh-cli/context" ) type PullRequestsPayload struct { @@ -19,7 +17,12 @@ type PullRequest struct { HeadRefName string } -func PullRequests() (*PullRequestsPayload, error) { +type Repo interface { + RepoName() string + RepoOwner() string +} + +func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername string) (*PullRequestsPayload, error) { type edges struct { Edges []struct { Node PullRequest @@ -48,7 +51,7 @@ func PullRequests() (*PullRequestsPayload, error) { query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { repository(owner: $owner, name: $repo) { - pullRequests(headRefName: $headRefName, first: 1) { + pullRequests(headRefName: $headRefName, states: OPEN, first: 1) { edges { node { ...pr @@ -79,26 +82,13 @@ func PullRequests() (*PullRequestsPayload, error) { } ` - ghRepo, err := context.Current().BaseRepo() - if err != nil { - return nil, err - } - currentBranch, err := context.Current().Branch() - if err != nil { - return nil, err - } - currentUsername, err := context.Current().AuthLogin() - if err != nil { - return nil, err - } - - owner := ghRepo.Owner - repo := ghRepo.Name + owner := ghRepo.RepoOwner() + repo := ghRepo.RepoName() viewerQuery := fmt.Sprintf("repo:%s/%s state:open is:pr author:%s", owner, repo, currentUsername) reviewerQuery := fmt.Sprintf("repo:%s/%s state:open review-requested:%s", owner, repo, currentUsername) - variables := map[string]string{ + variables := map[string]interface{}{ "viewerQuery": viewerQuery, "reviewerQuery": reviewerQuery, "owner": owner, @@ -107,7 +97,7 @@ func PullRequests() (*PullRequestsPayload, error) { } var resp response - err = GraphQL(query, variables, &resp) + err := client.GraphQL(query, variables, &resp) if err != nil { return nil, err } @@ -135,3 +125,49 @@ func PullRequests() (*PullRequestsPayload, error) { return &payload, nil } + +func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) { + type response struct { + Repository struct { + PullRequests struct { + Edges []struct { + Node PullRequest + } + } + } + } + + query := ` + query($owner: String!, $repo: String!, $headRefName: String!) { + repository(owner: $owner, name: $repo) { + pullRequests(headRefName: $headRefName, states: OPEN, first: 1) { + edges { + node { + number + title + url + } + } + } + } + }` + + variables := map[string]interface{}{ + "owner": ghRepo.RepoOwner(), + "repo": ghRepo.RepoName(), + "headRefName": branch, + } + + var resp response + err := client.GraphQL(query, variables, &resp) + if err != nil { + return nil, err + } + + prs := []PullRequest{} + for _, edge := range resp.Repository.PullRequests.Edges { + prs = append(prs, edge.Node) + } + + return prs, nil +} diff --git a/command/pr.go b/command/pr.go index 8f35b9183..f1dfe8fb0 100644 --- a/command/pr.go +++ b/command/pr.go @@ -5,7 +5,6 @@ import ( "strconv" "github.com/github/gh-cli/api" - "github.com/github/gh-cli/context" "github.com/github/gh-cli/utils" "github.com/spf13/cobra" ) @@ -38,7 +37,26 @@ work with pull requests.`, } func prList(cmd *cobra.Command, args []string) error { - prPayload, err := api.PullRequests() + ctx := contextForCommand(cmd) + apiClient, err := apiClientForContext(ctx) + if err != nil { + return err + } + + baseRepo, err := ctx.BaseRepo() + if err != nil { + return err + } + currentBranch, err := ctx.Branch() + if err != nil { + return err + } + currentUser, err := ctx.AuthLogin() + if err != nil { + return err + } + + prPayload, err := api.PullRequests(apiClient, baseRepo, currentBranch, currentUser) if err != nil { return err } @@ -47,10 +65,6 @@ func prList(cmd *cobra.Command, args []string) error { if prPayload.CurrentPR != nil { printPrs(*prPayload.CurrentPR) } else { - currentBranch, err := context.Current().Branch() - if err != nil { - return err - } message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentBranch+"]")) printMessage(message) } @@ -76,7 +90,8 @@ func prList(cmd *cobra.Command, args []string) error { } func prView(cmd *cobra.Command, args []string) error { - baseRepo, err := context.Current().BaseRepo() + ctx := contextForCommand(cmd) + baseRepo, err := ctx.BaseRepo() if err != nil { return err } @@ -85,23 +100,27 @@ func prView(cmd *cobra.Command, args []string) error { if len(args) > 0 { if prNumber, err := strconv.Atoi(args[0]); err == nil { // TODO: move URL generation into GitHubRepository - openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.Owner, baseRepo.Name, prNumber) + openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.RepoOwner(), baseRepo.RepoName(), prNumber) } else { return fmt.Errorf("invalid pull request number: '%s'", args[0]) } } else { - prPayload, err := api.PullRequests() + apiClient, err := apiClientForContext(ctx) if err != nil { return err - } else if prPayload.CurrentPR == nil { - branch, err := context.Current().Branch() - if err != nil { - return err - } - fmt.Printf("The [%s] branch has no open PRs", branch) - return nil } - openURL = prPayload.CurrentPR.URL + currentBranch, err := ctx.Branch() + if err != nil { + return err + } + + prs, err := api.PullRequestsForBranch(apiClient, baseRepo, currentBranch) + if err != nil { + return err + } else if len(prs) < 1 { + return fmt.Errorf("the '%s' branch has no open pull requests", currentBranch) + } + openURL = prs[0].URL } fmt.Printf("Opening %s in your browser.\n", openURL) diff --git a/command/pr_test.go b/command/pr_test.go index 10a47b5e9..e4a5122d7 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -1,21 +1,40 @@ package command import ( + "os" "regexp" "testing" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" "github.com/github/gh-cli/test" "github.com/github/gh-cli/utils" ) -func TestPRList(t *testing.T) { - ctx := context.InitBlankContext() - ctx.SetBaseRepo("github/FAKE-GITHUB-REPO-NAME") - ctx.SetBranch("master") +func initBlankContext(repo, branch string) { + initContext = func() context.Context { + ctx := context.NewBlank() + ctx.SetBaseRepo(repo) + ctx.SetBranch(branch) + return ctx + } +} - teardown := test.MockGraphQLResponse("test/fixtures/prList.json") - defer teardown() +func initFakeHTTP() *api.FakeHTTP { + http := &api.FakeHTTP{} + apiClientForContext = func(context.Context) (*api.Client, error) { + return api.NewClient(api.ReplaceTripper(http)), nil + } + return http +} + +func TestPRList(t *testing.T) { + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() + + jsonFile, _ := os.Open("../test/fixtures/prList.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) output, err := test.RunCommand(RootCmd, "pr list") if err != nil { @@ -37,11 +56,12 @@ func TestPRList(t *testing.T) { } func TestPRView(t *testing.T) { - teardown := test.MockGraphQLResponse("test/fixtures/prView.json") - defer teardown() + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() - gitRepo := test.UseTempGitRepo() - defer gitRepo.TearDown() + jsonFile, _ := os.Open("../test/fixtures/prView.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) teardown, callCount := mockOpenInBrowser() defer teardown() @@ -61,24 +81,21 @@ func TestPRView(t *testing.T) { } func TestPRView_NoActiveBranch(t *testing.T) { - teardown := test.MockGraphQLResponse("test/fixtures/prView_NoActiveBranch.json") - defer teardown() + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() - gitRepo := test.UseTempGitRepo() - defer gitRepo.TearDown() + jsonFile, _ := os.Open("../test/fixtures/prView_NoActiveBranch.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) teardown, callCount := mockOpenInBrowser() defer teardown() output, err := test.RunCommand(RootCmd, "pr view") - if err != nil { + if err == nil || err.Error() != "the 'master' branch has no open pull requests" { t.Errorf("error running command `pr view`: %v", err) } - if output == "" { - t.Errorf("command output expected got an empty string") - } - if *callCount > 0 { t.Errorf("OpenInBrowser should NOT be called but was called %d time(s)", *callCount) } diff --git a/command/root.go b/command/root.go index d9fa859ee..c0b676c86 100644 --- a/command/root.go +++ b/command/root.go @@ -4,31 +4,18 @@ import ( "fmt" "os" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" - "github.com/github/gh-cli/git" + "github.com/github/gh-cli/version" + "github.com/spf13/cobra" ) -var ( - currentRepo string - currentBranch string -) - func init() { - RootCmd.PersistentFlags().StringVarP(¤tRepo, "repo", "R", "", "current GitHub repository") - RootCmd.PersistentFlags().StringVarP(¤tBranch, "current-branch", "B", "", "current git branch") -} - -func initContext() { - ctx := context.InitDefaultContext() - ctx.SetBranch(currentBranch) - repo := currentRepo - if repo == "" { - repo = os.Getenv("GH_REPO") - } - ctx.SetBaseRepo(repo) - - git.InitSSHAliasMap(nil) + RootCmd.PersistentFlags().StringP("repo", "R", "", "current GitHub repository") + RootCmd.PersistentFlags().StringP("current-branch", "B", "", "current git branch") + // TODO: + // RootCmd.PersistentFlags().BoolP("verbose", "V", false, "enable verbose output") } // RootCmd is the entry point of command-line execution @@ -37,10 +24,43 @@ var RootCmd = &cobra.Command{ Short: "GitHub CLI", Long: `Do things with GitHub from your terminal`, Args: cobra.MinimumNArgs(1), - PersistentPreRun: func(cmd *cobra.Command, args []string) { - initContext() - }, Run: func(cmd *cobra.Command, args []string) { fmt.Println("root") }, } + +// overriden in tests +var initContext = func() context.Context { + ctx := context.New() + if repo := os.Getenv("GH_REPO"); repo != "" { + ctx.SetBaseRepo(repo) + } + return ctx +} + +func contextForCommand(cmd *cobra.Command) context.Context { + ctx := initContext() + if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" { + ctx.SetBaseRepo(repo) + } + if branch, err := cmd.Flags().GetString("current-branch"); err == nil && branch != "" { + ctx.SetBranch(branch) + } + return ctx +} + +// overriden in tests +var apiClientForContext = func(ctx context.Context) (*api.Client, error) { + token, err := ctx.AuthToken() + if err != nil { + return nil, err + } + opts := []api.ClientOption{ + api.AddHeader("Authorization", fmt.Sprintf("token %s", token)), + api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", version.Version)), + } + if verbose := os.Getenv("DEBUG"); verbose != "" { + opts = append(opts, api.VerboseLog(os.Stderr)) + } + return api.NewClient(opts...), nil +} diff --git a/context/blank_context.go b/context/blank_context.go index e7e48c878..d5ad2cafe 100644 --- a/context/blank_context.go +++ b/context/blank_context.go @@ -5,13 +5,9 @@ import ( "strings" ) -// InitBlankContext initializes a blank context for testing -func InitBlankContext() Context { - currentContext = &blankContext{ - authToken: "OTOKEN", - authLogin: "monalisa", - } - return currentContext +// NewBlank initializes a blank Context suitable for testing +func NewBlank() Context { + return &blankContext{} } // A Context implementation that queries the filesystem @@ -19,7 +15,19 @@ type blankContext struct { authToken string authLogin string branch string - baseRepo *GitHubRepository + baseRepo GitHubRepository +} + +type ghRepo struct { + owner string + name string +} + +func (r ghRepo) RepoOwner() string { + return r.owner +} +func (r ghRepo) RepoName() string { + return r.name } func (c *blankContext) AuthToken() (string, error) { @@ -49,7 +57,7 @@ func (c *blankContext) Remotes() (Remotes, error) { return Remotes{}, nil } -func (c *blankContext) BaseRepo() (*GitHubRepository, error) { +func (c *blankContext) BaseRepo() (GitHubRepository, error) { if c.baseRepo == nil { return nil, fmt.Errorf("base repo was not initialized") } @@ -59,9 +67,6 @@ func (c *blankContext) BaseRepo() (*GitHubRepository, error) { func (c *blankContext) SetBaseRepo(nwo string) { parts := strings.SplitN(nwo, "/", 2) if len(parts) == 2 { - c.baseRepo = &GitHubRepository{ - Owner: parts[0], - Name: parts[1], - } + c.baseRepo = &ghRepo{parts[0], parts[1]} } } diff --git a/context/config_file_test.go b/context/config_file_test.go index 7163e6b2a..5a915e78d 100644 --- a/context/config_file_test.go +++ b/context/config_file_test.go @@ -8,6 +8,7 @@ import ( ) func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() if !reflect.DeepEqual(got, expected) { t.Errorf("expected: %v, got: %v", expected, got) } diff --git a/context/config_setup.go b/context/config_setup.go index 115d31475..290547b4a 100644 --- a/context/config_setup.go +++ b/context/config_setup.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/auth" "gopkg.in/yaml.v3" ) @@ -20,6 +21,7 @@ const ( ) // TODO: have a conversation about whether this belongs in the "context" package +// FIXME: make testable func setupConfigFile(filename string) (*configEntry, error) { flow := &auth.OAuthFlow{ Hostname: oauthHost, @@ -38,12 +40,12 @@ func setupConfigFile(filename string) (*configEntry, error) { return nil, err } - u, err := getViewer(token) + userLogin, err := getViewer(token) if err != nil { return nil, err } entry := configEntry{ - User: u.Login, + User: userLogin, Token: token, } data := make(map[string][]configEntry) @@ -74,6 +76,18 @@ func setupConfigFile(filename string) (*configEntry, error) { return &entry, err } +func getViewer(token string) (string, error) { + http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token))) + + response := struct { + Viewer struct { + Login string + } + }{} + err := http.GraphQL("{ viewer { login } }", nil, &response) + return response.Viewer.Login, err +} + func waitForEnter(r io.Reader) error { scanner := bufio.NewScanner(r) scanner.Scan() diff --git a/context/config_viewer.go b/context/config_viewer.go deleted file mode 100644 index 0a03cdca3..000000000 --- a/context/config_viewer.go +++ /dev/null @@ -1,59 +0,0 @@ -package context - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" -) - -type viewer struct { - Login string -} -type responseData struct { - Data struct { - Viewer *viewer - } -} - -// TODO: figure out how to enable using the "api" package here -// -// Right now "api" is coupled to "context", so we can't import "api" from here. -func getViewer(token string) (user *viewer, err error) { - url := "https://api.github.com/graphql" - query := `{ viewer { login } }` - - reqBody, err := json.Marshal(map[string]interface{}{"query": query}) - if err != nil { - return - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody)) - if err != nil { - return - } - - req.Header.Set("Authorization", "token "+token) - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - client := http.Client{} - resp, err := client.Do(req) - if err != nil { - return - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return - } - - data := responseData{} - err = json.Unmarshal(body, &data) - if err != nil { - return - } - - user = data.Data.Viewer - return -} diff --git a/context/context.go b/context/context.go index 3e6dd009e..7e165299d 100644 --- a/context/context.go +++ b/context/context.go @@ -15,24 +15,19 @@ type Context interface { Branch() (string, error) SetBranch(string) Remotes() (Remotes, error) - BaseRepo() (*GitHubRepository, error) + BaseRepo() (GitHubRepository, error) SetBaseRepo(string) } -var currentContext Context - -// Current returns the currently initialized Context instance -func Current() Context { - return currentContext +// GitHubRepository is anything that can be mapped to an OWNER/REPO pair +type GitHubRepository interface { + RepoOwner() string + RepoName() string } -// InitDefaultContext initializes the default filesystem context -func InitDefaultContext() Context { - ctx := &fsContext{} - if currentContext == nil { - currentContext = ctx - } - return ctx +// New initializes a Context that reads from the filesystem +func New() Context { + return &fsContext{} } // A Context implementation that queries the filesystem @@ -40,7 +35,7 @@ type fsContext struct { config *configEntry remotes Remotes branch string - baseRepo *GitHubRepository + baseRepo GitHubRepository authToken string } @@ -109,12 +104,13 @@ func (c *fsContext) Remotes() (Remotes, error) { if err != nil { return nil, err } - c.remotes = parseRemotes(gitRemotes) + sshTranslate := git.ParseSSHConfig().Translator() + c.remotes = translateRemotes(gitRemotes, sshTranslate) } return c.remotes, nil } -func (c *fsContext) BaseRepo() (*GitHubRepository, error) { +func (c *fsContext) BaseRepo() (GitHubRepository, error) { if c.baseRepo != nil { return c.baseRepo, nil } @@ -128,19 +124,13 @@ func (c *fsContext) BaseRepo() (*GitHubRepository, error) { return nil, err } - c.baseRepo = &GitHubRepository{ - Owner: rem.Owner, - Name: rem.Repo, - } + c.baseRepo = rem return c.baseRepo, nil } func (c *fsContext) SetBaseRepo(nwo string) { parts := strings.SplitN(nwo, "/", 2) if len(parts) == 2 { - c.baseRepo = &GitHubRepository{ - Owner: parts[0], - Name: parts[1], - } + c.baseRepo = &ghRepo{parts[0], parts[1]} } } diff --git a/context/remote.go b/context/remote.go index 8ffcef567..9f3b228dc 100644 --- a/context/remote.go +++ b/context/remote.go @@ -2,7 +2,7 @@ package context import ( "fmt" - "regexp" + "net/url" "strings" "github.com/github/gh-cli/git" @@ -27,74 +27,48 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) { // Remote represents a git remote mapped to a GitHub repository type Remote struct { - Name string + *git.Remote Owner string Repo string } -func (r *Remote) String() string { - return r.Name +// RepoName is the name of the GitHub repository +func (r Remote) RepoName() string { + return r.Repo } -// GitHubRepository represents a GitHub respository -type GitHubRepository struct { - Name string - Owner string +// RepoOwner is the name of the GitHub account that owns the repo +func (r Remote) RepoOwner() string { + return r.Owner } -func parseRemotes(gitRemotes []string) (remotes Remotes) { - re := regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) - - names := []string{} - remotesMap := make(map[string]map[string]string) +// TODO: accept an interface instead of git.RemoteSet +func translateRemotes(gitRemotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) (remotes Remotes) { for _, r := range gitRemotes { - if re.MatchString(r) { - match := re.FindStringSubmatch(r) - name := strings.TrimSpace(match[1]) - url := strings.TrimSpace(match[2]) - urlType := strings.TrimSpace(match[3]) - utm, ok := remotesMap[name] - if !ok { - utm = make(map[string]string) - remotesMap[name] = utm - names = append(names, name) - } - utm[urlType] = url + var owner string + var repo string + if r.FetchURL != nil { + owner, repo, _ = repoFromURL(urlTranslate(r.FetchURL)) } + if r.PushURL != nil && owner == "" { + owner, repo, _ = repoFromURL(urlTranslate(r.PushURL)) + } + remotes = append(remotes, &Remote{ + Remote: r, + Owner: owner, + Repo: repo, + }) } - - for _, name := range names { - urlMap := remotesMap[name] - repo, err := repoFromURL(urlMap["fetch"]) - if err != nil { - repo, err = repoFromURL(urlMap["push"]) - } - if err == nil { - remotes = append(remotes, &Remote{ - Name: name, - Owner: repo.Owner, - Repo: repo.Name, - }) - } - } - return } -func repoFromURL(u string) (*GitHubRepository, error) { - url, err := git.ParseURL(u) - if err != nil { - return nil, err +func repoFromURL(u *url.URL) (string, string, error) { + if !strings.EqualFold(u.Hostname(), defaultHostname) { + return "", "", fmt.Errorf("unsupported hostname: %s", u.Hostname()) } - if url.Hostname() != defaultHostname { - return nil, fmt.Errorf("invalid hostname: %s", url.Hostname()) - } - parts := strings.SplitN(strings.TrimPrefix(url.Path, "/"), "/", 3) + parts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", 3) if len(parts) < 2 { - return nil, fmt.Errorf("invalid path: %s", url.Path) + return "", "", fmt.Errorf("invalid path: %s", u.Path) } - return &GitHubRepository{ - Owner: parts[0], - Name: strings.TrimSuffix(parts[1], ".git"), - }, nil + return parts[0], strings.TrimSuffix(parts[1], ".git"), nil } diff --git a/context/remote_test.go b/context/remote_test.go index 70b49c4e5..359fcaa7f 100644 --- a/context/remote_test.go +++ b/context/remote_test.go @@ -2,67 +2,43 @@ package context import ( "errors" + "net/url" "testing" "github.com/github/gh-cli/git" ) func Test_repoFromURL(t *testing.T) { - git.InitSSHAliasMap(nil) - - r, err := repoFromURL("http://github.com/monalisa/octo-cat.git") + u, _ := url.Parse("http://github.com/monalisa/octo-cat.git") + owner, repo, err := repoFromURL(u) eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) + eq(t, owner, "monalisa") + eq(t, repo, "octo-cat") } func Test_repoFromURL_invalid(t *testing.T) { - git.InitSSHAliasMap(nil) - - _, err := repoFromURL("https://example.com/one/two") - eq(t, err, errors.New(`invalid hostname: example.com`)) - - _, err = repoFromURL("/path/to/disk") - eq(t, err, errors.New(`invalid hostname: `)) -} - -func Test_repoFromURL_SSH(t *testing.T) { - git.InitSSHAliasMap(map[string]string{ - "gh": "github.com", - "github.com": "ssh.github.com", - }) - - r, err := repoFromURL("git@gh:monalisa/octo-cat") - eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) - - r, err = repoFromURL("git@github.com:monalisa/octo-cat") - eq(t, err, nil) - eq(t, r, &GitHubRepository{Owner: "monalisa", Name: "octo-cat"}) -} - -func Test_parseRemotes(t *testing.T) { - git.InitSSHAliasMap(nil) - - remoteList := []string{ - "mona\tgit@github.com:monalisa/myfork.git (fetch)", - "origin\thttps://github.com/monalisa/octo-cat.git (fetch)", - "origin\thttps://github.com/monalisa/octo-cat-push.git (push)", - "upstream\thttps://example.com/nowhere.git (fetch)", - "upstream\thttps://github.com/hubot/tools (push)", + cases := [][]string{ + []string{ + "https://example.com/one/two", + "unsupported hostname: example.com", + }, + []string{ + "/path/to/disk", + "unsupported hostname: ", + }, + } + for _, c := range cases { + u, _ := url.Parse(c[0]) + _, _, err := repoFromURL(u) + eq(t, err, errors.New(c[1])) } - r := parseRemotes(remoteList) - eq(t, len(r), 3) - - eq(t, r[0], &Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"}) - eq(t, r[1], &Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"}) - eq(t, r[2], &Remote{Name: "upstream", Owner: "hubot", Repo: "tools"}) } func Test_Remotes_FindByName(t *testing.T) { list := Remotes{ - &Remote{Name: "mona", Owner: "monalisa", Repo: "myfork"}, - &Remote{Name: "origin", Owner: "monalisa", Repo: "octo-cat"}, - &Remote{Name: "upstream", Owner: "hubot", Repo: "tools"}, + &Remote{Remote: &git.Remote{Name: "mona"}, Owner: "monalisa", Repo: "myfork"}, + &Remote{Remote: &git.Remote{Name: "origin"}, Owner: "monalisa", Repo: "octo-cat"}, + &Remote{Remote: &git.Remote{Name: "upstream"}, Owner: "hubot", Repo: "tools"}, } r, err := list.FindByName("upstream", "origin") diff --git a/git/git.go b/git/git.go index b9585024c..b3b63388e 100644 --- a/git/git.go +++ b/git/git.go @@ -165,7 +165,7 @@ func Log(sha1, sha2 string) (string, error) { return string(outputs), nil } -func Remotes() ([]string, error) { +func listRemotes() ([]string, error) { remoteCmd := exec.Command("git", "remote", "-v") remoteCmd.Stderr = nil output, err := remoteCmd.Output() diff --git a/git/remote.go b/git/remote.go new file mode 100644 index 000000000..ba29049c2 --- /dev/null +++ b/git/remote.go @@ -0,0 +1,69 @@ +package git + +import ( + "net/url" + "regexp" + "strings" +) + +var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) + +// RemoteSet is a slice of git remotes +type RemoteSet []*Remote + +// Remote is a parsed git remote +type Remote struct { + Name string + FetchURL *url.URL + PushURL *url.URL +} + +func (r *Remote) String() string { + return r.Name +} + +// Remotes gets the git remotes set for the current repo +func Remotes() (RemoteSet, error) { + list, err := listRemotes() + if err != nil { + return nil, err + } + return parseRemotes(list), nil +} + +func parseRemotes(gitRemotes []string) (remotes RemoteSet) { + for _, r := range gitRemotes { + match := remoteRE.FindStringSubmatch(r) + if match == nil { + continue + } + name := strings.TrimSpace(match[1]) + urlStr := strings.TrimSpace(match[2]) + urlType := strings.TrimSpace(match[3]) + + var rem *Remote + if len(remotes) > 0 { + rem = remotes[len(remotes)-1] + if name != rem.Name { + rem = nil + } + } + if rem == nil { + rem = &Remote{Name: name} + remotes = append(remotes, rem) + } + + u, err := ParseURL(urlStr) + if err != nil { + continue + } + + switch urlType { + case "fetch": + rem.FetchURL = u + case "push": + rem.PushURL = u + } + } + return +} diff --git a/git/remote_test.go b/git/remote_test.go new file mode 100644 index 000000000..2e7d30cb6 --- /dev/null +++ b/git/remote_test.go @@ -0,0 +1,31 @@ +package git + +import "testing" + +func Test_parseRemotes(t *testing.T) { + remoteList := []string{ + "mona\tgit@github.com:monalisa/myfork.git (fetch)", + "origin\thttps://github.com/monalisa/octo-cat.git (fetch)", + "origin\thttps://github.com/monalisa/octo-cat-push.git (push)", + "upstream\thttps://example.com/nowhere.git (fetch)", + "upstream\thttps://github.com/hubot/tools (push)", + "zardoz\thttps://example.com/zed.git (push)", + } + r := parseRemotes(remoteList) + eq(t, len(r), 4) + + eq(t, r[0].Name, "mona") + eq(t, r[0].FetchURL.String(), "ssh://git@github.com/monalisa/myfork.git") + if r[0].PushURL != nil { + t.Errorf("expected no PushURL, got %q", r[0].PushURL) + } + eq(t, r[1].Name, "origin") + eq(t, r[1].FetchURL.Path, "/monalisa/octo-cat.git") + eq(t, r[1].PushURL.Path, "/monalisa/octo-cat-push.git") + + eq(t, r[2].Name, "upstream") + eq(t, r[2].FetchURL.Host, "example.com") + eq(t, r[2].PushURL.Host, "github.com") + + eq(t, r[3].Name, "zardoz") +} diff --git a/git/ssh_config.go b/git/ssh_config.go index 47d46ac46..1ac5e828e 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -3,6 +3,7 @@ package git import ( "bufio" "io" + "net/url" "os" "path/filepath" "regexp" @@ -21,9 +22,32 @@ func init() { sshTokenRE = regexp.MustCompile(`%[%h]`) } -type sshAliasMap map[string]string +// SSHAliasMap encapsulates the translation of SSH hostname aliases +type SSHAliasMap map[string]string -func sshParseFiles() sshAliasMap { +// Translator returns a function that applies hostname aliases to URLs +func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { + return func(u *url.URL) *url.URL { + if u.Scheme != "ssh" { + return u + } + resolvedHost, ok := m[u.Hostname()] + if !ok { + return u + } + // FIXME: cleanup domain logic + if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") { + return u + } + newURL, _ := url.Parse(u.String()) + newURL.Host = resolvedHost + return newURL + } +} + +// ParseSSHConfig constructs a map of SSH hostname aliases based on user and +// system configuration files +func ParseSSHConfig() SSHAliasMap { configFiles := []string{ "/etc/ssh_config", "/etc/ssh/ssh_config", @@ -45,15 +69,15 @@ func sshParseFiles() sshAliasMap { return sshParse(openFiles...) } -func sshParse(r ...io.Reader) sshAliasMap { - config := sshAliasMap{} +func sshParse(r ...io.Reader) SSHAliasMap { + config := SSHAliasMap{} for _, file := range r { sshParseConfig(config, file) } return config } -func sshParseConfig(c sshAliasMap, file io.Reader) error { +func sshParseConfig(c SSHAliasMap, file io.Reader) error { hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 12de53bd8..35a0c93e6 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,6 +1,7 @@ package git import ( + "net/url" "reflect" "strings" "testing" @@ -8,6 +9,7 @@ import ( // TODO: extract assertion helpers into a shared package func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() if !reflect.DeepEqual(got, expected) { t.Errorf("expected: %v, got: %v", expected, got) } @@ -25,3 +27,24 @@ func Test_sshParse(t *testing.T) { eq(t, m["bar"], "%bar.net%") eq(t, m["nonexist"], "") } + +func Test_Translator(t *testing.T) { + m := SSHAliasMap{ + "gh": "github.com", + "github.com": "ssh.github.com", + } + tr := m.Translator() + + cases := [][]string{ + []string{"ssh://gh/o/r", "ssh://github.com/o/r"}, + []string{"ssh://github.com/o/r", "ssh://github.com/o/r"}, + []string{"https://gh/o/r", "https://gh/o/r"}, + } + for _, c := range cases { + u, _ := url.Parse(c[0]) + got := tr(u) + if got.String() != c[1] { + t.Errorf("%q: expected %q, got %q", c[0], c[1], got) + } + } +} diff --git a/git/url.go b/git/url.go index 792d75350..55e11c08f 100644 --- a/git/url.go +++ b/git/url.go @@ -7,8 +7,7 @@ import ( ) var ( - cachedSSHConfig sshAliasMap - protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://") + protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://") ) // ParseURL normalizes git remote urls @@ -41,28 +40,5 @@ func ParseURL(rawURL string) (u *url.URL, err error) { u.Host = u.Host[0:idx] } - if cachedSSHConfig == nil { - return - } - sshHost := cachedSSHConfig[u.Host] - // ignore replacing host that fixes for limited network - // https://help.github.com/articles/using-ssh-over-the-https-port - ignoredHost := u.Host == "github.com" && sshHost == "ssh.github.com" - if !ignoredHost && sshHost != "" { - u.Host = sshHost - } - return } - -// InitSSHAliasMap prepares globally cached SSH hostname alias mappings -func InitSSHAliasMap(m map[string]string) { - if m == nil { - cachedSSHConfig = sshParseFiles() - return - } - cachedSSHConfig = sshAliasMap{} - for k, v := range m { - cachedSSHConfig[k] = v - } -} diff --git a/test/fixtures/prList.json b/test/fixtures/prList.json index d3a7e89f2..444fa5e01 100644 --- a/test/fixtures/prList.json +++ b/test/fixtures/prList.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [ @@ -47,4 +47,4 @@ ], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/fixtures/prView.json b/test/fixtures/prView.json index d3a7e89f2..444fa5e01 100644 --- a/test/fixtures/prView.json +++ b/test/fixtures/prView.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [ @@ -47,4 +47,4 @@ ], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/fixtures/prView_NoActiveBranch.json b/test/fixtures/prView_NoActiveBranch.json index dd7ddbafd..7c1fb0c05 100644 --- a/test/fixtures/prView_NoActiveBranch.json +++ b/test/fixtures/prView_NoActiveBranch.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [] @@ -12,4 +12,4 @@ "edges": [], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/helpers.go b/test/helpers.go index 90b16778a..d9276565c 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -1,7 +1,6 @@ package test import ( - "encoding/json" "fmt" "io/ioutil" "os" @@ -9,7 +8,6 @@ import ( "path/filepath" "strings" - "github.com/github/gh-cli/api" "github.com/spf13/cobra" ) @@ -67,30 +65,6 @@ func UseTempGitRepo() *TempGitRepo { return &TempGitRepo{Remote: remotePath, TearDown: tearDown} } -func MockGraphQLResponse(fixturePath string) (teardown func()) { - pwd, _ := os.Getwd() - fixturePath = filepath.Join(pwd, "..", fixturePath) - - originalGraphQL := api.GraphQL - api.GraphQL = func(query string, variables map[string]string, v interface{}) error { - contents, err := ioutil.ReadFile(fixturePath) - if err != nil { - return err - } - - json.Unmarshal(contents, &v) - if err != nil { - return err - } - - return nil - } - - return func() { - api.GraphQL = originalGraphQL - } -} - func RunCommand(root *cobra.Command, s string) (string, error) { var err error output := captureOutput(func() {