Rework ref usage when finding and creating PRs

This commit is contained in:
William Martin 2025-04-02 13:21:10 +02:00 committed by Kynan Ware
parent ebd147b43e
commit a9dbda6913
28 changed files with 2254 additions and 1110 deletions

View file

@ -12,6 +12,7 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a fork
exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork
sleep 5
# Defer fork cleanup
defer gh repo delete --yes ${ORG}/${REPO}-fork

View file

@ -0,0 +1,46 @@
env REPO=${SCRIPT_NAME}-${RANDOM_STRING}
env FORK=${REPO}-fork
# Use gh as a credential helper
exec gh auth setup-git
# Get the current username for the fork owner
exec gh api user --jq .login
stdout2env USER
# Create a repository with a file so it has a default branch
exec gh repo create ${ORG}/${REPO} --add-readme --private
# Defer repo cleanup
defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID
exec gh repo clone ${USER}/${FORK}
cd ${FORK}
# Prepare a branch to commit
exec git checkout -b feature-branch
exec git commit --allow-empty -m 'Upstream Commit'
exec git push upstream feature-branch
# Prepare an additional commit
exec git commit --allow-empty -m 'Fork Commit'
exec git push origin feature-branch
# Create the PR
exec gh pr create --title 'Feature Title' --body 'Feature Body'
stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1
# Check the PR is indeed created
exec gh pr view ${USER}:feature-branch --json headRefName,headRepository,baseRefName,isCrossRepository
stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true}

View file

@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
sleep 5
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID

View file

@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
sleep 5
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID
@ -50,4 +50,4 @@ stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1
# Assert that the PR was created with the correct head repository and refs
exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository
stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true}
stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true}

View file

@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
sleep 5
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID

View file

@ -19,16 +19,16 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
sleep 5
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID
# Clone the repo
# Clone the fork
exec gh repo clone ${USER}/${FORK}
cd ${FORK}

View file

@ -1,20 +1,22 @@
# This test is the same as pr-create-basic, except that the git push doesn't include the -u argument
# This causes a git config read to fail during gh pr create, but it should not be fatal
env REPO=${SCRIPT_NAME}-${RANDOM_STRING}
# Use gh as a credential helper
exec gh auth setup-git
# Create a repository with a file so it has a default branch
exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private
exec gh repo create ${ORG}/${REPO} --add-readme --private
# Defer repo cleanup
defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING
defer gh repo delete --yes ${ORG}/${REPO}
# Clone the repo
exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING
exec gh repo clone ${ORG}/${REPO}
# Prepare a branch to PR
cd $SCRIPT_NAME-$RANDOM_STRING
cd ${REPO}
exec git checkout -b feature-branch
exec git commit --allow-empty -m 'Empty Commit'
exec git push origin feature-branch

View file

@ -0,0 +1,46 @@
skip 'it creates a fork owned by the user running the test'
# Setup environment variables used for testscript
env REPO=${SCRIPT_NAME}-${RANDOM_STRING}
env FORK=${REPO}-fork
# Use gh as a credential helper
exec gh auth setup-git
# Get the current username for the fork owner
exec gh api user --jq .login
stdout2env USER
# Create a repository to act as upstream with a file so it has a default branch
exec gh repo create ${ORG}/${REPO} --add-readme --private
# Defer repo cleanup of upstream
defer gh repo delete --yes ${ORG}/${REPO}
# Create a user fork of repository. This will be owned by USER.
exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${USER}/${FORK}
# Retrieve fork repository information
exec gh repo view ${USER}/${FORK} --json id --jq '.id'
stdout2env FORK_ID
# Clone the repo
exec gh repo clone ${USER}/${FORK}
cd ${FORK}
# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork
exec git checkout -b feature-branch
exec git commit --allow-empty -m 'Empty Commit'
exec git push -u origin feature-branch
# Create the PR spanning upstream and fork repositories
exec gh pr create --title 'Feature Title' --body 'Feature Body'
stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1
# Assert that the PR was created with the correct head repository and refs
exec gh pr status
! stdout 'There is no pull request associated with'

View file

@ -15,10 +15,11 @@ stdout2env REPO_ID
# Create a fork in the same org
exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${ORG}/${FORK}
sleep 1
exec gh repo view ${ORG}/${FORK} --json id --jq '.id'
stdout2env FORK_ID

View file

@ -15,10 +15,11 @@ stdout2env REPO_ID
# Create a user fork of repository as opposed to private organization fork
exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${ORG}/${FORK}
sleep 5
exec gh repo view ${ORG}/${FORK} --json id --jq '.id'
stdout2env FORK_ID

View file

@ -15,10 +15,11 @@ stdout2env REPO_ID
# Create a user fork of repository as opposed to private organization fork
exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK}
sleep 5
# Defer repo cleanup of fork
defer gh repo delete --yes ${ORG}/${FORK}
sleep 5
exec gh repo view ${ORG}/${FORK} --json id --jq '.id'
stdout2env FORK_ID

View file

@ -9,13 +9,11 @@ defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING
# Fork and clone the repo
exec gh repo fork $ORG/$SCRIPT_NAME-$RANDOM_STRING --org $ORG --fork-name $SCRIPT_NAME-$RANDOM_STRING-fork --clone
sleep 5
# Defer fork cleanup
defer gh repo delete $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --yes
# Sleep so that the BE has time to sync
sleep 5
# Check that the repo was forked
exec gh repo view $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --json='isFork' --jq='.isFork'
stdout 'true'

View file

@ -12,13 +12,11 @@ defer gh repo delete --yes ${ORG}/${REPO}
# Create a fork
exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork
sleep 5
# Defer fork cleanup
defer gh repo delete --yes ${ORG}/${REPO}-fork
# Sleep to allow the fork to be created before cloning
sleep 2
# Clone and move into the fork repo
exec gh repo clone ${ORG}/${REPO}-fork
cd ${REPO}-fork

View file

