From e021a072850dac409ce011b01238cd6b3931f7ac Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Sun, 8 Dec 2024 22:27:00 -0800 Subject: [PATCH] Confirm auto-detected base branch If interactive, confirm the automatically configured gh-merge-branch or, if not configured, the default branch. Based on PR feedback. --- pkg/cmd/pr/create/create.go | 24 +++++++- pkg/cmd/pr/create/create_test.go | 94 ++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index f3bd12870..ac747c36b 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -41,8 +41,9 @@ type CreateOptions struct { Finder shared.PRFinder TitledEditSurvey func(string, string) (string, string, error) - TitleProvided bool - BodyProvided bool + TitleProvided bool + BodyProvided bool + BaseBranchProvided bool RootDirOverride string RepoOverride string @@ -147,6 +148,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co opts.TitleProvided = cmd.Flags().Changed("title") opts.RepoOverride, _ = cmd.Flags().GetString("repo") + opts.BaseBranchProvided = cmd.Flags().Changed("base") // Workaround: Due to the way this command is implemented, we need to manually check GH_REPO. // Commands should use the standard BaseRepoOverride functionality to handle this behavior instead. if opts.RepoOverride == "" { @@ -340,7 +342,7 @@ func createRun(opts *CreateOptions) (err error) { ghrepo.FullName(ctx.BaseRepo)) } - if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided)) { + if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided && ctx.BaseBranch != "")) { err = handlePush(*opts, *ctx) if err != nil { return @@ -422,6 +424,14 @@ func createRun(opts *CreateOptions) (err error) { } } + // Confirm the automatically-selected base branch. + if !opts.BaseBranchProvided { + err = confirmTrackingBranch(opts, ctx) + if err != nil { + return + } + } + openURL, err = generateCompareURL(*ctx, *state) if err != nil { return @@ -557,6 +567,14 @@ func determineTrackingBranch(gitClient *git.Client, remotes ghContext.Remotes, h return nil } +func confirmTrackingBranch(opts *CreateOptions, ctx *CreateContext) (err error) { + ctx.BaseBranch, err = opts.Prompter.Input("Base branch", ctx.BaseBranch) + if err != nil { + return + } + return nil +} + func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { var milestoneTitles []string if opts.Milestone != "" { diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 27220d052..1b4e7dbf7 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1217,6 +1217,8 @@ func Test_createRun(t *testing.T) { pm.InputFunc = func(p, d string) (string, error) { if p == "Title (required)" { return d, nil + } else if p == "Base branch" { + return d, nil } else { return "", prompter.NoSuchPromptErr(p) } @@ -1323,6 +1325,8 @@ func Test_createRun(t *testing.T) { pm.InputFunc = func(p, d string) (string, error) { if p == "Title (required)" { return d, nil + } else if p == "Base branch" { + return d, nil } else if p == "Body" { return d, nil } else { @@ -1528,6 +1532,88 @@ func Test_createRun(t *testing.T) { expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for monalisa:task1 into feature/feat2 in OWNER/REPO\n\n", }, + { + name: "non-default base branch prompt", + tty: true, + setup: func(opts *CreateOptions, t *testing.T) func() { + opts.BodyProvided = true + opts.Body = "my body" + opts.Branch = func() (string, error) { + return "task1", nil + } + opts.Remotes = func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("monalisa", "REPO"), + }, + }, nil + } + return func() {} + }, + httpStubs: func(reg *httpmock.Registry, t *testing.T) { + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } } + `, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "my title", input["title"].(string)) + assert.Equal(t, "my body", input["body"].(string)) + assert.Equal(t, "feature/feat3", input["baseRefName"].(string)) + assert.Equal(t, "monalisa:task1", input["headRefName"].(string)) + })) + }, + customBranchConfig: true, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|gh-merge-base\)\$`, 0, heredoc.Doc(` + branch.task1.remote origin + branch.task1.merge refs/heads/task1 + branch.task1.gh-merge-base feature/feat2`)) // ReadBranchConfig + cs.Register(`git show-ref --verify`, 0, heredoc.Doc(` + deadbeef HEAD + deadb00f refs/remotes/upstream/feature/feat2 + deadb01f refs/remotes/upstream/feature/feat3 + deadbeef refs/remotes/origin/task1`)) // determineTrackingBranch + cs.Register( + "git -c log.ShowSignature=false log --pretty=format:%H%x00%s%x00%b%x00 --cherry upstream/feature/feat2...task1", + 0, + "3a9b48085046d156c5acce8f3b3a0532cd706a4a\u0000my title\u0000my body\u0000\n") // initDefaultTitleBody from original base branch + }, + promptStubs: func(pm *prompter.PrompterMock) { + pm.InputFunc = func(p, d string) (string, error) { + switch p { + case "Title (required)": + return "my title", nil + case "Base branch": + return "feature/feat3", nil + default: + return "", prompter.NoSuchPromptErr(p) + } + } + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + if p == "What's next?" { + return 0, nil + } else { + return -1, prompter.NoSuchPromptErr(p) + } + } + }, + expectedOut: "https://github.com/OWNER/REPO/pull/12\n", + // Output message is created based on initial configuration; prompt will later override. + expectedErrOut: "\nCreating pull request for monalisa:task1 into feature/feat2 in OWNER/REPO\n\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1541,6 +1627,14 @@ func Test_createRun(t *testing.T) { } pm := &prompter.PrompterMock{} + pm.InputFunc = func(p, d string) (string, error) { + switch p { + case "Base branch": + return d, nil + default: + return "", prompter.NoSuchPromptErr(p) + } + } if tt.promptStubs != nil { tt.promptStubs(pm)