diff --git a/api/client.go b/api/client.go index cafbbb507..a3fa8e72a 100644 --- a/api/client.go +++ b/api/client.go @@ -4,14 +4,69 @@ import ( "bytes" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" - "os" - - "github.com/github/gh-cli/context" - "github.com/github/gh-cli/version" ) +// ClientOption represents an argument to NewClient +type ClientOption = func(http.RoundTripper) http.RoundTripper + +// NewClient initializes a Client +func NewClient(opts ...ClientOption) *Client { + tr := http.DefaultTransport + for _, opt := range opts { + tr = opt(tr) + } + http := &http.Client{Transport: tr} + client := &Client{http: http} + return client +} + +// AddHeader turns a RoundTripper into one that adds a request header +func AddHeader(name, value string) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + req.Header.Add(name, value) + return tr.RoundTrip(req) + }} + } +} + +// VerboseLog enables request/response logging within a RoundTripper +func VerboseLog(out io.Writer) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + fmt.Fprintf(out, "> %s %s\n", req.Method, req.URL.RequestURI()) + res, err := tr.RoundTrip(req) + if err == nil { + fmt.Fprintf(out, "< HTTP %s\n", res.Status) + } + return res, err + }} + } +} + +// ReplaceTripper substitutes the underlying RoundTripper with a custom one +func ReplaceTripper(tr http.RoundTripper) ClientOption { + return func(http.RoundTripper) http.RoundTripper { + return tr + } +} + +type funcTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return tr.roundTrip(req) +} + +// Client facilitates making HTTP requests to the GitHub API +type Client struct { + http *http.Client +} + type graphQLResponse struct { Data interface{} Errors []struct { @@ -19,32 +74,8 @@ type graphQLResponse struct { } } -/* -GraphQL: Declared as an external variable so it can be mocked in tests - -type repoResponse struct { - Repository struct { - CreatedAt string - } -} - -query := `query { - repository(owner: "golang", name: "go") { - createdAt - } -}` - -variables := map[string]string{} - -var resp repoResponse -err := graphql(query, map[string]string{}, &resp) -if err != nil { - panic(err) -} - -fmt.Printf("%+v\n", resp) -*/ -var GraphQL = func(query string, variables map[string]string, data interface{}) error { +// GraphQL performs a GraphQL request and parses the response +func (c Client) GraphQL(query string, variables map[string]interface{}, data interface{}) error { url := "https://api.github.com/graphql" reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) if err != nil { @@ -56,42 +87,31 @@ var GraphQL = func(query string, variables map[string]string, data interface{}) return err } - token, err := context.Current().AuthToken() - if err != nil { - return err - } - - req.Header.Set("Authorization", "token "+token) req.Header.Set("Content-Type", "application/json; charset=utf-8") - req.Header.Set("User-Agent", "GitHub CLI "+version.Version) - debugRequest(req, string(reqBody)) - - client := &http.Client{} - resp, err := client.Do(req) + resp, err := c.http.Do(req) if err != nil { return err } defer resp.Body.Close() + return handleResponse(resp, data) +} + +func handleResponse(resp *http.Response, data interface{}) error { + success := resp.StatusCode >= 200 && resp.StatusCode < 300 + + if !success { + return handleHTTPError(resp) + } + body, err := ioutil.ReadAll(resp.Body) if err != nil { return err } - debugResponse(resp, string(body)) - return handleResponse(resp, body, data) -} - -func handleResponse(resp *http.Response, body []byte, data interface{}) error { - success := resp.StatusCode >= 200 && resp.StatusCode < 300 - - if !success { - return handleHTTPError(resp, body) - } - gr := &graphQLResponse{Data: data} - err := json.Unmarshal(body, &gr) + err = json.Unmarshal(body, &gr) if err != nil { return err } @@ -107,12 +127,16 @@ func handleResponse(resp *http.Response, body []byte, data interface{}) error { } -func handleHTTPError(resp *http.Response, body []byte) error { +func handleHTTPError(resp *http.Response) error { var message string var parsedBody struct { Message string `json:"message"` } - err := json.Unmarshal(body, &parsedBody) + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + err = json.Unmarshal(body, &parsedBody) if err != nil { message = string(body) } else { @@ -121,19 +145,3 @@ func handleHTTPError(resp *http.Response, body []byte) error { return fmt.Errorf("http error, '%s' failed (%d): '%s'", resp.Request.URL, resp.StatusCode, message) } - -func debugRequest(req *http.Request, body string) { - if _, ok := os.LookupEnv("DEBUG"); !ok { - return - } - - fmt.Printf("DEBUG: GraphQL request to %s:\n %s\n\n", req.URL, body) -} - -func debugResponse(resp *http.Response, body string) { - if _, ok := os.LookupEnv("DEBUG"); !ok { - return - } - - fmt.Printf("DEBUG: GraphQL response:\n%+v\n\n%s\n\n", resp, body) -} diff --git a/api/queries.go b/api/queries.go index b9140670a..e1e078138 100644 --- a/api/queries.go +++ b/api/queries.go @@ -2,8 +2,6 @@ package api import ( "fmt" - - "github.com/github/gh-cli/context" ) type PullRequestsPayload struct { @@ -19,7 +17,12 @@ type PullRequest struct { HeadRefName string } -func PullRequests() (*PullRequestsPayload, error) { +type Repo interface { + RepoName() string + RepoOwner() string +} + +func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername string) (*PullRequestsPayload, error) { type edges struct { Edges []struct { Node PullRequest @@ -48,7 +51,7 @@ func PullRequests() (*PullRequestsPayload, error) { query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { repository(owner: $owner, name: $repo) { - pullRequests(headRefName: $headRefName, first: 1) { + pullRequests(headRefName: $headRefName, states: OPEN, first: 1) { edges { node { ...pr @@ -79,26 +82,13 @@ func PullRequests() (*PullRequestsPayload, error) { } ` - ghRepo, err := context.Current().BaseRepo() - if err != nil { - return nil, err - } - currentBranch, err := context.Current().Branch() - if err != nil { - return nil, err - } - currentUsername, err := context.Current().AuthLogin() - if err != nil { - return nil, err - } - - owner := ghRepo.Owner - repo := ghRepo.Name + owner := ghRepo.RepoOwner() + repo := ghRepo.RepoName() viewerQuery := fmt.Sprintf("repo:%s/%s state:open is:pr author:%s", owner, repo, currentUsername) reviewerQuery := fmt.Sprintf("repo:%s/%s state:open review-requested:%s", owner, repo, currentUsername) - variables := map[string]string{ + variables := map[string]interface{}{ "viewerQuery": viewerQuery, "reviewerQuery": reviewerQuery, "owner": owner, @@ -107,7 +97,7 @@ func PullRequests() (*PullRequestsPayload, error) { } var resp response - err = GraphQL(query, variables, &resp) + err := client.GraphQL(query, variables, &resp) if err != nil { return nil, err } @@ -135,3 +125,49 @@ func PullRequests() (*PullRequestsPayload, error) { return &payload, nil } + +func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) { + type response struct { + Repository struct { + PullRequests struct { + Edges []struct { + Node PullRequest + } + } + } + } + + query := ` + query($owner: String!, $repo: String!, $headRefName: String!) { + repository(owner: $owner, name: $repo) { + pullRequests(headRefName: $headRefName, states: OPEN, first: 1) { + edges { + node { + number + title + url + } + } + } + } + }` + + variables := map[string]interface{}{ + "owner": ghRepo.RepoOwner(), + "repo": ghRepo.RepoName(), + "headRefName": branch, + } + + var resp response + err := client.GraphQL(query, variables, &resp) + if err != nil { + return nil, err + } + + prs := []PullRequest{} + for _, edge := range resp.Repository.PullRequests.Edges { + prs = append(prs, edge.Node) + } + + return prs, nil +} diff --git a/command/pr.go b/command/pr.go index 3ce11bc9b..f1dfe8fb0 100644 --- a/command/pr.go +++ b/command/pr.go @@ -38,7 +38,25 @@ work with pull requests.`, func prList(cmd *cobra.Command, args []string) error { ctx := contextForCommand(cmd) - prPayload, err := api.PullRequests() + apiClient, err := apiClientForContext(ctx) + if err != nil { + return err + } + + baseRepo, err := ctx.BaseRepo() + if err != nil { + return err + } + currentBranch, err := ctx.Branch() + if err != nil { + return err + } + currentUser, err := ctx.AuthLogin() + if err != nil { + return err + } + + prPayload, err := api.PullRequests(apiClient, baseRepo, currentBranch, currentUser) if err != nil { return err } @@ -47,10 +65,6 @@ func prList(cmd *cobra.Command, args []string) error { if prPayload.CurrentPR != nil { printPrs(*prPayload.CurrentPR) } else { - currentBranch, err := ctx.Branch() - if err != nil { - return err - } message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentBranch+"]")) printMessage(message) } @@ -86,23 +100,27 @@ func prView(cmd *cobra.Command, args []string) error { if len(args) > 0 { if prNumber, err := strconv.Atoi(args[0]); err == nil { // TODO: move URL generation into GitHubRepository - openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.Owner, baseRepo.Name, prNumber) + openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.RepoOwner(), baseRepo.RepoName(), prNumber) } else { return fmt.Errorf("invalid pull request number: '%s'", args[0]) } } else { - prPayload, err := api.PullRequests() + apiClient, err := apiClientForContext(ctx) if err != nil { return err - } else if prPayload.CurrentPR == nil { - branch, err := ctx.Branch() - if err != nil { - return err - } - fmt.Printf("The [%s] branch has no open PRs", branch) - return nil } - openURL = prPayload.CurrentPR.URL + currentBranch, err := ctx.Branch() + if err != nil { + return err + } + + prs, err := api.PullRequestsForBranch(apiClient, baseRepo, currentBranch) + if err != nil { + return err + } else if len(prs) < 1 { + return fmt.Errorf("the '%s' branch has no open pull requests", currentBranch) + } + openURL = prs[0].URL } fmt.Printf("Opening %s in your browser.\n", openURL) diff --git a/command/pr_test.go b/command/pr_test.go index 10a47b5e9..e4a5122d7 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -1,21 +1,40 @@ package command import ( + "os" "regexp" "testing" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" "github.com/github/gh-cli/test" "github.com/github/gh-cli/utils" ) -func TestPRList(t *testing.T) { - ctx := context.InitBlankContext() - ctx.SetBaseRepo("github/FAKE-GITHUB-REPO-NAME") - ctx.SetBranch("master") +func initBlankContext(repo, branch string) { + initContext = func() context.Context { + ctx := context.NewBlank() + ctx.SetBaseRepo(repo) + ctx.SetBranch(branch) + return ctx + } +} - teardown := test.MockGraphQLResponse("test/fixtures/prList.json") - defer teardown() +func initFakeHTTP() *api.FakeHTTP { + http := &api.FakeHTTP{} + apiClientForContext = func(context.Context) (*api.Client, error) { + return api.NewClient(api.ReplaceTripper(http)), nil + } + return http +} + +func TestPRList(t *testing.T) { + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() + + jsonFile, _ := os.Open("../test/fixtures/prList.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) output, err := test.RunCommand(RootCmd, "pr list") if err != nil { @@ -37,11 +56,12 @@ func TestPRList(t *testing.T) { } func TestPRView(t *testing.T) { - teardown := test.MockGraphQLResponse("test/fixtures/prView.json") - defer teardown() + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() - gitRepo := test.UseTempGitRepo() - defer gitRepo.TearDown() + jsonFile, _ := os.Open("../test/fixtures/prView.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) teardown, callCount := mockOpenInBrowser() defer teardown() @@ -61,24 +81,21 @@ func TestPRView(t *testing.T) { } func TestPRView_NoActiveBranch(t *testing.T) { - teardown := test.MockGraphQLResponse("test/fixtures/prView_NoActiveBranch.json") - defer teardown() + initBlankContext("OWNER/REPO", "master") + http := initFakeHTTP() - gitRepo := test.UseTempGitRepo() - defer gitRepo.TearDown() + jsonFile, _ := os.Open("../test/fixtures/prView_NoActiveBranch.json") + defer jsonFile.Close() + http.StubResponse(200, jsonFile) teardown, callCount := mockOpenInBrowser() defer teardown() output, err := test.RunCommand(RootCmd, "pr view") - if err != nil { + if err == nil || err.Error() != "the 'master' branch has no open pull requests" { t.Errorf("error running command `pr view`: %v", err) } - if output == "" { - t.Errorf("command output expected got an empty string") - } - if *callCount > 0 { t.Errorf("OpenInBrowser should NOT be called but was called %d time(s)", *callCount) } diff --git a/command/root.go b/command/root.go index 76b5220bf..c0b676c86 100644 --- a/command/root.go +++ b/command/root.go @@ -4,7 +4,9 @@ import ( "fmt" "os" + "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" + "github.com/github/gh-cli/version" "github.com/spf13/cobra" ) @@ -12,6 +14,8 @@ import ( func init() { RootCmd.PersistentFlags().StringP("repo", "R", "", "current GitHub repository") RootCmd.PersistentFlags().StringP("current-branch", "B", "", "current git branch") + // TODO: + // RootCmd.PersistentFlags().BoolP("verbose", "V", false, "enable verbose output") } // RootCmd is the entry point of command-line execution @@ -25,16 +29,38 @@ var RootCmd = &cobra.Command{ }, } -func contextForCommand(cmd *cobra.Command) context.Context { +// overriden in tests +var initContext = func() context.Context { ctx := context.New() if repo := os.Getenv("GH_REPO"); repo != "" { ctx.SetBaseRepo(repo) } - if repo, err := cmd.Flags().GetString("repo"); err == nil { + return ctx +} + +func contextForCommand(cmd *cobra.Command) context.Context { + ctx := initContext() + if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" { ctx.SetBaseRepo(repo) } - if branch, err := cmd.Flags().GetString("current-branch"); err == nil { + if branch, err := cmd.Flags().GetString("current-branch"); err == nil && branch != "" { ctx.SetBranch(branch) } return ctx } + +// overriden in tests +var apiClientForContext = func(ctx context.Context) (*api.Client, error) { + token, err := ctx.AuthToken() + if err != nil { + return nil, err + } + opts := []api.ClientOption{ + api.AddHeader("Authorization", fmt.Sprintf("token %s", token)), + api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", version.Version)), + } + if verbose := os.Getenv("DEBUG"); verbose != "" { + opts = append(opts, api.VerboseLog(os.Stderr)) + } + return api.NewClient(opts...), nil +} diff --git a/context/blank_context.go b/context/blank_context.go index 892f9bfdd..d5ad2cafe 100644 --- a/context/blank_context.go +++ b/context/blank_context.go @@ -5,6 +5,7 @@ import ( "strings" ) +// NewBlank initializes a blank Context suitable for testing func NewBlank() Context { return &blankContext{} } @@ -14,7 +15,19 @@ type blankContext struct { authToken string authLogin string branch string - baseRepo *GitHubRepository + baseRepo GitHubRepository +} + +type ghRepo struct { + owner string + name string +} + +func (r ghRepo) RepoOwner() string { + return r.owner +} +func (r ghRepo) RepoName() string { + return r.name } func (c *blankContext) AuthToken() (string, error) { @@ -44,7 +57,7 @@ func (c *blankContext) Remotes() (Remotes, error) { return Remotes{}, nil } -func (c *blankContext) BaseRepo() (*GitHubRepository, error) { +func (c *blankContext) BaseRepo() (GitHubRepository, error) { if c.baseRepo == nil { return nil, fmt.Errorf("base repo was not initialized") } @@ -54,9 +67,6 @@ func (c *blankContext) BaseRepo() (*GitHubRepository, error) { func (c *blankContext) SetBaseRepo(nwo string) { parts := strings.SplitN(nwo, "/", 2) if len(parts) == 2 { - c.baseRepo = &GitHubRepository{ - Owner: parts[0], - Name: parts[1], - } + c.baseRepo = &ghRepo{parts[0], parts[1]} } } diff --git a/context/context.go b/context/context.go index f73520f6e..7e165299d 100644 --- a/context/context.go +++ b/context/context.go @@ -15,12 +15,19 @@ type Context interface { Branch() (string, error) SetBranch(string) Remotes() (Remotes, error) - BaseRepo() (*GitHubRepository, error) + BaseRepo() (GitHubRepository, error) SetBaseRepo(string) } +// GitHubRepository is anything that can be mapped to an OWNER/REPO pair +type GitHubRepository interface { + RepoOwner() string + RepoName() string +} + +// New initializes a Context that reads from the filesystem func New() Context { - return &blankContext{} + return &fsContext{} } // A Context implementation that queries the filesystem @@ -28,7 +35,7 @@ type fsContext struct { config *configEntry remotes Remotes branch string - baseRepo *GitHubRepository + baseRepo GitHubRepository authToken string } @@ -103,7 +110,7 @@ func (c *fsContext) Remotes() (Remotes, error) { return c.remotes, nil } -func (c *fsContext) BaseRepo() (*GitHubRepository, error) { +func (c *fsContext) BaseRepo() (GitHubRepository, error) { if c.baseRepo != nil { return c.baseRepo, nil } @@ -117,19 +124,13 @@ func (c *fsContext) BaseRepo() (*GitHubRepository, error) { return nil, err } - c.baseRepo = &GitHubRepository{ - Owner: rem.Owner, - Name: rem.Repo, - } + c.baseRepo = rem return c.baseRepo, nil } func (c *fsContext) SetBaseRepo(nwo string) { parts := strings.SplitN(nwo, "/", 2) if len(parts) == 2 { - c.baseRepo = &GitHubRepository{ - Owner: parts[0], - Name: parts[1], - } + c.baseRepo = &ghRepo{parts[0], parts[1]} } } diff --git a/context/remote.go b/context/remote.go index f30a0958b..9f3b228dc 100644 --- a/context/remote.go +++ b/context/remote.go @@ -32,10 +32,14 @@ type Remote struct { Repo string } -// GitHubRepository represents a GitHub respository -type GitHubRepository struct { - Name string - Owner string +// RepoName is the name of the GitHub repository +func (r Remote) RepoName() string { + return r.Repo +} + +// RepoOwner is the name of the GitHub account that owns the repo +func (r Remote) RepoOwner() string { + return r.Owner } // TODO: accept an interface instead of git.RemoteSet diff --git a/test/fixtures/prList.json b/test/fixtures/prList.json index d3a7e89f2..444fa5e01 100644 --- a/test/fixtures/prList.json +++ b/test/fixtures/prList.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [ @@ -47,4 +47,4 @@ ], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/fixtures/prView.json b/test/fixtures/prView.json index d3a7e89f2..444fa5e01 100644 --- a/test/fixtures/prView.json +++ b/test/fixtures/prView.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [ @@ -47,4 +47,4 @@ ], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/fixtures/prView_NoActiveBranch.json b/test/fixtures/prView_NoActiveBranch.json index dd7ddbafd..7c1fb0c05 100644 --- a/test/fixtures/prView_NoActiveBranch.json +++ b/test/fixtures/prView_NoActiveBranch.json @@ -1,4 +1,4 @@ -{ +{"data":{ "repository": { "pullRequests": { "edges": [] @@ -12,4 +12,4 @@ "edges": [], "pageInfo": { "hasNextPage": false } } -} +}} \ No newline at end of file diff --git a/test/helpers.go b/test/helpers.go index 90b16778a..d9276565c 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -1,7 +1,6 @@ package test import ( - "encoding/json" "fmt" "io/ioutil" "os" @@ -9,7 +8,6 @@ import ( "path/filepath" "strings" - "github.com/github/gh-cli/api" "github.com/spf13/cobra" ) @@ -67,30 +65,6 @@ func UseTempGitRepo() *TempGitRepo { return &TempGitRepo{Remote: remotePath, TearDown: tearDown} } -func MockGraphQLResponse(fixturePath string) (teardown func()) { - pwd, _ := os.Getwd() - fixturePath = filepath.Join(pwd, "..", fixturePath) - - originalGraphQL := api.GraphQL - api.GraphQL = func(query string, variables map[string]string, v interface{}) error { - contents, err := ioutil.ReadFile(fixturePath) - if err != nil { - return err - } - - json.Unmarshal(contents, &v) - if err != nil { - return err - } - - return nil - } - - return func() { - api.GraphQL = originalGraphQL - } -} - func RunCommand(root *cobra.Command, s string) (string, error) { var err error output := captureOutput(func() {