@ -381,7 +381,6 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte,
// Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors,
// as an empty config is not necessarily breaking.
func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) {
prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch))
args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|pushremote|%s)$", prefix, MergeBaseConfig)}
cmd, err := c.Command(ctx, args...)
@ -441,18 +440,50 @@ func (c *Client) SetBranchConfig(ctx context.Context, branch, name, value string
return err
}
// PushDefault defines the action git push should take if no refspec is given.
// See: https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault
type PushDefault string
const (
PushDefaultNothing PushDefault = "nothing"
PushDefaultCurrent PushDefault = "current"
PushDefaultUpstream PushDefault = "upstream"
PushDefaultTracking PushDefault = "tracking"
PushDefaultSimple PushDefault = "simple"
PushDefaultMatching PushDefault = "matching"
)
func ParsePushDefault(s string) (PushDefault, error) {
validPushDefaults := map[string]struct{}{
string(PushDefaultNothing): {},
string(PushDefaultCurrent): {},
string(PushDefaultUpstream): {},
string(PushDefaultTracking): {},
string(PushDefaultSimple): {},
string(PushDefaultMatching): {},
}
if _, ok := validPushDefaults[s]; ok {
return PushDefault(s), nil
}
return "", fmt.Errorf("unknown push.default value: %s", s)
}
// PushDefault returns the value of push.default in the config. If the value
// is not set, it returns "simple" (the default git value). See
// https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault
func (c *Client) PushDefault(ctx context.Context) (string, error) {
func (c *Client) PushDefault(ctx context.Context) (PushDefault, error) {
pushDefault, err := c.Config(ctx, "push.default")
if err == nil {
return pushDefault, nil
return ParsePushDefault(pushDefault)
}
// If there is an error that the config key is not set, return the default value
// that git uses since 2.0.
var gitError *GitError
if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 {
return "simple", nil
return PushDefaultSimple, nil
}
return "", err
}
@ -473,13 +504,48 @@ func (c *Client) RemotePushDefault(ctx context.Context) (string, error) {
return "", err
}
// ParsePushRevision gets the value of the @{push} revision syntax
// RemoteTrackingRef is the structured form of the string "refs/remotes/<remote>/<branch>".
// For example, the @{push} revision syntax could report "refs/remotes/origin/main" which would
// be parsed into RemoteTrackingRef{Remote: "origin", Branch: "main"}.
type RemoteTrackingRef struct {
Remote string
Branch string
}
func (r RemoteTrackingRef) String() string {
return fmt.Sprintf("refs/remotes/%s/%s", r.Remote, r.Branch)
}
// ParseRemoteTrackingRef parses a string of the form "refs/remotes/<remote>/<branch>" into
// a RemoteTrackingBranch struct. If the string does not match this format, an error is returned.
func ParseRemoteTrackingRef(s string) (RemoteTrackingRef, error) {
parts := strings.Split(s, "/")
if len(parts) != 4 || parts[0] != "refs" || parts[1] != "remotes" {
return RemoteTrackingRef{}, fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: %s", s)
}
return RemoteTrackingRef{
Remote: parts[2],
Branch: parts[3],
}, nil
}
// PushRevision gets the value of the @{push} revision syntax
// An error here doesn't necessarily mean something is broken, but may mean that the @{push}
// revision syntax couldn't be resolved, such as in non-centralized workflows with
// push.default = simple. Downstream consumers should consider how to handle this error.
func (c *Client) ParsePushRevision(ctx context.Context, branch string) (string, error) {
revParseOut, err := c.revParse(ctx, "--abbrev-ref", branch+"@{push}")
return firstLine(revParseOut), err
func (c *Client) PushRevision(ctx context.Context, branch string) (RemoteTrackingRef, error) {
revParseOut, err := c.revParse(ctx, "--symbolic-full-name", branch+"@{push}")
if err != nil {
return RemoteTrackingRef{}, err
}
ref, err := ParseRemoteTrackingRef(firstLine(revParseOut))
if err != nil {
return RemoteTrackingRef{}, fmt.Errorf("could not parse push revision: %v", err)
}
return ref, nil
}
func (c *Client) DeleteLocalTag(ctx context.Context, tag string) error {

View file

@ -952,7 +952,7 @@ func TestClientPushDefault(t *testing.T) {
tests := []struct {
name string
commandResult commandResult
wantPushDefault string
wantPushDefault PushDefault
wantError *GitError
}{
{
@ -961,7 +961,7 @@ func TestClientPushDefault(t *testing.T) {
ExitStatus: 1,
Stderr: "error: key does not contain a section: remote.pushDefault",
},
wantPushDefault: "simple",
wantPushDefault: PushDefaultSimple,
wantError: nil,
},
{
@ -970,7 +970,7 @@ func TestClientPushDefault(t *testing.T) {
ExitStatus: 0,
Stdout: "current",
},
wantPushDefault: "current",
wantPushDefault: PushDefaultCurrent,
wantError: nil,
},
{
@ -1077,17 +1077,17 @@ func TestClientParsePushRevision(t *testing.T) {
name string
branch string
commandResult commandResult
wantParsedPushRevision string
wantError *GitError
wantParsedPushRevision RemoteTrackingRef
wantError error
}{
{
name: "@{push} resolves to origin/branchName",
name: "@{push} resolves to refs/remotes/origin/branchName",
branch: "branchName",
commandResult: commandResult{
ExitStatus: 0,
Stdout: "origin/branchName",
Stdout: "refs/remotes/origin/branchName",
},
wantParsedPushRevision: "origin/branchName",
wantParsedPushRevision: RemoteTrackingRef{Remote: "origin", Branch: "branchName"},
},
{
name: "@{push} doesn't resolve",
@ -1095,16 +1095,25 @@ func TestClientParsePushRevision(t *testing.T) {
ExitStatus: 128,
Stderr: "fatal: git error",
},
wantParsedPushRevision: "",
wantParsedPushRevision: RemoteTrackingRef{},
wantError: &GitError{
ExitCode: 128,
Stderr: "fatal: git error",
},
},
{
name: "@{push} resolves to something surprising",
commandResult: commandResult{
ExitStatus: 0,
Stdout: "not/a/valid/remote/ref",
},
wantParsedPushRevision: RemoteTrackingRef{},
wantError: fmt.Errorf("could not parse push revision: remote tracking branch must have format refs/remotes/<remote>/<branch> but was: not/a/valid/remote/ref"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := fmt.Sprintf("path/to/git rev-parse --abbrev-ref %s@{push}", tt.branch)
cmd := fmt.Sprintf("path/to/git rev-parse --symbolic-full-name %s@{push}", tt.branch)
cmdCtx := createMockedCommandContext(t, mockedCommands{
args(cmd): tt.commandResult,
})
@ -1112,20 +1121,91 @@ func TestClientParsePushRevision(t *testing.T) {
GitPath: "path/to/git",
commandContext: cmdCtx,
}
pushDefault, err := client.ParsePushRevision(context.Background(), tt.branch)
trackingRef, err := client.PushRevision(context.Background(), tt.branch)
if tt.wantError != nil {
var gitError *GitError
require.ErrorAs(t, err, &gitError)
assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode)
assert.Equal(t, tt.wantError.Stderr, gitError.Stderr)
var wantErrorAsGit *GitError
if errors.As(err, &wantErrorAsGit) {
var gitError *GitError
require.ErrorAs(t, err, &gitError)
assert.Equal(t, wantErrorAsGit.ExitCode, gitError.ExitCode)
assert.Equal(t, wantErrorAsGit.Stderr, gitError.Stderr)
} else {
assert.Equal(t, err, tt.wantError)
}
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantParsedPushRevision, pushDefault)
assert.Equal(t, tt.wantParsedPushRevision, trackingRef)
})
}
}
func TestRemoteTrackingRef(t *testing.T) {
t.Run("parsing", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
remoteTrackingRef string
wantRemoteTrackingRef RemoteTrackingRef
wantError error
}{
{
name: "valid remote tracking ref",
remoteTrackingRef: "refs/remotes/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{
Remote: "origin",
Branch: "branchName",
},
},
{
name: "incorrect parts",
remoteTrackingRef: "refs/remotes/origin",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: refs/remotes/origin"),
},
{
name: "incorrect prefix type",
remoteTrackingRef: "invalid/remotes/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: invalid/remotes/origin/branchName"),
},
{
name: "incorrect ref type",
remoteTrackingRef: "refs/invalid/origin/branchName",
wantRemoteTrackingRef: RemoteTrackingRef{},
wantError: fmt.Errorf("remote tracking branch must have format refs/remotes/<remote>/<branch> but was: refs/invalid/origin/branchName"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
trackingRef, err := ParseRemoteTrackingRef(tt.remoteTrackingRef)
if tt.wantError != nil {
require.Equal(t, tt.wantError, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantRemoteTrackingRef, trackingRef)
})
}
})
t.Run("stringifying", func(t *testing.T) {
t.Parallel()
remoteTrackingRef := RemoteTrackingRef{
Remote: "origin",
Branch: "branchName",
}
require.Equal(t, "refs/remotes/origin/branchName", remoteTrackingRef.String())
})
}
func TestClientDeleteLocalTag(t *testing.T) {
tests := []struct {
name string
@ -1992,6 +2072,41 @@ func TestCredentialPatternFromHost(t *testing.T) {
}
}
func TestPushDefault(t *testing.T) {
t.Run("it parses valid values correctly", func(t *testing.T) {
t.Parallel()
tests := []struct {
value string
expectedPushDefault PushDefault
}{
{"nothing", PushDefaultNothing},
{"current", PushDefaultCurrent},
{"upstream", PushDefaultUpstream},
{"tracking", PushDefaultTracking},
{"simple", PushDefaultSimple},
{"matching", PushDefaultMatching},
}
for _, test := range tests {
t.Run(test.value, func(t *testing.T) {
t.Parallel()
pushDefault, err := ParsePushDefault(test.value)
require.NoError(t, err)
assert.Equal(t, test.expectedPushDefault, pushDefault)
})
}
})
t.Run("it returns an error for invalid values", func(t *testing.T) {
t.Parallel()
_, err := ParsePushDefault("invalid")
require.Error(t, err)
})
}
func createCommandContext(t *testing.T, exitStatus int, stdout, stderr string) (*exec.Cmd, commandCtx) {
cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperProcess", "--")
cmd.Env = []string{

View file

@ -46,7 +46,7 @@ func Stub() (*CommandStubber, func(T)) {
return
}
t.Helper()
t.Errorf("unmatched stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", "))
t.Errorf("unmatched exec stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", "))
}
}

View file

@ -25,6 +25,7 @@ import (
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/markdown"
o "github.com/cli/cli/v2/pkg/option"
"github.com/spf13/cobra"
)
@ -72,16 +73,107 @@ type CreateOptions struct {
DryRun bool
}
// creationRefs is an interface that provides the necessary information for creating a pull request in the API.
// Upcasting to concrete implementations can provide further context on other operations (forking and pushing).
type creationRefs interface {
// QualifiedHeadRef returns a stringified form of the head ref, varying depending
// on whether the head ref is in the same repository as the base ref. If they are
// the same repository, we return the branch name only. If they are different repositories,
// we return the owner and branch name in the form <owner>:<branch>.
QualifiedHeadRef() string
// UnqualifiedHeadRef returns a head ref in the form of the branch name only.
UnqualifiedHeadRef() string
//BaseRef returns the base branch name.
BaseRef() string
// While the only thing really required from an api.Repository is the repository ID, changing that
// would require changing the API function signatures, and the refactor that introduced this refs
// type is already large enough.
BaseRepo() *api.Repository
}
type baseRefs struct {
baseRepo *api.Repository
baseBranchName string
}
func (r baseRefs) BaseRef() string {
return r.baseBranchName
}
func (r baseRefs) BaseRepo() *api.Repository {
return r.baseRepo
}
// skipPushRefs indicate to handlePush that no pushing is required.
type skipPushRefs struct {
baseRefs
qualifiedHeadRef shared.QualifiedHeadRef
}
func (r skipPushRefs) QualifiedHeadRef() string {
return r.qualifiedHeadRef.String()
}
func (r skipPushRefs) UnqualifiedHeadRef() string {
return r.qualifiedHeadRef.BranchName()
}
// pushableRefs indicate to handlePush that pushing is required,
// and provide further information (HeadRepo) on where that push
// should go.
type pushableRefs struct {
baseRefs
headRepo ghrepo.Interface
headBranchName string
}
func (r pushableRefs) QualifiedHeadRef() string {
if ghrepo.IsSame(r.headRepo, r.baseRepo) {
return r.headBranchName
}
return fmt.Sprintf("%s:%s", r.headRepo.RepoOwner(), r.headBranchName)
}
func (r pushableRefs) UnqualifiedHeadRef() string {
return r.headBranchName
}
func (r pushableRefs) HeadRepo() ghrepo.Interface {
return r.headRepo
}
// forkableRefs indicate to handlePush that forking is required before
// pushing. The expectation is that after forking, this is converted to
// pushableRefs. We could go very OOP and have a Fork method on this
// struct that returns a pushableRefs but then we'd need to embed an API client
// and it just seems nice that it is a simple bag of data.
type forkableRefs struct {
baseRefs
qualifiedHeadRef shared.QualifiedHeadRef
}
func (r forkableRefs) QualifiedHeadRef() string {
return r.qualifiedHeadRef.String()
}
func (r forkableRefs) UnqualifiedHeadRef() string {
return r.qualifiedHeadRef.BranchName()
}
// CreateContext stores contextual data about the creation process and is for building up enough
// data to create a pull request.
type CreateContext struct {
// This struct stores contextual data about the creation process and is for building up enough
// data to create a pull request
RepoContext *ghContext.ResolvedRemotes
PrRefs shared.PullRequestRefs
ResolvedRemotes *ghContext.ResolvedRemotes
PRRefs creationRefs
// BaseTrackingBranch is perhaps a slightly leaky abstraction in the presence
// of PRRefs, but a huge amount of refactoring was done to introduce that struct,
// and this is a small price to pay for the convenience of not having to do a lot
// more design.
BaseTrackingBranch string
BaseBranch string // Currently not supported by shared.PullRequestRefs struct
HeadRemote *ghContext.Remote
isPushEnabled bool
forkHeadRepo bool
Client *api.Client
GitClient *git.Client
}
@ -312,8 +404,8 @@ func createRun(opts *CreateOptions) error {
}
existingPR, _, err := opts.Finder.Find(shared.FindOptions{
Selector: ctx.PrRefs.GetPRHeadLabel(),
BaseBranch: ctx.BaseBranch,
Selector: ctx.PRRefs.QualifiedHeadRef(),
BaseBranch: ctx.PRRefs.BaseRef(),
States: []string{"OPEN"},
Fields: []string{"url"},
})
@ -323,7 +415,7 @@ func createRun(opts *CreateOptions) error {
}
if err == nil {
return fmt.Errorf("a pull request for branch %q into branch %q already exists:\n%s",
ctx.PrRefs.GetPRHeadLabel(), ctx.BaseBranch, existingPR.URL)
ctx.PRRefs.QualifiedHeadRef(), ctx.PRRefs.BaseRef(), existingPR.URL)
}
message := "\nCreating pull request for %s into %s in %s\n\n"
@ -338,9 +430,9 @@ func createRun(opts *CreateOptions) error {
if opts.IO.CanPrompt() {
fmt.Fprintf(opts.IO.ErrOut, message,
cs.Cyan(ctx.PrRefs.GetPRHeadLabel()),
cs.Cyan(ctx.BaseBranch),
ghrepo.FullName(ctx.PrRefs.BaseRepo))
cs.Cyan(ctx.PRRefs.QualifiedHeadRef()),
cs.Cyan(ctx.PRRefs.BaseRef()),
ghrepo.FullName(ctx.PRRefs.BaseRepo()))
}
if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided)) {
@ -363,7 +455,7 @@ func createRun(opts *CreateOptions) error {
action = shared.SubmitDraftAction
}
tpl := shared.NewTemplateManager(client.HTTP(), ctx.PrRefs.BaseRepo, opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true)
tpl := shared.NewTemplateManager(client.HTTP(), ctx.PRRefs.BaseRepo(), opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true)
if opts.EditorMode {
if opts.Template != "" {
@ -431,7 +523,7 @@ func createRun(opts *CreateOptions) error {
}
allowPreview := !state.HasMetadata() && shared.ValidURL(openURL) && !opts.DryRun
allowMetadata := ctx.PrRefs.BaseRepo.(*api.Repository).ViewerCanTriage()
allowMetadata := ctx.PRRefs.BaseRepo().ViewerCanTriage()
action, err = shared.ConfirmPRSubmission(opts.Prompter, allowPreview, allowMetadata, state.Draft)
if err != nil {
return fmt.Errorf("unable to confirm: %w", err)
@ -441,10 +533,10 @@ func createRun(opts *CreateOptions) error {
fetcher := &shared.MetadataFetcher{
IO: opts.IO,
APIClient: client,
Repo: ctx.PrRefs.BaseRepo,
Repo: ctx.PRRefs.BaseRepo(),
State: state,
}
err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PrRefs.BaseRepo, fetcher, state)
err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state)
if err != nil {
return err
}
@ -487,11 +579,7 @@ func createRun(opts *CreateOptions) error {
var regexPattern = regexp.MustCompile(`(?m)^`)
func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, useFirstCommit bool, addBody bool) error {
baseRef := ctx.BaseTrackingBranch
headRef := ctx.PrRefs.BranchName
gitClient := ctx.GitClient
commits, err := gitClient.Commits(context.Background(), baseRef, headRef)
commits, err := ctx.GitClient.Commits(context.Background(), ctx.BaseTrackingBranch, ctx.PRRefs.UnqualifiedHeadRef())
if err != nil {
return err
}
@ -500,7 +588,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u
state.Title = commits[len(commits)-1].Title
state.Body = commits[len(commits)-1].Body
} else {
state.Title = humanize(headRef)
state.Title = humanize(ctx.PRRefs.UnqualifiedHeadRef())
var body strings.Builder
for i := len(commits) - 1; i >= 0; i-- {
fmt.Fprintf(&body, "- **%s**\n", commits[i].Title)
@ -526,7 +614,7 @@ func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadata
milestoneTitles = []string{opts.Milestone}
}
meReplacer := shared.NewMeReplacer(ctx.Client, ctx.PrRefs.BaseRepo.RepoHost())
meReplacer := shared.NewMeReplacer(ctx.Client, ctx.PRRefs.BaseRepo().RepoHost())
assignees, err := meReplacer.ReplaceSlice(opts.Assignees)
if err != nil {
return nil, err
@ -553,7 +641,6 @@ func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadata
}
func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
ctx := context.Background()
httpClient, err := opts.HttpClient()
if err != nil {
return nil, err
@ -565,25 +652,19 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
return nil, err
}
gitClient := opts.GitClient
if ucc, err := gitClient.UncommittedChangeCount(ctx); err == nil && ucc > 0 {
fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change"))
}
// Resolve base repo
repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride)
resolvedRemotes, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride)
if err != nil {
return nil, err
}
var targetBaseRepo *api.Repository
if br, err := repoContext.BaseRepo(opts.IO); err == nil {
var baseRepo *api.Repository
if br, err := resolvedRemotes.BaseRepo(opts.IO); err == nil {
if r, ok := br.(*api.Repository); ok {
targetBaseRepo = r
baseRepo = r
} else {
// TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`,
// consider piggybacking on that result instead of performing a separate lookup
targetBaseRepo, err = api.GitHubRepo(client, br)
baseRepo, err = api.GitHubRepo(client, br)
if err != nil {
return nil, err
}
@ -592,181 +673,284 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
return nil, err
}
// Resolve target head branch name from either
// --head or the current branch.
var targetHeadBranch string
var targetHeadRepoOwner string
// This closure provides an easy way to instantiate a CreateContext with everything other than
// the refs. This probably indicates that CreateContext could do with some rework, but the refactor
// to introduce PRRefs is already large enough.
var newCreateContext = func(refs creationRefs) *CreateContext {
baseTrackingBranch := refs.BaseRef()
promptForHeadRepo := true
// The baseTrackingBranch is used later for a command like:
// `git commit upstream/main feature` in order to create a PR message showing the commits
// between these two refs. I'm not really sure what is expected to happen if we don't have a remote,
// which seems like it would be possible with a command `gh pr create --repo owner/repo-that-is-not-a-remote`.
// In that case, we might just have a mess? In any case, this is what the old code did, so I don't want to change
// it as part of an already large refactor.
baseRemote, _ := resolvedRemotes.RemoteForRepo(baseRepo)
if baseRemote != nil {
baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseTrackingBranch)
}
return &CreateContext{
ResolvedRemotes: resolvedRemotes,
Client: client,
GitClient: opts.GitClient,
PRRefs: refs,
BaseTrackingBranch: baseTrackingBranch,
}
}
// If the user provided a head branch we're going to use that without any interrogation
// of git. The value can take the form of <branch> or <user>:<branch>. In the former case, the
// PR base and head repos are the same. In the latter case we don't know the head repo
// (though we could look it up in the API) but fortunately we don't need to because the API
// will resolve this for us when we create the pull request. This is possible because
// users can only have a single fork in their namespace, and organizations don't work at all with this ref format.
//
// Note that providing the head branch in this way indicates that we shouldn't push the branch,
// and we indicate that via the returned type as well.
if opts.HeadBranch != "" {
promptForHeadRepo = false
targetHeadBranch = opts.HeadBranch
// If the --head provided contains a colon, that means
// this is <user>:<branch> syntax.
if idx := strings.IndexRune(opts.HeadBranch, ':'); idx >= 0 {
targetHeadRepoOwner = opts.HeadBranch[:idx]
targetHeadBranch = opts.HeadBranch[idx+1:]
}
} else {
// Use the current branch as the target local head branch when
// --head is not provided.
targetHeadBranch, err = opts.Branch()
if err != nil {
return nil, fmt.Errorf("could not determine the current branch: %w", err)
}
}
targetHeadBranchConfig, err := gitClient.ReadBranchConfig(ctx, targetHeadBranch)
if err != nil {
return nil, err
}
// See if we can determine if this branch has been push previously with
// Git configurations and @{push} revision syntax.
remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx)
if err != nil {
return nil, err
}
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
parsedPushRevision, _ := gitClient.ParsePushRevision(ctx, targetHeadBranch)
pushDefault, err := gitClient.PushDefault(ctx)
if err != nil {
return nil, err
}
prRefs, err := shared.ParsePRRefs(targetHeadBranch, targetHeadBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, targetBaseRepo, remotes)
if err != nil {
return nil, err
}
// If the --head provided contains <user>:<branch> syntax, we need to use
// the provided owner instead of the owner of the base repository.
if targetHeadRepoOwner != "" {
prRefs.HeadRepo = ghrepo.New(targetHeadRepoOwner, prRefs.HeadRepo.RepoName())
}
var headRemote *ghContext.Remote
// We received the head repository and branch from ParsePRRefs, or inferred
// it from --head input, but we need to check if it's up-to-date with
// our local branch state.
// If it is, we can use it as the head repo for the PR
// and avoid prompting the user.
// Errors raised here should not cause command to fail,
// prompt user for head repo if an error is raised or no remote found.
if prRefs.HasHead() {
// Check if the head branch is up-to-date with the local branch
headRemote, err := remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName())
if headRemote != nil && err == nil {
headRefName := fmt.Sprintf("refs/remotes/%s/%s", headRemote, prRefs.BranchName)
refsForLookup := []string{"HEAD", headRefName}
resolvedRefs, err := gitClient.ShowRefs(ctx, refsForLookup)
// If there is more than one resolved ref, then remote head ref was resolved.
if err == nil && len(resolvedRefs) > 1 {
headRef := resolvedRefs[0]
for _, r := range resolvedRefs[1:] {
// If the head ref is the same as the remote head ref,
// then the remote head is current and we can use it.
if r.Hash == headRef.Hash {
promptForHeadRepo = false
break
}
}
}
}
}
var forkHeadRepo bool
var isPushEnabled bool
if promptForHeadRepo && opts.IO.CanPrompt() {
isPushEnabled = true
// Since we could not determine a head ref, prompt the user for the head repository to push
// using a list of repositories obtained from the API
pushableRepos, err := repoContext.HeadRepos()
qualifiedHeadRef, err := shared.ParseQualifiedHeadRef(opts.HeadBranch)
if err != nil {
return nil, err
}
if len(pushableRepos) == 0 {
pushableRepos, err = api.RepoFindForks(client, prRefs.BaseRepo, 3)
if err != nil {
return nil, err
}
}
currentLogin, err := api.CurrentLoginName(client, prRefs.BaseRepo.RepoHost())
branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), qualifiedHeadRef.BranchName())
if err != nil {
return nil, err
}
hasOwnFork := false
var pushOptions []string
for _, r := range pushableRepos {
pushOptions = append(pushOptions, ghrepo.FullName(r))
if r.RepoOwner() == currentLogin {
hasOwnFork = true
}
baseBranch := opts.BaseBranch
if baseBranch == "" {
baseBranch = branchConfig.MergeBase
}
if baseBranch == "" {
baseBranch = baseRepo.DefaultBranchRef.Name
}
if !hasOwnFork {
pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(prRefs.BaseRepo))
}
pushOptions = append(pushOptions, "Skip pushing the branch")
pushOptions = append(pushOptions, "Cancel")
selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", prRefs.BranchName), "", pushOptions)
if err != nil {
return nil, err
}
if selectedOption < len(pushableRepos) {
prRefs.HeadRepo = pushableRepos[selectedOption]
} else if pushOptions[selectedOption] == "Skip pushing the branch" {
isPushEnabled = false
} else if pushOptions[selectedOption] == "Cancel" {
return nil, cmdutil.CancelError
} else {
// "Create a fork of ..."
forkHeadRepo = true
prRefs.HeadRepo = ghrepo.New(currentLogin, prRefs.HeadRepo.RepoName())
}
return newCreateContext(skipPushRefs{
qualifiedHeadRef: qualifiedHeadRef,
baseRefs: baseRefs{
baseRepo: baseRepo,
baseBranchName: baseBranch,
},
}), nil
}
if prRefs.HeadRepo == nil && isPushEnabled && !opts.IO.CanPrompt() {
fmt.Fprintf(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag")
return nil, cmdutil.SilentError
if ucc, err := opts.GitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 {
fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change"))
}
// If the user didn't provide a head branch then we're gettin' real. We're going to interrogate git
// and try to create refs that are pushable.
currentBranch, err := opts.Branch()
if err != nil {
return nil, fmt.Errorf("could not determine the current branch: %w", err)
}
branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), currentBranch)
if err != nil {
return nil, err
}
baseBranch := opts.BaseBranch
if baseBranch == "" {
baseBranch = targetHeadBranchConfig.MergeBase
baseBranch = branchConfig.MergeBase
}
if baseBranch == "" {
baseBranch = targetBaseRepo.DefaultBranchRef.Name
}
if prRefs.BranchName == baseBranch && prRefs.HeadRepo != nil && ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) {
return nil, fmt.Errorf("must be on a branch named differently than %q", baseBranch)
baseBranch = baseRepo.DefaultBranchRef.Name
}
baseTrackingBranch := baseBranch
if baseRemote, err := remotes.FindByRepo(prRefs.BaseRepo.RepoOwner(), prRefs.BaseRepo.RepoName()); err == nil {
baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseBranch)
// First we check with the git information we have to see if we can figure out the default
// head repo and remote branch name.
defaultPRHead, err := shared.TryDetermineDefaultPRHead(
// We requested the branch config already, so let's cache that
shared.CachedBranchConfigGitConfigClient{
CachedBranchConfig: branchConfig,
GitConfigClient: opts.GitClient,
},
shared.NewRemoteToRepoResolver(opts.Remotes),
currentBranch,
)
if err != nil {
return nil, err
}
return &CreateContext{
PrRefs: prRefs,
BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct
BaseTrackingBranch: baseTrackingBranch,
HeadRemote: headRemote,
isPushEnabled: isPushEnabled,
forkHeadRepo: forkHeadRepo,
RepoContext: repoContext,
Client: client,
GitClient: gitClient,
}, nil
// The baseRefs are always going to be the same from now on. If I could make this immutable I would!
baseRefs := baseRefs{
baseRepo: baseRepo,
baseBranchName: baseBranch,
}
// If we were able to determine a head repo, then let's check that the remote tracking ref matches the SHA of
// HEAD. If it does, then we don't need to push, otherwise we'll need to ask the user to tell us where to push.
if headRepo, present := defaultPRHead.Repo.Value(); present {
// We may not find a remote because the git branch config may have a URL rather than a remote name.
// Ideally, we would return a sentinel error from RemoteForRepo that we could compare to, but the
// refactor that introduced this code was already large enough.
headRemote, _ := resolvedRemotes.RemoteForRepo(headRepo)
if headRemote != nil {
resolvedRefs, _ := opts.GitClient.ShowRefs(
context.Background(),
[]string{
"HEAD",
fmt.Sprintf("refs/remotes/%s/%s", headRemote.Name, defaultPRHead.BranchName),
},
)
// Two refs returned means we can compare HEAD to the remote tracking branch.
// If we had a matching ref, then we can skip pushing.
refsMatch := len(resolvedRefs) == 2 && resolvedRefs[0].Hash == resolvedRefs[1].Hash
if refsMatch {
qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(defaultPRHead.BranchName)
if headRepo.RepoOwner() != baseRepo.RepoOwner() {
qualifiedHeadRef = shared.NewQualifiedHeadRef(headRepo.RepoOwner(), defaultPRHead.BranchName)
}
return newCreateContext(skipPushRefs{
qualifiedHeadRef: qualifiedHeadRef,
baseRefs: baseRefs,
}), nil
}
}
}
// If we didn't determine that the git indicated repo had the correct ref, we'll take a look at the other
// remotes and see whether any of them have the same SHA as HEAD. Now, at this point, you might be asking yourself:
// "Why didn't we collect all the SHAs with a single ShowRefs command above, for use in both cases?"
// ...
// That's because the code below has a bug that I've ported from the old code, in order to preserve the existing
// behaviour, and to limit the scope of an already large refactor. The intention of the original code was to loop
// over all the returned refs. However, as it turns out, our implementation of ShowRefs doesn't do that correctly.
// Since it provides the --verify flag, git will return the SHAs for refs up until it hits a ref that doesn't exist,
// at which point it bails out.
//
// Imagine you have a remotes "upstream" and "origin", and you have pushed your branch "feature" to "origin". Since
// the order of remotes is always guaranteed "upstream", "github", "origin", and then everything else unstably sorted,
// we will never get a SHA for origin, as refs/remotes/upstream/feature doesn't exist.
//
// Furthermore, when you really think about it, this code is a bit eager. What happens if you have the same SHA on
// remotes "origin" and "colleague", this will always offer origin. If it were "colleague-a" and "colleague-b", no
// order would be guaranteed between different invocations of pr create, because the order of remotes after "origin"
// is unstable sorted.
//
// All that said, this has been the behaviour for a long, long time, and I do not want to make other behavioural changes
// in what is mostly a refactor.
refsToLookup := []string{"HEAD"}
for _, remote := range remotes {
refsToLookup = append(refsToLookup, fmt.Sprintf("refs/remotes/%s/%s", remote.Name, currentBranch))
}
// Ignoring the error in this case is allowed because we may get refs and an error (see: --verify flag above).
// Ideally there would be a typed error to allow us to distinguish between an execution error and some refs
// not existing. However, this is too much to take on in an already large refactor.
refs, _ := opts.GitClient.ShowRefs(context.Background(), refsToLookup)
if len(refs) > 1 {
headRef := refs[0]
var firstMatchingRef o.Option[git.RemoteTrackingRef]
// Loop over all the refs, trying to find one that matches the SHA of HEAD.
for _, r := range refs[1:] {
if r.Hash == headRef.Hash {
remoteTrackingRef, err := git.ParseRemoteTrackingRef(r.Name)
if err != nil {
return nil, err
}
firstMatchingRef = o.Some(remoteTrackingRef)
break
}
}
// If we found a matching ref, then we don't need to push.
if ref, present := firstMatchingRef.Value(); present {
remote, err := remotes.FindByName(ref.Remote)
if err != nil {
return nil, err
}
qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(ref.Branch)
if baseRepo.RepoOwner() != remote.RepoOwner() {
qualifiedHeadRef = shared.NewQualifiedHeadRef(remote.RepoOwner(), ref.Branch)
}
return newCreateContext(skipPushRefs{
qualifiedHeadRef: qualifiedHeadRef,
baseRefs: baseRefs,
}), nil
}
}
// If we haven't got a repo by now, and we can't prompt then it's game over.
if !opts.IO.CanPrompt() {
fmt.Fprintln(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag")
return nil, cmdutil.SilentError
}
// Otherwise, hooray, prompting!
// First, we're going to look at our remotes and decide whether there are any repos we can push to.
pushableRepos, err := resolvedRemotes.HeadRepos()
if err != nil {
return nil, err
}
// If we couldn't find any pushable repos, then find forks of the base repo.
if len(pushableRepos) == 0 {
pushableRepos, err = api.RepoFindForks(client, baseRepo, 3)
if err != nil {
return nil, err
}
}
currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost())
if err != nil {
return nil, err
}
hasOwnFork := false
var pushOptions []string
for _, r := range pushableRepos {
pushOptions = append(pushOptions, ghrepo.FullName(r))
if r.RepoOwner() == currentLogin {
hasOwnFork = true
}
}
if !hasOwnFork {
pushOptions = append(pushOptions, fmt.Sprintf("Create a fork of %s", ghrepo.FullName(baseRepo)))
}
pushOptions = append(pushOptions, "Skip pushing the branch")
pushOptions = append(pushOptions, "Cancel")
selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", currentBranch), "", pushOptions)
if err != nil {
return nil, err
}
if selectedOption < len(pushableRepos) {
// A repository has been selected to push to.
return newCreateContext(pushableRefs{
headRepo: pushableRepos[selectedOption],
headBranchName: currentBranch,
baseRefs: baseRefs,
}), nil
} else if pushOptions[selectedOption] == "Skip pushing the branch" {
// We're going to skip pushing the branch altogether, meaning, use whatever SHA is already pushed.
// It's not exactly clear what repo the user expects to use here for the HEAD, and maybe we should
// make that clear in the UX somehow, but in the old implementation as far as I can tell, this
// always meant "use the base repo".
return newCreateContext(skipPushRefs{
qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner(currentBranch),
baseRefs: baseRefs,
}), nil
} else if pushOptions[selectedOption] == "Cancel" {
return nil, cmdutil.CancelError
} else {
// A fork should be created.
return newCreateContext(forkableRefs{
qualifiedHeadRef: shared.NewQualifiedHeadRef(currentLogin, currentBranch),
baseRefs: baseRefs,
}), nil
}
}
func getRemotes(opts *CreateOptions) (ghContext.Remotes, error) {
@ -789,8 +973,8 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS
"title": state.Title,
"body": state.Body,
"draft": state.Draft,
"baseRefName": ctx.BaseBranch,
"headRefName": ctx.PrRefs.GetPRHeadLabel(),
"baseRefName": ctx.PRRefs.BaseRef(),
"headRefName": ctx.PRRefs.QualifiedHeadRef(),
"maintainerCanModify": opts.MaintainerCanModify,
}
@ -798,7 +982,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS
return errors.New("pull request title must not be blank")
}
err := shared.AddMetadataToIssueParams(client, ctx.PrRefs.BaseRepo, params, &state)
err := shared.AddMetadataToIssueParams(client, ctx.PRRefs.BaseRepo(), params, &state)
if err != nil {
return err
}
@ -812,9 +996,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS
}
opts.IO.StartProgressIndicator()
// At this point, ctx.PrRefs.BaseRepo is guaranteed to be an *api.Repository
// because of https://github.com/cli/cli/blob/d29db2d44199ad4a987ea866f3f4ff601b1c90a0/pkg/cmd/pr/create/create.go#L578-L592
pr, err := api.CreatePullRequest(client, ctx.PrRefs.BaseRepo.(*api.Repository), params)
pr, err := api.CreatePullRequest(client, ctx.PRRefs.BaseRepo(), params)
opts.IO.StopProgressIndicator()
if pr != nil {
fmt.Fprintln(opts.IO.Out, pr.URL)
@ -910,38 +1092,43 @@ func previewPR(opts CreateOptions, openURL string) error {
}
func handlePush(opts CreateOptions, ctx CreateContext) error {
didForkRepo := false
headRepo := ctx.PrRefs.HeadRepo
headRemote := ctx.HeadRemote
client := ctx.Client
gitClient := ctx.GitClient
var err error
// if a head repository could not be determined so far, automatically create
// one by forking the base repository
if ctx.forkHeadRepo && ctx.isPushEnabled {
refs := ctx.PRRefs
forkableRefs, requiresFork := refs.(forkableRefs)
if requiresFork {
opts.IO.StartProgressIndicator()
headRepo, err = api.ForkRepo(client, ctx.PrRefs.BaseRepo, "", "", false)
forkedRepo, err := api.ForkRepo(ctx.Client, forkableRefs.BaseRepo(), "", "", false)
opts.IO.StopProgressIndicator()
if err != nil {
return fmt.Errorf("error forking repo: %w", err)
}
didForkRepo = true
refs = pushableRefs{
headRepo: forkedRepo,
headBranchName: forkableRefs.qualifiedHeadRef.BranchName(),
baseRefs: baseRefs{
baseRepo: forkableRefs.baseRepo,
baseBranchName: forkableRefs.baseBranchName,
},
}
}
if headRemote == nil && headRepo != nil {
headRemote, _ = ctx.RepoContext.RemoteForRepo(headRepo)
// We may have upcast to pushableRefs on fork, or we may have been passed an instance
// already. But if we haven't, then there's nothing more to do.
pushableRefs, ok := refs.(pushableRefs)
if !ok {
return nil
}
// There are two cases when an existing remote for the head repo will be
// missing:
// missing (and an error will be returned):
// 1. the head repo was just created by auto-forking;
// 2. an existing fork was discovered by querying the API.
// In either case, we want to add the head repo as a new git remote so we
// can push to it. We will try to add the head repo as the "origin" remote
// and fallback to the "fork" remote if it is unavailable. Also, if the
// base repo is the "origin" remote we will rename it "upstream".
if headRemote == nil && ctx.isPushEnabled {
headRemote, _ := ctx.ResolvedRemotes.RemoteForRepo(pushableRefs.HeadRepo())
if headRemote == nil {
cfg, err := opts.Config()
if err != nil {
return err
@ -952,8 +1139,8 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
return err
}
cloneProtocol := cfg.GitProtocol(headRepo.RepoHost()).Value
headRepoURL := ghrepo.FormatRemoteURL(headRepo, cloneProtocol)
cloneProtocol := cfg.GitProtocol(pushableRefs.HeadRepo().RepoHost()).Value
headRepoURL := ghrepo.FormatRemoteURL(pushableRefs.HeadRepo(), cloneProtocol)
gitClient := ctx.GitClient
origin, _ := remotes.FindByName("origin")
upstreamName := "upstream"
@ -964,7 +1151,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
remoteName = "fork"
}
if origin != nil && upstream == nil && ghrepo.IsSame(origin, ctx.PrRefs.BaseRepo) {
if origin != nil && upstream == nil && ghrepo.IsSame(origin, pushableRefs.BaseRepo()) {
renameCmd, err := gitClient.Command(context.Background(), "remote", "rename", "origin", upstreamName)
if err != nil {
return err
@ -973,7 +1160,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
return fmt.Errorf("error renaming origin remote: %w", err)
}
remoteName = "origin"
fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(ctx.PrRefs.BaseRepo), upstreamName)
fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(pushableRefs.BaseRepo()), upstreamName)
}
gitRemote, err := gitClient.AddRemote(context.Background(), remoteName, headRepoURL, []string{})
@ -981,10 +1168,10 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
return fmt.Errorf("error adding remote: %w", err)
}
fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(headRepo), remoteName)
fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(pushableRefs.HeadRepo()), remoteName)
// Only mark `upstream` remote as default if `gh pr create` created the remote.
if didForkRepo {
if requiresFork {
err := gitClient.SetRemoteResolution(context.Background(), upstreamName, "base")
if err != nil {
return fmt.Errorf("error setting upstream as default: %w", err)
@ -992,52 +1179,45 @@ func handlePush(opts CreateOptions, ctx CreateContext) error {
if opts.IO.IsStdoutTTY() {
cs := opts.IO.ColorScheme()
fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(headRepo)))
fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(pushableRefs.HeadRepo())))
}
}
headRemote = &ghContext.Remote{
Remote: gitRemote,
Repo: headRepo,
Repo: pushableRefs.HeadRepo(),
}
}
// automatically push the branch if it hasn't been pushed anywhere yet
if ctx.isPushEnabled {
pushBranch := func() error {
w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "")
defer w.Flush()
ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.PrRefs.BranchName)
bo := backoff.NewConstantBackOff(2 * time.Second)
ctx := context.Background()
return backoff.Retry(func() error {
if err := gitClient.Push(ctx, headRemote.Name, ref, git.WithStderr(w)); err != nil {
// Only retry if we have forked the repo else the push should succeed the first time.
if didForkRepo {
fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n")
return err
}
return backoff.Permanent(err)
pushBranch := func() error {
w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "")
defer w.Flush()
ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.PRRefs.UnqualifiedHeadRef())
bo := backoff.NewConstantBackOff(2 * time.Second)
root := context.Background()
return backoff.Retry(func() error {
if err := ctx.GitClient.Push(root, headRemote.Name, ref, git.WithStderr(w)); err != nil {
// Only retry if we have forked the repo else the push should succeed the first time.
if requiresFork {
fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n")
return err
}
return nil
}, backoff.WithContext(backoff.WithMaxRetries(bo, 3), ctx))
}
err := pushBranch()
if err != nil {
return err
}
return backoff.Permanent(err)
}
return nil
}, backoff.WithContext(backoff.WithMaxRetries(bo, 3), root))
}
return nil
return pushBranch()
}
func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (string, error) {
u := ghrepo.GenerateRepoURL(
ctx.PrRefs.BaseRepo,
ctx.PRRefs.BaseRepo(),
"compare/%s...%s?expand=1",
url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.PrRefs.GetPRHeadLabel()))
url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PrRefs.BaseRepo, u, state)
url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef()))
url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state)
if err != nil {
return "", err
}

