diff --git a/api/queries_pr.go b/api/queries_pr.go index 2bd7925c4..3f66c8ddc 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -474,11 +474,8 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) ([] prs := []PullRequest{} pageLimit := min(limit, 100) variables := map[string]interface{}{} - for name, val := range vars { - variables[name] = val - } - if _, ok := vars["assignee"]; ok { + if assignee, ok := vars["assignee"].(string); ok { query = fragment + ` query( $q: String!, @@ -499,21 +496,36 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) ([] }` owner := vars["owner"].(string) repo := vars["repo"].(string) - assignee := vars["assignee"].(string) - state := "" - states := vars["state"].([]string) - if len(states) == 1 { + search := []string{ + fmt.Sprintf("repo:%s/%s", owner, repo), + fmt.Sprintf("assignee:%s", assignee), + "is:pr", + "sort:created-desc", + } + if states, ok := vars["state"].([]string); ok && len(states) == 1 { switch states[0] { case "OPEN": - state = " state:open" + search = append(search, "state:open") case "CLOSED": - state = " state:closed" + search = append(search, "state:closed") case "MERGED": - state = " is:merged" + search = append(search, "is:merged") } } - // TODO: support base, label filtering - variables["q"] = fmt.Sprintf("repo:%s/%s assignee:%s is:pr%s sort:created-desc", owner, repo, assignee, state) + if labels, ok := vars["labels"].([]string); ok && len(labels) > 0 { + if len(labels) > 1 { + return nil, fmt.Errorf("multiple labels with --assignee are not supported: %#v", vars) + } + search = append(search, fmt.Sprintf(`label:"%s"`, labels[0])) + } + if baseBranch, ok := vars["baseBranch"].(string); ok { + search = append(search, fmt.Sprintf(`base:"%s"`, baseBranch)) + } + variables["q"] = strings.Join(search, " ") + } else { + for name, val := range vars { + variables[name] = val + } } for { diff --git a/command/pr.go b/command/pr.go index 8ff3d4f09..d4b923a91 100644 --- a/command/pr.go +++ b/command/pr.go @@ -27,7 +27,7 @@ func init() { prListCmd.Flags().IntP("limit", "L", 30, "Maximum number of items to fetch") prListCmd.Flags().StringP("state", "s", "open", "Filter by state") prListCmd.Flags().StringP("base", "B", "", "Filter by base branch") - prListCmd.Flags().StringArrayP("label", "l", nil, "Filter by label") + prListCmd.Flags().StringSliceP("label", "l", nil, "Filter by label") prListCmd.Flags().StringP("assignee", "a", "", "Filter by assignee") } @@ -137,7 +137,7 @@ func prList(cmd *cobra.Command, args []string) error { if err != nil { return err } - labels, err := cmd.Flags().GetStringArray("label") + labels, err := cmd.Flags().GetStringSlice("label") if err != nil { return err }