diff --git a/api/queries_pr.go b/api/queries_pr.go index ef49964dd..2104a6882 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -736,7 +736,9 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea return nil, err } - prs := sortPullRequestsByState(resp.Repository.PullRequests.Nodes) + prs := resp.Repository.PullRequests.Nodes + sortPullRequestsByState(prs) + for _, pr := range prs { if pr.HeadLabel() == headBranch { if baseBranch != "" { @@ -751,6 +753,13 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} } +// sortPullRequestsByState sorts a PullRequest slice by open-first +func sortPullRequestsByState(prs []PullRequest) { + sort.SliceStable(prs, func(a, b int) bool { + return prs[a].State == "OPEN" + }) +} + // CreatePullRequest creates a pull request in a GitHub repository func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) { query := ` @@ -1152,14 +1161,6 @@ func BranchDeleteRemote(client *Client, repo ghrepo.Interface, branch string) er return client.REST(repo.RepoHost(), "DELETE", path, nil, nil) } -// sortPullRequestsByState ensures that OPEN PRs precede non-open states (MERGED, CLOSED) -func sortPullRequestsByState(prs []PullRequest) []PullRequest { - sort.SliceStable(prs, func(a, b int) bool { - return prs[a].State == "OPEN" - }) - return prs -} - func min(a, b int) int { if a < b { return a diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 77a40ba8c..4e0c1581e 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -139,22 +139,31 @@ func Test_determinePullRequestFeatures(t *testing.T) { } } -func Test_sortPullRequestsByPrecedence(t *testing.T) { - prs := sortPullRequestsByState([]PullRequest{ +func Test_sortPullRequestsByState(t *testing.T) { + prs := []PullRequest{ { - BaseRefName: "test-PR", + BaseRefName: "test1", State: "MERGED", }, { - BaseRefName: "test-PR", + BaseRefName: "test2", State: "CLOSED", }, { - BaseRefName: "test-PR", + BaseRefName: "test3", State: "OPEN", }, - }) - if prs[0].State != "OPEN" { - t.Errorf("sortPullRequestsByPrecedence() = got %s, want \"OPEN\"", prs[0].State) + } + + sortPullRequestsByState(prs) + + if prs[0].BaseRefName != "test3" { + t.Errorf("prs[0]: got %s, want %q", prs[0].BaseRefName, "test3") + } + if prs[1].BaseRefName != "test1" { + t.Errorf("prs[1]: got %s, want %q", prs[1].BaseRefName, "test1") + } + if prs[2].BaseRefName != "test2" { + t.Errorf("prs[2]: got %s, want %q", prs[2].BaseRefName, "test2") } }