View file

@ -2,7 +2,6 @@ package create
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
@ -332,19 +331,18 @@ func TestNewCmdCreate(t *testing.T) {
func Test_createRun(t *testing.T) {
tests := []struct {
name string
setup func(*CreateOptions, *testing.T) func()
cmdStubs func(*run.CommandStubber)
promptStubs func(*prompter.PrompterMock)
httpStubs func(*httpmock.Registry, *testing.T)
expectedOutputs []string
expectedOut string
expectedErrOut string
expectedBrowse string
wantErr string
tty bool
customBranchConfig bool
customPushDestination bool
name string
setup func(*CreateOptions, *testing.T) func()
cmdStubs func(*run.CommandStubber)
promptStubs func(*prompter.PrompterMock)
httpStubs func(*httpmock.Registry, *testing.T)
expectedOutputs []string
expectedOut string
expectedErrOut string
expectedBrowse string
wantErr string
tty bool
customBranchConfig bool
}{
{
name: "nontty web",
@ -608,7 +606,7 @@ func Test_createRun(t *testing.T) {
`),
},
{
name: "survey",
name: "select a specific branch to push to on prompt",
tty: true,
setup: func(opts *CreateOptions, t *testing.T) func() {
opts.TitleProvided = true
@ -637,6 +635,9 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -651,6 +652,52 @@ func Test_createRun(t *testing.T) {
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n",
},
{
name: "skip pushing to branch on prompt",
tty: true,
setup: func(opts *CreateOptions, t *testing.T) func() {
opts.TitleProvided = true
opts.BodyProvided = true
opts.Title = "my title"
opts.Body = "my body"
return func() {}
},
httpStubs: func(reg *httpmock.Registry, t *testing.T) {
reg.StubRepoResponse("OWNER", "REPO")
reg.Register(
httpmock.GraphQL(`query UserCurrent\b`),
httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`))
reg.Register(
httpmock.GraphQL(`mutation PullRequestCreate\b`),
httpmock.GraphQLMutation(`
{ "data": { "createPullRequest": { "pullRequest": {
"URL": "https://github.com/OWNER/REPO/pull/12"
} } } }`, func(input map[string]interface{}) {
assert.Equal(t, "REPOID", input["repositoryId"].(string))
assert.Equal(t, "my title", input["title"].(string))
assert.Equal(t, "my body", input["body"].(string))
assert.Equal(t, "master", input["baseRefName"].(string))
assert.Equal(t, "feature", input["headRefName"].(string))
assert.Equal(t, false, input["draft"].(bool))
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
pm.SelectFunc = func(p, _ string, opts []string) (int, error) {
if p == "Where should we push the 'feature' branch?" {
return prompter.IndexFor(opts, "Skip pushing the branch")
} else {
return -1, prompter.NoSuchPromptErr(p)
}
}
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n",
},
{
name: "project v2",
tty: true,
@ -699,6 +746,9 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -744,6 +794,9 @@ func Test_createRun(t *testing.T) {
}))
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -791,12 +844,11 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:feature", input["headRefName"].(string))
}))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 1, "")
cs.Register("git config remote.pushDefault", 1, "")
cs.Register("git config push.default", 1, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register("git remote rename origin upstream", 0, "")
cs.Register(`git remote add origin https://github.com/monalisa/REPO.git`, 0, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
@ -854,15 +906,11 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:feature", input["headRefName"].(string))
}))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git show-ref --verify", 0, heredoc.Doc(`
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/upstream/feature
deadbeef refs/remotes/origin/feature`))
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
expectedErrOut: "\nCreating pull request for monalisa:feature into master in OWNER/REPO\n\n",
@ -890,20 +938,17 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "my-feat2", input["headRefName"].(string))
}))
},
customBranchConfig: true,
customPushDestination: true,
customBranchConfig: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp \^branch\\\.feature\\\.`, 0, heredoc.Doc(`
branch.feature.remote origin
branch.feature.merge refs/heads/my-feat2
`)) // determineTrackingBranch
cs.Register("git show-ref --verify", 0, heredoc.Doc(`
`))
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/my-feat2")
cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/my-feat2", 0, heredoc.Doc(`
deadbeef HEAD
deadbeef refs/remotes/origin/my-feat2
`)) // determineTrackingBranch
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/my-feat2")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
`))
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
expectedErrOut: "\nCreating pull request for my-feat2 into master in OWNER/REPO\n\n",
@ -1084,6 +1129,9 @@ func Test_createRun(t *testing.T) {
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "")
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -1115,6 +1163,9 @@ func Test_createRun(t *testing.T) {
},
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "")
cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "")
cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "")
},
promptStubs: func(pm *prompter.PrompterMock) {
@ -1279,37 +1330,6 @@ func Test_createRun(t *testing.T) {
},
wantErr: "cannot open in browser: maximum URL length exceeded",
},
{
name: "no local git repo",
setup: func(opts *CreateOptions, t *testing.T) func() {
opts.Title = "My PR"
opts.TitleProvided = true
opts.Body = ""
opts.BodyProvided = true
opts.HeadBranch = "feature"
opts.RepoOverride = "OWNER/REPO"
opts.Remotes = func() (context.Remotes, error) {
return nil, errors.New("not a git repository")
}
return func() {}
},
httpStubs: func(reg *httpmock.Registry, t *testing.T) {
reg.Register(
httpmock.GraphQL(`mutation PullRequestCreate\b`),
httpmock.StringResponse(`
{ "data": { "createPullRequest": { "pullRequest": {
"URL": "https://github.com/OWNER/REPO/pull/12"
} } } }
`))
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --abbrev-ref feature@{push}", 1, "fatal: not a git repository (or any of the parent directories): .git")
cs.Register("git config remote.pushDefault", 1, "")
cs.Register("git config push.default", 1, "")
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
},
{
name: "single commit title and body are used",
tty: true,
@ -1528,20 +1548,16 @@ func Test_createRun(t *testing.T) {
assert.Equal(t, "monalisa:task1", input["headRefName"].(string))
}))
},
customBranchConfig: true,
customPushDestination: true,
customBranchConfig: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, heredoc.Doc(`
branch.task1.remote origin
branch.task1.merge refs/heads/task1
branch.task1.gh-merge-base feature/feat2`)) // ReadBranchConfig
cs.Register(`git show-ref --verify`, 0, heredoc.Doc(`
cs.Register("git rev-parse --symbolic-full-name task1@{push}", 0, "refs/remotes/origin/task1")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/task1`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/upstream/feature/feat2
deadbeef refs/remotes/origin/task1`)) // determineTrackingBranch
cs.Register("git rev-parse --abbrev-ref task1@{push}", 0, "origin/task1")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
deadbeef refs/remotes/origin/task1`))
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
expectedErrOut: "\nCreating pull request for monalisa:task1 into feature/feat2 in OWNER/REPO\n\n",
@ -1571,12 +1587,6 @@ func Test_createRun(t *testing.T) {
opts.HeadBranch = "otherowner:feature"
return func() {}
},
customPushDestination: true,
cmdStubs: func(cs *run.CommandStubber) {
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
},
expectedOut: "https://github.com/OWNER/REPO/pull/12\n",
},
}
@ -1598,16 +1608,7 @@ func Test_createRun(t *testing.T) {
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git status --porcelain`, 0, "")
// TODO this could be values in the test struct with a helper
// function to invoke the appropriate command stubs based on
// those values.
if !tt.customPushDestination {
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "")
cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature")
cs.Register("git config remote.pushDefault", 0, "")
cs.Register("git config push.default", 0, "")
}
if !tt.customBranchConfig {
cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "")
}
@ -1658,6 +1659,10 @@ func Test_createRun(t *testing.T) {
}
defer cleanSetup()
if opts.HeadBranch == "" {
cs.Register(`git status --porcelain`, 0, "")
}
err := createRun(&opts)
output := &test.CmdOut{
OutBuf: stdout,
@ -1681,6 +1686,168 @@ func Test_createRun(t *testing.T) {
}
}
func TestRemoteGuessing(t *testing.T) {
// Given git config does not provide the necessary info to determine a remote
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git status --porcelain`, 0, "")
cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "")
cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "")
cs.Register("git config remote.pushDefault", 1, "")
cs.Register("git config push.default", 1, "")
// And Given there is a remote on a SHA that matches the current HEAD
cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/upstream/feature
deadbeef refs/remotes/origin/feature`))
// When the command is run
reg := &httpmock.Registry{}
reg.StubRepoInfoResponse("OWNER", "REPO", "master")
defer reg.Verify(t)
reg.Register(
httpmock.GraphQL(`mutation PullRequestCreate\b`),
httpmock.GraphQLMutation(`
{ "data": { "createPullRequest": { "pullRequest": {
"URL": "https://github.com/OWNER/REPO/pull/12"
} } } }`, func(input map[string]interface{}) {
assert.Equal(t, "REPOID", input["repositoryId"].(string))
assert.Equal(t, "master", input["baseRefName"].(string))
assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string))
}))
ios, _, _, _ := iostreams.Test()
opts := CreateOptions{
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
Browser: &browser.Stub{},
IO: ios,
Prompter: &prompter.PrompterMock{},
GitClient: &git.Client{
GhPath: "some/path/gh",
GitPath: "some/path/git",
},
Finder: shared.NewMockFinder("feature", nil, nil),
Remotes: func() (context.Remotes, error) {
return context.Remotes{
{
Remote: &git.Remote{
Name: "upstream",
Resolved: "base",
},
Repo: ghrepo.New("OWNER", "REPO"),
},
{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("OTHEROWNER", "REPO-FORK"),
},
}, nil
},
Branch: func() (string, error) {
return "feature", nil
},
TitleProvided: true,
BodyProvided: true,
Title: "my title",
Body: "my body",
}
require.NoError(t, createRun(&opts))
// Then guessed remote is used for the PR head,
// which annoyingly, is asserted above on the line:
// assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string))
//
// This is because OTHEROWNER relates to the "origin" remote, which has a
// SHA that matches the HEAD ref in the `git show-ref` output.
}
func TestNoRepoCanBeDetermined(t *testing.T) {
// Given no head repo can be determined from git config
cs, cmdTeardown := run.Stub()
defer cmdTeardown(t)
cs.Register(`git status --porcelain`, 0, "")
cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "")
cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "")
cs.Register("git config remote.pushDefault", 1, "")
cs.Register("git config push.default", 1, "")
// And Given there is no remote on the correct SHA
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/origin/feature`))
// When the command is run with no TTY
reg := &httpmock.Registry{}
reg.StubRepoInfoResponse("OWNER", "REPO", "master")
defer reg.Verify(t)
ios, _, _, stderr := iostreams.Test()
opts := CreateOptions{
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
Config: func() (gh.Config, error) {
return config.NewBlankConfig(), nil
},
Browser: &browser.Stub{},
IO: ios,
Prompter: &prompter.PrompterMock{},
GitClient: &git.Client{
GhPath: "some/path/gh",
GitPath: "some/path/git",
},
Finder: shared.NewMockFinder("feature", nil, nil),
Remotes: func() (context.Remotes, error) {
return context.Remotes{
{
Remote: &git.Remote{
Name: "origin",
Resolved: "base",
},
Repo: ghrepo.New("OWNER", "REPO"),
},
}, nil
},
Branch: func() (string, error) {
return "feature", nil
},
TitleProvided: true,
BodyProvided: true,
Title: "my title",
Body: "my body",
}
// When we run the command
err := createRun(&opts)
// Then create fails
require.Equal(t, cmdutil.SilentError, err)
assert.Equal(t, "aborted: you must first push the current branch to a remote, or use the --head flag\n", stderr.String())
}
func mustParseQualifiedHeadRef(ref string) shared.QualifiedHeadRef {
parsed, err := shared.ParseQualifiedHeadRef(ref)
if err != nil {
panic(err)
}
return parsed
}
func Test_generateCompareURL(t *testing.T) {
tests := []struct {
name string
@ -1692,12 +1859,13 @@ func Test_generateCompareURL(t *testing.T) {
{
name: "basic",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "feature",
PRRefs: &skipPushRefs{
qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"),
baseRefs: baseRefs{
baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
baseBranchName: "main",
},
},
BaseBranch: "main",
},
want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1",
wantErr: false,
@ -1705,12 +1873,13 @@ func Test_generateCompareURL(t *testing.T) {
{
name: "with labels",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "b",
PRRefs: &skipPushRefs{
qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("b"),
baseRefs: baseRefs{
baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
baseBranchName: "a",
},
},
BaseBranch: "a",
},
state: shared.IssueMetadataState{
Labels: []string{"one", "two three"},
@ -1721,12 +1890,13 @@ func Test_generateCompareURL(t *testing.T) {
{
name: "'/'s in branch names/labels are percent-encoded",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "ORIGINOWNER"}}, "github.com"),
BranchName: "feature",
PRRefs: &skipPushRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:feature"),
baseRefs: baseRefs{
baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"),
baseBranchName: "main/trunk",
},
},
BaseBranch: "main/trunk",
},
want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:feature?body=&expand=1",
wantErr: false,
@ -1734,18 +1904,19 @@ func Test_generateCompareURL(t *testing.T) {
{
name: "Any of !'(),; but none of $&+=@ and : in branch names/labels are percent-encoded ",
/*
- Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway
- !$&'()+,;=@ is a valid Git branch nameessentially RFC 3986 sub-delims without * and gen-delims without :/?#[]
- : is GitHub separator between a fork name and a branch name
- See https://github.com/golang/go/issues/27559.
- Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway
- !$&'()+,;=@ is a valid Git branch nameessentially RFC 3986 sub-delims without * and gen-delims without :/?#[]
- : is GitHub separator between a fork name and a branch name
- See https://github.com/golang/go/issues/27559.
*/
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "ORIGINOWNER"}}, "github.com"),
BranchName: "!$&'()+,;=@",
PRRefs: &skipPushRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:!$&'()+,;=@"),
baseRefs: baseRefs{
baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"),
baseBranchName: "main/trunk",
},
},
BaseBranch: "main/trunk",
},
want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:%21$&%27%28%29+%2C%3B=@?body=&expand=1",
wantErr: false,
@ -1753,12 +1924,13 @@ func Test_generateCompareURL(t *testing.T) {
{
name: "with template",
ctx: CreateContext{
PrRefs: shared.PullRequestRefs{
BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
BranchName: "feature",
PRRefs: &skipPushRefs{
qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"),
baseRefs: baseRefs{
baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"),
baseBranchName: "main",
},
},
BaseBranch: "main",
},
state: shared.IssueMetadataState{
Template: "story.md",

View file

@ -0,0 +1,394 @@
package shared
import (
"context"
"fmt"
"net/url"
"strings"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
o "github.com/cli/cli/v2/pkg/option"
)
// QualifiedHeadRef represents a git branch with an optional owner, used
// for the head of a pull request. For example, within a single repository,
// we would expect a PR to have a head ref of no owner, and a branch name.
// However, for cross-repository pull requests, we would expect a head ref
// with an owner and a branch name. In string form this is represented as
// <owner>:<branch>. The GitHub API is able to interpret this format in order
// to discover the correct fork repository.
//
// In other parts of the code, you may see this refered to as a HeadLabel.
type QualifiedHeadRef struct {
owner o.Option[string]
branchName string
}
// NewQualifiedHeadRef creates a QualifiedHeadRef. If the empty string is provided
// for the owner, it will be treated as None.
func NewQualifiedHeadRef(owner string, branchName string) QualifiedHeadRef {
return QualifiedHeadRef{
owner: o.SomeIfNonZero(owner),
branchName: branchName,
}
}
func NewQualifiedHeadRefWithoutOwner(branchName string) QualifiedHeadRef {
return QualifiedHeadRef{
owner: o.None[string](),
branchName: branchName,
}
}
// ParseQualifiedHeadRef takes strings of the form <owner>:<branch> or <branch>
// and returns a QualifiedHeadRef. If the form <owner>:<branch> is used,
// the owner is set to the value of <owner>, and the branch name is set to
// the value of <branch>. If the form <branch> is used, the owner is set to
// None, and the branch name is set to the value of <branch>.
//
// This does no further error checking about the validity of a ref, so
// it is not safe to assume the ref is truly a valid ref, e.g. "my~bad:ref?"
// is going to result in a nonsense result.
func ParseQualifiedHeadRef(ref string) (QualifiedHeadRef, error) {
if !strings.Contains(ref, ":") {
return NewQualifiedHeadRefWithoutOwner(ref), nil
}
parts := strings.Split(ref, ":")
if len(parts) != 2 {
return QualifiedHeadRef{}, fmt.Errorf("invalid qualified head ref format '%s'", ref)
}
return NewQualifiedHeadRef(parts[0], parts[1]), nil
}
// A QualifiedHeadRef without an owner returns <branch>, while a QualifiedHeadRef
// with an owner returns <owner>:<branch>.
func (r QualifiedHeadRef) String() string {
if owner, present := r.owner.Value(); present {
return fmt.Sprintf("%s:%s", owner, r.branchName)
}
return r.branchName
}
func (r QualifiedHeadRef) BranchName() string {
return r.branchName
}
// PRFindRefs represents the necessary data to find a pull request from the API.
type PRFindRefs struct {
qualifiedHeadRef QualifiedHeadRef
baseRepo ghrepo.Interface
// baseBranchName is an optional branch name, because it is not required for
// finding a pull request, only for disambiguation if multiple pull requests
// contain the same head ref.
baseBranchName o.Option[string]
}
// QualifiedHeadRef returns a stringified form of the head ref, varying depending
// on whether the head ref is in the same repository as the base ref. If they are
// the same repository, we return the branch name only. If they are different repositories,
// we return the owner and branch name in the form <owner>:<branch>.
func (r PRFindRefs) QualifiedHeadRef() string {
return r.qualifiedHeadRef.String()
}
func (r PRFindRefs) UnqualifiedHeadRef() string {
return r.qualifiedHeadRef.BranchName()
}
// Matches checks whether the provided baseBranchName and headRef match the refs.
// It is used to determine whether Pull Requests returned from the API
func (r PRFindRefs) Matches(baseBranchName, qualifiedHeadRef string) bool {
headMatches := qualifiedHeadRef == r.QualifiedHeadRef()
baseMatches := r.baseBranchName.IsNone() || baseBranchName == r.baseBranchName.Unwrap()
return headMatches && baseMatches
}
func (r PRFindRefs) BaseRepo() ghrepo.Interface {
return r.baseRepo
}
type RemoteNameToRepoFn func(remoteName string) (ghrepo.Interface, error)
// PullRequestFindRefsResolver interrogates git configuration to try and determine
// a head repository and a remote branch name, from a local branch name.
type PullRequestFindRefsResolver struct {
GitConfigClient GitConfigClient
RemoteNameToRepoFn RemoteNameToRepoFn
}
func NewPullRequestFindRefsResolver(gitConfigClient GitConfigClient, remotesFn func() (ghContext.Remotes, error)) PullRequestFindRefsResolver {
return PullRequestFindRefsResolver{
GitConfigClient: gitConfigClient,
RemoteNameToRepoFn: newRemoteNameToRepoFn(remotesFn),
}
}
// ResolvePullRequests takes a base repository, a base branch name and a local branch name and uses the git configuration to
// determine the head repository and remote branch name. If we were unable to determine this from git, we default the head
// repository to the base repository.
func (r *PullRequestFindRefsResolver) ResolvePullRequestRefs(baseRepo ghrepo.Interface, baseBranchName, localBranchName string) (PRFindRefs, error) {
if baseRepo == nil {
return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a base repository")
}
if localBranchName == "" {
return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a local branch name")
}
headPRRef, err := TryDetermineDefaultPRHead(r.GitConfigClient, remoteToRepoResolver{r.RemoteNameToRepoFn}, localBranchName)
if err != nil {
return PRFindRefs{}, err
}
// If the headRepo was resolved, we can just convert the response
// to refs and return it.
if headRepo, present := headPRRef.Repo.Value(); present {
qualifiedHeadRef := NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName)
if !ghrepo.IsSame(headRepo, baseRepo) {
qualifiedHeadRef = NewQualifiedHeadRef(headRepo.RepoOwner(), headPRRef.BranchName)
}
return PRFindRefs{
qualifiedHeadRef: qualifiedHeadRef,
baseRepo: baseRepo,
baseBranchName: o.SomeIfNonZero(baseBranchName),
}, nil
}
// If we didn't find a head repo, default to the base repo
return PRFindRefs{
qualifiedHeadRef: NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName),
baseRepo: baseRepo,
baseBranchName: o.SomeIfNonZero(baseBranchName),
}, nil
}
// DefaultPRHead is a neighbour to defaultPushTarget, but instead of holding
// basic git remote information, it holds a resolved repository in `gh` terms.
//
// Since we may not be able to determine a default remote for a branch, this
// is also true of the resolved repository.
type DefaultPRHead struct {
Repo o.Option[ghrepo.Interface]
BranchName string
}
// TryDetermineDefaultPRHead is a thin wrapper around determineDefaultPushTarget, which attempts to convert
// a present remote into a resolved repository. If the remote is not present, we indicate that to the caller
// by returning a None value for the repo.
func TryDetermineDefaultPRHead(gitClient GitConfigClient, remoteToRepo remoteToRepoResolver, branch string) (DefaultPRHead, error) {
pushTarget, err := tryDetermineDefaultPushTarget(gitClient, branch)
if err != nil {
return DefaultPRHead{}, err
}
// If we have no remote, let the caller decide what to do by indicating that with a None.
if pushTarget.remote.IsNone() {
return DefaultPRHead{
Repo: o.None[ghrepo.Interface](),
BranchName: pushTarget.branchName,
}, nil
}
repo, err := remoteToRepo.resolve(pushTarget.remote.Unwrap())
if err != nil {
return DefaultPRHead{}, err
}
return DefaultPRHead{
Repo: o.Some(repo),
BranchName: pushTarget.branchName,
}, nil
}
// remote represents the value of the remote key in a branch's git configuration.
// This value may be a name or a URL, both of which are strings, but are unfortunately
// parsed by ReadBranchConfig into separate fields, allowing for illegal states to be
// created by accident. This is an attempt to indicate that they are mutally exclusive.
type remote interface{ sealedRemote() }
type remoteName struct{ name string }
func (rn remoteName) sealedRemote() {}
type remoteURL struct{ url *url.URL }
func (ru remoteURL) sealedRemote() {}
// newRemoteNameToRepoFn takes a function that returns a list of remotes and
// returns a function that takes a remote name and returns the corresponding
// repository. It is a convenience function to call sites having to duplicate
// the same logic.
func newRemoteNameToRepoFn(remotesFn func() (ghContext.Remotes, error)) RemoteNameToRepoFn {
return func(remoteName string) (ghrepo.Interface, error) {
remotes, err := remotesFn()
if err != nil {
return nil, err
}
repo, err := remotes.FindByName(remoteName)
if err != nil {
return nil, err
}
return repo, nil
}
}
// remoteToRepoResolver provides a utility method to resolve a remote (either name or URL)
// to a repo (ghrepo.Interface).
type remoteToRepoResolver struct {
remoteNameToRepo RemoteNameToRepoFn
}
func NewRemoteToRepoResolver(remotesFn func() (ghContext.Remotes, error)) remoteToRepoResolver {
return remoteToRepoResolver{
remoteNameToRepo: newRemoteNameToRepoFn(remotesFn),
}
}
// resolve takes a remote and returns a repository representing it.
func (r remoteToRepoResolver) resolve(remote remote) (ghrepo.Interface, error) {
switch v := remote.(type) {
case remoteName:
repo, err := r.remoteNameToRepo(v.name)
if err != nil {
return nil, fmt.Errorf("could not resolve remote %q: %w", v.name, err)
}
return repo, nil
case remoteURL:
repo, err := ghrepo.FromURL(v.url)
if err != nil {
return nil, fmt.Errorf("could not parse remote URL %q: %w", v.url, err)
}
return repo, nil
default:
return nil, fmt.Errorf("unsupported remote type %T, value: %v", v, remote)
}
}
// A defaultPushTarget represents the remote name or URL and a branch name
// that we would expect a branch to be pushed to if `git push` were run with
// no further arguments. This is the most likely place for the head of the PR
// to be, but it's not guaranteed. The user may have pushed to another branch
// directly via `git push <remote> <local>:<remote>` and not set up tracking information.
// A branch name is always present.
//
// It's possible that we're unable to determine a remote, if the user had pushed directly
// to a URL for example `git push <url> <branch>`, which is why it is optional. When present,
// the remote may either be a name or a URL.
type defaultPushTarget struct {
remote o.Option[remote]
branchName string
}
// newDefaultPushTarget is a thin wrapper over defaultPushTarget to help with
// generic type inference, to reduce verbosity in repeating the parametric type.
func newDefaultPushTarget(remote remote, branchName string) defaultPushTarget {
return defaultPushTarget{
remote: o.Some(remote),
branchName: branchName,
}
}
// tryDetermineDefaultPushTarget uses git configuration to make a best guess about where a branch
// is pushed to, and where it would be pushed to if the user ran `git push` with no additional
// arguments.
//
// Firstly, it attempts to resolve the @{push} ref, which is the most reliable method, as this
// is what git uses to determine the remote tracking branch
//
// If this fails, we go through a series of steps to determine the remote:
//
// 1. check branch configuration for `branch.<name>.pushRemote = <name> | <url>`
// 2. check remote configuration for `remote.pushDefault = <name>`
// 3. check branch configuration for `branch.<name>.remote = <name> | <url>`
//
// If none of these are set, we indicate that we were unable to determine the
// remote by returning a None value for the remote.
//
// The branch name is always set. The default configuration for push.default (current) indicates
// that a git push should use the same remote branch name as the local branch name. If push.default
// is set to upstream or tracking (deprecated form of upstream), then we use the branch name from the merge ref.
func tryDetermineDefaultPushTarget(gitClient GitConfigClient, localBranchName string) (defaultPushTarget, error) {
// If @{push} resolves, then we have the remote tracking branch already, no problem.
if pushRevisionRef, err := gitClient.PushRevision(context.Background(), localBranchName); err == nil {
return newDefaultPushTarget(remoteName{pushRevisionRef.Remote}, pushRevisionRef.Branch), nil
}
// But it doesn't always resolve, so we can suppress the error and move on to other means
// of determination. We'll first look at branch and remote configuration to make a determination.
branchConfig, err := gitClient.ReadBranchConfig(context.Background(), localBranchName)
if err != nil {
return defaultPushTarget{}, err
}
pushDefault, err := gitClient.PushDefault(context.Background())
if err != nil {
return defaultPushTarget{}, err
}
// We assume the PR's branch name is the same as whatever was provided, unless the user has specified
// push.default = upstream or tracking, then we use the branch name from the merge ref.
remoteBranch := localBranchName
if pushDefault == git.PushDefaultUpstream || pushDefault == git.PushDefaultTracking {
remoteBranch = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
if remoteBranch == "" {
return defaultPushTarget{}, fmt.Errorf("could not determine remote branch name")
}
}
// To get the remote, we look to the git config. It comes from one of the following, in order of precedence:
// 1. branch.<name>.pushRemote (which may be a name or a URL)
// 2. remote.pushDefault (which is a remote name)
// 3. branch.<name>.remote (which may be a name or a URL)
if branchConfig.PushRemoteName != "" {
return newDefaultPushTarget(
remoteName{branchConfig.PushRemoteName},
remoteBranch,
), nil
}
if branchConfig.PushRemoteURL != nil {
return newDefaultPushTarget(
remoteURL{branchConfig.PushRemoteURL},
remoteBranch,
), nil
}
remotePushDefault, err := gitClient.RemotePushDefault(context.Background())
if err != nil {
return defaultPushTarget{}, err
}
if remotePushDefault != "" {
return newDefaultPushTarget(
remoteName{remotePushDefault},
remoteBranch,
), nil
}
if branchConfig.RemoteName != "" {
return newDefaultPushTarget(
remoteName{branchConfig.RemoteName},
remoteBranch,
), nil
}
if branchConfig.RemoteURL != nil {
return newDefaultPushTarget(
remoteURL{branchConfig.RemoteURL},
remoteBranch,
), nil
}
// If we couldn't find the remote, we'll indicate that to the caller via None.
return defaultPushTarget{
remote: o.None[remote](),
branchName: remoteBranch,
}, nil
}

View file

@ -0,0 +1,508 @@
package shared
import (
"errors"
"net/url"
"testing"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
o "github.com/cli/cli/v2/pkg/option"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestQualifiedHeadRef(t *testing.T) {
t.Parallel()
testCases := []struct {
behavior string
ref string
expectedString string
expectedBranchName string
expectedError error
}{
{
behavior: "when a branch is provided, the parsed qualified head ref only has a branch",
ref: "feature-branch",
expectedString: "feature-branch",
expectedBranchName: "feature-branch",
},
{
behavior: "when an owner and branch are provided, the parsed qualified head ref has both",
ref: "owner:feature-branch",
expectedString: "owner:feature-branch",
expectedBranchName: "feature-branch",
},
{
behavior: "when the structure cannot be interpreted correctly, an error is returned",
ref: "owner:feature-branch:extra",
expectedError: errors.New("invalid qualified head ref format 'owner:feature-branch:extra'"),
},
}
for _, tc := range testCases {
t.Run(tc.behavior, func(t *testing.T) {
t.Parallel()
qualifiedHeadRef, err := ParseQualifiedHeadRef(tc.ref)
if tc.expectedError != nil {
require.Equal(t, tc.expectedError, err)
return
}
require.NoError(t, err)
assert.Equal(t, tc.expectedString, qualifiedHeadRef.String())
assert.Equal(t, tc.expectedBranchName, qualifiedHeadRef.BranchName())
})
}
}
func TestPRFindRefs(t *testing.T) {
t.Parallel()
t.Run("qualified head ref with owner", func(t *testing.T) {
t.Parallel()
refs := PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("forkowner:feature-branch"),
}
require.Equal(t, "forkowner:feature-branch", refs.QualifiedHeadRef())
require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef())
})
t.Run("qualified head ref without owner", func(t *testing.T) {
t.Parallel()
refs := PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"),
}
require.Equal(t, "feature-branch", refs.QualifiedHeadRef())
require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef())
})
t.Run("base repo", func(t *testing.T) {
t.Parallel()
refs := PRFindRefs{
baseRepo: ghrepo.New("owner", "repo"),
}
require.True(t, ghrepo.IsSame(refs.BaseRepo(), ghrepo.New("owner", "repo")), "expected repos to be the same")
})
t.Run("matches", func(t *testing.T) {
t.Parallel()
testCases := []struct {
behavior string
refs PRFindRefs
baseBranchName string
qualifiedHeadRef string
expectedMatch bool
}{
{
behavior: "when qualified head refs don't match, returns false",
refs: PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("owner:feature-branch"),
},
baseBranchName: "feature-branch",
qualifiedHeadRef: "feature-branch",
expectedMatch: false,
},
{
behavior: "when base branches don't match, returns false",
refs: PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"),
baseBranchName: o.Some("not-main"),
},
baseBranchName: "main",
qualifiedHeadRef: "feature-branch",
expectedMatch: false,
},
{
behavior: "when head refs match and there is no base branch, returns true",
refs: PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"),
baseBranchName: o.None[string](),
},
baseBranchName: "main",
qualifiedHeadRef: "feature-branch",
expectedMatch: true,
},
{
behavior: "when head refs match and base branches match, returns true",
refs: PRFindRefs{
qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"),
baseBranchName: o.Some("main"),
},
baseBranchName: "main",
qualifiedHeadRef: "feature-branch",
expectedMatch: true,
},
}
for _, tc := range testCases {
t.Run(tc.behavior, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.expectedMatch, tc.refs.Matches(tc.baseBranchName, tc.qualifiedHeadRef))
})
}
})
}
func TestPullRequestResolution(t *testing.T) {
t.Parallel()
baseRepo := ghrepo.New("owner", "repo")
baseRemote := ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
forkRemote := ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("otherowner", "repo-fork"),
}
t.Run("when the base repo is nil, returns an error", func(t *testing.T) {
t.Parallel()
resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn)
_, err := resolver.ResolvePullRequestRefs(nil, "", "")
require.Error(t, err)
})
t.Run("when the local branch name is empty, returns an error", func(t *testing.T) {
t.Parallel()
resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn)
_, err := resolver.ResolvePullRequestRefs(baseRepo, "", "")
require.Error(t, err)
})
t.Run("when the default pr head has a repo, it is used for the refs", func(t *testing.T) {
t.Parallel()
// Push revision is the first thing checked for resolution,
// so nothing else needs to be stubbed.
repoResolvedFromPushRevisionClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{
Remote: "origin",
Branch: "feature-branch",
}, nil),
}
resolver := NewPullRequestFindRefsResolver(
repoResolvedFromPushRevisionClient,
stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
)
refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch")
require.NoError(t, err)
expectedRefs := PRFindRefs{
qualifiedHeadRef: QualifiedHeadRef{
owner: o.Some("otherowner"),
branchName: "feature-branch",
},
baseRepo: baseRepo,
baseBranchName: o.Some("main"),
}
require.Equal(t, expectedRefs, refs)
})
t.Run("when the default pr head does not have a repo, we use the base repo for the head", func(t *testing.T) {
t.Parallel()
// All the values stubbed here result in being unable to resolve a default repo.
noRepoResolutionStubClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault("", nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
}
resolver := NewPullRequestFindRefsResolver(
noRepoResolutionStubClient,
stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
)
refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch")
require.NoError(t, err)
expectedRefs := PRFindRefs{
qualifiedHeadRef: QualifiedHeadRef{
owner: o.None[string](),
branchName: "feature-branch",
},
baseRepo: baseRepo,
baseBranchName: o.Some("main"),
}
require.Equal(t, expectedRefs, refs)
})
}
func TestTryDetermineDefaultPRHead(t *testing.T) {
t.Parallel()
baseRepo := ghrepo.New("owner", "repo")
baseRemote := ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: baseRepo,
}
forkRepo := ghrepo.New("otherowner", "repo-fork")
forkRemote := ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: forkRepo,
}
forkRepoURL, err := url.Parse("https://github.com/otherowner/repo-fork.git")
require.NoError(t, err)
t.Run("when the push revision is set, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRevisionClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{
Remote: "origin",
Branch: "remote-feature-branch",
}, nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRevisionClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "remote-feature-branch", defaultPRHead.BranchName)
})
t.Run("when the branch config push remote is set to a name, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
PushRemoteName: "origin",
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when the branch config push remote is set to a URL, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
PushRemoteURL: forkRepoURL,
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
dummyRemoteToRepoResolver(),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when a remote push default is set, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil),
remotePushDefaultFn: stubRemotePushDefault("origin", nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when the branch config remote is set to a name, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
RemoteName: "origin",
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when the branch config remote is set to a URL, use that", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
RemoteURL: forkRepoURL,
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
dummyRemoteToRepoResolver(),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when git didn't provide the necessary information, return none for the remote", func(t *testing.T) {
t.Parallel()
// All the values stubbed here result in being unable to resolve a default repo.
noRepoResolutionStubClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault("", nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
noRepoResolutionStubClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, defaultPRHead.Repo.IsNone(), "expected repo to be none")
require.Equal(t, "feature-branch", defaultPRHead.BranchName)
})
t.Run("when the push default is tracking or upstream, use the merge ref", func(t *testing.T) {
t.Parallel()
testCases := []struct {
pushDefault git.PushDefault
}{
{pushDefault: git.PushDefaultTracking},
{pushDefault: git.PushDefaultUpstream},
}
for _, tc := range testCases {
t.Run(string(tc.pushDefault), func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
PushRemoteName: "origin",
MergeRef: "main",
}, nil),
pushDefaultFn: stubPushDefault(tc.pushDefault, nil),
}
defaultPRHead, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.NoError(t, err)
require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same")
require.Equal(t, "main", defaultPRHead.BranchName)
})
}
t.Run("but if the merge ref is empty, error", func(t *testing.T) {
t.Parallel()
repoResolvedFromPushRemoteClient := stubGitConfigClient{
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")),
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
PushRemoteName: "origin",
MergeRef: "", // intentionally empty
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultUpstream, nil),
}
_, err := TryDetermineDefaultPRHead(
repoResolvedFromPushRemoteClient,
stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil),
"feature-branch",
)
require.Error(t, err)
})
})
}
func dummyRemotesFn() (ghContext.Remotes, error) {
panic("remotes fn not implemented")
}
func dummyRemoteToRepoResolver() remoteToRepoResolver {
return NewRemoteToRepoResolver(dummyRemotesFn)
}
func stubRemoteToRepoResolver(remotes ghContext.Remotes, err error) remoteToRepoResolver {
return NewRemoteToRepoResolver(func() (ghContext.Remotes, error) {
return remotes, err
})
}
func mustParseQualifiedHeadRef(ref string) QualifiedHeadRef {
parsed, err := ParseQualifiedHeadRef(ref)
if err != nil {
panic(err)
}
return parsed
}

