From c3087cde991bce40b29e5133201507f167159115 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:51:34 -0700 Subject: [PATCH] refactor(pr create): Refactor NewCreateContext - Use prRefs instead of local vars more. - Rename variables for readability. - Improve comments. - Refactor tests. --- pkg/cmd/pr/create/create.go | 240 ++++++++++++++++--------------- pkg/cmd/pr/create/create_test.go | 77 +++++----- 2 files changed, 160 insertions(+), 157 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 0b4a1ca8f..e1e5186cf 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -75,17 +75,17 @@ type CreateOptions struct { type CreateContext struct { // This struct stores contextual data about the creation process and is for building up enough // data to create a pull request - RepoContext *ghContext.ResolvedRemotes - BaseRepo *api.Repository - HeadRepo ghrepo.Interface - BaseTrackingBranch string - BaseBranch string - HeadBranch string - HeadBranchLabel string - HeadRemote *ghContext.Remote - PromptForPushDestination bool - Client *api.Client - GitClient *git.Client + RepoContext *ghContext.ResolvedRemotes + BaseRepo *api.Repository + HeadRepo ghrepo.Interface + BaseTrackingBranch string + BaseBranch string + HeadBranch string + HeadBranchLabel string + HeadRemote *ghContext.Remote + isPushEnabled bool + Client *api.Client + GitClient *git.Client } func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command { @@ -592,112 +592,115 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { if err != nil { return nil, err } - repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) - if err != nil { - return nil, err - } - - var baseRepo *api.Repository - if br, err := repoContext.BaseRepo(opts.IO); err == nil { - if r, ok := br.(*api.Repository); ok { - baseRepo = r - } else { - // TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`, - // consider piggybacking on that result instead of performing a separate lookup - baseRepo, err = api.GitHubRepo(client, br) - if err != nil { - return nil, err - } - } - } else { - return nil, err - } gitClient := opts.GitClient if ucc, err := gitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 { fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change")) } - headBranch := opts.HeadBranch - headBranchLabel := opts.HeadBranch - promptForPushDestination := true // Whether we will prompt the user for where to push the branch. - var headRepo ghrepo.Interface - var headRemote *ghContext.Remote - var headBranchConfig git.BranchConfig - // If --head was provided, then we don't ever ask where to push. - if headBranch != "" { - promptForPushDestination = false - // If the --head provided contains a colon, that means - // this is : syntax. - if idx := strings.IndexRune(headBranch, ':'); idx >= 0 { - headBranch = headBranch[idx+1:] - } - headBranchConfig, err = gitClient.ReadBranchConfig(context.Background(), headBranch) - if err != nil { - return nil, err - } - } else { - // If --head is not specified, we'll try to determine the ref - // from the current branch. If we can't, we'll prompt the user later. - headBranch, err = opts.Branch() - if err != nil { - return nil, fmt.Errorf("could not determine the current branch: %w", err) - } - headBranchLabel = headBranch + // Resolve base repo + repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) + if err != nil { + return nil, err + } - headBranchConfig, err = gitClient.ReadBranchConfig(context.Background(), headBranch) - if err != nil { - return nil, err - } - - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, headBranch) - - remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) - if err != nil { - return nil, err - } - - pushDefault, err := opts.GitClient.PushDefault(ctx) - if err != nil { - return nil, err - } - - prRefs, err := shared.ParsePRRefs(headBranch, headBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRepo, remotes) - if err != nil { - return nil, err - } - - remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes) - // If the remote head is up-to-date, and we have the headRef, we do not need to push anything. - if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" { - promptForPushDestination = false - headRepo = prRefs.HeadRepo - headRemote, err = remotes.FindByRepo(headRepo.RepoOwner(), headRepo.RepoName()) - // TODO: KW what does an err here mean? + var targetBaseRepo *api.Repository + if br, err := repoContext.BaseRepo(opts.IO); err == nil { + if r, ok := br.(*api.Repository); ok { + targetBaseRepo = r + } else { + // TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`, + // consider piggybacking on that result instead of performing a separate lookup + targetBaseRepo, err = api.GitHubRepo(client, br) if err != nil { return nil, err } + } + } else { + return nil, err + } - headBranchLabel = prRefs.GetPRHeadLabel() + // Resolve target head branch name from either + // --head or the current branch. + var targetHeadBranch string + var headBranchLabel string + isPushEnabled := true + + if opts.HeadBranch != "" { + isPushEnabled = false + targetHeadBranch = opts.HeadBranch + headBranchLabel = opts.HeadBranch + // If the --head provided contains a colon, that means + // this is : syntax. + if idx := strings.IndexRune(targetHeadBranch, ':'); idx >= 0 { + targetHeadBranch = targetHeadBranch[idx+1:] + } + } else { + // Use the current branch as the target local head branch when + // --head is not provided. + targetHeadBranch, err = opts.Branch() + headBranchLabel = targetHeadBranch + if err != nil { + return nil, fmt.Errorf("could not determine the current branch: %w", err) } } - // otherwise, ask the user for the head repository using info obtained from the API - if headRepo == nil && promptForPushDestination && opts.IO.CanPrompt() { + targetHeadBranchConfig, err := gitClient.ReadBranchConfig(context.Background(), targetHeadBranch) + if err != nil { + return nil, err + } + + // See if we can determine if this branch has been push previously with + // Git configurations and @{push} revision syntax. + remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) + if err != nil { + return nil, err + } + // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. + parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, targetHeadBranch) + pushDefault, err := opts.GitClient.PushDefault(ctx) + if err != nil { + return nil, err + } + + prRefs, err := shared.ParsePRRefs(targetHeadBranch, targetHeadBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, targetBaseRepo, remotes) + if err != nil { + return nil, err + } + + // We received the head repository and branch from ParsePRRefs, but we + // need to check if it's up-to-date with our local branch state. + // If it is, we can use it as the head remote for the PR + // and avoid prompting the user. + var headRemote *ghContext.Remote + + remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes) + if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" { + isPushEnabled = false + headRemote, err = remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName()) + // TODO: KW what does an err here mean? + // If we fail to find a remote for that repo, shouldn't we just try to prompt + // for head repos? + if err != nil { + return nil, err + } + headBranchLabel = prRefs.GetPRHeadLabel() + } else if isPushEnabled && opts.IO.CanPrompt() { + // Since we could not determine a head ref, prompt the user for the head repository to push + // using a list of repositories obtained from the API pushableRepos, err := repoContext.HeadRepos() if err != nil { return nil, err } if len(pushableRepos) == 0 { - pushableRepos, err = api.RepoFindForks(client, baseRepo, 3) + pushableRepos, err = api.RepoFindForks(client, prRefs.BaseRepo, 3) if err != nil { return nil, err } } - currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost()) + currentLogin, err := api.CurrentLoginName(client, prRefs.BaseRepo.RepoHost()) if err != nil { return nil, err } @@ -712,64 +715,65 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { } if !hasOwnFork { - pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(baseRepo)) + pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(prRefs.BaseRepo)) } pushOptions = append(pushOptions, "Skip pushing the branch") pushOptions = append(pushOptions, "Cancel") - selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", headBranch), "", pushOptions) + selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", prRefs.BranchName), "", pushOptions) if err != nil { return nil, err } if selectedOption < len(pushableRepos) { - headRepo = pushableRepos[selectedOption] - if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) + prRefs.HeadRepo = pushableRepos[selectedOption] + if !ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) { + headBranchLabel = fmt.Sprintf("%s:%s", prRefs.HeadRepo.RepoOwner(), prRefs.BranchName) } } else if pushOptions[selectedOption] == "Skip pushing the branch" { - promptForPushDestination = false + isPushEnabled = false } else if pushOptions[selectedOption] == "Cancel" { return nil, cmdutil.CancelError } else { // "Create a fork of ..." - headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, headBranch) + headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, prRefs.BranchName) + prRefs.HeadRepo = nil } } - if headRepo == nil && promptForPushDestination && !opts.IO.CanPrompt() { + if prRefs.HeadRepo == nil && isPushEnabled && !opts.IO.CanPrompt() { fmt.Fprintf(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag") return nil, cmdutil.SilentError } baseBranch := opts.BaseBranch if baseBranch == "" { - baseBranch = headBranchConfig.MergeBase + baseBranch = targetHeadBranchConfig.MergeBase } if baseBranch == "" { - baseBranch = baseRepo.DefaultBranchRef.Name + baseBranch = targetBaseRepo.DefaultBranchRef.Name } - if headBranch == baseBranch && headRepo != nil && ghrepo.IsSame(baseRepo, headRepo) { + if prRefs.BranchName == baseBranch && prRefs.HeadRepo != nil && ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) { return nil, fmt.Errorf("must be on a branch named differently than %q", baseBranch) } baseTrackingBranch := baseBranch - if baseRemote, err := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()); err == nil { + if baseRemote, err := remotes.FindByRepo(prRefs.BaseRepo.RepoOwner(), prRefs.BaseRepo.RepoName()); err == nil { baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseBranch) } return &CreateContext{ - BaseRepo: baseRepo, - HeadRepo: headRepo, - BaseBranch: baseBranch, - BaseTrackingBranch: baseTrackingBranch, - HeadBranch: headBranch, - HeadBranchLabel: headBranchLabel, - HeadRemote: headRemote, - PromptForPushDestination: promptForPushDestination, - RepoContext: repoContext, - Client: client, - GitClient: gitClient, + BaseRepo: prRefs.BaseRepo.(*api.Repository), + HeadRepo: prRefs.HeadRepo, + BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct + BaseTrackingBranch: baseTrackingBranch, + HeadBranch: prRefs.BranchName, + HeadBranchLabel: headBranchLabel, + HeadRemote: headRemote, + isPushEnabled: isPushEnabled, + RepoContext: repoContext, + Client: client, + GitClient: gitClient, }, nil } @@ -921,7 +925,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { var err error // if a head repository could not be determined so far, automatically create // one by forking the base repository - if headRepo == nil && ctx.PromptForPushDestination { + if headRepo == nil && ctx.isPushEnabled { opts.IO.StartProgressIndicator() headRepo, err = api.ForkRepo(client, ctx.BaseRepo, "", "", false) opts.IO.StopProgressIndicator() @@ -943,7 +947,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { // can push to it. We will try to add the head repo as the "origin" remote // and fallback to the "fork" remote if it is unavailable. Also, if the // base repo is the "origin" remote we will rename it "upstream". - if headRemote == nil && ctx.PromptForPushDestination { + if headRemote == nil && ctx.isPushEnabled { cfg, err := opts.Config() if err != nil { return err @@ -1005,7 +1009,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { } // automatically push the branch if it hasn't been pushed anywhere yet - if ctx.PromptForPushDestination { + if ctx.isPushEnabled { pushBranch := func() error { w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "") defer w.Flush() diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index a94dab48a..8353e8c4b 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -332,18 +332,19 @@ func TestNewCmdCreate(t *testing.T) { func Test_createRun(t *testing.T) { tests := []struct { - name string - setup func(*CreateOptions, *testing.T) func() - cmdStubs func(*run.CommandStubber) - promptStubs func(*prompter.PrompterMock) - httpStubs func(*httpmock.Registry, *testing.T) - expectedOutputs []string - expectedOut string - expectedErrOut string - expectedBrowse string - wantErr string - tty bool - customBranchConfig bool + name string + setup func(*CreateOptions, *testing.T) func() + cmdStubs func(*run.CommandStubber) + promptStubs func(*prompter.PrompterMock) + httpStubs func(*httpmock.Registry, *testing.T) + expectedOutputs []string + expectedOut string + expectedErrOut string + expectedBrowse string + wantErr string + tty bool + customBranchConfig bool + customPushDestination bool }{ { name: "nontty web", @@ -636,10 +637,6 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -702,10 +699,6 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -751,10 +744,6 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -802,9 +791,10 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:feature", input["headRefName"].(string)) })) }, + customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") + cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "") cs.Register("git config remote.pushDefault", 0, "") cs.Register("git config push.default", 0, "") cs.Register("git remote rename origin upstream", 0, "") @@ -864,6 +854,7 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:feature", input["headRefName"].(string)) })) }, + customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register("git show-ref --verify", 0, heredoc.Doc(` deadbeef HEAD @@ -899,7 +890,8 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "my-feat2", input["headRefName"].(string)) })) }, - customBranchConfig: true, + customBranchConfig: true, + customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp \^branch\\\.feature\\\.`, 0, heredoc.Doc(` branch.feature.remote origin @@ -1091,11 +1083,7 @@ func Test_createRun(t *testing.T) { httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`)) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1126,11 +1114,7 @@ func Test_createRun(t *testing.T) { mockRetrieveProjects(t, reg) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1313,11 +1297,17 @@ func Test_createRun(t *testing.T) { reg.Register( httpmock.GraphQL(`mutation PullRequestCreate\b`), httpmock.StringResponse(` - { "data": { "createPullRequest": { "pullRequest": { - "URL": "https://github.com/OWNER/REPO/pull/12" - } } } } + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } } `)) }, + customPushDestination: true, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "") + cs.Register("git config remote.pushDefault", 0, "") + cs.Register("git config push.default", 0, "") + }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", }, { @@ -1538,7 +1528,8 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:task1", input["headRefName"].(string)) })) }, - customBranchConfig: true, + customBranchConfig: true, + customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, heredoc.Doc(` branch.task1.remote origin @@ -1586,7 +1577,6 @@ func Test_createRun(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { branch := "feature" - reg := &httpmock.Registry{} reg.StubRepoInfoResponse("OWNER", "REPO", "master") defer reg.Verify(t) @@ -1603,6 +1593,15 @@ func Test_createRun(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) cs.Register(`git status --porcelain`, 0, "") + // TODO this could be be values in the test struct with a helper + // function to invoke the apporpriate command stubs based on + // those values. + if !tt.customPushDestination { + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") + cs.Register("git config remote.pushDefault", 0, "") + cs.Register("git config push.default", 0, "") + } if !tt.customBranchConfig { cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") }