refactor(pr create): use GetPRHeadLabel()

Use PrRefs.GetPRHeadLabel() instead of headBranchLabel.
Also remove headBranchLabel from CreateContext struct.

To do this, we needed a new identifier for when the head repo should be
created via a new fork of the base repo. Previously, this was done by
checking if the head repo was nil, but if we want to call
GetPRHeadLabel(), it requires a non-nil head repo to construct the
headBranchLabel. So, instead of the head repo being nil to signal
a fork, we pass a new forkHeadRepo bool in the CreateContext struct.
This also makes the decision to fork more intentional; now the decision
is made clearly instead of if the headRepo happens to be nil.
This commit is contained in:
Kynan Ware 2025-03-05 14:02:57 -07:00
parent 178fb40515
commit c0c5d9123d
2 changed files with 55 additions and 38 deletions

View file

@ -79,9 +79,9 @@ type CreateContext struct {
PrRefs shared.PullRequestRefs
BaseTrackingBranch string
BaseBranch string // Currently not supported by shared.PullRequestRefs struct
HeadBranchLabel string
HeadRemote *ghContext.Remote
isPushEnabled bool
forkHeadRepo bool
Client *api.Client
GitClient *git.Client
}
@ -308,7 +308,7 @@ func createRun(opts *CreateOptions) error {
}
existingPR, _, err := opts.Finder.Find(shared.FindOptions{
Selector: ctx.HeadBranchLabel,
Selector: ctx.PrRefs.GetPRHeadLabel(),
BaseBranch: ctx.BaseBranch,
States: []string{"OPEN"},
Fields: []string{"url"},
@ -319,7 +319,7 @@ func createRun(opts *CreateOptions) error {
}
if err == nil {
return fmt.Errorf("a pull request for branch %q into branch %q already exists:\n%s",
ctx.HeadBranchLabel, ctx.BaseBranch, existingPR.URL)
ctx.PrRefs.GetPRHeadLabel(), ctx.BaseBranch, existingPR.URL)
}
message := "\nCreating pull request for %s into %s in %s\n\n"
@ -334,7 +334,7 @@ func createRun(opts *CreateOptions) error {
if opts.IO.CanPrompt() {
fmt.Fprintf(opts.IO.ErrOut, message,
cs.Cyan(ctx.HeadBranchLabel),
cs.Cyan(ctx.PrRefs.GetPRHeadLabel()),
cs.Cyan(ctx.BaseBranch),
ghrepo.FullName(ctx.PrRefs.BaseRepo))
}
@ -621,23 +621,23 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
// Resolve target head branch name from either
// --head or the current branch.
var targetHeadBranch string
var headBranchLabel string
var targetHeadRepoOwner string
isPushEnabled := true
if opts.HeadBranch != "" {
isPushEnabled = false
targetHeadBranch = opts.HeadBranch
headBranchLabel = opts.HeadBranch
// If the --head provided contains a colon, that means
// this is <remote>:<branch> syntax.
if idx := strings.IndexRune(targetHeadBranch, ':'); idx >= 0 {
targetHeadBranch = targetHeadBranch[idx+1:]
if idx := strings.IndexRune(opts.HeadBranch, ':'); idx >= 0 {
targetHeadRepoOwner = opts.HeadBranch[:idx]
targetHeadBranch = opts.HeadBranch[idx+1:]
}
} else {
// Use the current branch as the target local head branch when
// --head is not provided.
targetHeadBranch, err = opts.Branch()
headBranchLabel = targetHeadBranch
if err != nil {
return nil, fmt.Errorf("could not determine the current branch: %w", err)
}
@ -666,11 +666,19 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
return nil, err
}
// We received the head repository and branch from ParsePRRefs, but we
// need to check if it's up-to-date with our local branch state.
// If the --head provided contains <owner>:<branch> syntax, we need to use
// the provided owner instead of the owner of the base repository.
if targetHeadRepoOwner != "" {
prRefs.HeadRepo = ghrepo.New(targetHeadRepoOwner, prRefs.HeadRepo.RepoName())
}
// We received the head repository and branch from ParsePRRefs, or inferred
// it from --head input, but we need to check if it's up-to-date with
// our local branch state.
// If it is, we can use it as the head remote for the PR
// and avoid prompting the user.
var headRemote *ghContext.Remote
var forkHeadRepo bool
remoteHeadCurrent := isRemoteHeadCurrent(gitClient, prRefs, remotes)
if remoteHeadCurrent && prRefs.HeadRepo != nil && prRefs.BranchName != "" {
@ -682,7 +690,6 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
if err != nil {
return nil, err
}
headBranchLabel = prRefs.GetPRHeadLabel()
} else if isPushEnabled && opts.IO.CanPrompt() {
// Since we could not determine a head ref, prompt the user for the head repository to push
// using a list of repositories obtained from the API
@ -725,17 +732,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
if selectedOption < len(pushableRepos) {
prRefs.HeadRepo = pushableRepos[selectedOption]
if !ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) {
headBranchLabel = fmt.Sprintf("%s:%s", prRefs.HeadRepo.RepoOwner(), prRefs.BranchName)
}
} else if pushOptions[selectedOption] == "Skip pushing the branch" {
isPushEnabled = false
} else if pushOptions[selectedOption] == "Cancel" {
return nil, cmdutil.CancelError
} else {
// "Create a fork of ..."
headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, prRefs.BranchName)
prRefs.HeadRepo = nil
forkHeadRepo = true
prRefs.HeadRepo = ghrepo.New(currentLogin, prRefs.HeadRepo.RepoName())
}
}
@ -764,9 +768,9 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
PrRefs: prRefs,
BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct
BaseTrackingBranch: baseTrackingBranch,
HeadBranchLabel: headBranchLabel,
HeadRemote: headRemote,
isPushEnabled: isPushEnabled,
forkHeadRepo: forkHeadRepo,
RepoContext: repoContext,
Client: client,
GitClient: gitClient,
@ -794,7 +798,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS
"body": state.Body,
"draft": state.Draft,
"baseRefName": ctx.BaseBranch,
"headRefName": ctx.HeadBranchLabel,
"headRefName": ctx.PrRefs.GetPRHeadLabel(),
"maintainerCanModify": opts.MaintainerCanModify,
}
@ -921,7 +925,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
var err error
// if a head repository could not be determined so far, automatically create
// one by forking the base repository
if headRepo == nil && ctx.isPushEnabled {
if ctx.forkHeadRepo && ctx.isPushEnabled {
opts.IO.StartProgressIndicator()
headRepo, err = api.ForkRepo(client, ctx.PrRefs.BaseRepo, "", "", false)
opts.IO.StopProgressIndicator()
@ -1038,7 +1042,7 @@ func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (str
u := ghrepo.GenerateRepoURL(
ctx.PrRefs.BaseRepo,
"compare/%s...%s?expand=1",
url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.HeadBranchLabel))
url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.PrRefs.GetPRHeadLabel()))
url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PrRefs.BaseRepo, u, state)
if err != nil {
return "", err

View file

@ -1571,6 +1571,12 @@ func Test_createRun(t *testing.T) {
opts.HeadBranch = "otherowner:feature"
return func() {}
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
},
}
@ -1687,10 +1693,11 @@ func Test_generateCompareURL(t *testing.T) {
name: "basic",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "feature",
},
BaseBranch: "main",
HeadBranchLabel: "feature",
BaseBranch: "main",
},
want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1",
wantErr: false,
@ -1699,10 +1706,11 @@ func Test_generateCompareURL(t *testing.T) {
name: "with labels",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "b",
},
BaseBranch: "a",
HeadBranchLabel: "b",
BaseBranch: "a",
},
state: shared.IssueMetadataState{
Labels: []string{"one", "two three"},
@ -1714,12 +1722,13 @@ func Test_generateCompareURL(t *testing.T) {
name: "'/'s in branch names/labels are percent-encoded",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER-UPSTREAM"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "feature",
},
BaseBranch: "main/trunk",
HeadBranchLabel: "owner:feature",
BaseBranch: "main/trunk",
},
want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:feature?body=&expand=1",
want: "https://github.com/OWNER-UPSTREAM/REPO/compare/main%2Ftrunk...OWNER:feature?body=&expand=1",
wantErr: false,
},
{
@ -1732,22 +1741,26 @@ func Test_generateCompareURL(t *testing.T) {
*/
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER-UPSTREAM"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "!$&'()+,;=@",
},
BaseBranch: "main/trunk",
HeadBranchLabel: "owner:!$&'()+,;=@",
BaseBranch: "main/trunk",
//TODO check this
// HeadBranchLabel: "owner:!$&'()+,;=@",
},
want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:%21$&%27%28%29+%2C%3B=@?body=&expand=1",
want: "https://github.com/OWNER-UPSTREAM/REPO/compare/main%2Ftrunk...OWNER:%21$&%27%28%29+%2C%3B=@?body=&expand=1",
wantErr: false,
},
{
name: "with template",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "feature",
},
BaseBranch: "main",
HeadBranchLabel: "feature",
BaseBranch: "main",
},
state: shared.IssueMetadataState{
Template: "story.md",