diff --git a/api/queries_pr.go b/api/queries_pr.go index 83104db29..3eeae1cee 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -117,7 +117,7 @@ type Repo interface { RepoOwner() string } -func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername string) (*PullRequestsPayload, error) { +func PullRequests(client *Client, ghRepo Repo, currentPRNumber int, currentPRHeadRef, currentUsername string) (*PullRequestsPayload, error) { type edges struct { Edges []struct { Node PullRequest @@ -131,12 +131,13 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st type response struct { Repository struct { PullRequests edges + PullRequest *PullRequest } ViewerCreated edges ReviewRequested edges } - query := ` + fragments := ` fragment pr on PullRequest { number title @@ -170,16 +171,32 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st ...pr reviewDecision } - query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { - repository(owner: $owner, name: $repo) { - pullRequests(headRefName: $headRefName, states: OPEN, first: $per_page) { - edges { - node { - ...prWithReviews - } - } - } - } + ` + + queryPrefix := ` + query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { + repository(owner: $owner, name: $repo) { + pullRequests(headRefName: $headRefName, states: OPEN, first: $per_page) { + edges { + node { + ...prWithReviews + } + } + } + } + ` + if currentPRNumber > 0 { + queryPrefix = ` + query($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $number) { + ...prWithReviews + } + } + ` + } + + query := fragments + queryPrefix + ` viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) { edges { node { @@ -201,7 +218,7 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st } } } - ` + ` owner := ghRepo.RepoOwner() repo := ghRepo.RepoName() @@ -209,9 +226,9 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st viewerQuery := fmt.Sprintf("repo:%s/%s state:open is:pr author:%s", owner, repo, currentUsername) reviewerQuery := fmt.Sprintf("repo:%s/%s state:open review-requested:%s", owner, repo, currentUsername) - branchWithoutOwner := currentBranch - if idx := strings.Index(currentBranch, ":"); idx >= 0 { - branchWithoutOwner = currentBranch[idx+1:] + branchWithoutOwner := currentPRHeadRef + if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 { + branchWithoutOwner = currentPRHeadRef[idx+1:] } variables := map[string]interface{}{ @@ -220,6 +237,7 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st "owner": owner, "repo": repo, "headRefName": branchWithoutOwner, + "number": currentPRNumber, } var resp response @@ -238,10 +256,12 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st reviewRequested = append(reviewRequested, edge.Node) } - var currentPR *PullRequest - for _, edge := range resp.Repository.PullRequests.Edges { - if edge.Node.HeadLabel() == currentBranch { - currentPR = &edge.Node + var currentPR = resp.Repository.PullRequest + if currentPR == nil { + for _, edge := range resp.Repository.PullRequests.Edges { + if edge.Node.HeadLabel() == currentPRHeadRef { + currentPR = &edge.Node + } } } diff --git a/command/pr.go b/command/pr.go index e0715d639..122fbbe9c 100644 --- a/command/pr.go +++ b/command/pr.go @@ -68,7 +68,7 @@ func prStatus(cmd *cobra.Command, args []string) error { if err != nil { return err } - currentBranch, err := ctx.Branch() + currentPRNumber, currentPRHeadRef, err := prSelectorForCurrentBranch(ctx) if err != nil { return err } @@ -77,7 +77,7 @@ func prStatus(cmd *cobra.Command, args []string) error { return err } - prPayload, err := api.PullRequests(apiClient, baseRepo, currentBranch, currentUser) + prPayload, err := api.PullRequests(apiClient, baseRepo, currentPRNumber, currentPRHeadRef, currentUser) if err != nil { return err } @@ -88,7 +88,7 @@ func prStatus(cmd *cobra.Command, args []string) error { if prPayload.CurrentPR != nil { printPrs(out, *prPayload.CurrentPR) } else { - message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentBranch+"]")) + message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentPRHeadRef+"]")) printMessage(out, message) } fmt.Fprintln(out)