diff --git a/api/queries.go b/api/queries.go index 5a8dfec9e..335bb6f20 100644 --- a/api/queries.go +++ b/api/queries.go @@ -277,13 +277,10 @@ func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRe return prs, nil } -func CreatePullRequest(client *Client, ghRepo Repo, title string, body string, draft bool, base string, head string) (string, error) { - repoId, err := GitHubRepoId(client, ghRepo) +func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) { + repoID, err := GitHubRepoId(client, ghRepo) if err != nil { - return "", err - } - if repoId == "" { - return "", fmt.Errorf("could not determine GH repo ID") + return nil, err } query := ` @@ -295,15 +292,14 @@ func CreatePullRequest(client *Client, ghRepo Repo, title string, body string, d } }` + inputParams := map[string]interface{}{ + "repositoryId": repoID, + } + for key, val := range params { + inputParams[key] = val + } variables := map[string]interface{}{ - "input": map[string]interface{}{ - "repositoryId": repoId, - "baseRefName": base, - "headRefName": head, - "title": title, - "body": body, - "draft": draft, - }, + "input": inputParams, } result := struct { @@ -314,10 +310,10 @@ func CreatePullRequest(client *Client, ghRepo Repo, title string, body string, d err = client.GraphQL(query, variables, &result) if err != nil { - return "", err + return nil, err } - return result.CreatePullRequest.PullRequest.URL, nil + return &result.CreatePullRequest.PullRequest, nil } func PullRequestList(client *Client, vars map[string]interface{}, limit int) ([]PullRequest, error) { diff --git a/command/pr_create.go b/command/pr_create.go index a3e3581fc..c52571123 100644 --- a/command/pr_create.go +++ b/command/pr_create.go @@ -148,13 +148,20 @@ func prCreate(ctx context.Context) error { return fmt.Errorf("could not determine GitHub repo: %s", err) } - payload, err := api.CreatePullRequest(client, repo, title, body, _draftF, base, head) + params := map[string]interface{}{ + "title": title, + "body": body, + "draft": _draftF, + "baseRefName": base, + "headRefName": head, + } + + pr, err := api.CreatePullRequest(client, repo, params) if err != nil { return fmt.Errorf("failed to create PR: %s", err) } - fmt.Println(payload) - + fmt.Fprintln(cmd.OutOrStdout(), pr.URL) return nil }