diff --git a/api/queries_pr.go b/api/queries_pr.go index 34011d5be..1fa009ea1 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "sort" "strings" "time" @@ -621,7 +622,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu return &resp.Repository.PullRequest, nil } -func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, headBranch string) (*PullRequest, error) { +func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters []string) (*PullRequest, error) { type response struct { Repository struct { PullRequests struct { @@ -637,9 +638,9 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } query := ` - query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!) { + query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { repository(owner: $owner, name: $repo) { - pullRequests(headRefName: $headRefName, states: OPEN, first: 30) { + pullRequests(headRefName: $headRefName, states: $states, first: 30) { nodes { id number @@ -726,6 +727,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea "owner": repo.RepoOwner(), "repo": repo.RepoName(), "headRefName": branchWithoutOwner, + "states": stateFilters, } var resp response @@ -734,7 +736,8 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea return nil, err } - for _, pr := range resp.Repository.PullRequests.Nodes { + prs := sortPullRequestsByState(resp.Repository.PullRequests.Nodes) + for _, pr := range prs { if pr.HeadLabel() == headBranch { if baseBranch != "" { if pr.BaseRefName != baseBranch { @@ -1149,6 +1152,14 @@ func BranchDeleteRemote(client *Client, repo ghrepo.Interface, branch string) er return client.REST(repo.RepoHost(), "DELETE", path, nil, nil) } +// sortPullRequestsByState ensures that OPEN PRs precede non-open states (MERGED, CLOSED) +func sortPullRequestsByState(prs []PullRequest) []PullRequest { + sort.SliceStable(prs, func(a, b int) bool { + return prs[a].State == "OPEN" + }) + return prs +} + func min(a, b int) int { if a < b { return a diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index e624b61b3..77a40ba8c 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -138,3 +138,23 @@ func Test_determinePullRequestFeatures(t *testing.T) { }) } } + +func Test_sortPullRequestsByPrecedence(t *testing.T) { + prs := sortPullRequestsByState([]PullRequest{ + { + BaseRefName: "test-PR", + State: "MERGED", + }, + { + BaseRefName: "test-PR", + State: "CLOSED", + }, + { + BaseRefName: "test-PR", + State: "OPEN", + }, + }) + if prs[0].State != "OPEN" { + t.Errorf("sortPullRequestsByPrecedence() = got %s, want \"OPEN\"", prs[0].State) + } +} diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 670491ac7..38950cb65 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -303,7 +303,7 @@ func createRun(opts *CreateOptions) error { } if !opts.WebMode { - existingPR, err := api.PullRequestForBranch(client, baseRepo, baseBranch, headBranchLabel) + existingPR, err := api.PullRequestForBranch(client, baseRepo, baseBranch, headBranchLabel, []string{"OPEN"}) var notFound *api.NotFoundError if err != nil && !errors.As(err, ¬Found) { return fmt.Errorf("error checking for existing pull request: %w", err) diff --git a/pkg/cmd/pr/shared/lookup.go b/pkg/cmd/pr/shared/lookup.go index 9583b854f..06e9221c0 100644 --- a/pkg/cmd/pr/shared/lookup.go +++ b/pkg/cmd/pr/shared/lookup.go @@ -43,7 +43,7 @@ func PRFromArgs(apiClient *api.Client, baseRepoFn func() (ghrepo.Interface, erro } // Last see if it is a branch name - pr, err = api.PullRequestForBranch(apiClient, repo, "", arg) + pr, err = api.PullRequestForBranch(apiClient, repo, "", arg, nil) return pr, repo, err } } @@ -117,5 +117,5 @@ func prForCurrentBranch(apiClient *api.Client, repo ghrepo.Interface, branchFn f } } - return api.PullRequestForBranch(apiClient, repo, "", prHeadRef) + return api.PullRequestForBranch(apiClient, repo, "", prHeadRef, nil) }