diff --git a/Makefile b/Makefile index b925c11fd..fd219d42d 100644 --- a/Makefile +++ b/Makefile @@ -8,15 +8,26 @@ ifdef SOURCE_DATE_EPOCH else BUILD_DATE ?= $(shell date "$(DATE_FMT)") endif -LDFLAGS := -X github.com/cli/cli/command.Version=$(GH_VERSION) $(LDFLAGS) -LDFLAGS := -X github.com/cli/cli/command.BuildDate=$(BUILD_DATE) $(LDFLAGS) + +ifndef CGO_CPPFLAGS + export CGO_CPPFLAGS := $(CPPFLAGS) +endif +ifndef CGO_CFLAGS + export CGO_CFLAGS := $(CFLAGS) +endif +ifndef CGO_LDFLAGS + export CGO_LDFLAGS := $(LDFLAGS) +endif + +GO_LDFLAGS := -X github.com/cli/cli/command.Version=$(GH_VERSION) +GO_LDFLAGS := -X github.com/cli/cli/command.BuildDate=$(BUILD_DATE) ifdef GH_OAUTH_CLIENT_SECRET - LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientID=$(GH_OAUTH_CLIENT_ID) $(LDFLAGS) - LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) $(LDFLAGS) + GO_LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientID=$(GH_OAUTH_CLIENT_ID) + GO_LDFLAGS := -X github.com/cli/cli/internal/config.oauthClientSecret=$(GH_OAUTH_CLIENT_SECRET) endif bin/gh: $(BUILD_FILES) - @go build -trimpath -ldflags "$(LDFLAGS)" -o "$@" ./cmd/gh + @go build -trimpath -ldflags "$(GO_LDFLAGS)" -o "$@" ./cmd/gh test: go test ./... diff --git a/api/client.go b/api/client.go index 1c8c1eaa2..74a5a0417 100644 --- a/api/client.go +++ b/api/client.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "regexp" "strings" @@ -154,7 +155,21 @@ func (gr GraphQLErrorResponse) Error() string { for _, e := range gr.Errors { errorMessages = append(errorMessages, e.Message) } - return fmt.Sprintf("graphql error: '%s'", strings.Join(errorMessages, ", ")) + return fmt.Sprintf("GraphQL error: %s", strings.Join(errorMessages, "\n")) +} + +// HTTPError is an error returned by a failed API call +type HTTPError struct { + StatusCode int + RequestURL *url.URL + Message string +} + +func (err HTTPError) Error() string { + if err.Message != "" { + return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) + } + return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) } // Returns whether or not scopes are present, appID, and error @@ -283,22 +298,25 @@ func handleResponse(resp *http.Response, data interface{}) error { } func handleHTTPError(resp *http.Response) error { - var message string + httpError := HTTPError{ + StatusCode: resp.StatusCode, + RequestURL: resp.Request.URL, + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + httpError.Message = err.Error() + return httpError + } + var parsedBody struct { Message string `json:"message"` } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - err = json.Unmarshal(body, &parsedBody) - if err != nil { - message = string(body) - } else { - message = parsedBody.Message + if err := json.Unmarshal(body, &parsedBody); err == nil { + httpError.Message = parsedBody.Message } - return fmt.Errorf("http error, '%s' failed (%d): '%s'", resp.Request.URL, resp.StatusCode, message) + return httpError } var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) diff --git a/api/client_test.go b/api/client_test.go index 4c81bf315..b7c226c8f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "errors" "io/ioutil" "reflect" "testing" @@ -46,9 +47,15 @@ func TestGraphQLError(t *testing.T) { client := NewClient(ReplaceTripper(http)) response := struct{}{} - http.StubResponse(200, bytes.NewBufferString(`{"errors":[{"message":"OH NO"}]}`)) + http.StubResponse(200, bytes.NewBufferString(` + { "errors": [ + {"message":"OH NO"}, + {"message":"this is fine"} + ] + }`)) + err := client.GraphQL("", nil, &response) - if err == nil || err.Error() != "graphql error: 'OH NO'" { + if err == nil || err.Error() != "GraphQL error: OH NO\nthis is fine" { t.Fatalf("got %q", err.Error()) } } @@ -66,3 +73,23 @@ func TestRESTGetDelete(t *testing.T) { err := client.REST("DELETE", "applications/CLIENTID/grant", r, nil) eq(t, err, nil) } + +func TestRESTError(t *testing.T) { + http := &httpmock.Registry{} + client := NewClient(ReplaceTripper(http)) + + http.StubResponse(422, bytes.NewBufferString(`{"message": "OH NO"}`)) + + var httpErr HTTPError + err := client.REST("DELETE", "repos/branch", nil, nil) + if err == nil || !errors.As(err, &httpErr) { + t.Fatalf("got %v", err) + } + + if httpErr.StatusCode != 422 { + t.Errorf("expected status code 422, got %d", httpErr.StatusCode) + } + if httpErr.Error() != "HTTP 422: OH NO (https://api.github.com/repos/branch)" { + t.Errorf("got %q", httpErr.Error()) + } +} diff --git a/api/queries_issue.go b/api/queries_issue.go index 56a04edd9..76981f8d7 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -198,7 +198,7 @@ func IssueStatus(client *Client, repo ghrepo.Interface, currentUsername string) return &payload, nil } -func IssueList(client *Client, repo ghrepo.Interface, state string, labels []string, assigneeString string, limit int, authorString string) (*IssuesAndTotalCount, error) { +func IssueList(client *Client, repo ghrepo.Interface, state string, labels []string, assigneeString string, limit int, authorString string, mentionString string, milestoneString string) (*IssuesAndTotalCount, error) { var states []string switch state { case "open", "": @@ -212,10 +212,10 @@ func IssueList(client *Client, repo ghrepo.Interface, state string, labels []str } query := fragments + ` - query($owner: String!, $repo: String!, $limit: Int, $endCursor: String, $states: [IssueState!] = OPEN, $labels: [String!], $assignee: String, $author: String) { + query($owner: String!, $repo: String!, $limit: Int, $endCursor: String, $states: [IssueState!] = OPEN, $labels: [String!], $assignee: String, $author: String, $mention: String, $milestone: String) { repository(owner: $owner, name: $repo) { hasIssuesEnabled - issues(first: $limit, after: $endCursor, orderBy: {field: CREATED_AT, direction: DESC}, states: $states, labels: $labels, filterBy: {assignee: $assignee, createdBy: $author}) { + issues(first: $limit, after: $endCursor, orderBy: {field: CREATED_AT, direction: DESC}, states: $states, labels: $labels, filterBy: {assignee: $assignee, createdBy: $author, mentioned: $mention, milestone: $milestone}) { totalCount nodes { ...issue @@ -243,6 +243,12 @@ func IssueList(client *Client, repo ghrepo.Interface, state string, labels []str if authorString != "" { variables["author"] = authorString } + if mentionString != "" { + variables["mention"] = mentionString + } + if milestoneString != "" { + variables["milestone"] = milestoneString + } var response struct { Repository struct { diff --git a/api/queries_issue_test.go b/api/queries_issue_test.go index ffe4aaad1..2ea64f0c6 100644 --- a/api/queries_issue_test.go +++ b/api/queries_issue_test.go @@ -40,7 +40,7 @@ func TestIssueList(t *testing.T) { `)) repo, _ := ghrepo.FromFullName("OWNER/REPO") - _, err := IssueList(client, repo, "open", []string{}, "", 251, "") + _, err := IssueList(client, repo, "open", []string{}, "", 251, "", "", "") if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/api/queries_pr.go b/api/queries_pr.go index 188e708a8..56582d0e3 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -540,8 +540,12 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea headRepositoryOwner { login } + headRepository { + name + } isCrossRepository isDraft + maintainerCanModify reviewRequests(first: 100) { nodes { requestedReviewer { @@ -1007,11 +1011,8 @@ func PullRequestReady(client *Client, repo ghrepo.Interface, pr *PullRequest) er } func BranchDeleteRemote(client *Client, repo ghrepo.Interface, branch string) error { - var response struct { - NodeID string `json:"node_id"` - } path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch) - return client.REST("DELETE", path, nil, &response) + return client.REST("DELETE", path, nil, nil) } func min(a, b int) int { diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go new file mode 100644 index 000000000..07370023b --- /dev/null +++ b/api/queries_pr_test.go @@ -0,0 +1,45 @@ +package api + +import ( + "testing" + + "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/pkg/httpmock" +) + +func TestBranchDeleteRemote(t *testing.T) { + var tests = []struct { + name string + responseStatus int + responseBody string + expectError bool + }{ + { + name: "success", + responseStatus: 204, + responseBody: "", + expectError: false, + }, + { + name: "error", + responseStatus: 500, + responseBody: `{"message": "oh no"}`, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + http := &httpmock.Registry{} + http.Register(httpmock.MatchAny, httpmock.StatusStringResponse(tt.responseStatus, tt.responseBody)) + + client := NewClient(ReplaceTripper(http)) + repo, _ := ghrepo.FromFullName("OWNER/REPO") + + err := BranchDeleteRemote(client, repo, "branch") + if (err != nil) != tt.expectError { + t.Fatalf("unexpected result: %v", err) + } + }) + } +} diff --git a/command/config.go b/command/config.go index 7f34b8687..0a77e99e5 100644 --- a/command/config.go +++ b/command/config.go @@ -48,6 +48,7 @@ var configSetCmd = &cobra.Command{ Short: "Update configuration with a value for the given key", Example: heredoc.Doc(` $ gh config set editor vim + $ gh config set editor "code --wait" `), Args: cobra.ExactArgs(2), RunE: configSet, diff --git a/command/issue.go b/command/issue.go index f790b3d05..15c74f82a 100644 --- a/command/issue.go +++ b/command/issue.go @@ -20,6 +20,8 @@ import ( ) func init() { + issueCmd.PersistentFlags().StringP("repo", "R", "", "Select another repository using the `OWNER/REPO` format") + RootCmd.AddCommand(issueCmd) issueCmd.AddCommand(issueStatusCmd) @@ -40,6 +42,8 @@ func init() { issueListCmd.Flags().StringP("state", "s", "open", "Filter by state: {open|closed|all}") issueListCmd.Flags().IntP("limit", "L", 30, "Maximum number of issues to fetch") issueListCmd.Flags().StringP("author", "A", "", "Filter by author") + issueListCmd.Flags().String("mention", "", "Filter by mention") + issueListCmd.Flags().StringP("milestone", "m", "", "Filter by milestone `name`") issueCmd.AddCommand(issueViewCmd) issueViewCmd.Flags().BoolP("web", "w", false, "Open an issue in the browser") @@ -154,7 +158,17 @@ func issueList(cmd *cobra.Command, args []string) error { return err } - listResult, err := api.IssueList(apiClient, baseRepo, state, labels, assignee, limit, author) + mention, err := cmd.Flags().GetString("mention") + if err != nil { + return err + } + + milestone, err := cmd.Flags().GetString("milestone") + if err != nil { + return err + } + + listResult, err := api.IssueList(apiClient, baseRepo, state, labels, assignee, limit, author, mention, milestone) if err != nil { return err } @@ -162,7 +176,7 @@ func issueList(cmd *cobra.Command, args []string) error { hasFilters := false cmd.Flags().Visit(func(f *pflag.Flag) { switch f.Name { - case "state", "label", "assignee", "author": + case "state", "label", "assignee", "author", "mention", "milestone": hasFilters = true } }) diff --git a/command/issue_test.go b/command/issue_test.go index 105ff8997..d2f1f1958 100644 --- a/command/issue_test.go +++ b/command/issue_test.go @@ -153,7 +153,7 @@ func TestIssueList_withFlags(t *testing.T) { } } } `)) - output, err := RunCommand("issue list -a probablyCher -l web,bug -s open -A foo") + output, err := RunCommand("issue list -a probablyCher -l web,bug -s open -A foo --mention me --milestone 1.x") if err != nil { t.Errorf("error running command `issue list`: %v", err) } @@ -167,10 +167,12 @@ No issues match your search in OWNER/REPO bodyBytes, _ := ioutil.ReadAll(http.Requests[1].Body) reqBody := struct { Variables struct { - Assignee string - Labels []string - States []string - Author string + Assignee string + Labels []string + States []string + Author string + Mention string + Milestone string } }{} _ = json.Unmarshal(bodyBytes, &reqBody) @@ -179,6 +181,8 @@ No issues match your search in OWNER/REPO eq(t, reqBody.Variables.Labels, []string{"web", "bug"}) eq(t, reqBody.Variables.States, []string{"OPEN"}) eq(t, reqBody.Variables.Author, "foo") + eq(t, reqBody.Variables.Mention, "me") + eq(t, reqBody.Variables.Milestone, "1.x") } func TestIssueList_withInvalidLimitFlag(t *testing.T) { @@ -403,7 +407,7 @@ func TestIssueView_web_notFound(t *testing.T) { defer restoreCmd() _, err := RunCommand("issue view -w 9999") - if err == nil || err.Error() != "graphql error: 'Could not resolve to an Issue with the number of 9999.'" { + if err == nil || err.Error() != "GraphQL error: Could not resolve to an Issue with the number of 9999." { t.Errorf("error running command `issue view`: %v", err) } diff --git a/command/pr.go b/command/pr.go index 18150d4eb..79c65b874 100644 --- a/command/pr.go +++ b/command/pr.go @@ -23,6 +23,8 @@ import ( ) func init() { + prCmd.PersistentFlags().StringP("repo", "R", "", "Select another repository using the `OWNER/REPO` format") + RootCmd.AddCommand(prCmd) prCmd.AddCommand(prCheckoutCmd) prCmd.AddCommand(prCreateCmd) @@ -131,7 +133,7 @@ func prStatus(cmd *cobra.Command, args []string) error { repoOverride, _ := cmd.Flags().GetString("repo") currentPRNumber, currentPRHeadRef, err := prSelectorForCurrentBranch(ctx, baseRepo) - if err != nil && repoOverride == "" && err.Error() != "git: not on any branch" { + if err != nil && repoOverride == "" && !errors.Is(err, git.ErrNotOnAnyBranch) { return fmt.Errorf("could not query for pull request for current branch: %w", err) } @@ -523,7 +525,9 @@ func prMerge(cmd *cobra.Command, args []string) error { if !crossRepoPR { err = api.BranchDeleteRemote(apiClient, baseRepo, pr.HeadRefName) - if err != nil { + var httpErr api.HTTPError + // The ref might have already been deleted by GitHub + if err != nil && (!errors.As(err, &httpErr) || httpErr.StatusCode != 422) { err = fmt.Errorf("failed to delete remote branch %s: %w", utils.Cyan(pr.HeadRefName), err) return err } diff --git a/command/pr_test.go b/command/pr_test.go index 4c31198b0..e7e1e2dc3 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -299,6 +299,38 @@ Requesting a code review from you } } +func TestPRStatus_detachedHead(t *testing.T) { + initBlankContext("", "OWNER/REPO", "") + http := initFakeHTTP() + http.StubRepoResponse("OWNER", "REPO") + + http.StubResponse(200, bytes.NewBufferString(` + { "data": {} } + `)) + + output, err := RunCommand("pr status") + if err != nil { + t.Errorf("error running command `pr status`: %v", err) + } + + expected := ` +Relevant pull requests in OWNER/REPO + +Current branch + There is no current branch + +Created by you + You have no open pull requests + +Requesting a code review from you + You have no pull requests to review + +` + if output.String() != expected { + t.Errorf("expected %q, got %q", expected, output.String()) + } +} + func TestPRList(t *testing.T) { initBlankContext("", "OWNER/REPO", "master") http := initFakeHTTP() diff --git a/command/root.go b/command/root.go index 3d53c3a71..40f44d136 100644 --- a/command/root.go +++ b/command/root.go @@ -52,7 +52,6 @@ func init() { RootCmd.AddCommand(versionCmd) RootCmd.SetVersionTemplate(versionOutput) - RootCmd.PersistentFlags().StringP("repo", "R", "", "Select another repository using the `OWNER/REPO` format") RootCmd.PersistentFlags().Bool("help", false, "Show help for command") RootCmd.Flags().Bool("version", false, "Show gh version") // TODO: @@ -304,8 +303,8 @@ func changelogURL(version string) string { } func determineBaseRepo(apiClient *api.Client, cmd *cobra.Command, ctx context.Context) (ghrepo.Interface, error) { - repo, err := cmd.Flags().GetString("repo") - if err == nil && repo != "" { + repo, _ := cmd.Flags().GetString("repo") + if repo != "" { baseRepo, err := ghrepo.FromFullName(repo) if err != nil { return nil, fmt.Errorf("argument error: %w", err) @@ -313,17 +312,12 @@ func determineBaseRepo(apiClient *api.Client, cmd *cobra.Command, ctx context.Co return baseRepo, nil } - baseOverride, err := cmd.Flags().GetString("repo") - if err != nil { - return nil, err - } - remotes, err := ctx.Remotes() if err != nil { return nil, err } - repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, baseOverride) + repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "") if err != nil { return nil, err } diff --git a/context/blank_context.go b/context/blank_context.go index 151e5a9ed..3035b4d21 100644 --- a/context/blank_context.go +++ b/context/blank_context.go @@ -40,7 +40,7 @@ func (c *blankContext) SetAuthToken(t string) { func (c *blankContext) Branch() (string, error) { if c.branch == "" { - return "", fmt.Errorf("branch was not initialized") + return "", fmt.Errorf("branch was not initialized: %w", git.ErrNotOnAnyBranch) } return c.branch, nil } diff --git a/git/git.go b/git/git.go index 355f905f0..ff40adde3 100644 --- a/git/git.go +++ b/git/git.go @@ -13,6 +13,9 @@ import ( "github.com/cli/cli/internal/run" ) +// ErrNotOnAnyBranch indicates that the users is in detached HEAD state +var ErrNotOnAnyBranch = errors.New("git: not on any branch") + // Ref represents a git commit reference type Ref struct { Hash string @@ -64,7 +67,7 @@ func CurrentBranch() (string, error) { if errors.As(err, &cmdErr) { if cmdErr.Stderr.Len() == 0 { // Detached head - return "", errors.New("git: not on any branch") + return "", ErrNotOnAnyBranch } } diff --git a/git/git_test.go b/git/git_test.go index 8cd97a10d..c03e7153d 100644 --- a/git/git_test.go +++ b/git/git_test.go @@ -67,9 +67,8 @@ func Test_CurrentBranch_detached_head(t *testing.T) { if err == nil { t.Errorf("expected an error") } - expectedError := "git: not on any branch" - if err.Error() != expectedError { - t.Errorf("got unexpected error: %s instead of %s", err.Error(), expectedError) + if err != ErrNotOnAnyBranch { + t.Errorf("got unexpected error: %s instead of %s", err, ErrNotOnAnyBranch) } if len(cs.Calls) != 1 { t.Errorf("expected 1 git call, saw %d", len(cs.Calls)) diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index a5540cf2d..170c41b55 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -3,6 +3,7 @@ package api import ( "bytes" "encoding/json" + "errors" "fmt" "io" "io/ioutil" @@ -32,6 +33,8 @@ type ApiOptions struct { RawFields []string RequestHeaders []string ShowResponseHeaders bool + Paginate bool + Silent bool HttpClient func() (*http.Client, error) BaseRepo func() (ghrepo.Interface, error) @@ -74,7 +77,11 @@ on the format of the value: Raw request body may be passed from the outside via a file specified by '--input'. Pass "-" to read from standard input. In this mode, parameters specified via '--field' flags are serialized into URL query parameters. -`, + +In '--paginate' mode, all pages of results will sequentially be requested until +there are no more pages of results. For GraphQL requests, this requires that the +original query accepts an '$endCursor: String' variable and that it fetches the +'pageInfo{ hasNextPage, endCursor }' set of fields from a collection.`, Example: heredoc.Doc(` $ gh api repos/:owner/:repo/releases @@ -87,12 +94,33 @@ Pass "-" to read from standard input. In this mode, parameters specified via } } ' + + $ gh api graphql --paginate -f query=' + query($endCursor: String) { + viewer { + repositories(first: 100, after: $endCursor) { + nodes { nameWithOwner } + pageInfo { + hasNextPage + endCursor + } + } + } + } + ' `), Args: cobra.ExactArgs(1), RunE: func(c *cobra.Command, args []string) error { opts.RequestPath = args[0] opts.RequestMethodPassed = c.Flags().Changed("method") + if opts.Paginate && !strings.EqualFold(opts.RequestMethod, "GET") && opts.RequestPath != "graphql" { + return &cmdutil.FlagError{Err: errors.New(`the '--paginate' option is not supported for non-GET requests`)} + } + if opts.Paginate && opts.RequestInputFile != "" { + return &cmdutil.FlagError{Err: errors.New(`the '--paginate' option is not supported with '--input'`)} + } + if runF != nil { return runF(&opts) } @@ -105,7 +133,9 @@ Pass "-" to read from standard input. In this mode, parameters specified via cmd.Flags().StringArrayVarP(&opts.RawFields, "raw-field", "f", nil, "Add a string parameter") cmd.Flags().StringArrayVarP(&opts.RequestHeaders, "header", "H", nil, "Add an additional HTTP request header") cmd.Flags().BoolVarP(&opts.ShowResponseHeaders, "include", "i", false, "Include HTTP response headers in the output") + cmd.Flags().BoolVar(&opts.Paginate, "paginate", false, "Make additional HTTP requests to fetch all pages of results") cmd.Flags().StringVar(&opts.RequestInputFile, "input", "", "The file to use as body for the HTTP request") + cmd.Flags().BoolVar(&opts.Silent, "silent", false, "Do not print the response body") return cmd } @@ -115,6 +145,7 @@ func apiRun(opts *ApiOptions) error { return err } + isGraphQL := opts.RequestPath == "graphql" requestPath, err := fillPlaceholders(opts.RequestPath, opts) if err != nil { return fmt.Errorf("unable to expand placeholder in path: %w", err) @@ -127,6 +158,10 @@ func apiRun(opts *ApiOptions) error { method = "POST" } + if opts.Paginate && !isGraphQL { + requestPath = addPerPage(requestPath, 100, params) + } + if opts.RequestInputFile != "" { file, size, err := openUserFile(opts.RequestInputFile, opts.IO.In) if err != nil { @@ -145,19 +180,53 @@ func apiRun(opts *ApiOptions) error { return err } - resp, err := httpRequest(httpClient, method, requestPath, requestBody, requestHeaders) - if err != nil { - return err + headersOutputStream := opts.IO.Out + if opts.Silent { + opts.IO.Out = ioutil.Discard } + hasNextPage := true + for hasNextPage { + resp, err := httpRequest(httpClient, method, requestPath, requestBody, requestHeaders) + if err != nil { + return err + } + + endCursor, err := processResponse(resp, opts, headersOutputStream) + if err != nil { + return err + } + + if !opts.Paginate { + break + } + + if isGraphQL { + hasNextPage = endCursor != "" + if hasNextPage { + params["endCursor"] = endCursor + } + } else { + requestPath, hasNextPage = findNextPage(resp) + } + + if hasNextPage && opts.ShowResponseHeaders { + fmt.Fprint(opts.IO.Out, "\n") + } + } + + return nil +} + +func processResponse(resp *http.Response, opts *ApiOptions, headersOutputStream io.Writer) (endCursor string, err error) { if opts.ShowResponseHeaders { - fmt.Fprintln(opts.IO.Out, resp.Proto, resp.Status) - printHeaders(opts.IO.Out, resp.Header, opts.IO.ColorEnabled()) - fmt.Fprint(opts.IO.Out, "\r\n") + fmt.Fprintln(headersOutputStream, resp.Proto, resp.Status) + printHeaders(headersOutputStream, resp.Header, opts.IO.ColorEnabled()) + fmt.Fprint(headersOutputStream, "\r\n") } if resp.StatusCode == 204 { - return nil + return } var responseBody io.Reader = resp.Body defer resp.Body.Close() @@ -168,31 +237,44 @@ func apiRun(opts *ApiOptions) error { if isJSON && (opts.RequestPath == "graphql" || resp.StatusCode >= 400) { responseBody, serverError, err = parseErrorResponse(responseBody, resp.StatusCode) if err != nil { - return err + return } } + var bodyCopy *bytes.Buffer + isGraphQLPaginate := isJSON && resp.StatusCode == 200 && opts.Paginate && opts.RequestPath == "graphql" + if isGraphQLPaginate { + bodyCopy = &bytes.Buffer{} + responseBody = io.TeeReader(responseBody, bodyCopy) + } + if isJSON && opts.IO.ColorEnabled() { err = jsoncolor.Write(opts.IO.Out, responseBody, " ") if err != nil { - return err + return } } else { _, err = io.Copy(opts.IO.Out, responseBody) if err != nil { - return err + return } } if serverError != "" { fmt.Fprintf(opts.IO.ErrOut, "gh: %s\n", serverError) - return cmdutil.SilentError + err = cmdutil.SilentError + return } else if resp.StatusCode > 299 { fmt.Fprintf(opts.IO.ErrOut, "gh: HTTP %d\n", resp.StatusCode) - return cmdutil.SilentError + err = cmdutil.SilentError + return } - return nil + if isGraphQLPaginate { + endCursor = findEndCursor(bodyCopy) + } + + return } var placeholderRE = regexp.MustCompile(`\:(owner|repo)\b`) diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index 605e4bf9e..f590d93b4 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -13,6 +14,7 @@ import ( "github.com/cli/cli/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_NewCmdApi(t *testing.T) { @@ -36,6 +38,8 @@ func Test_NewCmdApi(t *testing.T) { MagicFields: []string(nil), RequestHeaders: []string(nil), ShowResponseHeaders: false, + Paginate: false, + Silent: false, }, wantsErr: false, }, @@ -51,6 +55,8 @@ func Test_NewCmdApi(t *testing.T) { MagicFields: []string(nil), RequestHeaders: []string(nil), ShowResponseHeaders: false, + Paginate: false, + Silent: false, }, wantsErr: false, }, @@ -66,6 +72,8 @@ func Test_NewCmdApi(t *testing.T) { MagicFields: []string{"body=@file.txt"}, RequestHeaders: []string(nil), ShowResponseHeaders: false, + Paginate: false, + Silent: false, }, wantsErr: false, }, @@ -81,9 +89,72 @@ func Test_NewCmdApi(t *testing.T) { MagicFields: []string(nil), RequestHeaders: []string{"accept: text/plain"}, ShowResponseHeaders: true, + Paginate: false, + Silent: false, }, wantsErr: false, }, + { + name: "with pagination", + cli: "repos/OWNER/REPO/issues --paginate", + wants: ApiOptions{ + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "repos/OWNER/REPO/issues", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: true, + Silent: false, + }, + wantsErr: false, + }, + { + name: "with silenced output", + cli: "repos/OWNER/REPO/issues --silent", + wants: ApiOptions{ + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "repos/OWNER/REPO/issues", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: false, + Silent: true, + }, + wantsErr: false, + }, + { + name: "POST pagination", + cli: "-XPOST repos/OWNER/REPO/issues --paginate", + wantsErr: true, + }, + { + name: "GraphQL pagination", + cli: "-XPOST graphql --paginate", + wants: ApiOptions{ + RequestMethod: "POST", + RequestMethodPassed: true, + RequestPath: "graphql", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: true, + Silent: false, + }, + wantsErr: false, + }, + { + name: "input pagination", + cli: "--input repos/OWNER/REPO/issues --paginate", + wantsErr: true, + }, { name: "with request body from file", cli: "user --input myfile", @@ -96,6 +167,8 @@ func Test_NewCmdApi(t *testing.T) { MagicFields: []string(nil), RequestHeaders: []string(nil), ShowResponseHeaders: false, + Paginate: false, + Silent: false, }, wantsErr: false, }, @@ -215,6 +288,36 @@ func Test_apiRun(t *testing.T) { stdout: `gateway timeout`, stderr: "gh: HTTP 502\n", }, + { + name: "silent", + options: ApiOptions{ + Silent: true, + }, + httpResponse: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`body`)), + }, + err: nil, + stdout: ``, + stderr: ``, + }, + { + name: "show response headers even when silent", + options: ApiOptions{ + ShowResponseHeaders: true, + Silent: true, + }, + httpResponse: &http.Response{ + Proto: "HTTP/1.1", + Status: "200 Okey-dokey", + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`body`)), + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + err: nil, + stdout: "HTTP/1.1 200 Okey-dokey\nContent-Type: text/plain\r\n\r\n", + stderr: ``, + }, } for _, tt := range tests { @@ -246,6 +349,136 @@ func Test_apiRun(t *testing.T) { } } +func Test_apiRun_paginationREST(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + requestCount := 0 + responses := []*http.Response{ + { + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"page":1}`)), + Header: http.Header{ + "Link": []string{`; rel="next", ; rel="last"`}, + }, + }, + { + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"page":2}`)), + Header: http.Header{ + "Link": []string{`; rel="next", ; rel="last"`}, + }, + }, + { + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"page":3}`)), + Header: http.Header{}, + }, + } + + options := ApiOptions{ + IO: io, + HttpClient: func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + resp := responses[requestCount] + resp.Request = req + requestCount++ + return resp, nil + } + return &http.Client{Transport: tr}, nil + }, + + RequestPath: "issues", + Paginate: true, + } + + err := apiRun(&options) + assert.NoError(t, err) + + assert.Equal(t, `{"page":1}{"page":2}{"page":3}`, stdout.String(), "stdout") + assert.Equal(t, "", stderr.String(), "stderr") + + assert.Equal(t, "https://api.github.com/issues?per_page=100", responses[0].Request.URL.String()) + assert.Equal(t, "https://api.github.com/repositories/1227/issues?page=2", responses[1].Request.URL.String()) + assert.Equal(t, "https://api.github.com/repositories/1227/issues?page=3", responses[2].Request.URL.String()) +} + +func Test_apiRun_paginationGraphQL(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + requestCount := 0 + responses := []*http.Response{ + { + StatusCode: 200, + Header: http.Header{"Content-Type": []string{`application/json`}}, + Body: ioutil.NopCloser(bytes.NewBufferString(`{ + "data": { + "nodes": ["page one"], + "pageInfo": { + "endCursor": "PAGE1_END", + "hasNextPage": true + } + } + }`)), + }, + { + StatusCode: 200, + Header: http.Header{"Content-Type": []string{`application/json`}}, + Body: ioutil.NopCloser(bytes.NewBufferString(`{ + "data": { + "nodes": ["page two"], + "pageInfo": { + "endCursor": "PAGE2_END", + "hasNextPage": false + } + } + }`)), + }, + } + + options := ApiOptions{ + IO: io, + HttpClient: func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + resp := responses[requestCount] + resp.Request = req + requestCount++ + return resp, nil + } + return &http.Client{Transport: tr}, nil + }, + + RequestMethod: "POST", + RequestPath: "graphql", + Paginate: true, + } + + err := apiRun(&options) + require.NoError(t, err) + + assert.Contains(t, stdout.String(), `"page one"`) + assert.Contains(t, stdout.String(), `"page two"`) + assert.Equal(t, "", stderr.String(), "stderr") + + var requestData struct { + Variables map[string]interface{} + } + + bb, err := ioutil.ReadAll(responses[0].Request.Body) + require.NoError(t, err) + err = json.Unmarshal(bb, &requestData) + require.NoError(t, err) + _, hasCursor := requestData.Variables["endCursor"].(string) + assert.Equal(t, false, hasCursor) + + bb, err = ioutil.ReadAll(responses[1].Request.Body) + require.NoError(t, err) + err = json.Unmarshal(bb, &requestData) + require.NoError(t, err) + endCursor, hasCursor := requestData.Variables["endCursor"].(string) + assert.Equal(t, true, hasCursor) + assert.Equal(t, "PAGE1_END", endCursor) +} + func Test_apiRun_inputFile(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/api/pagination.go b/pkg/cmd/api/pagination.go new file mode 100644 index 000000000..fce88fb92 --- /dev/null +++ b/pkg/cmd/api/pagination.go @@ -0,0 +1,108 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" +) + +var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`) + +func findNextPage(resp *http.Response) (string, bool) { + for _, m := range linkRE.FindAllStringSubmatch(resp.Header.Get("Link"), -1) { + if len(m) >= 2 && m[2] == "next" { + return m[1], true + } + } + return "", false +} + +func findEndCursor(r io.Reader) string { + dec := json.NewDecoder(r) + + var idx int + var stack []json.Delim + var lastKey string + var contextKey string + + var endCursor string + var hasNextPage bool + var foundEndCursor bool + var foundNextPage bool + +loop: + for { + t, err := dec.Token() + if err == io.EOF { + break + } + if err != nil { + return "" + } + + switch tt := t.(type) { + case json.Delim: + switch tt { + case '{', '[': + stack = append(stack, tt) + contextKey = lastKey + idx = 0 + case '}', ']': + stack = stack[:len(stack)-1] + contextKey = "" + idx = 0 + } + default: + isKey := len(stack) > 0 && stack[len(stack)-1] == '{' && idx%2 == 0 + idx++ + + switch tt := t.(type) { + case string: + if isKey { + lastKey = tt + } else if contextKey == "pageInfo" && lastKey == "endCursor" { + endCursor = tt + foundEndCursor = true + if foundNextPage { + break loop + } + } + case bool: + if contextKey == "pageInfo" && lastKey == "hasNextPage" { + hasNextPage = tt + foundNextPage = true + if foundEndCursor { + break loop + } + } + } + } + } + + if hasNextPage { + return endCursor + } + return "" +} + +func addPerPage(p string, perPage int, params map[string]interface{}) string { + if _, hasPerPage := params["per_page"]; hasPerPage { + return p + } + + idx := strings.IndexRune(p, '?') + sep := "?" + + if idx >= 0 { + if qp, err := url.ParseQuery(p[idx+1:]); err == nil && qp.Get("per_page") != "" { + return p + } + sep = "&" + } + + return fmt.Sprintf("%s%sper_page=%d", p, sep, perPage) +} diff --git a/pkg/cmd/api/pagination_test.go b/pkg/cmd/api/pagination_test.go new file mode 100644 index 000000000..3bb1a8ec5 --- /dev/null +++ b/pkg/cmd/api/pagination_test.go @@ -0,0 +1,169 @@ +package api + +import ( + "bytes" + "io" + "net/http" + "testing" +) + +func Test_findNextPage(t *testing.T) { + tests := []struct { + name string + resp *http.Response + want string + want1 bool + }{ + { + name: "no Link header", + resp: &http.Response{}, + want: "", + want1: false, + }, + { + name: "no next page in Link", + resp: &http.Response{ + Header: http.Header{ + "Link": []string{`; rel="last"`}, + }, + }, + want: "", + want1: false, + }, + { + name: "has next page", + resp: &http.Response{ + Header: http.Header{ + "Link": []string{`; rel="next", ; rel="last"`}, + }, + }, + want: "https://api.github.com/issues?page=2", + want1: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := findNextPage(tt.resp) + if got != tt.want { + t.Errorf("findNextPage() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("findNextPage() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_findEndCursor(t *testing.T) { + tests := []struct { + name string + json io.Reader + want string + }{ + { + name: "blank", + json: bytes.NewBufferString(`{}`), + want: "", + }, + { + name: "unrelated fields", + json: bytes.NewBufferString(`{ + "hasNextPage": true, + "endCursor": "THE_END" + }`), + want: "", + }, + { + name: "has next page", + json: bytes.NewBufferString(`{ + "pageInfo": { + "hasNextPage": true, + "endCursor": "THE_END" + } + }`), + want: "THE_END", + }, + { + name: "more pageInfo blocks", + json: bytes.NewBufferString(`{ + "pageInfo": { + "hasNextPage": true, + "endCursor": "THE_END" + }, + "pageInfo": { + "hasNextPage": true, + "endCursor": "NOT_THIS" + } + }`), + want: "THE_END", + }, + { + name: "no next page", + json: bytes.NewBufferString(`{ + "pageInfo": { + "hasNextPage": false, + "endCursor": "THE_END" + } + }`), + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := findEndCursor(tt.json); got != tt.want { + t.Errorf("findEndCursor() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_addPerPage(t *testing.T) { + type args struct { + p string + perPage int + params map[string]interface{} + } + tests := []struct { + name string + args args + want string + }{ + { + name: "adds per_page", + args: args{ + p: "items", + perPage: 13, + params: nil, + }, + want: "items?per_page=13", + }, + { + name: "avoids adding per_page if already in params", + args: args{ + p: "items", + perPage: 13, + params: map[string]interface{}{ + "state": "open", + "per_page": 99, + }, + }, + want: "items", + }, + { + name: "avoids adding per_page if already in query", + args: args{ + p: "items?per_page=6&state=open", + perPage: 13, + params: nil, + }, + want: "items?per_page=6&state=open", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := addPerPage(tt.args.p, tt.args.perPage, tt.args.params); got != tt.want { + t.Errorf("addPerPage() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/httpmock/legacy.go b/pkg/httpmock/legacy.go index 9474c3dd8..9402c21c7 100644 --- a/pkg/httpmock/legacy.go +++ b/pkg/httpmock/legacy.go @@ -12,19 +12,19 @@ import ( // TODO: clean up methods in this file when there are no more callers func (r *Registry) StubResponse(status int, body io.Reader) { - r.Register(MatchAny, func(*http.Request) (*http.Response, error) { - return httpResponse(status, body), nil + r.Register(MatchAny, func(req *http.Request) (*http.Response, error) { + return httpResponse(status, req, body), nil }) } func (r *Registry) StubWithFixture(status int, fixtureFileName string) func() { fixturePath := path.Join("../test/fixtures/", fixtureFileName) fixtureFile, err := os.Open(fixturePath) - r.Register(MatchAny, func(*http.Request) (*http.Response, error) { + r.Register(MatchAny, func(req *http.Request) (*http.Response, error) { if err != nil { return nil, err } - return httpResponse(200, fixtureFile), nil + return httpResponse(200, req, fixtureFile), nil }) return func() { if err == nil { diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 48077ed4e..c57b7a1ad 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -59,15 +59,21 @@ func decodeJSONBody(req *http.Request, dest interface{}) error { } func StringResponse(body string) Responder { - return func(*http.Request) (*http.Response, error) { - return httpResponse(200, bytes.NewBufferString(body)), nil + return func(req *http.Request) (*http.Response, error) { + return httpResponse(200, req, bytes.NewBufferString(body)), nil + } +} + +func StatusStringResponse(status int, body string) Responder { + return func(req *http.Request) (*http.Response, error) { + return httpResponse(status, req, bytes.NewBufferString(body)), nil } } func JSONResponse(body interface{}) Responder { - return func(*http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { b, _ := json.Marshal(body) - return httpResponse(200, bytes.NewBuffer(b)), nil + return httpResponse(200, req, bytes.NewBuffer(b)), nil } } @@ -84,7 +90,7 @@ func GraphQLMutation(body string, cb func(map[string]interface{})) Responder { } cb(bodyData.Variables.Input) - return httpResponse(200, bytes.NewBufferString(body)), nil + return httpResponse(200, req, bytes.NewBufferString(body)), nil } } @@ -100,13 +106,14 @@ func GraphQLQuery(body string, cb func(string, map[string]interface{})) Responde } cb(bodyData.Query, bodyData.Variables) - return httpResponse(200, bytes.NewBufferString(body)), nil + return httpResponse(200, req, bytes.NewBufferString(body)), nil } } -func httpResponse(status int, body io.Reader) *http.Response { +func httpResponse(status int, req *http.Request, body io.Reader) *http.Response { return &http.Response{ StatusCode: status, + Request: req, Body: ioutil.NopCloser(body), } } diff --git a/utils/color.go b/utils/color.go index 4940fe26b..dd9a7d11c 100644 --- a/utils/color.go +++ b/utils/color.go @@ -10,8 +10,7 @@ import ( ) var ( - _isColorEnabled bool = true - _isStdoutTerminal, checkedTerminal, checkedNoColor bool + _isStdoutTerminal, checkedTerminal bool // Outputs ANSI color if stdout is a tty Magenta = makeColorFunc("magenta") @@ -45,7 +44,7 @@ func NewColorable(f *os.File) io.Writer { func makeColorFunc(color string) func(string) string { cf := ansi.ColorFunc(color) return func(arg string) string { - if isColorEnabled() && isStdoutTerminal() { + if isColorEnabled() { return cf(arg) } return arg @@ -53,9 +52,9 @@ func makeColorFunc(color string) func(string) string { } func isColorEnabled() bool { - if !checkedNoColor { - _isColorEnabled = os.Getenv("NO_COLOR") == "" - checkedNoColor = true + if os.Getenv("NO_COLOR") != "" { + return false } - return _isColorEnabled + + return isStdoutTerminal() }