From 7512d9131fa9a22388f8abcdd9ae13692969d7da Mon Sep 17 00:00:00 2001 From: Ishida Yuya Date: Thu, 22 Oct 2020 01:22:06 +0900 Subject: [PATCH] Get open and closed milestones when milestones are filtered by title (#2209) --- api/queries_issue.go | 2 +- api/queries_repo.go | 25 +++++++++++++++----- api/queries_repo_test.go | 51 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/api/queries_issue.go b/api/queries_issue.go index 42fc9e196..690180feb 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -257,7 +257,7 @@ func IssueList(client *Client, repo ghrepo.Interface, state string, labels []str return nil, err } } else { - milestone, err = MilestoneByTitle(client, repo, milestoneString) + milestone, err = MilestoneByTitle(client, repo, "all", milestoneString) if err != nil { return nil, err } diff --git a/api/queries_repo.go b/api/queries_repo.go index 402afc783..54f88a7e1 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -536,7 +536,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput if input.Milestones { count++ go func() { - milestones, err := RepoMilestones(client, repo) + milestones, err := RepoMilestones(client, repo, "open") if err != nil { err = fmt.Errorf("error fetching milestones: %w", err) } @@ -797,8 +797,8 @@ type RepoMilestone struct { Title string } -// RepoMilestones fetches all open milestones in a repository -func RepoMilestones(client *Client, repo ghrepo.Interface) ([]RepoMilestone, error) { +// RepoMilestones fetches milestones in a repository +func RepoMilestones(client *Client, repo ghrepo.Interface, state string) ([]RepoMilestone, error) { type responseData struct { Repository struct { Milestones struct { @@ -807,13 +807,26 @@ func RepoMilestones(client *Client, repo ghrepo.Interface) ([]RepoMilestone, err HasNextPage bool EndCursor string } - } `graphql:"milestones(states: [OPEN], first: 100, after: $endCursor)"` + } `graphql:"milestones(states: $states, first: 100, after: $endCursor)"` } `graphql:"repository(owner: $owner, name: $name)"` } + var states []githubv4.MilestoneState + switch state { + case "open": + states = []githubv4.MilestoneState{"OPEN"} + case "closed": + states = []githubv4.MilestoneState{"CLOSED"} + case "all": + states = []githubv4.MilestoneState{"OPEN", "CLOSED"} + default: + return nil, fmt.Errorf("invalid state: %s", state) + } + variables := map[string]interface{}{ "owner": githubv4.String(repo.RepoOwner()), "name": githubv4.String(repo.RepoName()), + "states": states, "endCursor": (*githubv4.String)(nil), } @@ -837,8 +850,8 @@ func RepoMilestones(client *Client, repo ghrepo.Interface) ([]RepoMilestone, err return milestones, nil } -func MilestoneByTitle(client *Client, repo ghrepo.Interface, title string) (*RepoMilestone, error) { - milestones, err := RepoMilestones(client, repo) +func MilestoneByTitle(client *Client, repo ghrepo.Interface, state, title string) (*RepoMilestone, error) { + milestones, err := RepoMilestones(client, repo, state) if err != nil { return nil, err } diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index ad0e36864..a90f6bea5 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -1,6 +1,9 @@ package api import ( + "io" + "net/http" + "strings" "testing" "github.com/cli/cli/internal/ghrepo" @@ -233,3 +236,51 @@ func sliceEqual(a, b []string) bool { return true } + +func Test_RepoMilestones(t *testing.T) { + tests := []struct { + state string + want string + wantErr bool + }{ + { + state: "open", + want: `"states":["OPEN"]`, + }, + { + state: "closed", + want: `"states":["CLOSED"]`, + }, + { + state: "all", + want: `"states":["OPEN","CLOSED"]`, + }, + { + state: "invalid state", + wantErr: true, + }, + } + for _, tt := range tests { + var query string + reg := &httpmock.Registry{} + reg.Register(httpmock.MatchAny, func(req *http.Request) (*http.Response, error) { + buf := new(strings.Builder) + _, err := io.Copy(buf, req.Body) + if err != nil { + return nil, err + } + query = buf.String() + return httpmock.StringResponse("{}")(req) + }) + client := NewClient(ReplaceTripper(reg)) + + _, err := RepoMilestones(client, ghrepo.New("OWNER", "REPO"), tt.state) + if (err != nil) != tt.wantErr { + t.Errorf("RepoMilestones() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !strings.Contains(query, tt.want) { + t.Errorf("query does not contain %v", tt.want) + } + } +}