View file

@ -13,11 +13,12 @@ import (
"time"
"github.com/cli/cli/v2/api"
remotes "github.com/cli/cli/v2/context"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
fd "github.com/cli/cli/v2/internal/featuredetection"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
o "github.com/cli/cli/v2/pkg/option"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
@ -32,16 +33,20 @@ type progressIndicator interface {
StopProgressIndicator()
}
type GitConfigClient interface {
ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error)
PushDefault(ctx context.Context) (git.PushDefault, error)
RemotePushDefault(ctx context.Context) (string, error)
PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error)
}
type finder struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
remotesFn func() (remotes.Remotes, error)
httpClient func() (*http.Client, error)
pushDefault func() (string, error)
remotePushDefault func() (string, error)
parsePushRevision func(string) (string, error)
branchConfig func(string) (git.BranchConfig, error)
progress progressIndicator
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
httpClient func() (*http.Client, error)
remotesFn func() (ghContext.Remotes, error)
gitConfigClient GitConfigClient
progress progressIndicator
baseRefRepo ghrepo.Interface
prNumber int
@ -56,23 +61,12 @@ func NewFinder(factory *cmdutil.Factory) PRFinder {
}
return &finder{
baseRepoFn: factory.BaseRepo,
branchFn: factory.Branch,
remotesFn: factory.Remotes,
httpClient: factory.HttpClient,
pushDefault: func() (string, error) {
return factory.GitClient.PushDefault(context.Background())
},
remotePushDefault: func() (string, error) {
return factory.GitClient.RemotePushDefault(context.Background())
},
parsePushRevision: func(branch string) (string, error) {
return factory.GitClient.ParsePushRevision(context.Background(), branch)
},
progress: factory.IOStreams,
branchConfig: func(s string) (git.BranchConfig, error) {
return factory.GitClient.ReadBranchConfig(context.Background(), s)
},
baseRepoFn: factory.BaseRepo,
branchFn: factory.Branch,
httpClient: factory.HttpClient,
gitConfigClient: factory.GitClient,
remotesFn: factory.Remotes,
progress: factory.IOStreams,
}
}
@ -97,32 +91,6 @@ type FindOptions struct {
States []string
}
// TODO: Does this also need the BaseBranchName?
// PR's are represented by the following:
// headRef -----PR-----> baseRef
//
// A ref is described as "remoteName/branchName", so
// headRepoName/headBranchName -----PR-----> baseRepoName/baseBranchName
type PullRequestRefs struct {
BranchName string
HeadRepo ghrepo.Interface
BaseRepo ghrepo.Interface
}
func (s *PullRequestRefs) HasHead() bool {
return s.HeadRepo != nil && s.BranchName != ""
}
// GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is
// either just the branch name or, if the PR is originating from a fork, the fork owner
// and the branch name, like <user>:<branch>.
func (s *PullRequestRefs) GetPRHeadLabel() string {
if ghrepo.IsSame(s.HeadRepo, s.BaseRepo) {
return s.BranchName
}
return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName)
}
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
// If we have a URL, we don't need git stuff
if len(opts.Fields) == 0 {
@ -142,7 +110,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
f.baseRefRepo = repo
}
var prRefs PullRequestRefs
var prRefs PRFindRefs
if opts.Selector == "" {
// You must be in a git repo for this case to work
currentBranchName, err := f.branchFn()
@ -152,7 +120,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
f.branchName = currentBranchName
// Get the branch config for the current branchName
branchConfig, err := f.branchConfig(f.branchName)
branchConfig, err := f.gitConfigClient.ReadBranchConfig(context.Background(), f.branchName)
if err != nil {
return nil, nil, err
}
@ -166,30 +134,19 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
// Determine the PullRequestRefs from config
if f.prNumber == 0 {
rems, err := f.remotesFn()
if err != nil {
return nil, nil, err
}
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
parsedPushRevision, _ := f.parsePushRevision(f.branchName)
pushDefault, err := f.pushDefault()
if err != nil {
return nil, nil, err
}
remotePushDefault, err := f.remotePushDefault()
if err != nil {
return nil, nil, err
}
prRefs, err = ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems)
prRefsResolver := NewPullRequestFindRefsResolver(
// We requested the branch config already, so let's cache that
CachedBranchConfigGitConfigClient{
CachedBranchConfig: branchConfig,
GitConfigClient: f.gitConfigClient,
},
f.remotesFn,
)
prRefs, err = prRefsResolver.ResolvePullRequestRefs(f.baseRefRepo, opts.BaseBranch, f.branchName)
if err != nil {
return nil, nil, err
}
}
} else if f.prNumber == 0 {
// You gave me a selector but I couldn't find a PR number (it wasn't a URL)
@ -204,11 +161,17 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
f.prNumber = prNumber
} else {
f.branchName = opts.Selector
// We don't expect an error here because parsedPushRevision is empty
prRefs, err = ParsePRRefs(f.branchName, git.BranchConfig{}, "", "", "", f.baseRefRepo, remotes.Remotes{})
qualifiedHeadRef, err := ParseQualifiedHeadRef(f.branchName)
if err != nil {
return nil, nil, err
}
prRefs = PRFindRefs{
qualifiedHeadRef: qualifiedHeadRef,
baseRepo: f.baseRefRepo,
baseBranchName: o.SomeIfNonZero(opts.BaseBranch),
}
}
}
@ -259,7 +222,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
return pr, f.baseRefRepo, err
}
} else {
pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRHeadLabel(), opts.States, fields.ToSlice())
pr, err = findForRefs(httpClient, prRefs, opts.States, fields.ToSlice())
if err != nil {
return pr, f.baseRefRepo, err
}
@ -321,72 +284,6 @@ func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
return repo, prNumber, nil
}
func ParsePRRefs(currentBranchName string, branchConfig git.BranchConfig, parsedPushRevision string, pushDefault string, remotePushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PullRequestRefs, error) {
prRefs := PullRequestRefs{
BaseRepo: baseRefRepo,
}
// If @{push} resolves, then we have all the information we need to determine the head repo
// and branch name. It is of the form <remote>/<branch>.
if parsedPushRevision != "" {
for _, r := range rems {
// Find the remote who's name matches the push <remote> prefix
if strings.HasPrefix(parsedPushRevision, r.Name+"/") {
prRefs.BranchName = strings.TrimPrefix(parsedPushRevision, r.Name+"/")
prRefs.HeadRepo = r.Repo
return prRefs, nil
}
}
remoteNames := make([]string, len(rems))
for i, r := range rems {
remoteNames[i] = r.Name
}
return PullRequestRefs{}, fmt.Errorf("no remote for %q found in %q", parsedPushRevision, strings.Join(remoteNames, ", "))
}
// We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier
// unless the user has specified push.default = upstream or tracking, then we use the
// branch name from the merge ref.
prRefs.BranchName = currentBranchName
if pushDefault == "upstream" || pushDefault == "tracking" {
prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
}
// To get the HeadRepo, we look to the git config. The HeadRepo comes from one of the following, in order of precedence:
// 1. branch.<name>.pushRemote
// 2. remote.pushDefault
// 3. branch.<name>.remote
if branchConfig.PushRemoteName != "" {
if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil {
prRefs.HeadRepo = r.Repo
}
} else if branchConfig.PushRemoteURL != nil {
if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil {
prRefs.HeadRepo = r
}
} else if remotePushDefault != "" {
if r, err := rems.FindByName(remotePushDefault); err == nil {
prRefs.HeadRepo = r.Repo
}
} else if branchConfig.RemoteName != "" {
if r, err := rems.FindByName(branchConfig.RemoteName); err == nil {
prRefs.HeadRepo = r.Repo
}
} else if branchConfig.RemoteURL != nil {
if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
prRefs.HeadRepo = r
}
}
// The PR merges from a branch in the same repo as the base branch (usually the default branch)
if prRefs.HeadRepo == nil {
prRefs.HeadRepo = baseRefRepo
}
return prRefs, nil
}
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {
@ -417,7 +314,7 @@ func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fi
return &resp.Repository.PullRequest, nil
}
func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) {
func findForRefs(httpClient *http.Client, prRefs PRFindRefs, stateFilters, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {
PullRequests struct {
@ -444,21 +341,16 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h
}
}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
branchWithoutOwner := headBranchWithOwnerIfFork
if idx := strings.Index(headBranchWithOwnerIfFork, ":"); idx >= 0 {
branchWithoutOwner = headBranchWithOwnerIfFork[idx+1:]
}
variables := map[string]interface{}{
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"owner": prRefs.BaseRepo().RepoOwner(),
"repo": prRefs.BaseRepo().RepoName(),
"headRefName": prRefs.UnqualifiedHeadRef(),
"states": stateFilters,
}
var resp response
client := api.NewClientFromHTTP(httpClient)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
err := client.GraphQL(prRefs.BaseRepo().RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -469,17 +361,15 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h
})
for _, pr := range prs {
headBranchMatches := pr.HeadLabel() == headBranchWithOwnerIfFork
baseBranchEmptyOrMatches := baseBranch == "" || pr.BaseRefName == baseBranch
// When the head is the default branch, it doesn't really make sense to show merged or closed PRs.
// https://github.com/cli/cli/issues/4263
isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranchWithOwnerIfFork
if headBranchMatches && baseBranchEmptyOrMatches && isNotClosedOrMergedWhenHeadIsDefault {
isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != prRefs.QualifiedHeadRef()
if prRefs.Matches(pr.BaseRefName, pr.HeadLabel()) && isNotClosedOrMergedWhenHeadIsDefault {
return &pr, nil
}
}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)}
return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", prRefs.QualifiedHeadRef())}
}
func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {

View file

@ -1,46 +1,41 @@
package shared
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"testing"
"github.com/cli/cli/v2/context"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) (git.BranchConfig, error)
pushDefault func() (string, error)
remotePushDefault func() (string, error)
parsePushRevision func(string) (string, error)
selector string
fields []string
baseBranch string
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
gitConfigClient stubGitConfigClient
selector string
fields []string
baseBranch string
}
func TestFind(t *testing.T) {
// TODO: Abstract these out meaningfully for reuse in parsePRRefs tests
originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteOrigin := context.Remote{
remoteOrigin := ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
remoteOther := context.Remote{
remoteOther := ghContext.Remote{
Remote: &git.Remote{
Name: "other",
FetchURL: originOwnerUrl,
@ -52,7 +47,7 @@ func TestFind(t *testing.T) {
if err != nil {
t.Fatal(err)
}
remoteUpstream := context.Remote{
remoteUpstream := ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
FetchURL: upstreamOwnerUrl,
@ -77,7 +72,6 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -99,12 +93,14 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -134,9 +130,11 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
wantErr: true,
},
@ -157,9 +155,11 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
httpStub: nil,
wantPR: 13,
@ -174,9 +174,11 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -197,9 +199,11 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -223,15 +227,17 @@ func TestFind(t *testing.T) {
ExitCode: 128,
}
},
branchConfig: stubBranchConfig(git.BranchConfig{}, &git.GitError{
Stderr: "fatal: branchConfig error",
ExitCode: 128,
}),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", &git.GitError{
Stderr: "fatal: remotePushDefault error",
ExitCode: 128,
}),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, &git.GitError{
Stderr: "fatal: branchConfig error",
ExitCode: 128,
}),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", &git.GitError{
Stderr: "fatal: remotePushDefault error",
ExitCode: 128,
}),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -252,10 +258,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
parsePushRevision: stubParsedPushRevision("", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -296,10 +304,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -339,10 +349,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -374,10 +386,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -423,13 +437,15 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
PushRemoteName: "upstream",
}, nil),
pushDefault: stubPushDefault("upstream", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
PushRemoteName: "upstream",
}, nil),
pushDefaultFn: stubPushDefault("upstream", nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -463,13 +479,15 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
PushRemoteURL: remoteUpstream.Remote.FetchURL,
}, nil),
pushDefault: stubPushDefault("upstream", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
PushRemoteURL: remoteUpstream.Remote.FetchURL,
}, nil),
pushDefaultFn: stubPushDefault("upstream", nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -499,10 +517,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
parsePushRevision: stubParsedPushRevision("other/blueberries", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{Remote: "other", Branch: "blueberries"}, nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -534,9 +554,11 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -559,11 +581,13 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
pushDefault: stubPushDefault("simple", nil),
remotePushDefault: stubRemotePushDefault("", nil),
gitConfigClient: stubGitConfigClient{
readBranchConfigFn: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil),
remotePushDefaultFn: stubRemotePushDefault("", nil),
},
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -575,32 +599,32 @@ func TestFind(t *testing.T) {
r.Register(
httpmock.GraphQL(`query PullRequestProjectItems\b`),
httpmock.GraphQLQuery(`{
"data": {
"repository": {
"pullRequest": {
"projectItems": {
"nodes": [
{
"id": "PVTI_lADOB-vozM4AVk16zgK6U50",
"project": {
"id": "PVT_kwDOB-vozM4AVk16",
"title": "Test Project"
},
"status": {
"optionId": "47fc9ee4",
"name": "In Progress"
}
}
],
"pageInfo": {
"hasNextPage": false,
"endCursor": "MQ"
}
}
}
}
}
}`,
"data": {
"repository": {
"pullRequest": {
"projectItems": {
"nodes": [
{
"id": "PVTI_lADOB-vozM4AVk16zgK6U50",
"project": {
"id": "PVT_kwDOB-vozM4AVk16",
"title": "Test Project"
},
"status": {
"optionId": "47fc9ee4",
"name": "In Progress"
}
}
],
"pageInfo": {
"hasNextPage": false,
"endCursor": "MQ"
}
}
}
}
}
}`,
func(query string, inputs map[string]interface{}) {
require.Equal(t, float64(13), inputs["number"])
require.Equal(t, "OWNER", inputs["owner"])
@ -624,13 +648,10 @@ func TestFind(t *testing.T) {
httpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
branchConfig: tt.args.branchConfig,
pushDefault: tt.args.pushDefault,
remotePushDefault: tt.args.remotePushDefault,
parsePushRevision: tt.args.parsePushRevision,
remotesFn: stubRemotes(context.Remotes{
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
gitConfigClient: tt.args.gitConfigClient,
remotesFn: stubRemotes(ghContext.Remotes{
&remoteOrigin,
&remoteOther,
&remoteUpstream,
@ -667,366 +688,14 @@ func TestFind(t *testing.T) {
}
}
func TestParsePRRefs(t *testing.T) {
originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteOrigin := context.Remote{
Remote: &git.Remote{
Name: "origin",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
remoteOther := context.Remote{
Remote: &git.Remote{
Name: "other",
FetchURL: originOwnerUrl,
},
Repo: ghrepo.New("ORIGINOWNER", "REPO"),
}
upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git")
if err != nil {
t.Fatal(err)
}
remoteUpstream := context.Remote{
Remote: &git.Remote{
Name: "upstream",
FetchURL: upstreamOwnerUrl,
},
Repo: ghrepo.New("UPSTREAMOWNER", "REPO"),
}
tests := []struct {
name string
branchConfig git.BranchConfig
pushDefault string
parsedPushRevision string
remotePushDefault string
currentBranchName string
baseRefRepo ghrepo.Interface
rems context.Remotes
wantPRRefs PullRequestRefs
wantErr error
}{
{
name: "When the branch is called 'blueberries' with an empty branch config, it returns the correct PullRequestRefs",
branchConfig: git.BranchConfig{},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the branch is called 'otherBranch' with an empty branch config, it returns the correct PullRequestRefs",
branchConfig: git.BranchConfig{},
currentBranchName: "otherBranch",
baseRefRepo: remoteOrigin.Repo,
wantPRRefs: PullRequestRefs{
BranchName: "otherBranch",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the branch name doesn't match the branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name",
parsedPushRevision: "origin/pushBranch",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
},
wantPRRefs: PullRequestRefs{
BranchName: "pushBranch",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push revision doesn't match a remote, it returns an error",
parsedPushRevision: "origin/differentPushBranch",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteUpstream,
&remoteOther,
},
wantPRRefs: PullRequestRefs{},
wantErr: fmt.Errorf("no remote for %q found in %q", "origin/differentPushBranch", "upstream, other"),
},
{
name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name",
parsedPushRevision: "other/pushBranch",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOther,
},
wantPRRefs: PullRequestRefs{
BranchName: "pushBranch",
HeadRepo: remoteOther.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote is the same as the baseRepo, it returns the baseRepo as the PullRequestRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote is different from the baseRepo, it returns the push remote repo as the PullRequestRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteOrigin.Remote.Name,
},
currentBranchName: "blueberries",
baseRefRepo: remoteUpstream.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteUpstream.Repo,
},
wantErr: nil,
},
{
name: "When the push remote defined by a URL and the baseRepo is different from the push remote, it returns the push remote repo as the PullRequestRefs HeadRepo",
branchConfig: git.BranchConfig{
PushRemoteURL: remoteOrigin.Remote.FetchURL,
},
currentBranchName: "blueberries",
baseRefRepo: remoteUpstream.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteOrigin.Repo,
BaseRepo: remoteUpstream.Repo,
},
wantErr: nil,
},
{
name: "When the push remote and merge ref are configured to a different repo and push.default = upstream, it should return the branch name from the other repo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteUpstream.Remote.Name,
MergeRef: "refs/heads/blue-upstream-berries",
},
pushDefault: "upstream",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blue-upstream-berries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the push remote and merge ref are configured to a different repo and push.default = tracking, it should return the branch name from the other repo",
branchConfig: git.BranchConfig{
PushRemoteName: remoteUpstream.Remote.Name,
MergeRef: "refs/heads/blue-upstream-berries",
},
pushDefault: "tracking",
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blue-upstream-berries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When remote.pushDefault is set, it returns the correct PullRequestRefs",
branchConfig: git.BranchConfig{},
remotePushDefault: remoteUpstream.Remote.Name,
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the remote name is set on the branch, it returns the correct PullRequestRefs",
branchConfig: git.BranchConfig{
RemoteName: remoteUpstream.Remote.Name,
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
{
name: "When the remote URL is set on the branch, it returns the correct PullRequestRefs",
branchConfig: git.BranchConfig{
RemoteURL: remoteUpstream.Remote.FetchURL,
},
currentBranchName: "blueberries",
baseRefRepo: remoteOrigin.Repo,
rems: context.Remotes{
&remoteOrigin,
&remoteUpstream,
},
wantPRRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: remoteUpstream.Repo,
BaseRepo: remoteOrigin.Repo,
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prRefs, err := ParsePRRefs(tt.currentBranchName, tt.branchConfig, tt.parsedPushRevision, tt.pushDefault, tt.remotePushDefault, tt.baseRefRepo, tt.rems)
if tt.wantErr != nil {
require.Equal(t, tt.wantErr, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantPRRefs, prRefs)
})
}
}
func TestPRRefs_GetPRHeadLabel(t *testing.T) {
originRepo := ghrepo.New("ORIGINOWNER", "REPO")
upstreamRepo := ghrepo.New("UPSTREAMOWNER", "REPO")
tests := []struct {
name string
prRefs PullRequestRefs
want string
}{
{
name: "When the HeadRepo and BaseRepo match, it returns the branch name",
prRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: originRepo,
BaseRepo: originRepo,
},
want: "blueberries",
},
{
name: "When the HeadRepo and BaseRepo do not match, it returns the prepended HeadRepo owner to the branch name",
prRefs: PullRequestRefs{
BranchName: "blueberries",
HeadRepo: originRepo,
BaseRepo: upstreamRepo,
},
want: "ORIGINOWNER:blueberries",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, tt.prRefs.GetPRHeadLabel())
})
}
}
func TestPullRequestRefs_HasHead(t *testing.T) {
tests := []struct {
name string
prRefs PullRequestRefs
want bool
}{
{
name: "HeadRepo is nil and BranchName is empty, return false",
prRefs: PullRequestRefs{
HeadRepo: nil,
BranchName: "",
},
want: false,
},
{
name: "HeadRepo is not nil and BranchName is empty, return false",
prRefs: PullRequestRefs{
HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"),
BranchName: "",
},
want: false,
},
{
name: "HeadRepo is nil and BranchName is not empty, return false",
prRefs: PullRequestRefs{
HeadRepo: nil,
BranchName: "feature-branch",
},
want: false,
},
{
name: "HeadRepo is not nil and BranchName is not empty, return true",
prRefs: PullRequestRefs{
HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"),
BranchName: "feature-branch",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, tt.prRefs.HasHead())
})
}
}
func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) {
return func(branch string) (git.BranchConfig, error) {
func stubBranchConfig(branchConfig git.BranchConfig, err error) func(context.Context, string) (git.BranchConfig, error) {
return func(_ context.Context, branch string) (git.BranchConfig, error) {
return branchConfig, err
}
}
func stubRemotes(remotes context.Remotes, err error) func() (context.Remotes, error) {
return func() (context.Remotes, error) {
func stubRemotes(remotes ghContext.Remotes, err error) func() (ghContext.Remotes, error) {
return func() (ghContext.Remotes, error) {
return remotes, err
}
}
@ -1037,20 +706,55 @@ func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interfa
}
}
func stubPushDefault(pushDefault string, err error) func() (string, error) {
return func() (string, error) {
func stubPushDefault(pushDefault git.PushDefault, err error) func(context.Context) (git.PushDefault, error) {
return func(_ context.Context) (git.PushDefault, error) {
return pushDefault, err
}
}
func stubRemotePushDefault(remotePushDefault string, err error) func() (string, error) {
return func() (string, error) {
func stubRemotePushDefault(remotePushDefault string, err error) func(context.Context) (string, error) {
return func(_ context.Context) (string, error) {
return remotePushDefault, err
}
}
func stubParsedPushRevision(parsedPushRevision string, err error) func(string) (string, error) {
return func(_ string) (string, error) {
func stubPushRevision(parsedPushRevision git.RemoteTrackingRef, err error) func(context.Context, string) (git.RemoteTrackingRef, error) {
return func(_ context.Context, _ string) (git.RemoteTrackingRef, error) {
return parsedPushRevision, err
}
}
type stubGitConfigClient struct {
readBranchConfigFn func(ctx context.Context, branchName string) (git.BranchConfig, error)
pushDefaultFn func(ctx context.Context) (git.PushDefault, error)
remotePushDefaultFn func(ctx context.Context) (string, error)
pushRevisionFn func(ctx context.Context, branchName string) (git.RemoteTrackingRef, error)
}
func (s stubGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) {
if s.readBranchConfigFn == nil {
panic("unexpected call to ReadBranchConfig")
}
return s.readBranchConfigFn(ctx, branchName)
}
func (s stubGitConfigClient) PushDefault(ctx context.Context) (git.PushDefault, error) {
if s.pushDefaultFn == nil {
panic("unexpected call to PushDefault")
}
return s.pushDefaultFn(ctx)
}
func (s stubGitConfigClient) RemotePushDefault(ctx context.Context) (string, error) {
if s.remotePushDefaultFn == nil {
panic("unexpected call to RemotePushDefault")
}
return s.remotePushDefaultFn(ctx)
}
func (s stubGitConfigClient) PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) {
if s.pushRevisionFn == nil {
panic("unexpected call to PushRevision")
}
return s.pushRevisionFn(ctx, branchName)
}

View file

@ -0,0 +1,18 @@
package shared
import (
"context"
"github.com/cli/cli/v2/git"
)
var _ GitConfigClient = &CachedBranchConfigGitConfigClient{}
type CachedBranchConfigGitConfigClient struct {
CachedBranchConfig git.BranchConfig
GitConfigClient
}
func (c CachedBranchConfigGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) {
return c.CachedBranchConfig, nil
}

View file

@ -102,43 +102,34 @@ func statusRun(opts *StatusOptions) error {
return fmt.Errorf("could not query for pull request for current branch: %w", err)
}
branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName)
if err != nil {
return err
}
// Determine if the branch is configured to merge to a special PR ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
currentPRNumber, _ = strconv.Atoi(m[1])
}
if currentPRNumber == 0 {
remotes, err := opts.Remotes()
if !errors.Is(err, git.ErrNotOnAnyBranch) {
branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName)
if err != nil {
return err
}
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, currentBranchName)
remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx)
if err != nil {
return err
// Determine if the branch is configured to merge to a special PR ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
currentPRNumber, _ = strconv.Atoi(m[1])
}
pushDefault, err := opts.GitClient.PushDefault(ctx)
if err != nil {
return err
}
if currentPRNumber == 0 {
prRefsResolver := shared.NewPullRequestFindRefsResolver(
// We requested the branch config already, so let's cache that
shared.CachedBranchConfigGitConfigClient{
CachedBranchConfig: branchConfig,
GitConfigClient: opts.GitClient,
},
opts.Remotes,
)
prRefs, err := shared.ParsePRRefs(currentBranchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRefRepo, remotes)
if err != nil {
return err
}
currentHeadRefBranchName = prRefs.BranchName
}
prRefs, err := prRefsResolver.ResolvePullRequestRefs(baseRefRepo, "", currentBranchName)
if err != nil {
return err
}
if err != nil {
return fmt.Errorf("could not query for pull request for current branch: %w", err)
currentHeadRefBranchName = prRefs.QualifiedHeadRef()
}
}
}

View file

@ -98,10 +98,10 @@ func TestPRStatus(t *testing.T) {
// stub successful git commands
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -133,8 +133,8 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -166,8 +166,8 @@ func TestPRStatus_reviewsAndChecksWithStatesByCount(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommandWithDetector(http, "blueberries", true, "", &fd.EnabledDetectorMock{})
if err != nil {
@ -198,8 +198,8 @@ func TestPRStatus_currentBranch_showTheMostRecentPR(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -234,8 +234,8 @@ func TestPRStatus_currentBranch_defaultBranch(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -276,8 +276,8 @@ func TestPRStatus_currentBranch_Closed(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -301,8 +301,8 @@ func TestPRStatus_currentBranch_Closed_defaultBranch(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -326,8 +326,8 @@ func TestPRStatus_currentBranch_Merged(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -351,8 +351,8 @@ func TestPRStatus_currentBranch_Merged_defaultBranch(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -376,8 +376,8 @@ func TestPRStatus_blankSlate(t *testing.T) {
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "")
rs.Register(`git config push.default`, 1, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
@ -432,14 +432,6 @@ func TestPRStatus_detachedHead(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.StringResponse(`{"data": {}}`))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
rs.Register(`git config remote.pushDefault`, 0, "")
rs.Register(`git rev-parse --abbrev-ref @{push}`, 0, "")
rs.Register(`git config push.default`, 0, "")
output, err := runCommand(http, "", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)

View file

@ -3,6 +3,8 @@ package httpmock
import (
"fmt"
"net/http"
"runtime/debug"
"strings"
"sync"
"testing"
@ -23,6 +25,7 @@ type Registry struct {
func (r *Registry) Register(m Matcher, resp Responder) {
r.stubs = append(r.stubs, &Stub{
Stack: string(debug.Stack()),
Matcher: m,
Responder: resp,
})
@ -46,17 +49,24 @@ type Testing interface {
}
func (r *Registry) Verify(t Testing) {
n := 0
var unmatchedStubStacks []string
for _, s := range r.stubs {
if !s.matched && !s.exclude {
n++
unmatchedStubStacks = append(unmatchedStubStacks, s.Stack)
}
}
if n > 0 {
if len(unmatchedStubStacks) > 0 {
t.Helper()
// NOTE: stubs offer no useful reflection, so we can't print details
stacks := strings.Builder{}
for i, stack := range unmatchedStubStacks {
stacks.WriteString(fmt.Sprintf("Stub %d:\n", i+1))
stacks.WriteString(fmt.Sprintf("\t%s", stack))
if stack != unmatchedStubStacks[len(unmatchedStubStacks)-1] {
stacks.WriteString("\n")
}
}
// about dead stubs and what they were trying to match
t.Errorf("%d unmatched HTTP stubs", n)
t.Errorf("%d HTTP stubs unmatched, stacks:\n%s", len(unmatchedStubStacks), stacks.String())
}
}
@ -84,7 +94,7 @@ func (r *Registry) RoundTrip(req *http.Request) (*http.Response, error) {
if stub == nil {
r.mu.Unlock()
return nil, fmt.Errorf("no registered stubs matched %v", req)
return nil, fmt.Errorf("no registered HTTP stubs matched %v", req)
}
r.Requests = append(r.Requests, req)

View file

@ -15,6 +15,7 @@ type Matcher func(req *http.Request) bool
type Responder func(req *http.Request) (*http.Response, error)
type Stub struct {
Stack string
matched bool
Matcher Matcher
Responder Responder

View file

@ -46,6 +46,15 @@ func None[T any]() Option[T] {
return Option[T]{}
}
func SomeIfNonZero[T comparable](value T) Option[T] {
// value is a zero value then return a None
var zero T
if value == zero {
return None[T]()
}
return Some(value)
}
// String implements the [fmt.Stringer] interface.
func (o Option[T]) String() string {
if o.present {