diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index a6c785fdc..d1f9a9683 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -598,6 +598,16 @@ func TestPRCreate_nonLegacyTemplate(t *testing.T) { http.StubRepoInfoResponse("OWNER", "REPO", "master") shared.RunCommandFinder("feature", nil, nil) + http.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(` + { "data": { "repository": { "pullRequestTemplates": [ + { "filename": "template1", + "body": "this is a bug" }, + { "filename": "template2", + "body": "this is a enhancement" } + ] } } }`), + ) http.Register( httpmock.GraphQL(`mutation PullRequestCreate\b`), httpmock.GraphQLMutation(` @@ -606,7 +616,7 @@ func TestPRCreate_nonLegacyTemplate(t *testing.T) { } } } } `, func(input map[string]interface{}) { assert.Equal(t, "my title", input["title"].(string)) - assert.Equal(t, "- commit 1\n- commit 0\n\nFixes a bug and Closes an issue", input["body"].(string)) + assert.Equal(t, "- commit 1\n- commit 0\n\nthis is a bug", input["body"].(string)) })) cs, cmdTeardown := run.Stub() @@ -615,13 +625,11 @@ func TestPRCreate_nonLegacyTemplate(t *testing.T) { cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "1234567890,commit 0\n2345678901,commit 1") cs.Register(`git status --porcelain`, 0, "") - //nolint:staticcheck // SA1019: prompt.InitAskStubber is deprecated: use NewAskStubber - as, teardown := prompt.InitAskStubber() - defer teardown() + as := prompt.NewAskStubber(t) as.StubPrompt("Choose a template"). - AssertOptions([]string{"Bug fix", "Open a blank pull request"}). - AnswerWith("Bug fix") + AssertOptions([]string{"template1", "template2", "Open a blank pull request"}). + AnswerWith("template1") as.StubPrompt("Body").AnswerDefault() as.StubPrompt("What's next?"). AssertOptions([]string{"Submit", "Continue in browser", "Add metadata", "Cancel"}). diff --git a/pkg/cmd/pr/shared/templates.go b/pkg/cmd/pr/shared/templates.go index 4e60714f7..6750bd6d6 100644 --- a/pkg/cmd/pr/shared/templates.go +++ b/pkg/cmd/pr/shared/templates.go @@ -23,6 +23,12 @@ type issueTemplate struct { Gbody string `graphql:"body"` } +type pullRequestTemplate struct { + // I would have un-exported these fields, except `cli/shurcool-graphql` then cannot unmarshal them :/ + Gname string `graphql:"filename"` + Gbody string `graphql:"body"` +} + func (t *issueTemplate) Name() string { return t.Gname } @@ -35,7 +41,19 @@ func (t *issueTemplate) Body() []byte { return []byte(t.Gbody) } -func listIssueTemplates(httpClient *http.Client, repo ghrepo.Interface) ([]issueTemplate, error) { +func (t *pullRequestTemplate) Name() string { + return t.Gname +} + +func (t *pullRequestTemplate) NameForSubmit() string { + return "" +} + +func (t *pullRequestTemplate) Body() []byte { + return []byte(t.Gbody) +} + +func listIssueTemplates(httpClient *http.Client, repo ghrepo.Interface) ([]Template, error) { var query struct { Repository struct { IssueTemplates []issueTemplate @@ -54,10 +72,44 @@ func listIssueTemplates(httpClient *http.Client, repo ghrepo.Interface) ([]issue return nil, err } - return query.Repository.IssueTemplates, nil + ts := query.Repository.IssueTemplates + templates := make([]Template, len(ts)) + for i := range templates { + templates[i] = &ts[i] + } + + return templates, nil } -func hasIssueTemplateSupport(httpClient *http.Client, hostname string) (bool, error) { +func listPullRequestTemplates(httpClient *http.Client, repo ghrepo.Interface) ([]Template, error) { + var query struct { + Repository struct { + PullRequestTemplates []pullRequestTemplate + } `graphql:"repository(owner: $owner, name: $name)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "name": githubv4.String(repo.RepoName()), + } + + gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient) + + err := gql.QueryNamed(context.Background(), "PullRequestTemplates", &query, variables) + if err != nil { + return nil, err + } + + ts := query.Repository.PullRequestTemplates + templates := make([]Template, len(ts)) + for i := range templates { + templates[i] = &ts[i] + } + + return templates, nil +} + +func hasTemplateSupport(httpClient *http.Client, hostname string, isPR bool) (bool, error) { if !ghinstance.IsEnterprise(hostname) { return true, nil } @@ -81,20 +133,29 @@ func hasIssueTemplateSupport(httpClient *http.Client, hostname string) (bool, er return false, err } - var hasQuerySupport bool - var hasMutationSupport bool + var hasIssueQuerySupport bool + var hasIssueMutationSupport bool + var hasPullRequestQuerySupport bool + for _, field := range featureDetection.Repository.Fields { if field.Name == "issueTemplates" { - hasQuerySupport = true + hasIssueQuerySupport = true + } + if field.Name == "pullRequestTemplates" { + hasPullRequestQuerySupport = true } } for _, field := range featureDetection.CreateIssueInput.InputFields { if field.Name == "issueTemplate" { - hasMutationSupport = true + hasIssueMutationSupport = true } } - return hasQuerySupport && hasMutationSupport, nil + if isPR { + return hasPullRequestQuerySupport, nil + } else { + return hasIssueQuerySupport && hasIssueMutationSupport, nil + } } type Template interface { @@ -129,13 +190,10 @@ func NewTemplateManager(httpClient *http.Client, repo ghrepo.Interface, dir stri } func (m *templateManager) hasAPI() (bool, error) { - if m.isPR { - return false, nil - } if m.cachedClient == nil { m.cachedClient = api.NewCachedClient(m.httpClient, time.Hour*24) } - return hasIssueTemplateSupport(m.cachedClient, m.repo.RepoHost()) + return hasTemplateSupport(m.cachedClient, m.repo.RepoHost(), m.isPR) } func (m *templateManager) HasTemplates() (bool, error) { @@ -201,14 +259,15 @@ func (m *templateManager) fetch() error { } if hasAPI { - issueTemplates, err := listIssueTemplates(m.httpClient, m.repo) + lister := listIssueTemplates + if m.isPR { + lister = listPullRequestTemplates + } + templates, err := lister(m.httpClient, m.repo) if err != nil { return err } - m.templates = make([]Template, len(issueTemplates)) - for i := range issueTemplates { - m.templates[i] = &issueTemplates[i] - } + m.templates = templates } if !m.allowFS { diff --git a/pkg/cmd/pr/shared/templates_test.go b/pkg/cmd/pr/shared/templates_test.go index 03d522be6..5cce7a412 100644 --- a/pkg/cmd/pr/shared/templates_test.go +++ b/pkg/cmd/pr/shared/templates_test.go @@ -63,14 +63,70 @@ func TestTemplateManager_hasAPI(t *testing.T) { assert.Equal(t, "LEGACY", string(m.LegacyBody())) - //nolint:staticcheck // SA1019: prompt.InitAskStubber is deprecated: use NewAskStubber - as, askRestore := prompt.InitAskStubber() - defer askRestore() + as := prompt.NewAskStubber(t) + as.StubPrompt("Choose a template"). + AssertOptions([]string{"Bug report", "Feature request", "Open a blank issue"}). + AnswerWith("Feature request") - //nolint:staticcheck // SA1019: as.StubOne is deprecated: use StubPrompt - as.StubOne(1) // choose "Feature Request" tpl, err := m.Choose() + assert.NoError(t, err) assert.Equal(t, "Feature request", tpl.NameForSubmit()) assert.Equal(t, "I need a feature", string(tpl.Body())) } + +func TestTemplateManager_hasAPI_PullRequest(t *testing.T) { + rootDir := t.TempDir() + legacyTemplateFile := filepath.Join(rootDir, ".github", "PULL_REQUEST_TEMPLATE.md") + _ = os.MkdirAll(filepath.Dir(legacyTemplateFile), 0755) + _ = ioutil.WriteFile(legacyTemplateFile, []byte("LEGACY"), 0644) + + tr := httpmock.Registry{} + httpClient := &http.Client{Transport: &tr} + defer tr.Verify(t) + + tr.Register( + httpmock.GraphQL(`query IssueTemplates_fields\b`), + httpmock.StringResponse(`{"data":{ + "Repository": { + "fields": [ + {"name": "foo"}, + {"name": "pullRequestTemplates"} + ] + } + }}`)) + tr.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(`{"data":{"repository":{ + "pullRequestTemplates": [ + {"filename": "bug_pr.md", "body": "I fixed a problem"}, + {"filename": "feature_pr.md", "body": "I added a feature"} + ] + }}}`)) + + m := templateManager{ + repo: ghrepo.NewWithHost("OWNER", "REPO", "example.com"), + rootDir: rootDir, + allowFS: true, + isPR: true, + httpClient: httpClient, + cachedClient: httpClient, + } + + hasTemplates, err := m.HasTemplates() + assert.NoError(t, err) + assert.True(t, hasTemplates) + + assert.Equal(t, "LEGACY", string(m.LegacyBody())) + + as := prompt.NewAskStubber(t) + as.StubPrompt("Choose a template"). + AssertOptions([]string{"bug_pr.md", "feature_pr.md", "Open a blank pull request"}). + AnswerWith("bug_pr.md") + + tpl, err := m.Choose() + + assert.NoError(t, err) + assert.Equal(t, "", tpl.NameForSubmit()) + assert.Equal(t, "I fixed a problem", string(tpl.Body())) +}