refactor(pr create): Refactor NewCreateContext

- Use prRefs instead of local vars more.
- Rename variables for readability.
- Improve comments.
- Refactor tests.
This commit is contained in:
Kynan Ware 2025-03-05 10:51:34 -07:00
parent a5fe37f91b
commit c3087cde99
2 changed files with 160 additions and 157 deletions

View file

@ -75,17 +75,17 @@ type CreateOptions struct {
type CreateContext struct {
// This struct stores contextual data about the creation process and is for building up enough
// data to create a pull request
RepoContext *ghContext.ResolvedRemotes
BaseRepo *api.Repository
HeadRepo ghrepo.Interface
BaseTrackingBranch string
BaseBranch string
HeadBranch string
HeadBranchLabel string
HeadRemote *ghContext.Remote
PromptForPushDestination bool
Client *api.Client
GitClient *git.Client
RepoContext *ghContext.ResolvedRemotes
BaseRepo *api.Repository
HeadRepo ghrepo.Interface
BaseTrackingBranch string
BaseBranch string
HeadBranch string
HeadBranchLabel string
HeadRemote *ghContext.Remote
isPushEnabled bool
Client *api.Client
GitClient *git.Client
}
func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command {
@ -592,112 +592,115 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
if err != nil {
return nil, err
}
repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride)
if err != nil {
return nil, err
}
var baseRepo *api.Repository
if br, err := repoContext.BaseRepo(opts.IO); err == nil {
if r, ok := br.(*api.Repository); ok {
baseRepo = r
} else {
// TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`,
// consider piggybacking on that result instead of performing a separate lookup
baseRepo, err = api.GitHubRepo(client, br)
if err != nil {
return nil, err
}
}
} else {
return nil, err
}
gitClient := opts.GitClient
if ucc, err := gitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 {
fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change"))
}
headBranch := opts.HeadBranch
headBranchLabel := opts.HeadBranch
promptForPushDestination := true // Whether we will prompt the user for where to push the branch.
var headRepo ghrepo.Interface
var headRemote *ghContext.Remote
var headBranchConfig git.BranchConfig
// If --head was provided, then we don't ever ask where to push.
if headBranch != "" {
promptForPushDestination = false
// If the --head provided contains a colon, that means
// this is <remote>:<branch> syntax.
if idx := strings.IndexRune(headBranch, ':'); idx >= 0 {
headBranch = headBranch[idx+1:]
}
headBranchConfig, err = gitClient.ReadBranchConfig(context.Background(), headBranch)
if err != nil {
return nil, err
}
} else {
// If --head is not specified, we'll try to determine the ref
// from the current branch. If we can't, we'll prompt the user later.
headBranch, err = opts.Branch()
if err != nil {
return nil, fmt.Errorf("could not determine the current branch: %w", err)
}
headBranchLabel = headBranch
// Resolve base repo
repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride)
if err != nil {
return nil, err
}
headBranchConfig, err = gitClient.ReadBranchConfig(context.Background(), headBranch)
if err != nil {
return nil, err
}
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, headBranch)
remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx)
if err != nil {
return nil, err
}
pushDefault, err := opts.GitClient.PushDefault(ctx)
if err != nil {
return nil, err
}
prRefs, err := shared.ParsePRRefs(headBranch, headBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRepo, remotes)
if err != nil {
return nil, err
}
remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes)
// If the remote head is up-to-date, and we have the headRef, we do not need to push anything.
if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" {
promptForPushDestination = false
headRepo = prRefs.HeadRepo
headRemote, err = remotes.FindByRepo(headRepo.RepoOwner(), headRepo.RepoName())
// TODO: KW what does an err here mean?
var targetBaseRepo *api.Repository
if br, err := repoContext.BaseRepo(opts.IO); err == nil {
if r, ok := br.(*api.Repository); ok {
targetBaseRepo = r
} else {
// TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`,
// consider piggybacking on that result instead of performing a separate lookup
targetBaseRepo, err = api.GitHubRepo(client, br)
if err != nil {
return nil, err
}
}
} else {
return nil, err
}
headBranchLabel = prRefs.GetPRHeadLabel()
// Resolve target head branch name from either
// --head or the current branch.
var targetHeadBranch string
var headBranchLabel string
isPushEnabled := true
if opts.HeadBranch != "" {
isPushEnabled = false
targetHeadBranch = opts.HeadBranch
headBranchLabel = opts.HeadBranch
// If the --head provided contains a colon, that means
// this is <remote>:<branch> syntax.
if idx := strings.IndexRune(targetHeadBranch, ':'); idx >= 0 {
targetHeadBranch = targetHeadBranch[idx+1:]
}
} else {
// Use the current branch as the target local head branch when
// --head is not provided.
targetHeadBranch, err = opts.Branch()
headBranchLabel = targetHeadBranch
if err != nil {
return nil, fmt.Errorf("could not determine the current branch: %w", err)
}
}
// otherwise, ask the user for the head repository using info obtained from the API
if headRepo == nil && promptForPushDestination && opts.IO.CanPrompt() {
targetHeadBranchConfig, err := gitClient.ReadBranchConfig(context.Background(), targetHeadBranch)
if err != nil {
return nil, err
}
// See if we can determine if this branch has been push previously with
// Git configurations and @{push} revision syntax.
remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx)
if err != nil {
return nil, err
}
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, targetHeadBranch)
pushDefault, err := opts.GitClient.PushDefault(ctx)
if err != nil {
return nil, err
}
prRefs, err := shared.ParsePRRefs(targetHeadBranch, targetHeadBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, targetBaseRepo, remotes)
if err != nil {
return nil, err
}
// We received the head repository and branch from ParsePRRefs, but we
// need to check if it's up-to-date with our local branch state.
// If it is, we can use it as the head remote for the PR
// and avoid prompting the user.
var headRemote *ghContext.Remote
remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes)
if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" {
isPushEnabled = false
headRemote, err = remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName())
// TODO: KW what does an err here mean?
// If we fail to find a remote for that repo, shouldn't we just try to prompt
// for head repos?
if err != nil {
return nil, err
}
headBranchLabel = prRefs.GetPRHeadLabel()
} else if isPushEnabled && opts.IO.CanPrompt() {
// Since we could not determine a head ref, prompt the user for the head repository to push
// using a list of repositories obtained from the API
pushableRepos, err := repoContext.HeadRepos()
if err != nil {
return nil, err
}
if len(pushableRepos) == 0 {
pushableRepos, err = api.RepoFindForks(client, baseRepo, 3)
pushableRepos, err = api.RepoFindForks(client, prRefs.BaseRepo, 3)
if err != nil {
return nil, err
}
}
currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost())
currentLogin, err := api.CurrentLoginName(client, prRefs.BaseRepo.RepoHost())
if err != nil {
return nil, err
}
@ -712,64 +715,65 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
}
if !hasOwnFork {
pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(baseRepo))
pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(prRefs.BaseRepo))
}
pushOptions = append(pushOptions, "Skip pushing the branch")
pushOptions = append(pushOptions, "Cancel")
selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", headBranch), "", pushOptions)
selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", prRefs.BranchName), "", pushOptions)
if err != nil {
return nil, err
}
if selectedOption < len(pushableRepos) {
headRepo = pushableRepos[selectedOption]
if !ghrepo.IsSame(baseRepo, headRepo) {
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
prRefs.HeadRepo = pushableRepos[selectedOption]
if !ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) {
headBranchLabel = fmt.Sprintf("%s:%s", prRefs.HeadRepo.RepoOwner(), prRefs.BranchName)
}
} else if pushOptions[selectedOption] == "Skip pushing the branch" {
promptForPushDestination = false
isPushEnabled = false
} else if pushOptions[selectedOption] == "Cancel" {
return nil, cmdutil.CancelError
} else {
// "Create a fork of ..."
headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, headBranch)
headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, prRefs.BranchName)
prRefs.HeadRepo = nil
}
}
if headRepo == nil && promptForPushDestination && !opts.IO.CanPrompt() {
if prRefs.HeadRepo == nil && isPushEnabled && !opts.IO.CanPrompt() {
fmt.Fprintf(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag")
return nil, cmdutil.SilentError
}
baseBranch := opts.BaseBranch
if baseBranch == "" {
baseBranch = headBranchConfig.MergeBase
baseBranch = targetHeadBranchConfig.MergeBase
}
if baseBranch == "" {
baseBranch = baseRepo.DefaultBranchRef.Name
baseBranch = targetBaseRepo.DefaultBranchRef.Name
}
if headBranch == baseBranch && headRepo != nil && ghrepo.IsSame(baseRepo, headRepo) {
if prRefs.BranchName == baseBranch && prRefs.HeadRepo != nil && ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) {
return nil, fmt.Errorf("must be on a branch named differently than %q", baseBranch)
}
baseTrackingBranch := baseBranch
if baseRemote, err := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()); err == nil {
if baseRemote, err := remotes.FindByRepo(prRefs.BaseRepo.RepoOwner(), prRefs.BaseRepo.RepoName()); err == nil {
baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseBranch)
}
return &CreateContext{
BaseRepo: baseRepo,
HeadRepo: headRepo,
BaseBranch: baseBranch,
BaseTrackingBranch: baseTrackingBranch,
HeadBranch: headBranch,
HeadBranchLabel: headBranchLabel,
HeadRemote: headRemote,
PromptForPushDestination: promptForPushDestination,
RepoContext: repoContext,
Client: client,
GitClient: gitClient,
BaseRepo: prRefs.BaseRepo.(*api.Repository),
HeadRepo: prRefs.HeadRepo,
BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct
BaseTrackingBranch: baseTrackingBranch,
HeadBranch: prRefs.BranchName,
HeadBranchLabel: headBranchLabel,
HeadRemote: headRemote,
isPushEnabled: isPushEnabled,
RepoContext: repoContext,
Client: client,
GitClient: gitClient,
}, nil
}
@ -921,7 +925,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
var err error
// if a head repository could not be determined so far, automatically create
// one by forking the base repository
if headRepo == nil && ctx.PromptForPushDestination {
if headRepo == nil && ctx.isPushEnabled {
opts.IO.StartProgressIndicator()
headRepo, err = api.ForkRepo(client, ctx.BaseRepo, "", "", false)
opts.IO.StopProgressIndicator()
@ -943,7 +947,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
// can push to it. We will try to add the head repo as the "origin" remote
// and fallback to the "fork" remote if it is unavailable. Also, if the
// base repo is the "origin" remote we will rename it "upstream".
if headRemote == nil && ctx.PromptForPushDestination {
if headRemote == nil && ctx.isPushEnabled {
cfg, err := opts.Config()
if err != nil {
return err
@ -1005,7 +1009,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
}
// automatically push the branch if it hasn't been pushed anywhere yet
if ctx.PromptForPushDestination {
if ctx.isPushEnabled {
pushBranch := func() error {
w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "")
defer w.Flush()

View file

@ -332,18 +332,19 @@ func TestNewCmdCreate(t *testing.T) {
func Test_createRun(t *testing.T) {
tests := []struct {
name string
setup func(*CreateOptions, *testing.T) func()
cmdStubs func(*run.CommandStubber)
promptStubs func(*prompter.PrompterMock)
httpStubs func(*httpmock.Registry, *testing.T)
expectedOutputs []string
expectedOut string
expectedErrOut string
expectedBrowse string
wantErr string
tty bool
customBranchConfig bool
name string
setup func(*CreateOptions, *testing.T) func()
cmdStubs func(*run.CommandStubber)
promptStubs func(*prompter.PrompterMock)
httpStubs func(*httpmock.Registry, *testing.T)
expectedOutputs []string
expectedOut string
expectedErrOut string
expectedBrowse string
wantErr string
tty bool
customBranchConfig bool
customPushDestination bool
}{
{
name: "nontty web",
@ -636,10 +637,6 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -702,10 +699,6 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -751,10 +744,6 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -802,9 +791,10 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:feature", input["headRefName"].(string))
}))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register("git remote rename origin upstream", 0, "")
@ -864,6 +854,7 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:feature", input["headRefName"].(string))
}))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git show-ref --verify", 0, heredoc.Doc(`
deadbeef HEAD
@ -899,7 +890,8 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "my-feat2", input["headRefName"].(string))
}))
},
customBranchConfig: true,
customBranchConfig: true,
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp \^branch\\\.feature\\\.`, 0, heredoc.Doc(`
branch.feature.remote origin
@ -1091,11 +1083,7 @@ func Test_createRun(t *testing.T) {
httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -1126,11 +1114,7 @@ func Test_createRun(t *testing.T) {
mockRetrieveProjects(t, reg)
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -1313,11 +1297,17 @@ func Test_createRun(t *testing.T) {
reg.Register(
httpmock.GraphQL(`mutation PullRequestCreate\b`),
httpmock.StringResponse(`
{ "data": { "createPullRequest": { "pullRequest": {
"URL": "https://github.com/OWNER/REPO/pull/12"
} } } }
{ "data": { "createPullRequest": { "pullRequest": {
"URL": "https://github.com/OWNER/REPO/pull/12"
} } } }
`))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
},
{
@ -1538,7 +1528,8 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:task1", input["headRefName"].(string))
}))
},
customBranchConfig: true,
customBranchConfig: true,
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, heredoc.Doc(`
branch.task1.remote origin
@ -1586,7 +1577,6 @@ func Test_createRun(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
branch := "feature"
reg := &httpmock.Registry{}
reg.StubRepoInfoResponse("OWNER", "REPO", "master")
defer reg.Verify(t)
@ -1603,6 +1593,15 @@ func Test_createRun(t *testing.T) {
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git status --porcelain`, 0, "")
// TODO this could be be values in the test struct with a helper
// function to invoke the apporpriate command stubs based on
// those values.
if !tt.customPushDestination {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
}
if !tt.customBranchConfig {
cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "")
}