From c0c5d9123deefc45814daa716ab25e92c70e2f63 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Wed, 5 Mar 2025 14:02:57 -0700 Subject: [PATCH] refactor(pr create): use GetPRHeadLabel() Use PrRefs.GetPRHeadLabel() instead of headBranchLabel. Also remove headBranchLabel from CreateContext struct. To do this, we needed a new identifier for when the head repo should be created via a new fork of the base repo. Previously, this was done by checking if the head repo was nil, but if we want to call GetPRHeadLabel(), it requires a non-nil head repo to construct the headBranchLabel. So, instead of the head repo being nil to signal a fork, we pass a new forkHeadRepo bool in the CreateContext struct. This also makes the decision to fork more intentional; now the decision is made clearly instead of if the headRepo happens to be nil. --- pkg/cmd/pr/create/create.go | 46 +++++++++++++++++-------------- pkg/cmd/pr/create/create_test.go | 47 ++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 38 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 2ba3d2d10..371b14dbe 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -79,9 +79,9 @@ type CreateContext struct { PrRefs shared.PullRequestRefs BaseTrackingBranch string BaseBranch string // Currently not supported by shared.PullRequestRefs struct - HeadBranchLabel string HeadRemote *ghContext.Remote isPushEnabled bool + forkHeadRepo bool Client *api.Client GitClient *git.Client } @@ -308,7 +308,7 @@ func createRun(opts *CreateOptions) error { } existingPR, _, err := opts.Finder.Find(shared.FindOptions{ - Selector: ctx.HeadBranchLabel, + Selector: ctx.PrRefs.GetPRHeadLabel(), BaseBranch: ctx.BaseBranch, States: []string{"OPEN"}, Fields: []string{"url"}, @@ -319,7 +319,7 @@ func createRun(opts *CreateOptions) error { } if err == nil { return fmt.Errorf("a pull request for branch %q into branch %q already exists:\n%s", - ctx.HeadBranchLabel, ctx.BaseBranch, existingPR.URL) + ctx.PrRefs.GetPRHeadLabel(), ctx.BaseBranch, existingPR.URL) } message := "\nCreating pull request for %s into %s in %s\n\n" @@ -334,7 +334,7 @@ func createRun(opts *CreateOptions) error { if opts.IO.CanPrompt() { fmt.Fprintf(opts.IO.ErrOut, message, - cs.Cyan(ctx.HeadBranchLabel), + cs.Cyan(ctx.PrRefs.GetPRHeadLabel()), cs.Cyan(ctx.BaseBranch), ghrepo.FullName(ctx.PrRefs.BaseRepo)) } @@ -621,23 +621,23 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { // Resolve target head branch name from either // --head or the current branch. var targetHeadBranch string - var headBranchLabel string + var targetHeadRepoOwner string + isPushEnabled := true if opts.HeadBranch != "" { isPushEnabled = false targetHeadBranch = opts.HeadBranch - headBranchLabel = opts.HeadBranch // If the --head provided contains a colon, that means // this is : syntax. - if idx := strings.IndexRune(targetHeadBranch, ':'); idx >= 0 { - targetHeadBranch = targetHeadBranch[idx+1:] + if idx := strings.IndexRune(opts.HeadBranch, ':'); idx >= 0 { + targetHeadRepoOwner = opts.HeadBranch[:idx] + targetHeadBranch = opts.HeadBranch[idx+1:] } } else { // Use the current branch as the target local head branch when // --head is not provided. targetHeadBranch, err = opts.Branch() - headBranchLabel = targetHeadBranch if err != nil { return nil, fmt.Errorf("could not determine the current branch: %w", err) } @@ -666,11 +666,19 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { return nil, err } - // We received the head repository and branch from ParsePRRefs, but we - // need to check if it's up-to-date with our local branch state. + // If the --head provided contains : syntax, we need to use + // the provided owner instead of the owner of the base repository. + if targetHeadRepoOwner != "" { + prRefs.HeadRepo = ghrepo.New(targetHeadRepoOwner, prRefs.HeadRepo.RepoName()) + } + + // We received the head repository and branch from ParsePRRefs, or inferred + // it from --head input, but we need to check if it's up-to-date with + // our local branch state. // If it is, we can use it as the head remote for the PR // and avoid prompting the user. var headRemote *ghContext.Remote + var forkHeadRepo bool remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes) if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" { @@ -682,7 +690,6 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { if err != nil { return nil, err } - headBranchLabel = prRefs.GetPRHeadLabel() } else if isPushEnabled && opts.IO.CanPrompt() { // Since we could not determine a head ref, prompt the user for the head repository to push // using a list of repositories obtained from the API @@ -725,17 +732,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { if selectedOption < len(pushableRepos) { prRefs.HeadRepo = pushableRepos[selectedOption] - if !ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", prRefs.HeadRepo.RepoOwner(), prRefs.BranchName) - } } else if pushOptions[selectedOption] == "Skip pushing the branch" { isPushEnabled = false } else if pushOptions[selectedOption] == "Cancel" { return nil, cmdutil.CancelError } else { // "Create a fork of ..." - headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, prRefs.BranchName) - prRefs.HeadRepo = nil + forkHeadRepo = true + prRefs.HeadRepo = ghrepo.New(currentLogin, prRefs.HeadRepo.RepoName()) } } @@ -764,9 +768,9 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { PrRefs: prRefs, BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct BaseTrackingBranch: baseTrackingBranch, - HeadBranchLabel: headBranchLabel, HeadRemote: headRemote, isPushEnabled: isPushEnabled, + forkHeadRepo: forkHeadRepo, RepoContext: repoContext, Client: client, GitClient: gitClient, @@ -794,7 +798,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS "body": state.Body, "draft": state.Draft, "baseRefName": ctx.BaseBranch, - "headRefName": ctx.HeadBranchLabel, + "headRefName": ctx.PrRefs.GetPRHeadLabel(), "maintainerCanModify": opts.MaintainerCanModify, } @@ -921,7 +925,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { var err error // if a head repository could not be determined so far, automatically create // one by forking the base repository - if headRepo == nil && ctx.isPushEnabled { + if ctx.forkHeadRepo && ctx.isPushEnabled { opts.IO.StartProgressIndicator() headRepo, err = api.ForkRepo(client, ctx.PrRefs.BaseRepo, "", "", false) opts.IO.StopProgressIndicator() @@ -1038,7 +1042,7 @@ func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (str u := ghrepo.GenerateRepoURL( ctx.PrRefs.BaseRepo, "compare/%s...%s?expand=1", - url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.HeadBranchLabel)) + url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.PrRefs.GetPRHeadLabel())) url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PrRefs.BaseRepo, u, state) if err != nil { return "", err diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index a69ab3b3a..c07195d40 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1571,6 +1571,12 @@ func Test_createRun(t *testing.T) { opts.HeadBranch = "otherowner:feature" return func() {} }, + customPushDestination: true, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") + cs.Register("git config remote.pushDefault", 0, "") + cs.Register("git config push.default", 0, "") + }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", }, } @@ -1687,10 +1693,11 @@ func Test_generateCompareURL(t *testing.T) { name: "basic", ctx: CreateContext{ PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BranchName: "feature", }, - BaseBranch: "main", - HeadBranchLabel: "feature", + BaseBranch: "main", }, want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1", wantErr: false, @@ -1699,10 +1706,11 @@ func Test_generateCompareURL(t *testing.T) { name: "with labels", ctx: CreateContext{ PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BranchName: "b", }, - BaseBranch: "a", - HeadBranchLabel: "b", + BaseBranch: "a", }, state: shared.IssueMetadataState{ Labels: []string{"one", "two three"}, @@ -1714,12 +1722,13 @@ func Test_generateCompareURL(t *testing.T) { name: "'/'s in branch names/labels are percent-encoded", ctx: CreateContext{ PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER-UPSTREAM"}}, "github.com"), + HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BranchName: "feature", }, - BaseBranch: "main/trunk", - HeadBranchLabel: "owner:feature", + BaseBranch: "main/trunk", }, - want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:feature?body=&expand=1", + want: "https://github.com/OWNER-UPSTREAM/REPO/compare/main%2Ftrunk...OWNER:feature?body=&expand=1", wantErr: false, }, { @@ -1732,22 +1741,26 @@ func Test_generateCompareURL(t *testing.T) { */ ctx: CreateContext{ PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER-UPSTREAM"}}, "github.com"), + HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BranchName: "!$&'()+,;=@", }, - BaseBranch: "main/trunk", - HeadBranchLabel: "owner:!$&'()+,;=@", + BaseBranch: "main/trunk", + //TODO check this + // HeadBranchLabel: "owner:!$&'()+,;=@", }, - want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:%21$&%27%28%29+%2C%3B=@?body=&expand=1", + want: "https://github.com/OWNER-UPSTREAM/REPO/compare/main%2Ftrunk...OWNER:%21$&%27%28%29+%2C%3B=@?body=&expand=1", wantErr: false, }, { name: "with template", ctx: CreateContext{ PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + BranchName: "feature", }, - BaseBranch: "main", - HeadBranchLabel: "feature", + BaseBranch: "main", }, state: shared.IssueMetadataState{ Template: "story.md",