From a9dbda69135631511eebfa8173f80e93e726c7b2 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 2 Apr 2025 13:21:10 +0200 Subject: [PATCH] Rework ref usage when finding and creating PRs --- .../pr/pr-checkout-with-url-from-fork.txtar | 1 + .../pr-create-guesses-remote-from-sha.txtar | 46 ++ ...pr-create-respects-branch-pushremote.txtar | 2 +- .../pr-create-respects-push-destination.txtar | 4 +- ...r-create-respects-remote-pushdefault.txtar | 2 +- ...te-respects-user-colon-branch-syntax.txtar | 4 +- .../pr-create-without-upstream-config.txtar | 10 +- .../pr/pr-status-respects-cross-org.txtar | 46 ++ .../testdata/pr/pr-view-same-org-fork.txtar | 3 +- ...ew-status-respects-branch-pushremote.txtar | 3 +- ...w-status-respects-remote-pushdefault.txtar | 3 +- acceptance/testdata/repo/repo-fork-sync.txtar | 4 +- ...secret-require-remote-disambiguation.txtar | 4 +- git/client.go | 82 ++- git/client_test.go | 147 +++- internal/run/stub.go | 2 +- pkg/cmd/pr/create/create.go | 668 +++++++++++------- pkg/cmd/pr/create/create_test.go | 408 +++++++---- pkg/cmd/pr/shared/find_refs_resolution.go | 394 +++++++++++ .../pr/shared/find_refs_resolution_test.go | 508 +++++++++++++ pkg/cmd/pr/shared/finder.go | 208 ++---- pkg/cmd/pr/shared/finder_test.go | 666 +++++------------ pkg/cmd/pr/shared/git_cached_config_client.go | 18 + pkg/cmd/pr/status/status.go | 51 +- pkg/cmd/pr/status/status_test.go | 48 +- pkg/httpmock/registry.go | 22 +- pkg/httpmock/stub.go | 1 + pkg/option/option.go | 9 + 28 files changed, 2254 insertions(+), 1110 deletions(-) create mode 100644 acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar create mode 100644 acceptance/testdata/pr/pr-status-respects-cross-org.txtar create mode 100644 pkg/cmd/pr/shared/find_refs_resolution.go create mode 100644 pkg/cmd/pr/shared/find_refs_resolution_test.go create mode 100644 pkg/cmd/pr/shared/git_cached_config_client.go diff --git a/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar index 9a0494f4b..637422a5a 100644 --- a/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar +++ b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar @@ -12,6 +12,7 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork +sleep 5 # Defer fork cleanup defer gh repo delete --yes ${ORG}/${REPO}-fork diff --git a/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar b/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar new file mode 100644 index 000000000..52579b501 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar @@ -0,0 +1,46 @@ +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch to commit +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Upstream Commit' +exec git push upstream feature-branch + +# Prepare an additional commit +exec git commit --allow-empty -m 'Fork Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Check the PR is indeed created +exec gh pr view ${USER}:feature-branch --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar b/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar index 189caaf9e..e0d0c099c 100644 --- a/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar +++ b/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar @@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a user fork of repository. This will be owned by USER. exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${USER}/${FORK} # Retrieve fork repository information -sleep 5 exec gh repo view ${USER}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/pr/pr-create-respects-push-destination.txtar b/acceptance/testdata/pr/pr-create-respects-push-destination.txtar index 142a2ec35..51708405d 100644 --- a/acceptance/testdata/pr/pr-create-respects-push-destination.txtar +++ b/acceptance/testdata/pr/pr-create-respects-push-destination.txtar @@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a user fork of repository. This will be owned by USER. exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${USER}/${FORK} # Retrieve fork repository information -sleep 5 exec gh repo view ${USER}/${FORK} --json id --jq '.id' stdout2env FORK_ID @@ -50,4 +50,4 @@ stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 # Assert that the PR was created with the correct head repository and refs exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository -stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} \ No newline at end of file +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar b/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar index 2b4b28809..ff92f1e2d 100644 --- a/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar +++ b/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar @@ -19,12 +19,12 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a user fork of repository. This will be owned by USER. exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${USER}/${FORK} # Retrieve fork repository information -sleep 5 exec gh repo view ${USER}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar b/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar index 097775cbd..a59171d58 100644 --- a/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar +++ b/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar @@ -19,16 +19,16 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a user fork of repository. This will be owned by USER. exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${USER}/${FORK} # Retrieve fork repository information -sleep 5 exec gh repo view ${USER}/${FORK} --json id --jq '.id' stdout2env FORK_ID -# Clone the repo +# Clone the fork exec gh repo clone ${USER}/${FORK} cd ${FORK} diff --git a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar index 00f3535a7..e5a40af72 100644 --- a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar +++ b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar @@ -1,20 +1,22 @@ # This test is the same as pr-create-basic, except that the git push doesn't include the -u argument # This causes a git config read to fail during gh pr create, but it should not be fatal +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + # Use gh as a credential helper exec gh auth setup-git # Create a repository with a file so it has a default branch -exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private +exec gh repo create ${ORG}/${REPO} --add-readme --private # Defer repo cleanup -defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING +defer gh repo delete --yes ${ORG}/${REPO} # Clone the repo -exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING +exec gh repo clone ${ORG}/${REPO} # Prepare a branch to PR -cd $SCRIPT_NAME-$RANDOM_STRING +cd ${REPO} exec git checkout -b feature-branch exec git commit --allow-empty -m 'Empty Commit' exec git push origin feature-branch diff --git a/acceptance/testdata/pr/pr-status-respects-cross-org.txtar b/acceptance/testdata/pr/pr-status-respects-cross-org.txtar new file mode 100644 index 000000000..4505be923 --- /dev/null +++ b/acceptance/testdata/pr/pr-status-respects-cross-org.txtar @@ -0,0 +1,46 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push -u origin feature-branch + +# Create the PR spanning upstream and fork repositories +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr status +! stdout 'There is no pull request associated with' diff --git a/acceptance/testdata/pr/pr-view-same-org-fork.txtar b/acceptance/testdata/pr/pr-view-same-org-fork.txtar index ca58918a9..eed524dec 100644 --- a/acceptance/testdata/pr/pr-view-same-org-fork.txtar +++ b/acceptance/testdata/pr/pr-view-same-org-fork.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a fork in the same org exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 1 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar index f0bb0e6e7..4e1e5e64a 100644 --- a/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar +++ b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a user fork of repository as opposed to private organization fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 5 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar index a3d376b80..6c0743a6f 100644 --- a/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar +++ b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a user fork of repository as opposed to private organization fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 5 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/repo/repo-fork-sync.txtar b/acceptance/testdata/repo/repo-fork-sync.txtar index 6ed7b94e1..04c4c5845 100644 --- a/acceptance/testdata/repo/repo-fork-sync.txtar +++ b/acceptance/testdata/repo/repo-fork-sync.txtar @@ -9,13 +9,11 @@ defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING # Fork and clone the repo exec gh repo fork $ORG/$SCRIPT_NAME-$RANDOM_STRING --org $ORG --fork-name $SCRIPT_NAME-$RANDOM_STRING-fork --clone +sleep 5 # Defer fork cleanup defer gh repo delete $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --yes -# Sleep so that the BE has time to sync -sleep 5 - # Check that the repo was forked exec gh repo view $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --json='isFork' --jq='.isFork' stdout 'true' diff --git a/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar b/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar index 02dec06a0..f3fa4a47a 100644 --- a/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar +++ b/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar @@ -12,13 +12,11 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork +sleep 5 # Defer fork cleanup defer gh repo delete --yes ${ORG}/${REPO}-fork -# Sleep to allow the fork to be created before cloning -sleep 2 - # Clone and move into the fork repo exec gh repo clone ${ORG}/${REPO}-fork cd ${REPO}-fork diff --git a/git/client.go b/git/client.go index 11a2e2e20..fe2819cf0 100644 --- a/git/client.go +++ b/git/client.go @@ -381,7 +381,6 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte, // Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors, // as an empty config is not necessarily breaking. func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) { - prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch)) args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|pushremote|%s)$", prefix, MergeBaseConfig)} cmd, err := c.Command(ctx, args...) @@ -441,18 +440,50 @@ func (c *Client) SetBranchConfig(ctx context.Context, branch, name, value string return err } +// PushDefault defines the action git push should take if no refspec is given. +// See: https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault +type PushDefault string + +const ( + PushDefaultNothing PushDefault = "nothing" + PushDefaultCurrent PushDefault = "current" + PushDefaultUpstream PushDefault = "upstream" + PushDefaultTracking PushDefault = "tracking" + PushDefaultSimple PushDefault = "simple" + PushDefaultMatching PushDefault = "matching" +) + +func ParsePushDefault(s string) (PushDefault, error) { + validPushDefaults := map[string]struct{}{ + string(PushDefaultNothing): {}, + string(PushDefaultCurrent): {}, + string(PushDefaultUpstream): {}, + string(PushDefaultTracking): {}, + string(PushDefaultSimple): {}, + string(PushDefaultMatching): {}, + } + + if _, ok := validPushDefaults[s]; ok { + return PushDefault(s), nil + } + + return "", fmt.Errorf("unknown push.default value: %s", s) +} + // PushDefault returns the value of push.default in the config. If the value // is not set, it returns "simple" (the default git value). See // https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault -func (c *Client) PushDefault(ctx context.Context) (string, error) { +func (c *Client) PushDefault(ctx context.Context) (PushDefault, error) { pushDefault, err := c.Config(ctx, "push.default") if err == nil { - return pushDefault, nil + return ParsePushDefault(pushDefault) } + // If there is an error that the config key is not set, return the default value + // that git uses since 2.0. var gitError *GitError if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 { - return "simple", nil + return PushDefaultSimple, nil } return "", err } @@ -473,13 +504,48 @@ func (c *Client) RemotePushDefault(ctx context.Context) (string, error) { return "", err } -// ParsePushRevision gets the value of the @{push} revision syntax +// RemoteTrackingRef is the structured form of the string "refs/remotes//". +// For example, the @{push} revision syntax could report "refs/remotes/origin/main" which would +// be parsed into RemoteTrackingRef{Remote: "origin", Branch: "main"}. +type RemoteTrackingRef struct { + Remote string + Branch string +} + +func (r RemoteTrackingRef) String() string { + return fmt.Sprintf("refs/remotes/%s/%s", r.Remote, r.Branch) +} + +// ParseRemoteTrackingRef parses a string of the form "refs/remotes//" into +// a RemoteTrackingBranch struct. If the string does not match this format, an error is returned. +func ParseRemoteTrackingRef(s string) (RemoteTrackingRef, error) { + parts := strings.Split(s, "/") + if len(parts) != 4 || parts[0] != "refs" || parts[1] != "remotes" { + return RemoteTrackingRef{}, fmt.Errorf("remote tracking branch must have format refs/remotes// but was: %s", s) + } + + return RemoteTrackingRef{ + Remote: parts[2], + Branch: parts[3], + }, nil +} + +// PushRevision gets the value of the @{push} revision syntax // An error here doesn't necessarily mean something is broken, but may mean that the @{push} // revision syntax couldn't be resolved, such as in non-centralized workflows with // push.default = simple. Downstream consumers should consider how to handle this error. -func (c *Client) ParsePushRevision(ctx context.Context, branch string) (string, error) { - revParseOut, err := c.revParse(ctx, "--abbrev-ref", branch+"@{push}") - return firstLine(revParseOut), err +func (c *Client) PushRevision(ctx context.Context, branch string) (RemoteTrackingRef, error) { + revParseOut, err := c.revParse(ctx, "--symbolic-full-name", branch+"@{push}") + if err != nil { + return RemoteTrackingRef{}, err + } + + ref, err := ParseRemoteTrackingRef(firstLine(revParseOut)) + if err != nil { + return RemoteTrackingRef{}, fmt.Errorf("could not parse push revision: %v", err) + } + + return ref, nil } func (c *Client) DeleteLocalTag(ctx context.Context, tag string) error { diff --git a/git/client_test.go b/git/client_test.go index 9fa076199..3d7560228 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -952,7 +952,7 @@ func TestClientPushDefault(t *testing.T) { tests := []struct { name string commandResult commandResult - wantPushDefault string + wantPushDefault PushDefault wantError *GitError }{ { @@ -961,7 +961,7 @@ func TestClientPushDefault(t *testing.T) { ExitStatus: 1, Stderr: "error: key does not contain a section: remote.pushDefault", }, - wantPushDefault: "simple", + wantPushDefault: PushDefaultSimple, wantError: nil, }, { @@ -970,7 +970,7 @@ func TestClientPushDefault(t *testing.T) { ExitStatus: 0, Stdout: "current", }, - wantPushDefault: "current", + wantPushDefault: PushDefaultCurrent, wantError: nil, }, { @@ -1077,17 +1077,17 @@ func TestClientParsePushRevision(t *testing.T) { name string branch string commandResult commandResult - wantParsedPushRevision string - wantError *GitError + wantParsedPushRevision RemoteTrackingRef + wantError error }{ { - name: "@{push} resolves to origin/branchName", + name: "@{push} resolves to refs/remotes/origin/branchName", branch: "branchName", commandResult: commandResult{ ExitStatus: 0, - Stdout: "origin/branchName", + Stdout: "refs/remotes/origin/branchName", }, - wantParsedPushRevision: "origin/branchName", + wantParsedPushRevision: RemoteTrackingRef{Remote: "origin", Branch: "branchName"}, }, { name: "@{push} doesn't resolve", @@ -1095,16 +1095,25 @@ func TestClientParsePushRevision(t *testing.T) { ExitStatus: 128, Stderr: "fatal: git error", }, - wantParsedPushRevision: "", + wantParsedPushRevision: RemoteTrackingRef{}, wantError: &GitError{ ExitCode: 128, Stderr: "fatal: git error", }, }, + { + name: "@{push} resolves to something surprising", + commandResult: commandResult{ + ExitStatus: 0, + Stdout: "not/a/valid/remote/ref", + }, + wantParsedPushRevision: RemoteTrackingRef{}, + wantError: fmt.Errorf("could not parse push revision: remote tracking branch must have format refs/remotes// but was: not/a/valid/remote/ref"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd := fmt.Sprintf("path/to/git rev-parse --abbrev-ref %s@{push}", tt.branch) + cmd := fmt.Sprintf("path/to/git rev-parse --symbolic-full-name %s@{push}", tt.branch) cmdCtx := createMockedCommandContext(t, mockedCommands{ args(cmd): tt.commandResult, }) @@ -1112,20 +1121,91 @@ func TestClientParsePushRevision(t *testing.T) { GitPath: "path/to/git", commandContext: cmdCtx, } - pushDefault, err := client.ParsePushRevision(context.Background(), tt.branch) + trackingRef, err := client.PushRevision(context.Background(), tt.branch) if tt.wantError != nil { - var gitError *GitError - require.ErrorAs(t, err, &gitError) - assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) - assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) + var wantErrorAsGit *GitError + if errors.As(err, &wantErrorAsGit) { + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, wantErrorAsGit.ExitCode, gitError.ExitCode) + assert.Equal(t, wantErrorAsGit.Stderr, gitError.Stderr) + } else { + assert.Equal(t, err, tt.wantError) + } } else { require.NoError(t, err) } - assert.Equal(t, tt.wantParsedPushRevision, pushDefault) + assert.Equal(t, tt.wantParsedPushRevision, trackingRef) }) } } +func TestRemoteTrackingRef(t *testing.T) { + t.Run("parsing", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteTrackingRef string + wantRemoteTrackingRef RemoteTrackingRef + wantError error + }{ + { + name: "valid remote tracking ref", + remoteTrackingRef: "refs/remotes/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{ + Remote: "origin", + Branch: "branchName", + }, + }, + { + name: "incorrect parts", + remoteTrackingRef: "refs/remotes/origin", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: refs/remotes/origin"), + }, + { + name: "incorrect prefix type", + remoteTrackingRef: "invalid/remotes/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: invalid/remotes/origin/branchName"), + }, + { + name: "incorrect ref type", + remoteTrackingRef: "refs/invalid/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: refs/invalid/origin/branchName"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + trackingRef, err := ParseRemoteTrackingRef(tt.remoteTrackingRef) + if tt.wantError != nil { + require.Equal(t, tt.wantError, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantRemoteTrackingRef, trackingRef) + }) + } + }) + + t.Run("stringifying", func(t *testing.T) { + t.Parallel() + + remoteTrackingRef := RemoteTrackingRef{ + Remote: "origin", + Branch: "branchName", + } + + require.Equal(t, "refs/remotes/origin/branchName", remoteTrackingRef.String()) + }) +} + func TestClientDeleteLocalTag(t *testing.T) { tests := []struct { name string @@ -1992,6 +2072,41 @@ func TestCredentialPatternFromHost(t *testing.T) { } } +func TestPushDefault(t *testing.T) { + t.Run("it parses valid values correctly", func(t *testing.T) { + t.Parallel() + + tests := []struct { + value string + expectedPushDefault PushDefault + }{ + {"nothing", PushDefaultNothing}, + {"current", PushDefaultCurrent}, + {"upstream", PushDefaultUpstream}, + {"tracking", PushDefaultTracking}, + {"simple", PushDefaultSimple}, + {"matching", PushDefaultMatching}, + } + + for _, test := range tests { + t.Run(test.value, func(t *testing.T) { + t.Parallel() + + pushDefault, err := ParsePushDefault(test.value) + require.NoError(t, err) + assert.Equal(t, test.expectedPushDefault, pushDefault) + }) + } + }) + + t.Run("it returns an error for invalid values", func(t *testing.T) { + t.Parallel() + + _, err := ParsePushDefault("invalid") + require.Error(t, err) + }) +} + func createCommandContext(t *testing.T, exitStatus int, stdout, stderr string) (*exec.Cmd, commandCtx) { cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperProcess", "--") cmd.Env = []string{ diff --git a/internal/run/stub.go b/internal/run/stub.go index 5cd3c6de5..507fd61d6 100644 --- a/internal/run/stub.go +++ b/internal/run/stub.go @@ -46,7 +46,7 @@ func Stub() (*CommandStubber, func(T)) { return } t.Helper() - t.Errorf("unmatched stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", ")) + t.Errorf("unmatched exec stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", ")) } } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 8ea0b48db..eda7a3ce7 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -25,6 +25,7 @@ import ( "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/markdown" + o "github.com/cli/cli/v2/pkg/option" "github.com/spf13/cobra" ) @@ -72,16 +73,107 @@ type CreateOptions struct { DryRun bool } +// creationRefs is an interface that provides the necessary information for creating a pull request in the API. +// Upcasting to concrete implementations can provide further context on other operations (forking and pushing). +type creationRefs interface { + // QualifiedHeadRef returns a stringified form of the head ref, varying depending + // on whether the head ref is in the same repository as the base ref. If they are + // the same repository, we return the branch name only. If they are different repositories, + // we return the owner and branch name in the form :. + QualifiedHeadRef() string + // UnqualifiedHeadRef returns a head ref in the form of the branch name only. + UnqualifiedHeadRef() string + //BaseRef returns the base branch name. + BaseRef() string + + // While the only thing really required from an api.Repository is the repository ID, changing that + // would require changing the API function signatures, and the refactor that introduced this refs + // type is already large enough. + BaseRepo() *api.Repository +} + +type baseRefs struct { + baseRepo *api.Repository + baseBranchName string +} + +func (r baseRefs) BaseRef() string { + return r.baseBranchName +} + +func (r baseRefs) BaseRepo() *api.Repository { + return r.baseRepo +} + +// skipPushRefs indicate to handlePush that no pushing is required. +type skipPushRefs struct { + baseRefs + + qualifiedHeadRef shared.QualifiedHeadRef +} + +func (r skipPushRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r skipPushRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// pushableRefs indicate to handlePush that pushing is required, +// and provide further information (HeadRepo) on where that push +// should go. +type pushableRefs struct { + baseRefs + + headRepo ghrepo.Interface + headBranchName string +} + +func (r pushableRefs) QualifiedHeadRef() string { + if ghrepo.IsSame(r.headRepo, r.baseRepo) { + return r.headBranchName + } + return fmt.Sprintf("%s:%s", r.headRepo.RepoOwner(), r.headBranchName) +} + +func (r pushableRefs) UnqualifiedHeadRef() string { + return r.headBranchName +} + +func (r pushableRefs) HeadRepo() ghrepo.Interface { + return r.headRepo +} + +// forkableRefs indicate to handlePush that forking is required before +// pushing. The expectation is that after forking, this is converted to +// pushableRefs. We could go very OOP and have a Fork method on this +// struct that returns a pushableRefs but then we'd need to embed an API client +// and it just seems nice that it is a simple bag of data. +type forkableRefs struct { + baseRefs + + qualifiedHeadRef shared.QualifiedHeadRef +} + +func (r forkableRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r forkableRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// CreateContext stores contextual data about the creation process and is for building up enough +// data to create a pull request. type CreateContext struct { - // This struct stores contextual data about the creation process and is for building up enough - // data to create a pull request - RepoContext *ghContext.ResolvedRemotes - PrRefs shared.PullRequestRefs + ResolvedRemotes *ghContext.ResolvedRemotes + PRRefs creationRefs + // BaseTrackingBranch is perhaps a slightly leaky abstraction in the presence + // of PRRefs, but a huge amount of refactoring was done to introduce that struct, + // and this is a small price to pay for the convenience of not having to do a lot + // more design. BaseTrackingBranch string - BaseBranch string // Currently not supported by shared.PullRequestRefs struct - HeadRemote *ghContext.Remote - isPushEnabled bool - forkHeadRepo bool Client *api.Client GitClient *git.Client } @@ -312,8 +404,8 @@ func createRun(opts *CreateOptions) error { } existingPR, _, err := opts.Finder.Find(shared.FindOptions{ - Selector: ctx.PrRefs.GetPRHeadLabel(), - BaseBranch: ctx.BaseBranch, + Selector: ctx.PRRefs.QualifiedHeadRef(), + BaseBranch: ctx.PRRefs.BaseRef(), States: []string{"OPEN"}, Fields: []string{"url"}, }) @@ -323,7 +415,7 @@ func createRun(opts *CreateOptions) error { } if err == nil { return fmt.Errorf("a pull request for branch %q into branch %q already exists:\n%s", - ctx.PrRefs.GetPRHeadLabel(), ctx.BaseBranch, existingPR.URL) + ctx.PRRefs.QualifiedHeadRef(), ctx.PRRefs.BaseRef(), existingPR.URL) } message := "\nCreating pull request for %s into %s in %s\n\n" @@ -338,9 +430,9 @@ func createRun(opts *CreateOptions) error { if opts.IO.CanPrompt() { fmt.Fprintf(opts.IO.ErrOut, message, - cs.Cyan(ctx.PrRefs.GetPRHeadLabel()), - cs.Cyan(ctx.BaseBranch), - ghrepo.FullName(ctx.PrRefs.BaseRepo)) + cs.Cyan(ctx.PRRefs.QualifiedHeadRef()), + cs.Cyan(ctx.PRRefs.BaseRef()), + ghrepo.FullName(ctx.PRRefs.BaseRepo())) } if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided)) { @@ -363,7 +455,7 @@ func createRun(opts *CreateOptions) error { action = shared.SubmitDraftAction } - tpl := shared.NewTemplateManager(client.HTTP(), ctx.PrRefs.BaseRepo, opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true) + tpl := shared.NewTemplateManager(client.HTTP(), ctx.PRRefs.BaseRepo(), opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true) if opts.EditorMode { if opts.Template != "" { @@ -431,7 +523,7 @@ func createRun(opts *CreateOptions) error { } allowPreview := !state.HasMetadata() && shared.ValidURL(openURL) && !opts.DryRun - allowMetadata := ctx.PrRefs.BaseRepo.(*api.Repository).ViewerCanTriage() + allowMetadata := ctx.PRRefs.BaseRepo().ViewerCanTriage() action, err = shared.ConfirmPRSubmission(opts.Prompter, allowPreview, allowMetadata, state.Draft) if err != nil { return fmt.Errorf("unable to confirm: %w", err) @@ -441,10 +533,10 @@ func createRun(opts *CreateOptions) error { fetcher := &shared.MetadataFetcher{ IO: opts.IO, APIClient: client, - Repo: ctx.PrRefs.BaseRepo, + Repo: ctx.PRRefs.BaseRepo(), State: state, } - err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PrRefs.BaseRepo, fetcher, state) + err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state) if err != nil { return err } @@ -487,11 +579,7 @@ func createRun(opts *CreateOptions) error { var regexPattern = regexp.MustCompile(`(?m)^`) func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, useFirstCommit bool, addBody bool) error { - baseRef := ctx.BaseTrackingBranch - headRef := ctx.PrRefs.BranchName - gitClient := ctx.GitClient - - commits, err := gitClient.Commits(context.Background(), baseRef, headRef) + commits, err := ctx.GitClient.Commits(context.Background(), ctx.BaseTrackingBranch, ctx.PRRefs.UnqualifiedHeadRef()) if err != nil { return err } @@ -500,7 +588,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u state.Title = commits[len(commits)-1].Title state.Body = commits[len(commits)-1].Body } else { - state.Title = humanize(headRef) + state.Title = humanize(ctx.PRRefs.UnqualifiedHeadRef()) var body strings.Builder for i := len(commits) - 1; i >= 0; i-- { fmt.Fprintf(&body, "- **%s**\n", commits[i].Title) @@ -526,7 +614,7 @@ func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadata milestoneTitles = []string{opts.Milestone} } - meReplacer := shared.NewMeReplacer(ctx.Client, ctx.PrRefs.BaseRepo.RepoHost()) + meReplacer := shared.NewMeReplacer(ctx.Client, ctx.PRRefs.BaseRepo().RepoHost()) assignees, err := meReplacer.ReplaceSlice(opts.Assignees) if err != nil { return nil, err @@ -553,7 +641,6 @@ func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadata } func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { - ctx := context.Background() httpClient, err := opts.HttpClient() if err != nil { return nil, err @@ -565,25 +652,19 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { return nil, err } - gitClient := opts.GitClient - if ucc, err := gitClient.UncommittedChangeCount(ctx); err == nil && ucc > 0 { - fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change")) - } - - // Resolve base repo - repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) + resolvedRemotes, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) if err != nil { return nil, err } - var targetBaseRepo *api.Repository - if br, err := repoContext.BaseRepo(opts.IO); err == nil { + var baseRepo *api.Repository + if br, err := resolvedRemotes.BaseRepo(opts.IO); err == nil { if r, ok := br.(*api.Repository); ok { - targetBaseRepo = r + baseRepo = r } else { // TODO: if RepoNetwork is going to be requested anyway in `repoContext.HeadRepos()`, // consider piggybacking on that result instead of performing a separate lookup - targetBaseRepo, err = api.GitHubRepo(client, br) + baseRepo, err = api.GitHubRepo(client, br) if err != nil { return nil, err } @@ -592,181 +673,284 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { return nil, err } - // Resolve target head branch name from either - // --head or the current branch. - var targetHeadBranch string - var targetHeadRepoOwner string + // This closure provides an easy way to instantiate a CreateContext with everything other than + // the refs. This probably indicates that CreateContext could do with some rework, but the refactor + // to introduce PRRefs is already large enough. + var newCreateContext = func(refs creationRefs) *CreateContext { + baseTrackingBranch := refs.BaseRef() - promptForHeadRepo := true + // The baseTrackingBranch is used later for a command like: + // `git commit upstream/main feature` in order to create a PR message showing the commits + // between these two refs. I'm not really sure what is expected to happen if we don't have a remote, + // which seems like it would be possible with a command `gh pr create --repo owner/repo-that-is-not-a-remote`. + // In that case, we might just have a mess? In any case, this is what the old code did, so I don't want to change + // it as part of an already large refactor. + baseRemote, _ := resolvedRemotes.RemoteForRepo(baseRepo) + if baseRemote != nil { + baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseTrackingBranch) + } + return &CreateContext{ + ResolvedRemotes: resolvedRemotes, + Client: client, + GitClient: opts.GitClient, + PRRefs: refs, + BaseTrackingBranch: baseTrackingBranch, + } + } + + // If the user provided a head branch we're going to use that without any interrogation + // of git. The value can take the form of or :. In the former case, the + // PR base and head repos are the same. In the latter case we don't know the head repo + // (though we could look it up in the API) but fortunately we don't need to because the API + // will resolve this for us when we create the pull request. This is possible because + // users can only have a single fork in their namespace, and organizations don't work at all with this ref format. + // + // Note that providing the head branch in this way indicates that we shouldn't push the branch, + // and we indicate that via the returned type as well. if opts.HeadBranch != "" { - promptForHeadRepo = false - targetHeadBranch = opts.HeadBranch - // If the --head provided contains a colon, that means - // this is : syntax. - if idx := strings.IndexRune(opts.HeadBranch, ':'); idx >= 0 { - targetHeadRepoOwner = opts.HeadBranch[:idx] - targetHeadBranch = opts.HeadBranch[idx+1:] - } - } else { - // Use the current branch as the target local head branch when - // --head is not provided. - targetHeadBranch, err = opts.Branch() - if err != nil { - return nil, fmt.Errorf("could not determine the current branch: %w", err) - } - } - - targetHeadBranchConfig, err := gitClient.ReadBranchConfig(ctx, targetHeadBranch) - if err != nil { - return nil, err - } - - // See if we can determine if this branch has been push previously with - // Git configurations and @{push} revision syntax. - remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) - if err != nil { - return nil, err - } - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := gitClient.ParsePushRevision(ctx, targetHeadBranch) - pushDefault, err := gitClient.PushDefault(ctx) - if err != nil { - return nil, err - } - - prRefs, err := shared.ParsePRRefs(targetHeadBranch, targetHeadBranchConfig, parsedPushRevision, pushDefault, remotePushDefault, targetBaseRepo, remotes) - if err != nil { - return nil, err - } - - // If the --head provided contains : syntax, we need to use - // the provided owner instead of the owner of the base repository. - if targetHeadRepoOwner != "" { - prRefs.HeadRepo = ghrepo.New(targetHeadRepoOwner, prRefs.HeadRepo.RepoName()) - } - - var headRemote *ghContext.Remote - - // We received the head repository and branch from ParsePRRefs, or inferred - // it from --head input, but we need to check if it's up-to-date with - // our local branch state. - // If it is, we can use it as the head repo for the PR - // and avoid prompting the user. - // Errors raised here should not cause command to fail, - // prompt user for head repo if an error is raised or no remote found. - if prRefs.HasHead() { - // Check if the head branch is up-to-date with the local branch - headRemote, err := remotes.FindByRepo(prRefs.HeadRepo.RepoOwner(), prRefs.HeadRepo.RepoName()) - if headRemote != nil && err == nil { - headRefName := fmt.Sprintf("refs/remotes/%s/%s", headRemote, prRefs.BranchName) - refsForLookup := []string{"HEAD", headRefName} - resolvedRefs, err := gitClient.ShowRefs(ctx, refsForLookup) - - // If there is more than one resolved ref, then remote head ref was resolved. - if err == nil && len(resolvedRefs) > 1 { - headRef := resolvedRefs[0] - for _, r := range resolvedRefs[1:] { - // If the head ref is the same as the remote head ref, - // then the remote head is current and we can use it. - if r.Hash == headRef.Hash { - promptForHeadRepo = false - break - } - } - } - } - } - - var forkHeadRepo bool - var isPushEnabled bool - - if promptForHeadRepo && opts.IO.CanPrompt() { - isPushEnabled = true - // Since we could not determine a head ref, prompt the user for the head repository to push - // using a list of repositories obtained from the API - pushableRepos, err := repoContext.HeadRepos() + qualifiedHeadRef, err := shared.ParseQualifiedHeadRef(opts.HeadBranch) if err != nil { return nil, err } - if len(pushableRepos) == 0 { - pushableRepos, err = api.RepoFindForks(client, prRefs.BaseRepo, 3) - if err != nil { - return nil, err - } - } - - currentLogin, err := api.CurrentLoginName(client, prRefs.BaseRepo.RepoHost()) + branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), qualifiedHeadRef.BranchName()) if err != nil { return nil, err } - hasOwnFork := false - var pushOptions []string - for _, r := range pushableRepos { - pushOptions = append(pushOptions, ghrepo.FullName(r)) - if r.RepoOwner() == currentLogin { - hasOwnFork = true - } + baseBranch := opts.BaseBranch + if baseBranch == "" { + baseBranch = branchConfig.MergeBase + } + if baseBranch == "" { + baseBranch = baseRepo.DefaultBranchRef.Name } - if !hasOwnFork { - pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(prRefs.BaseRepo)) - } - pushOptions = append(pushOptions, "Skip pushing the branch") - pushOptions = append(pushOptions, "Cancel") - - selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", prRefs.BranchName), "", pushOptions) - if err != nil { - return nil, err - } - - if selectedOption < len(pushableRepos) { - prRefs.HeadRepo = pushableRepos[selectedOption] - } else if pushOptions[selectedOption] == "Skip pushing the branch" { - isPushEnabled = false - } else if pushOptions[selectedOption] == "Cancel" { - return nil, cmdutil.CancelError - } else { - // "Create a fork of ..." - forkHeadRepo = true - prRefs.HeadRepo = ghrepo.New(currentLogin, prRefs.HeadRepo.RepoName()) - } + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs{ + baseRepo: baseRepo, + baseBranchName: baseBranch, + }, + }), nil } - if prRefs.HeadRepo == nil && isPushEnabled && !opts.IO.CanPrompt() { - fmt.Fprintf(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag") - return nil, cmdutil.SilentError + if ucc, err := opts.GitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 { + fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change")) + } + + // If the user didn't provide a head branch then we're gettin' real. We're going to interrogate git + // and try to create refs that are pushable. + currentBranch, err := opts.Branch() + if err != nil { + return nil, fmt.Errorf("could not determine the current branch: %w", err) + } + + branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), currentBranch) + if err != nil { + return nil, err } baseBranch := opts.BaseBranch if baseBranch == "" { - baseBranch = targetHeadBranchConfig.MergeBase + baseBranch = branchConfig.MergeBase } if baseBranch == "" { - baseBranch = targetBaseRepo.DefaultBranchRef.Name - } - if prRefs.BranchName == baseBranch && prRefs.HeadRepo != nil && ghrepo.IsSame(prRefs.BaseRepo, prRefs.HeadRepo) { - return nil, fmt.Errorf("must be on a branch named differently than %q", baseBranch) + baseBranch = baseRepo.DefaultBranchRef.Name } - baseTrackingBranch := baseBranch - if baseRemote, err := remotes.FindByRepo(prRefs.BaseRepo.RepoOwner(), prRefs.BaseRepo.RepoName()); err == nil { - baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseBranch) + // First we check with the git information we have to see if we can figure out the default + // head repo and remote branch name. + defaultPRHead, err := shared.TryDetermineDefaultPRHead( + // We requested the branch config already, so let's cache that + shared.CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: opts.GitClient, + }, + shared.NewRemoteToRepoResolver(opts.Remotes), + currentBranch, + ) + if err != nil { + return nil, err } - return &CreateContext{ - PrRefs: prRefs, - BaseBranch: baseBranch, // Currently not supported by shared.PullRequestRefs struct - BaseTrackingBranch: baseTrackingBranch, - HeadRemote: headRemote, - isPushEnabled: isPushEnabled, - forkHeadRepo: forkHeadRepo, - RepoContext: repoContext, - Client: client, - GitClient: gitClient, - }, nil + // The baseRefs are always going to be the same from now on. If I could make this immutable I would! + baseRefs := baseRefs{ + baseRepo: baseRepo, + baseBranchName: baseBranch, + } + + // If we were able to determine a head repo, then let's check that the remote tracking ref matches the SHA of + // HEAD. If it does, then we don't need to push, otherwise we'll need to ask the user to tell us where to push. + if headRepo, present := defaultPRHead.Repo.Value(); present { + // We may not find a remote because the git branch config may have a URL rather than a remote name. + // Ideally, we would return a sentinel error from RemoteForRepo that we could compare to, but the + // refactor that introduced this code was already large enough. + headRemote, _ := resolvedRemotes.RemoteForRepo(headRepo) + if headRemote != nil { + resolvedRefs, _ := opts.GitClient.ShowRefs( + context.Background(), + []string{ + "HEAD", + fmt.Sprintf("refs/remotes/%s/%s", headRemote.Name, defaultPRHead.BranchName), + }, + ) + + // Two refs returned means we can compare HEAD to the remote tracking branch. + // If we had a matching ref, then we can skip pushing. + refsMatch := len(resolvedRefs) == 2 && resolvedRefs[0].Hash == resolvedRefs[1].Hash + if refsMatch { + qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(defaultPRHead.BranchName) + if headRepo.RepoOwner() != baseRepo.RepoOwner() { + qualifiedHeadRef = shared.NewQualifiedHeadRef(headRepo.RepoOwner(), defaultPRHead.BranchName) + } + + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs, + }), nil + } + } + } + + // If we didn't determine that the git indicated repo had the correct ref, we'll take a look at the other + // remotes and see whether any of them have the same SHA as HEAD. Now, at this point, you might be asking yourself: + // "Why didn't we collect all the SHAs with a single ShowRefs command above, for use in both cases?" + // ... + // That's because the code below has a bug that I've ported from the old code, in order to preserve the existing + // behaviour, and to limit the scope of an already large refactor. The intention of the original code was to loop + // over all the returned refs. However, as it turns out, our implementation of ShowRefs doesn't do that correctly. + // Since it provides the --verify flag, git will return the SHAs for refs up until it hits a ref that doesn't exist, + // at which point it bails out. + // + // Imagine you have a remotes "upstream" and "origin", and you have pushed your branch "feature" to "origin". Since + // the order of remotes is always guaranteed "upstream", "github", "origin", and then everything else unstably sorted, + // we will never get a SHA for origin, as refs/remotes/upstream/feature doesn't exist. + // + // Furthermore, when you really think about it, this code is a bit eager. What happens if you have the same SHA on + // remotes "origin" and "colleague", this will always offer origin. If it were "colleague-a" and "colleague-b", no + // order would be guaranteed between different invocations of pr create, because the order of remotes after "origin" + // is unstable sorted. + // + // All that said, this has been the behaviour for a long, long time, and I do not want to make other behavioural changes + // in what is mostly a refactor. + refsToLookup := []string{"HEAD"} + for _, remote := range remotes { + refsToLookup = append(refsToLookup, fmt.Sprintf("refs/remotes/%s/%s", remote.Name, currentBranch)) + } + + // Ignoring the error in this case is allowed because we may get refs and an error (see: --verify flag above). + // Ideally there would be a typed error to allow us to distinguish between an execution error and some refs + // not existing. However, this is too much to take on in an already large refactor. + refs, _ := opts.GitClient.ShowRefs(context.Background(), refsToLookup) + if len(refs) > 1 { + headRef := refs[0] + var firstMatchingRef o.Option[git.RemoteTrackingRef] + // Loop over all the refs, trying to find one that matches the SHA of HEAD. + for _, r := range refs[1:] { + if r.Hash == headRef.Hash { + remoteTrackingRef, err := git.ParseRemoteTrackingRef(r.Name) + if err != nil { + return nil, err + } + + firstMatchingRef = o.Some(remoteTrackingRef) + break + } + } + + // If we found a matching ref, then we don't need to push. + if ref, present := firstMatchingRef.Value(); present { + remote, err := remotes.FindByName(ref.Remote) + if err != nil { + return nil, err + } + + qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(ref.Branch) + if baseRepo.RepoOwner() != remote.RepoOwner() { + qualifiedHeadRef = shared.NewQualifiedHeadRef(remote.RepoOwner(), ref.Branch) + } + + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs, + }), nil + } + } + + // If we haven't got a repo by now, and we can't prompt then it's game over. + if !opts.IO.CanPrompt() { + fmt.Fprintln(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag") + return nil, cmdutil.SilentError + } + + // Otherwise, hooray, prompting! + + // First, we're going to look at our remotes and decide whether there are any repos we can push to. + pushableRepos, err := resolvedRemotes.HeadRepos() + if err != nil { + return nil, err + } + + // If we couldn't find any pushable repos, then find forks of the base repo. + if len(pushableRepos) == 0 { + pushableRepos, err = api.RepoFindForks(client, baseRepo, 3) + if err != nil { + return nil, err + } + } + + currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost()) + if err != nil { + return nil, err + } + + hasOwnFork := false + var pushOptions []string + for _, r := range pushableRepos { + pushOptions = append(pushOptions, ghrepo.FullName(r)) + if r.RepoOwner() == currentLogin { + hasOwnFork = true + } + } + + if !hasOwnFork { + pushOptions = append(pushOptions, fmt.Sprintf("Create a fork of %s", ghrepo.FullName(baseRepo))) + } + pushOptions = append(pushOptions, "Skip pushing the branch") + pushOptions = append(pushOptions, "Cancel") + + selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", currentBranch), "", pushOptions) + if err != nil { + return nil, err + } + + if selectedOption < len(pushableRepos) { + // A repository has been selected to push to. + return newCreateContext(pushableRefs{ + headRepo: pushableRepos[selectedOption], + headBranchName: currentBranch, + baseRefs: baseRefs, + }), nil + } else if pushOptions[selectedOption] == "Skip pushing the branch" { + // We're going to skip pushing the branch altogether, meaning, use whatever SHA is already pushed. + // It's not exactly clear what repo the user expects to use here for the HEAD, and maybe we should + // make that clear in the UX somehow, but in the old implementation as far as I can tell, this + // always meant "use the base repo". + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner(currentBranch), + baseRefs: baseRefs, + }), nil + } else if pushOptions[selectedOption] == "Cancel" { + return nil, cmdutil.CancelError + } else { + // A fork should be created. + return newCreateContext(forkableRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRef(currentLogin, currentBranch), + baseRefs: baseRefs, + }), nil + } } func getRemotes(opts *CreateOptions) (ghContext.Remotes, error) { @@ -789,8 +973,8 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS "title": state.Title, "body": state.Body, "draft": state.Draft, - "baseRefName": ctx.BaseBranch, - "headRefName": ctx.PrRefs.GetPRHeadLabel(), + "baseRefName": ctx.PRRefs.BaseRef(), + "headRefName": ctx.PRRefs.QualifiedHeadRef(), "maintainerCanModify": opts.MaintainerCanModify, } @@ -798,7 +982,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS return errors.New("pull request title must not be blank") } - err := shared.AddMetadataToIssueParams(client, ctx.PrRefs.BaseRepo, params, &state) + err := shared.AddMetadataToIssueParams(client, ctx.PRRefs.BaseRepo(), params, &state) if err != nil { return err } @@ -812,9 +996,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS } opts.IO.StartProgressIndicator() - // At this point, ctx.PrRefs.BaseRepo is guaranteed to be an *api.Repository - // because of https://github.com/cli/cli/blob/d29db2d44199ad4a987ea866f3f4ff601b1c90a0/pkg/cmd/pr/create/create.go#L578-L592 - pr, err := api.CreatePullRequest(client, ctx.PrRefs.BaseRepo.(*api.Repository), params) + pr, err := api.CreatePullRequest(client, ctx.PRRefs.BaseRepo(), params) opts.IO.StopProgressIndicator() if pr != nil { fmt.Fprintln(opts.IO.Out, pr.URL) @@ -910,38 +1092,43 @@ func previewPR(opts CreateOptions, openURL string) error { } func handlePush(opts CreateOptions, ctx CreateContext) error { - didForkRepo := false - headRepo := ctx.PrRefs.HeadRepo - headRemote := ctx.HeadRemote - client := ctx.Client - gitClient := ctx.GitClient - - var err error - // if a head repository could not be determined so far, automatically create - // one by forking the base repository - if ctx.forkHeadRepo && ctx.isPushEnabled { + refs := ctx.PRRefs + forkableRefs, requiresFork := refs.(forkableRefs) + if requiresFork { opts.IO.StartProgressIndicator() - headRepo, err = api.ForkRepo(client, ctx.PrRefs.BaseRepo, "", "", false) + forkedRepo, err := api.ForkRepo(ctx.Client, forkableRefs.BaseRepo(), "", "", false) opts.IO.StopProgressIndicator() if err != nil { return fmt.Errorf("error forking repo: %w", err) } - didForkRepo = true + + refs = pushableRefs{ + headRepo: forkedRepo, + headBranchName: forkableRefs.qualifiedHeadRef.BranchName(), + baseRefs: baseRefs{ + baseRepo: forkableRefs.baseRepo, + baseBranchName: forkableRefs.baseBranchName, + }, + } } - if headRemote == nil && headRepo != nil { - headRemote, _ = ctx.RepoContext.RemoteForRepo(headRepo) + // We may have upcast to pushableRefs on fork, or we may have been passed an instance + // already. But if we haven't, then there's nothing more to do. + pushableRefs, ok := refs.(pushableRefs) + if !ok { + return nil } // There are two cases when an existing remote for the head repo will be - // missing: + // missing (and an error will be returned): // 1. the head repo was just created by auto-forking; // 2. an existing fork was discovered by querying the API. // In either case, we want to add the head repo as a new git remote so we // can push to it. We will try to add the head repo as the "origin" remote // and fallback to the "fork" remote if it is unavailable. Also, if the // base repo is the "origin" remote we will rename it "upstream". - if headRemote == nil && ctx.isPushEnabled { + headRemote, _ := ctx.ResolvedRemotes.RemoteForRepo(pushableRefs.HeadRepo()) + if headRemote == nil { cfg, err := opts.Config() if err != nil { return err @@ -952,8 +1139,8 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return err } - cloneProtocol := cfg.GitProtocol(headRepo.RepoHost()).Value - headRepoURL := ghrepo.FormatRemoteURL(headRepo, cloneProtocol) + cloneProtocol := cfg.GitProtocol(pushableRefs.HeadRepo().RepoHost()).Value + headRepoURL := ghrepo.FormatRemoteURL(pushableRefs.HeadRepo(), cloneProtocol) gitClient := ctx.GitClient origin, _ := remotes.FindByName("origin") upstreamName := "upstream" @@ -964,7 +1151,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { remoteName = "fork" } - if origin != nil && upstream == nil && ghrepo.IsSame(origin, ctx.PrRefs.BaseRepo) { + if origin != nil && upstream == nil && ghrepo.IsSame(origin, pushableRefs.BaseRepo()) { renameCmd, err := gitClient.Command(context.Background(), "remote", "rename", "origin", upstreamName) if err != nil { return err @@ -973,7 +1160,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return fmt.Errorf("error renaming origin remote: %w", err) } remoteName = "origin" - fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(ctx.PrRefs.BaseRepo), upstreamName) + fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(pushableRefs.BaseRepo()), upstreamName) } gitRemote, err := gitClient.AddRemote(context.Background(), remoteName, headRepoURL, []string{}) @@ -981,10 +1168,10 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return fmt.Errorf("error adding remote: %w", err) } - fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(headRepo), remoteName) + fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(pushableRefs.HeadRepo()), remoteName) // Only mark `upstream` remote as default if `gh pr create` created the remote. - if didForkRepo { + if requiresFork { err := gitClient.SetRemoteResolution(context.Background(), upstreamName, "base") if err != nil { return fmt.Errorf("error setting upstream as default: %w", err) @@ -992,52 +1179,45 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { if opts.IO.IsStdoutTTY() { cs := opts.IO.ColorScheme() - fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(headRepo))) + fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(pushableRefs.HeadRepo()))) } } headRemote = &ghContext.Remote{ Remote: gitRemote, - Repo: headRepo, + Repo: pushableRefs.HeadRepo(), } } // automatically push the branch if it hasn't been pushed anywhere yet - if ctx.isPushEnabled { - pushBranch := func() error { - w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "") - defer w.Flush() - ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.PrRefs.BranchName) - bo := backoff.NewConstantBackOff(2 * time.Second) - ctx := context.Background() - return backoff.Retry(func() error { - if err := gitClient.Push(ctx, headRemote.Name, ref, git.WithStderr(w)); err != nil { - // Only retry if we have forked the repo else the push should succeed the first time. - if didForkRepo { - fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n") - return err - } - return backoff.Permanent(err) + pushBranch := func() error { + w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "") + defer w.Flush() + ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.PRRefs.UnqualifiedHeadRef()) + bo := backoff.NewConstantBackOff(2 * time.Second) + root := context.Background() + return backoff.Retry(func() error { + if err := ctx.GitClient.Push(root, headRemote.Name, ref, git.WithStderr(w)); err != nil { + // Only retry if we have forked the repo else the push should succeed the first time. + if requiresFork { + fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n") + return err } - return nil - }, backoff.WithContext(backoff.WithMaxRetries(bo, 3), ctx)) - } - - err := pushBranch() - if err != nil { - return err - } + return backoff.Permanent(err) + } + return nil + }, backoff.WithContext(backoff.WithMaxRetries(bo, 3), root)) } - return nil + return pushBranch() } func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (string, error) { u := ghrepo.GenerateRepoURL( - ctx.PrRefs.BaseRepo, + ctx.PRRefs.BaseRepo(), "compare/%s...%s?expand=1", - url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.PrRefs.GetPRHeadLabel())) - url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PrRefs.BaseRepo, u, state) + url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef())) + url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state) if err != nil { return "", err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 51cbfa724..2a88b5eee 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -2,7 +2,6 @@ package create import ( "encoding/json" - "errors" "fmt" "net/http" "os" @@ -332,19 +331,18 @@ func TestNewCmdCreate(t *testing.T) { func Test_createRun(t *testing.T) { tests := []struct { - name string - setup func(*CreateOptions, *testing.T) func() - cmdStubs func(*run.CommandStubber) - promptStubs func(*prompter.PrompterMock) - httpStubs func(*httpmock.Registry, *testing.T) - expectedOutputs []string - expectedOut string - expectedErrOut string - expectedBrowse string - wantErr string - tty bool - customBranchConfig bool - customPushDestination bool + name string + setup func(*CreateOptions, *testing.T) func() + cmdStubs func(*run.CommandStubber) + promptStubs func(*prompter.PrompterMock) + httpStubs func(*httpmock.Registry, *testing.T) + expectedOutputs []string + expectedOut string + expectedErrOut string + expectedBrowse string + wantErr string + tty bool + customBranchConfig bool }{ { name: "nontty web", @@ -608,7 +606,7 @@ func Test_createRun(t *testing.T) { `), }, { - name: "survey", + name: "select a specific branch to push to on prompt", tty: true, setup: func(opts *CreateOptions, t *testing.T) func() { opts.TitleProvided = true @@ -637,6 +635,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -651,6 +652,52 @@ func Test_createRun(t *testing.T) { expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", }, + { + name: "skip pushing to branch on prompt", + tty: true, + setup: func(opts *CreateOptions, t *testing.T) func() { + opts.TitleProvided = true + opts.BodyProvided = true + opts.Title = "my title" + opts.Body = "my body" + return func() {} + }, + httpStubs: func(reg *httpmock.Registry, t *testing.T) { + reg.StubRepoResponse("OWNER", "REPO") + reg.Register( + httpmock.GraphQL(`query UserCurrent\b`), + httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`)) + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } }`, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "my title", input["title"].(string)) + assert.Equal(t, "my body", input["body"].(string)) + assert.Equal(t, "master", input["baseRefName"].(string)) + assert.Equal(t, "feature", input["headRefName"].(string)) + assert.Equal(t, false, input["draft"].(bool)) + })) + }, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "") + }, + promptStubs: func(pm *prompter.PrompterMock) { + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + if p == "Where should we push the 'feature' branch?" { + return prompter.IndexFor(opts, "Skip pushing the branch") + } else { + return -1, prompter.NoSuchPromptErr(p) + } + } + }, + expectedOut: "https://github.com/OWNER/REPO/pull/12\n", + expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", + }, { name: "project v2", tty: true, @@ -699,6 +746,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -744,6 +794,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -791,12 +844,11 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:feature", input["headRefName"].(string)) })) }, - customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register("git remote rename origin upstream", 0, "") cs.Register(`git remote add origin https://github.com/monalisa/REPO.git`, 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") @@ -854,15 +906,11 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:feature", input["headRefName"].(string)) })) }, - customPushDestination: true, cmdStubs: func(cs *run.CommandStubber) { - cs.Register("git show-ref --verify", 0, heredoc.Doc(` + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 0, heredoc.Doc(` deadbeef HEAD - deadb00f refs/remotes/upstream/feature deadbeef refs/remotes/origin/feature`)) - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for monalisa:feature into master in OWNER/REPO\n\n", @@ -890,20 +938,17 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "my-feat2", input["headRefName"].(string)) })) }, - customBranchConfig: true, - customPushDestination: true, + customBranchConfig: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp \^branch\\\.feature\\\.`, 0, heredoc.Doc(` branch.feature.remote origin branch.feature.merge refs/heads/my-feat2 - `)) // determineTrackingBranch - cs.Register("git show-ref --verify", 0, heredoc.Doc(` + `)) + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/my-feat2") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/my-feat2", 0, heredoc.Doc(` deadbeef HEAD deadbeef refs/remotes/origin/my-feat2 - `)) // determineTrackingBranch - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/my-feat2") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") + `)) }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for my-feat2 into master in OWNER/REPO\n\n", @@ -1084,6 +1129,9 @@ func Test_createRun(t *testing.T) { }, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1115,6 +1163,9 @@ func Test_createRun(t *testing.T) { }, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1279,37 +1330,6 @@ func Test_createRun(t *testing.T) { }, wantErr: "cannot open in browser: maximum URL length exceeded", }, - { - name: "no local git repo", - setup: func(opts *CreateOptions, t *testing.T) func() { - opts.Title = "My PR" - opts.TitleProvided = true - opts.Body = "" - opts.BodyProvided = true - opts.HeadBranch = "feature" - opts.RepoOverride = "OWNER/REPO" - opts.Remotes = func() (context.Remotes, error) { - return nil, errors.New("not a git repository") - } - return func() {} - }, - httpStubs: func(reg *httpmock.Registry, t *testing.T) { - reg.Register( - httpmock.GraphQL(`mutation PullRequestCreate\b`), - httpmock.StringResponse(` - { "data": { "createPullRequest": { "pullRequest": { - "URL": "https://github.com/OWNER/REPO/pull/12" - } } } } - `)) - }, - customPushDestination: true, - cmdStubs: func(cs *run.CommandStubber) { - cs.Register("git rev-parse --abbrev-ref feature@{push}", 1, "fatal: not a git repository (or any of the parent directories): .git") - cs.Register("git config remote.pushDefault", 1, "") - cs.Register("git config push.default", 1, "") - }, - expectedOut: "https://github.com/OWNER/REPO/pull/12\n", - }, { name: "single commit title and body are used", tty: true, @@ -1528,20 +1548,16 @@ func Test_createRun(t *testing.T) { assert.Equal(t, "monalisa:task1", input["headRefName"].(string)) })) }, - customBranchConfig: true, - customPushDestination: true, + customBranchConfig: true, cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, heredoc.Doc(` branch.task1.remote origin branch.task1.merge refs/heads/task1 branch.task1.gh-merge-base feature/feat2`)) // ReadBranchConfig - cs.Register(`git show-ref --verify`, 0, heredoc.Doc(` + cs.Register("git rev-parse --symbolic-full-name task1@{push}", 0, "refs/remotes/origin/task1") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/task1`, 0, heredoc.Doc(` deadbeef HEAD - deadb00f refs/remotes/upstream/feature/feat2 - deadbeef refs/remotes/origin/task1`)) // determineTrackingBranch - cs.Register("git rev-parse --abbrev-ref task1@{push}", 0, "origin/task1") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") + deadbeef refs/remotes/origin/task1`)) }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for monalisa:task1 into feature/feat2 in OWNER/REPO\n\n", @@ -1571,12 +1587,6 @@ func Test_createRun(t *testing.T) { opts.HeadBranch = "otherowner:feature" return func() {} }, - customPushDestination: true, - cmdStubs: func(cs *run.CommandStubber) { - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") - }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", }, } @@ -1598,16 +1608,7 @@ func Test_createRun(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git status --porcelain`, 0, "") - // TODO this could be values in the test struct with a helper - // function to invoke the appropriate command stubs based on - // those values. - if !tt.customPushDestination { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register("git rev-parse --abbrev-ref feature@{push}", 0, "origin/feature") - cs.Register("git config remote.pushDefault", 0, "") - cs.Register("git config push.default", 0, "") - } + if !tt.customBranchConfig { cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") } @@ -1658,6 +1659,10 @@ func Test_createRun(t *testing.T) { } defer cleanSetup() + if opts.HeadBranch == "" { + cs.Register(`git status --porcelain`, 0, "") + } + err := createRun(&opts) output := &test.CmdOut{ OutBuf: stdout, @@ -1681,6 +1686,168 @@ func Test_createRun(t *testing.T) { } } +func TestRemoteGuessing(t *testing.T) { + // Given git config does not provide the necessary info to determine a remote + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git status --porcelain`, 0, "") + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + + // And Given there is a remote on a SHA that matches the current HEAD + cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature`, 0, heredoc.Doc(` + deadbeef HEAD + deadb00f refs/remotes/upstream/feature + deadbeef refs/remotes/origin/feature`)) + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "master") + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } }`, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "master", input["baseRefName"].(string)) + assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string)) + })) + + ios, _, _, _ := iostreams.Test() + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: &prompter.PrompterMock{}, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OTHEROWNER", "REPO-FORK"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + TitleProvided: true, + BodyProvided: true, + Title: "my title", + Body: "my body", + } + + require.NoError(t, createRun(&opts)) + + // Then guessed remote is used for the PR head, + // which annoyingly, is asserted above on the line: + // assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string)) + // + // This is because OTHEROWNER relates to the "origin" remote, which has a + // SHA that matches the HEAD ref in the `git show-ref` output. +} + +func TestNoRepoCanBeDetermined(t *testing.T) { + // Given no head repo can be determined from git config + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git status --porcelain`, 0, "") + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + + // And Given there is no remote on the correct SHA + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, heredoc.Doc(` + deadbeef HEAD + deadb00f refs/remotes/origin/feature`)) + + // When the command is run with no TTY + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "master") + defer reg.Verify(t) + + ios, _, _, stderr := iostreams.Test() + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: &prompter.PrompterMock{}, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + TitleProvided: true, + BodyProvided: true, + Title: "my title", + Body: "my body", + } + + // When we run the command + err := createRun(&opts) + + // Then create fails + require.Equal(t, cmdutil.SilentError, err) + assert.Equal(t, "aborted: you must first push the current branch to a remote, or use the --head flag\n", stderr.String()) +} + +func mustParseQualifiedHeadRef(ref string) shared.QualifiedHeadRef { + parsed, err := shared.ParseQualifiedHeadRef(ref) + if err != nil { + panic(err) + } + return parsed +} + func Test_generateCompareURL(t *testing.T) { tests := []struct { name string @@ -1692,12 +1859,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "basic", ctx: CreateContext{ - PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BranchName: "feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, }, - BaseBranch: "main", }, want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1", wantErr: false, @@ -1705,12 +1873,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "with labels", ctx: CreateContext{ - PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BranchName: "b", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("b"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "a", + }, }, - BaseBranch: "a", }, state: shared.IssueMetadataState{ Labels: []string{"one", "two three"}, @@ -1721,12 +1890,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "'/'s in branch names/labels are percent-encoded", ctx: CreateContext{ - PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), - HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "ORIGINOWNER"}}, "github.com"), - BranchName: "feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), + baseBranchName: "main/trunk", + }, }, - BaseBranch: "main/trunk", }, want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:feature?body=&expand=1", wantErr: false, @@ -1734,18 +1904,19 @@ func Test_generateCompareURL(t *testing.T) { { name: "Any of !'(),; but none of $&+=@ and : in branch names/labels are percent-encoded ", /* - - Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway - - !$&'()+,;=@ is a valid Git branch name—essentially RFC 3986 sub-delims without * and gen-delims without :/?#[] - - : is GitHub separator between a fork name and a branch name - - See https://github.com/golang/go/issues/27559. + - Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway + - !$&'()+,;=@ is a valid Git branch name—essentially RFC 3986 sub-delims without * and gen-delims without :/?#[] + - : is GitHub separator between a fork name and a branch name + - See https://github.com/golang/go/issues/27559. */ ctx: CreateContext{ - PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), - HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "ORIGINOWNER"}}, "github.com"), - BranchName: "!$&'()+,;=@", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:!$&'()+,;=@"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), + baseBranchName: "main/trunk", + }, }, - BaseBranch: "main/trunk", }, want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:%21$&%27%28%29+%2C%3B=@?body=&expand=1", wantErr: false, @@ -1753,12 +1924,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "with template", ctx: CreateContext{ - PrRefs: shared.PullRequestRefs{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - HeadRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BranchName: "feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, }, - BaseBranch: "main", }, state: shared.IssueMetadataState{ Template: "story.md", diff --git a/pkg/cmd/pr/shared/find_refs_resolution.go b/pkg/cmd/pr/shared/find_refs_resolution.go new file mode 100644 index 000000000..833075af8 --- /dev/null +++ b/pkg/cmd/pr/shared/find_refs_resolution.go @@ -0,0 +1,394 @@ +package shared + +import ( + "context" + "fmt" + "net/url" + "strings" + + ghContext "github.com/cli/cli/v2/context" + "github.com/cli/cli/v2/git" + + "github.com/cli/cli/v2/internal/ghrepo" + o "github.com/cli/cli/v2/pkg/option" +) + +// QualifiedHeadRef represents a git branch with an optional owner, used +// for the head of a pull request. For example, within a single repository, +// we would expect a PR to have a head ref of no owner, and a branch name. +// However, for cross-repository pull requests, we would expect a head ref +// with an owner and a branch name. In string form this is represented as +// :. The GitHub API is able to interpret this format in order +// to discover the correct fork repository. +// +// In other parts of the code, you may see this refered to as a HeadLabel. +type QualifiedHeadRef struct { + owner o.Option[string] + branchName string +} + +// NewQualifiedHeadRef creates a QualifiedHeadRef. If the empty string is provided +// for the owner, it will be treated as None. +func NewQualifiedHeadRef(owner string, branchName string) QualifiedHeadRef { + return QualifiedHeadRef{ + owner: o.SomeIfNonZero(owner), + branchName: branchName, + } +} + +func NewQualifiedHeadRefWithoutOwner(branchName string) QualifiedHeadRef { + return QualifiedHeadRef{ + owner: o.None[string](), + branchName: branchName, + } +} + +// ParseQualifiedHeadRef takes strings of the form : or +// and returns a QualifiedHeadRef. If the form : is used, +// the owner is set to the value of , and the branch name is set to +// the value of . If the form is used, the owner is set to +// None, and the branch name is set to the value of . +// +// This does no further error checking about the validity of a ref, so +// it is not safe to assume the ref is truly a valid ref, e.g. "my~bad:ref?" +// is going to result in a nonsense result. +func ParseQualifiedHeadRef(ref string) (QualifiedHeadRef, error) { + if !strings.Contains(ref, ":") { + return NewQualifiedHeadRefWithoutOwner(ref), nil + } + + parts := strings.Split(ref, ":") + if len(parts) != 2 { + return QualifiedHeadRef{}, fmt.Errorf("invalid qualified head ref format '%s'", ref) + } + + return NewQualifiedHeadRef(parts[0], parts[1]), nil +} + +// A QualifiedHeadRef without an owner returns , while a QualifiedHeadRef +// with an owner returns :. +func (r QualifiedHeadRef) String() string { + if owner, present := r.owner.Value(); present { + return fmt.Sprintf("%s:%s", owner, r.branchName) + } + return r.branchName +} + +func (r QualifiedHeadRef) BranchName() string { + return r.branchName +} + +// PRFindRefs represents the necessary data to find a pull request from the API. +type PRFindRefs struct { + qualifiedHeadRef QualifiedHeadRef + + baseRepo ghrepo.Interface + // baseBranchName is an optional branch name, because it is not required for + // finding a pull request, only for disambiguation if multiple pull requests + // contain the same head ref. + baseBranchName o.Option[string] +} + +// QualifiedHeadRef returns a stringified form of the head ref, varying depending +// on whether the head ref is in the same repository as the base ref. If they are +// the same repository, we return the branch name only. If they are different repositories, +// we return the owner and branch name in the form :. +func (r PRFindRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r PRFindRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// Matches checks whether the provided baseBranchName and headRef match the refs. +// It is used to determine whether Pull Requests returned from the API +func (r PRFindRefs) Matches(baseBranchName, qualifiedHeadRef string) bool { + headMatches := qualifiedHeadRef == r.QualifiedHeadRef() + baseMatches := r.baseBranchName.IsNone() || baseBranchName == r.baseBranchName.Unwrap() + return headMatches && baseMatches +} + +func (r PRFindRefs) BaseRepo() ghrepo.Interface { + return r.baseRepo +} + +type RemoteNameToRepoFn func(remoteName string) (ghrepo.Interface, error) + +// PullRequestFindRefsResolver interrogates git configuration to try and determine +// a head repository and a remote branch name, from a local branch name. +type PullRequestFindRefsResolver struct { + GitConfigClient GitConfigClient + RemoteNameToRepoFn RemoteNameToRepoFn +} + +func NewPullRequestFindRefsResolver(gitConfigClient GitConfigClient, remotesFn func() (ghContext.Remotes, error)) PullRequestFindRefsResolver { + return PullRequestFindRefsResolver{ + GitConfigClient: gitConfigClient, + RemoteNameToRepoFn: newRemoteNameToRepoFn(remotesFn), + } +} + +// ResolvePullRequests takes a base repository, a base branch name and a local branch name and uses the git configuration to +// determine the head repository and remote branch name. If we were unable to determine this from git, we default the head +// repository to the base repository. +func (r *PullRequestFindRefsResolver) ResolvePullRequestRefs(baseRepo ghrepo.Interface, baseBranchName, localBranchName string) (PRFindRefs, error) { + if baseRepo == nil { + return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a base repository") + } + + if localBranchName == "" { + return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a local branch name") + } + + headPRRef, err := TryDetermineDefaultPRHead(r.GitConfigClient, remoteToRepoResolver{r.RemoteNameToRepoFn}, localBranchName) + if err != nil { + return PRFindRefs{}, err + } + + // If the headRepo was resolved, we can just convert the response + // to refs and return it. + if headRepo, present := headPRRef.Repo.Value(); present { + qualifiedHeadRef := NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName) + if !ghrepo.IsSame(headRepo, baseRepo) { + qualifiedHeadRef = NewQualifiedHeadRef(headRepo.RepoOwner(), headPRRef.BranchName) + } + + return PRFindRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRepo: baseRepo, + baseBranchName: o.SomeIfNonZero(baseBranchName), + }, nil + } + + // If we didn't find a head repo, default to the base repo + return PRFindRefs{ + qualifiedHeadRef: NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName), + baseRepo: baseRepo, + baseBranchName: o.SomeIfNonZero(baseBranchName), + }, nil +} + +// DefaultPRHead is a neighbour to defaultPushTarget, but instead of holding +// basic git remote information, it holds a resolved repository in `gh` terms. +// +// Since we may not be able to determine a default remote for a branch, this +// is also true of the resolved repository. +type DefaultPRHead struct { + Repo o.Option[ghrepo.Interface] + BranchName string +} + +// TryDetermineDefaultPRHead is a thin wrapper around determineDefaultPushTarget, which attempts to convert +// a present remote into a resolved repository. If the remote is not present, we indicate that to the caller +// by returning a None value for the repo. +func TryDetermineDefaultPRHead(gitClient GitConfigClient, remoteToRepo remoteToRepoResolver, branch string) (DefaultPRHead, error) { + pushTarget, err := tryDetermineDefaultPushTarget(gitClient, branch) + if err != nil { + return DefaultPRHead{}, err + } + + // If we have no remote, let the caller decide what to do by indicating that with a None. + if pushTarget.remote.IsNone() { + return DefaultPRHead{ + Repo: o.None[ghrepo.Interface](), + BranchName: pushTarget.branchName, + }, nil + } + + repo, err := remoteToRepo.resolve(pushTarget.remote.Unwrap()) + if err != nil { + return DefaultPRHead{}, err + } + + return DefaultPRHead{ + Repo: o.Some(repo), + BranchName: pushTarget.branchName, + }, nil +} + +// remote represents the value of the remote key in a branch's git configuration. +// This value may be a name or a URL, both of which are strings, but are unfortunately +// parsed by ReadBranchConfig into separate fields, allowing for illegal states to be +// created by accident. This is an attempt to indicate that they are mutally exclusive. +type remote interface{ sealedRemote() } + +type remoteName struct{ name string } + +func (rn remoteName) sealedRemote() {} + +type remoteURL struct{ url *url.URL } + +func (ru remoteURL) sealedRemote() {} + +// newRemoteNameToRepoFn takes a function that returns a list of remotes and +// returns a function that takes a remote name and returns the corresponding +// repository. It is a convenience function to call sites having to duplicate +// the same logic. +func newRemoteNameToRepoFn(remotesFn func() (ghContext.Remotes, error)) RemoteNameToRepoFn { + return func(remoteName string) (ghrepo.Interface, error) { + remotes, err := remotesFn() + if err != nil { + return nil, err + } + repo, err := remotes.FindByName(remoteName) + if err != nil { + return nil, err + } + return repo, nil + } +} + +// remoteToRepoResolver provides a utility method to resolve a remote (either name or URL) +// to a repo (ghrepo.Interface). +type remoteToRepoResolver struct { + remoteNameToRepo RemoteNameToRepoFn +} + +func NewRemoteToRepoResolver(remotesFn func() (ghContext.Remotes, error)) remoteToRepoResolver { + return remoteToRepoResolver{ + remoteNameToRepo: newRemoteNameToRepoFn(remotesFn), + } +} + +// resolve takes a remote and returns a repository representing it. +func (r remoteToRepoResolver) resolve(remote remote) (ghrepo.Interface, error) { + switch v := remote.(type) { + case remoteName: + repo, err := r.remoteNameToRepo(v.name) + if err != nil { + return nil, fmt.Errorf("could not resolve remote %q: %w", v.name, err) + } + return repo, nil + case remoteURL: + repo, err := ghrepo.FromURL(v.url) + if err != nil { + return nil, fmt.Errorf("could not parse remote URL %q: %w", v.url, err) + } + return repo, nil + default: + return nil, fmt.Errorf("unsupported remote type %T, value: %v", v, remote) + } +} + +// A defaultPushTarget represents the remote name or URL and a branch name +// that we would expect a branch to be pushed to if `git push` were run with +// no further arguments. This is the most likely place for the head of the PR +// to be, but it's not guaranteed. The user may have pushed to another branch +// directly via `git push :` and not set up tracking information. +// A branch name is always present. +// +// It's possible that we're unable to determine a remote, if the user had pushed directly +// to a URL for example `git push `, which is why it is optional. When present, +// the remote may either be a name or a URL. +type defaultPushTarget struct { + remote o.Option[remote] + branchName string +} + +// newDefaultPushTarget is a thin wrapper over defaultPushTarget to help with +// generic type inference, to reduce verbosity in repeating the parametric type. +func newDefaultPushTarget(remote remote, branchName string) defaultPushTarget { + return defaultPushTarget{ + remote: o.Some(remote), + branchName: branchName, + } +} + +// tryDetermineDefaultPushTarget uses git configuration to make a best guess about where a branch +// is pushed to, and where it would be pushed to if the user ran `git push` with no additional +// arguments. +// +// Firstly, it attempts to resolve the @{push} ref, which is the most reliable method, as this +// is what git uses to determine the remote tracking branch +// +// If this fails, we go through a series of steps to determine the remote: +// +// 1. check branch configuration for `branch..pushRemote = | ` +// 2. check remote configuration for `remote.pushDefault = ` +// 3. check branch configuration for `branch..remote = | ` +// +// If none of these are set, we indicate that we were unable to determine the +// remote by returning a None value for the remote. +// +// The branch name is always set. The default configuration for push.default (current) indicates +// that a git push should use the same remote branch name as the local branch name. If push.default +// is set to upstream or tracking (deprecated form of upstream), then we use the branch name from the merge ref. +func tryDetermineDefaultPushTarget(gitClient GitConfigClient, localBranchName string) (defaultPushTarget, error) { + // If @{push} resolves, then we have the remote tracking branch already, no problem. + if pushRevisionRef, err := gitClient.PushRevision(context.Background(), localBranchName); err == nil { + return newDefaultPushTarget(remoteName{pushRevisionRef.Remote}, pushRevisionRef.Branch), nil + } + + // But it doesn't always resolve, so we can suppress the error and move on to other means + // of determination. We'll first look at branch and remote configuration to make a determination. + branchConfig, err := gitClient.ReadBranchConfig(context.Background(), localBranchName) + if err != nil { + return defaultPushTarget{}, err + } + + pushDefault, err := gitClient.PushDefault(context.Background()) + if err != nil { + return defaultPushTarget{}, err + } + + // We assume the PR's branch name is the same as whatever was provided, unless the user has specified + // push.default = upstream or tracking, then we use the branch name from the merge ref. + remoteBranch := localBranchName + if pushDefault == git.PushDefaultUpstream || pushDefault == git.PushDefaultTracking { + remoteBranch = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") + if remoteBranch == "" { + return defaultPushTarget{}, fmt.Errorf("could not determine remote branch name") + } + } + + // To get the remote, we look to the git config. It comes from one of the following, in order of precedence: + // 1. branch..pushRemote (which may be a name or a URL) + // 2. remote.pushDefault (which is a remote name) + // 3. branch..remote (which may be a name or a URL) + if branchConfig.PushRemoteName != "" { + return newDefaultPushTarget( + remoteName{branchConfig.PushRemoteName}, + remoteBranch, + ), nil + } + + if branchConfig.PushRemoteURL != nil { + return newDefaultPushTarget( + remoteURL{branchConfig.PushRemoteURL}, + remoteBranch, + ), nil + } + + remotePushDefault, err := gitClient.RemotePushDefault(context.Background()) + if err != nil { + return defaultPushTarget{}, err + } + + if remotePushDefault != "" { + return newDefaultPushTarget( + remoteName{remotePushDefault}, + remoteBranch, + ), nil + } + + if branchConfig.RemoteName != "" { + return newDefaultPushTarget( + remoteName{branchConfig.RemoteName}, + remoteBranch, + ), nil + } + + if branchConfig.RemoteURL != nil { + return newDefaultPushTarget( + remoteURL{branchConfig.RemoteURL}, + remoteBranch, + ), nil + } + + // If we couldn't find the remote, we'll indicate that to the caller via None. + return defaultPushTarget{ + remote: o.None[remote](), + branchName: remoteBranch, + }, nil +} diff --git a/pkg/cmd/pr/shared/find_refs_resolution_test.go b/pkg/cmd/pr/shared/find_refs_resolution_test.go new file mode 100644 index 000000000..8cbb62146 --- /dev/null +++ b/pkg/cmd/pr/shared/find_refs_resolution_test.go @@ -0,0 +1,508 @@ +package shared + +import ( + "errors" + "net/url" + "testing" + + ghContext "github.com/cli/cli/v2/context" + "github.com/cli/cli/v2/git" + "github.com/cli/cli/v2/internal/ghrepo" + o "github.com/cli/cli/v2/pkg/option" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQualifiedHeadRef(t *testing.T) { + t.Parallel() + + testCases := []struct { + behavior string + ref string + expectedString string + expectedBranchName string + expectedError error + }{ + { + behavior: "when a branch is provided, the parsed qualified head ref only has a branch", + ref: "feature-branch", + expectedString: "feature-branch", + expectedBranchName: "feature-branch", + }, + { + behavior: "when an owner and branch are provided, the parsed qualified head ref has both", + ref: "owner:feature-branch", + expectedString: "owner:feature-branch", + expectedBranchName: "feature-branch", + }, + { + behavior: "when the structure cannot be interpreted correctly, an error is returned", + ref: "owner:feature-branch:extra", + expectedError: errors.New("invalid qualified head ref format 'owner:feature-branch:extra'"), + }, + } + + for _, tc := range testCases { + t.Run(tc.behavior, func(t *testing.T) { + t.Parallel() + + qualifiedHeadRef, err := ParseQualifiedHeadRef(tc.ref) + if tc.expectedError != nil { + require.Equal(t, tc.expectedError, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedString, qualifiedHeadRef.String()) + assert.Equal(t, tc.expectedBranchName, qualifiedHeadRef.BranchName()) + }) + } +} + +func TestPRFindRefs(t *testing.T) { + t.Parallel() + + t.Run("qualified head ref with owner", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("forkowner:feature-branch"), + } + + require.Equal(t, "forkowner:feature-branch", refs.QualifiedHeadRef()) + require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef()) + }) + + t.Run("qualified head ref without owner", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + } + + require.Equal(t, "feature-branch", refs.QualifiedHeadRef()) + require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef()) + }) + + t.Run("base repo", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + baseRepo: ghrepo.New("owner", "repo"), + } + + require.True(t, ghrepo.IsSame(refs.BaseRepo(), ghrepo.New("owner", "repo")), "expected repos to be the same") + }) + + t.Run("matches", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + behavior string + refs PRFindRefs + baseBranchName string + qualifiedHeadRef string + expectedMatch bool + }{ + { + behavior: "when qualified head refs don't match, returns false", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("owner:feature-branch"), + }, + baseBranchName: "feature-branch", + qualifiedHeadRef: "feature-branch", + expectedMatch: false, + }, + { + behavior: "when base branches don't match, returns false", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.Some("not-main"), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: false, + }, + { + behavior: "when head refs match and there is no base branch, returns true", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.None[string](), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: true, + }, + { + behavior: "when head refs match and base branches match, returns true", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.Some("main"), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.behavior, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.expectedMatch, tc.refs.Matches(tc.baseBranchName, tc.qualifiedHeadRef)) + }) + } + }) +} + +func TestPullRequestResolution(t *testing.T) { + t.Parallel() + + baseRepo := ghrepo.New("owner", "repo") + baseRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + } + + forkRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("otherowner", "repo-fork"), + } + + t.Run("when the base repo is nil, returns an error", func(t *testing.T) { + t.Parallel() + + resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn) + _, err := resolver.ResolvePullRequestRefs(nil, "", "") + require.Error(t, err) + }) + + t.Run("when the local branch name is empty, returns an error", func(t *testing.T) { + t.Parallel() + + resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn) + _, err := resolver.ResolvePullRequestRefs(baseRepo, "", "") + require.Error(t, err) + }) + + t.Run("when the default pr head has a repo, it is used for the refs", func(t *testing.T) { + t.Parallel() + + // Push revision is the first thing checked for resolution, + // so nothing else needs to be stubbed. + repoResolvedFromPushRevisionClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{ + Remote: "origin", + Branch: "feature-branch", + }, nil), + } + + resolver := NewPullRequestFindRefsResolver( + repoResolvedFromPushRevisionClient, + stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + ) + + refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch") + require.NoError(t, err) + + expectedRefs := PRFindRefs{ + qualifiedHeadRef: QualifiedHeadRef{ + owner: o.Some("otherowner"), + branchName: "feature-branch", + }, + baseRepo: baseRepo, + baseBranchName: o.Some("main"), + } + + require.Equal(t, expectedRefs, refs) + }) + + t.Run("when the default pr head does not have a repo, we use the base repo for the head", func(t *testing.T) { + t.Parallel() + + // All the values stubbed here result in being unable to resolve a default repo. + noRepoResolutionStubClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault("", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + resolver := NewPullRequestFindRefsResolver( + noRepoResolutionStubClient, + stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + ) + + refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch") + require.NoError(t, err) + + expectedRefs := PRFindRefs{ + qualifiedHeadRef: QualifiedHeadRef{ + owner: o.None[string](), + branchName: "feature-branch", + }, + baseRepo: baseRepo, + baseBranchName: o.Some("main"), + } + require.Equal(t, expectedRefs, refs) + }) +} + +func TestTryDetermineDefaultPRHead(t *testing.T) { + t.Parallel() + + baseRepo := ghrepo.New("owner", "repo") + baseRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: baseRepo, + } + + forkRepo := ghrepo.New("otherowner", "repo-fork") + forkRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: forkRepo, + } + forkRepoURL, err := url.Parse("https://github.com/otherowner/repo-fork.git") + require.NoError(t, err) + + t.Run("when the push revision is set, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRevisionClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{ + Remote: "origin", + Branch: "remote-feature-branch", + }, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRevisionClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "remote-feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config push remote is set to a name, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config push remote is set to a URL, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteURL: forkRepoURL, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + dummyRemoteToRepoResolver(), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when a remote push default is set, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("origin", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config remote is set to a name, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + RemoteName: "origin", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config remote is set to a URL, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + RemoteURL: forkRepoURL, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + dummyRemoteToRepoResolver(), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when git didn't provide the necessary information, return none for the remote", func(t *testing.T) { + t.Parallel() + + // All the values stubbed here result in being unable to resolve a default repo. + noRepoResolutionStubClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault("", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + noRepoResolutionStubClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, defaultPRHead.Repo.IsNone(), "expected repo to be none") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the push default is tracking or upstream, use the merge ref", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + pushDefault git.PushDefault + }{ + {pushDefault: git.PushDefaultTracking}, + {pushDefault: git.PushDefaultUpstream}, + } + + for _, tc := range testCases { + t.Run(string(tc.pushDefault), func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + MergeRef: "main", + }, nil), + pushDefaultFn: stubPushDefault(tc.pushDefault, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "main", defaultPRHead.BranchName) + }) + } + + t.Run("but if the merge ref is empty, error", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + MergeRef: "", // intentionally empty + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultUpstream, nil), + } + + _, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.Error(t, err) + }) + }) + +} + +func dummyRemotesFn() (ghContext.Remotes, error) { + panic("remotes fn not implemented") +} + +func dummyRemoteToRepoResolver() remoteToRepoResolver { + return NewRemoteToRepoResolver(dummyRemotesFn) +} + +func stubRemoteToRepoResolver(remotes ghContext.Remotes, err error) remoteToRepoResolver { + return NewRemoteToRepoResolver(func() (ghContext.Remotes, error) { + return remotes, err + }) +} + +func mustParseQualifiedHeadRef(ref string) QualifiedHeadRef { + parsed, err := ParseQualifiedHeadRef(ref) + if err != nil { + panic(err) + } + return parsed +} diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 7fed231cb..6d36ef816 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -13,11 +13,12 @@ import ( "time" "github.com/cli/cli/v2/api" - remotes "github.com/cli/cli/v2/context" + ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmdutil" + o "github.com/cli/cli/v2/pkg/option" "github.com/cli/cli/v2/pkg/set" "github.com/shurcooL/githubv4" "golang.org/x/sync/errgroup" @@ -32,16 +33,20 @@ type progressIndicator interface { StopProgressIndicator() } +type GitConfigClient interface { + ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) + PushDefault(ctx context.Context) (git.PushDefault, error) + RemotePushDefault(ctx context.Context) (string, error) + PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) +} + type finder struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - remotesFn func() (remotes.Remotes, error) - httpClient func() (*http.Client, error) - pushDefault func() (string, error) - remotePushDefault func() (string, error) - parsePushRevision func(string) (string, error) - branchConfig func(string) (git.BranchConfig, error) - progress progressIndicator + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + httpClient func() (*http.Client, error) + remotesFn func() (ghContext.Remotes, error) + gitConfigClient GitConfigClient + progress progressIndicator baseRefRepo ghrepo.Interface prNumber int @@ -56,23 +61,12 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { } return &finder{ - baseRepoFn: factory.BaseRepo, - branchFn: factory.Branch, - remotesFn: factory.Remotes, - httpClient: factory.HttpClient, - pushDefault: func() (string, error) { - return factory.GitClient.PushDefault(context.Background()) - }, - remotePushDefault: func() (string, error) { - return factory.GitClient.RemotePushDefault(context.Background()) - }, - parsePushRevision: func(branch string) (string, error) { - return factory.GitClient.ParsePushRevision(context.Background(), branch) - }, - progress: factory.IOStreams, - branchConfig: func(s string) (git.BranchConfig, error) { - return factory.GitClient.ReadBranchConfig(context.Background(), s) - }, + baseRepoFn: factory.BaseRepo, + branchFn: factory.Branch, + httpClient: factory.HttpClient, + gitConfigClient: factory.GitClient, + remotesFn: factory.Remotes, + progress: factory.IOStreams, } } @@ -97,32 +91,6 @@ type FindOptions struct { States []string } -// TODO: Does this also need the BaseBranchName? -// PR's are represented by the following: -// headRef -----PR-----> baseRef -// -// A ref is described as "remoteName/branchName", so -// headRepoName/headBranchName -----PR-----> baseRepoName/baseBranchName -type PullRequestRefs struct { - BranchName string - HeadRepo ghrepo.Interface - BaseRepo ghrepo.Interface -} - -func (s *PullRequestRefs) HasHead() bool { - return s.HeadRepo != nil && s.BranchName != "" -} - -// GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is -// either just the branch name or, if the PR is originating from a fork, the fork owner -// and the branch name, like :. -func (s *PullRequestRefs) GetPRHeadLabel() string { - if ghrepo.IsSame(s.HeadRepo, s.BaseRepo) { - return s.BranchName - } - return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName) -} - func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { // If we have a URL, we don't need git stuff if len(opts.Fields) == 0 { @@ -142,7 +110,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.baseRefRepo = repo } - var prRefs PullRequestRefs + var prRefs PRFindRefs if opts.Selector == "" { // You must be in a git repo for this case to work currentBranchName, err := f.branchFn() @@ -152,7 +120,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.branchName = currentBranchName // Get the branch config for the current branchName - branchConfig, err := f.branchConfig(f.branchName) + branchConfig, err := f.gitConfigClient.ReadBranchConfig(context.Background(), f.branchName) if err != nil { return nil, nil, err } @@ -166,30 +134,19 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err // Determine the PullRequestRefs from config if f.prNumber == 0 { - rems, err := f.remotesFn() - if err != nil { - return nil, nil, err - } - - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := f.parsePushRevision(f.branchName) - - pushDefault, err := f.pushDefault() - if err != nil { - return nil, nil, err - } - - remotePushDefault, err := f.remotePushDefault() - if err != nil { - return nil, nil, err - } - - prRefs, err = ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems) + prRefsResolver := NewPullRequestFindRefsResolver( + // We requested the branch config already, so let's cache that + CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: f.gitConfigClient, + }, + f.remotesFn, + ) + prRefs, err = prRefsResolver.ResolvePullRequestRefs(f.baseRefRepo, opts.BaseBranch, f.branchName) if err != nil { return nil, nil, err } } - } else if f.prNumber == 0 { // You gave me a selector but I couldn't find a PR number (it wasn't a URL) @@ -204,11 +161,17 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.prNumber = prNumber } else { f.branchName = opts.Selector - // We don't expect an error here because parsedPushRevision is empty - prRefs, err = ParsePRRefs(f.branchName, git.BranchConfig{}, "", "", "", f.baseRefRepo, remotes.Remotes{}) + + qualifiedHeadRef, err := ParseQualifiedHeadRef(f.branchName) if err != nil { return nil, nil, err } + + prRefs = PRFindRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRepo: f.baseRefRepo, + baseBranchName: o.SomeIfNonZero(opts.BaseBranch), + } } } @@ -259,7 +222,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return pr, f.baseRefRepo, err } } else { - pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRHeadLabel(), opts.States, fields.ToSlice()) + pr, err = findForRefs(httpClient, prRefs, opts.States, fields.ToSlice()) if err != nil { return pr, f.baseRefRepo, err } @@ -321,72 +284,6 @@ func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { return repo, prNumber, nil } -func ParsePRRefs(currentBranchName string, branchConfig git.BranchConfig, parsedPushRevision string, pushDefault string, remotePushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PullRequestRefs, error) { - prRefs := PullRequestRefs{ - BaseRepo: baseRefRepo, - } - - // If @{push} resolves, then we have all the information we need to determine the head repo - // and branch name. It is of the form /. - if parsedPushRevision != "" { - for _, r := range rems { - // Find the remote who's name matches the push prefix - if strings.HasPrefix(parsedPushRevision, r.Name+"/") { - prRefs.BranchName = strings.TrimPrefix(parsedPushRevision, r.Name+"/") - prRefs.HeadRepo = r.Repo - return prRefs, nil - } - } - - remoteNames := make([]string, len(rems)) - for i, r := range rems { - remoteNames[i] = r.Name - } - return PullRequestRefs{}, fmt.Errorf("no remote for %q found in %q", parsedPushRevision, strings.Join(remoteNames, ", ")) - } - - // We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier - // unless the user has specified push.default = upstream or tracking, then we use the - // branch name from the merge ref. - prRefs.BranchName = currentBranchName - if pushDefault == "upstream" || pushDefault == "tracking" { - prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - } - - // To get the HeadRepo, we look to the git config. The HeadRepo comes from one of the following, in order of precedence: - // 1. branch..pushRemote - // 2. remote.pushDefault - // 3. branch..remote - if branchConfig.PushRemoteName != "" { - if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.PushRemoteURL != nil { - if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil { - prRefs.HeadRepo = r - } - } else if remotePushDefault != "" { - if r, err := rems.FindByName(remotePushDefault); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.RemoteName != "" { - if r, err := rems.FindByName(branchConfig.RemoteName); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.RemoteURL != nil { - if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { - prRefs.HeadRepo = r - } - } - - // The PR merges from a branch in the same repo as the base branch (usually the default branch) - if prRefs.HeadRepo == nil { - prRefs.HeadRepo = baseRefRepo - } - - return prRefs, nil -} - func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { @@ -417,7 +314,7 @@ func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fi return &resp.Repository.PullRequest, nil } -func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) { +func findForRefs(httpClient *http.Client, prRefs PRFindRefs, stateFilters, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { PullRequests struct { @@ -444,21 +341,16 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h } }`, api.PullRequestGraphQL(fieldSet.ToSlice())) - branchWithoutOwner := headBranchWithOwnerIfFork - if idx := strings.Index(headBranchWithOwnerIfFork, ":"); idx >= 0 { - branchWithoutOwner = headBranchWithOwnerIfFork[idx+1:] - } - variables := map[string]interface{}{ - "owner": repo.RepoOwner(), - "repo": repo.RepoName(), - "headRefName": branchWithoutOwner, + "owner": prRefs.BaseRepo().RepoOwner(), + "repo": prRefs.BaseRepo().RepoName(), + "headRefName": prRefs.UnqualifiedHeadRef(), "states": stateFilters, } var resp response client := api.NewClientFromHTTP(httpClient) - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err := client.GraphQL(prRefs.BaseRepo().RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -469,17 +361,15 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h }) for _, pr := range prs { - headBranchMatches := pr.HeadLabel() == headBranchWithOwnerIfFork - baseBranchEmptyOrMatches := baseBranch == "" || pr.BaseRefName == baseBranch // When the head is the default branch, it doesn't really make sense to show merged or closed PRs. // https://github.com/cli/cli/issues/4263 - isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranchWithOwnerIfFork - if headBranchMatches && baseBranchEmptyOrMatches && isNotClosedOrMergedWhenHeadIsDefault { + isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != prRefs.QualifiedHeadRef() + if prRefs.Matches(pr.BaseRefName, pr.HeadLabel()) && isNotClosedOrMergedWhenHeadIsDefault { return &pr, nil } } - return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)} + return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", prRefs.QualifiedHeadRef())} } func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 3349197e2..e1aae16b1 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -1,46 +1,41 @@ package shared import ( + "context" "errors" - "fmt" "net/http" "net/url" "testing" - "github.com/cli/cli/v2/context" + ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type args struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - branchConfig func(string) (git.BranchConfig, error) - pushDefault func() (string, error) - remotePushDefault func() (string, error) - parsePushRevision func(string) (string, error) - selector string - fields []string - baseBranch string + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + gitConfigClient stubGitConfigClient + selector string + fields []string + baseBranch string } func TestFind(t *testing.T) { - // TODO: Abstract these out meaningfully for reuse in parsePRRefs tests originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") if err != nil { t.Fatal(err) } - remoteOrigin := context.Remote{ + remoteOrigin := ghContext.Remote{ Remote: &git.Remote{ Name: "origin", FetchURL: originOwnerUrl, }, Repo: ghrepo.New("ORIGINOWNER", "REPO"), } - remoteOther := context.Remote{ + remoteOther := ghContext.Remote{ Remote: &git.Remote{ Name: "other", FetchURL: originOwnerUrl, @@ -52,7 +47,7 @@ func TestFind(t *testing.T) { if err != nil { t.Fatal(err) } - remoteUpstream := context.Remote{ + remoteUpstream := ghContext.Remote{ Remote: &git.Remote{ Name: "upstream", FetchURL: upstreamOwnerUrl, @@ -77,7 +72,6 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -99,12 +93,14 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: remoteOrigin.Remote.Name, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -134,9 +130,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, wantErr: true, }, @@ -157,9 +155,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: nil, wantPR: 13, @@ -174,9 +174,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -197,9 +199,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -223,15 +227,17 @@ func TestFind(t *testing.T) { ExitCode: 128, } }, - branchConfig: stubBranchConfig(git.BranchConfig{}, &git.GitError{ - Stderr: "fatal: branchConfig error", - ExitCode: 128, - }), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", &git.GitError{ - Stderr: "fatal: remotePushDefault error", - ExitCode: 128, - }), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, &git.GitError{ + Stderr: "fatal: branchConfig error", + ExitCode: 128, + }), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", &git.GitError{ + Stderr: "fatal: remotePushDefault error", + ExitCode: 128, + }), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -252,10 +258,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - parsePushRevision: stubParsedPushRevision("", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -296,10 +304,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -339,10 +349,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -374,10 +386,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -423,13 +437,15 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - PushRemoteName: "upstream", - }, nil), - pushDefault: stubPushDefault("upstream", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteName: "upstream", + }, nil), + pushDefaultFn: stubPushDefault("upstream", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -463,13 +479,15 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - PushRemoteURL: remoteUpstream.Remote.FetchURL, - }, nil), - pushDefault: stubPushDefault("upstream", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteURL: remoteUpstream.Remote.FetchURL, + }, nil), + pushDefaultFn: stubPushDefault("upstream", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -499,10 +517,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("other/blueberries", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{Remote: "other", Branch: "blueberries"}, nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -534,9 +554,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/pull/13/head", - }, nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/pull/13/head", + }, nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -559,11 +581,13 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/pull/13/head", - }, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/pull/13/head", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -575,32 +599,32 @@ func TestFind(t *testing.T) { r.Register( httpmock.GraphQL(`query PullRequestProjectItems\b`), httpmock.GraphQLQuery(`{ - "data": { - "repository": { - "pullRequest": { - "projectItems": { - "nodes": [ - { - "id": "PVTI_lADOB-vozM4AVk16zgK6U50", - "project": { - "id": "PVT_kwDOB-vozM4AVk16", - "title": "Test Project" - }, - "status": { - "optionId": "47fc9ee4", - "name": "In Progress" - } - } - ], - "pageInfo": { - "hasNextPage": false, - "endCursor": "MQ" - } - } - } - } - } - }`, + "data": { + "repository": { + "pullRequest": { + "projectItems": { + "nodes": [ + { + "id": "PVTI_lADOB-vozM4AVk16zgK6U50", + "project": { + "id": "PVT_kwDOB-vozM4AVk16", + "title": "Test Project" + }, + "status": { + "optionId": "47fc9ee4", + "name": "In Progress" + } + } + ], + "pageInfo": { + "hasNextPage": false, + "endCursor": "MQ" + } + } + } + } + } + }`, func(query string, inputs map[string]interface{}) { require.Equal(t, float64(13), inputs["number"]) require.Equal(t, "OWNER", inputs["owner"]) @@ -624,13 +648,10 @@ func TestFind(t *testing.T) { httpClient: func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }, - baseRepoFn: tt.args.baseRepoFn, - branchFn: tt.args.branchFn, - branchConfig: tt.args.branchConfig, - pushDefault: tt.args.pushDefault, - remotePushDefault: tt.args.remotePushDefault, - parsePushRevision: tt.args.parsePushRevision, - remotesFn: stubRemotes(context.Remotes{ + baseRepoFn: tt.args.baseRepoFn, + branchFn: tt.args.branchFn, + gitConfigClient: tt.args.gitConfigClient, + remotesFn: stubRemotes(ghContext.Remotes{ &remoteOrigin, &remoteOther, &remoteUpstream, @@ -667,366 +688,14 @@ func TestFind(t *testing.T) { } } -func TestParsePRRefs(t *testing.T) { - originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") - if err != nil { - t.Fatal(err) - } - remoteOrigin := context.Remote{ - Remote: &git.Remote{ - Name: "origin", - FetchURL: originOwnerUrl, - }, - Repo: ghrepo.New("ORIGINOWNER", "REPO"), - } - remoteOther := context.Remote{ - Remote: &git.Remote{ - Name: "other", - FetchURL: originOwnerUrl, - }, - Repo: ghrepo.New("ORIGINOWNER", "REPO"), - } - - upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git") - if err != nil { - t.Fatal(err) - } - remoteUpstream := context.Remote{ - Remote: &git.Remote{ - Name: "upstream", - FetchURL: upstreamOwnerUrl, - }, - Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), - } - - tests := []struct { - name string - branchConfig git.BranchConfig - pushDefault string - parsedPushRevision string - remotePushDefault string - currentBranchName string - baseRefRepo ghrepo.Interface - rems context.Remotes - wantPRRefs PullRequestRefs - wantErr error - }{ - { - name: "When the branch is called 'blueberries' with an empty branch config, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the branch is called 'otherBranch' with an empty branch config, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - currentBranchName: "otherBranch", - baseRefRepo: remoteOrigin.Repo, - wantPRRefs: PullRequestRefs{ - BranchName: "otherBranch", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the branch name doesn't match the branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name", - parsedPushRevision: "origin/pushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "pushBranch", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push revision doesn't match a remote, it returns an error", - parsedPushRevision: "origin/differentPushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteUpstream, - &remoteOther, - }, - wantPRRefs: PullRequestRefs{}, - wantErr: fmt.Errorf("no remote for %q found in %q", "origin/differentPushBranch", "upstream, other"), - }, - { - name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name", - parsedPushRevision: "other/pushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOther, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "pushBranch", - HeadRepo: remoteOther.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote is the same as the baseRepo, it returns the baseRepo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote is different from the baseRepo, it returns the push remote repo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteUpstream.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteUpstream.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote defined by a URL and the baseRepo is different from the push remote, it returns the push remote repo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteURL: remoteOrigin.Remote.FetchURL, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteUpstream.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteUpstream.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote and merge ref are configured to a different repo and push.default = upstream, it should return the branch name from the other repo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteUpstream.Remote.Name, - MergeRef: "refs/heads/blue-upstream-berries", - }, - pushDefault: "upstream", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blue-upstream-berries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote and merge ref are configured to a different repo and push.default = tracking, it should return the branch name from the other repo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteUpstream.Remote.Name, - MergeRef: "refs/heads/blue-upstream-berries", - }, - pushDefault: "tracking", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blue-upstream-berries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When remote.pushDefault is set, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - remotePushDefault: remoteUpstream.Remote.Name, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the remote name is set on the branch, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{ - RemoteName: remoteUpstream.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the remote URL is set on the branch, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{ - RemoteURL: remoteUpstream.Remote.FetchURL, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prRefs, err := ParsePRRefs(tt.currentBranchName, tt.branchConfig, tt.parsedPushRevision, tt.pushDefault, tt.remotePushDefault, tt.baseRefRepo, tt.rems) - if tt.wantErr != nil { - require.Equal(t, tt.wantErr, err) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantPRRefs, prRefs) - }) - } -} - -func TestPRRefs_GetPRHeadLabel(t *testing.T) { - originRepo := ghrepo.New("ORIGINOWNER", "REPO") - upstreamRepo := ghrepo.New("UPSTREAMOWNER", "REPO") - tests := []struct { - name string - prRefs PullRequestRefs - want string - }{ - { - name: "When the HeadRepo and BaseRepo match, it returns the branch name", - prRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: originRepo, - BaseRepo: originRepo, - }, - want: "blueberries", - }, - { - name: "When the HeadRepo and BaseRepo do not match, it returns the prepended HeadRepo owner to the branch name", - prRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: originRepo, - BaseRepo: upstreamRepo, - }, - want: "ORIGINOWNER:blueberries", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.prRefs.GetPRHeadLabel()) - }) - } -} - -func TestPullRequestRefs_HasHead(t *testing.T) { - tests := []struct { - name string - prRefs PullRequestRefs - want bool - }{ - { - name: "HeadRepo is nil and BranchName is empty, return false", - prRefs: PullRequestRefs{ - HeadRepo: nil, - BranchName: "", - }, - want: false, - }, - { - name: "HeadRepo is not nil and BranchName is empty, return false", - prRefs: PullRequestRefs{ - HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"), - BranchName: "", - }, - want: false, - }, - { - name: "HeadRepo is nil and BranchName is not empty, return false", - prRefs: PullRequestRefs{ - HeadRepo: nil, - BranchName: "feature-branch", - }, - want: false, - }, - { - name: "HeadRepo is not nil and BranchName is not empty, return true", - prRefs: PullRequestRefs{ - HeadRepo: ghrepo.New("ORIGINOWNER", "REPO"), - BranchName: "feature-branch", - }, - want: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.want, tt.prRefs.HasHead()) - }) - } -} - -func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) { - return func(branch string) (git.BranchConfig, error) { +func stubBranchConfig(branchConfig git.BranchConfig, err error) func(context.Context, string) (git.BranchConfig, error) { + return func(_ context.Context, branch string) (git.BranchConfig, error) { return branchConfig, err } } -func stubRemotes(remotes context.Remotes, err error) func() (context.Remotes, error) { - return func() (context.Remotes, error) { +func stubRemotes(remotes ghContext.Remotes, err error) func() (ghContext.Remotes, error) { + return func() (ghContext.Remotes, error) { return remotes, err } } @@ -1037,20 +706,55 @@ func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interfa } } -func stubPushDefault(pushDefault string, err error) func() (string, error) { - return func() (string, error) { +func stubPushDefault(pushDefault git.PushDefault, err error) func(context.Context) (git.PushDefault, error) { + return func(_ context.Context) (git.PushDefault, error) { return pushDefault, err } } -func stubRemotePushDefault(remotePushDefault string, err error) func() (string, error) { - return func() (string, error) { +func stubRemotePushDefault(remotePushDefault string, err error) func(context.Context) (string, error) { + return func(_ context.Context) (string, error) { return remotePushDefault, err } } -func stubParsedPushRevision(parsedPushRevision string, err error) func(string) (string, error) { - return func(_ string) (string, error) { +func stubPushRevision(parsedPushRevision git.RemoteTrackingRef, err error) func(context.Context, string) (git.RemoteTrackingRef, error) { + return func(_ context.Context, _ string) (git.RemoteTrackingRef, error) { return parsedPushRevision, err } } + +type stubGitConfigClient struct { + readBranchConfigFn func(ctx context.Context, branchName string) (git.BranchConfig, error) + pushDefaultFn func(ctx context.Context) (git.PushDefault, error) + remotePushDefaultFn func(ctx context.Context) (string, error) + pushRevisionFn func(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) +} + +func (s stubGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) { + if s.readBranchConfigFn == nil { + panic("unexpected call to ReadBranchConfig") + } + return s.readBranchConfigFn(ctx, branchName) +} + +func (s stubGitConfigClient) PushDefault(ctx context.Context) (git.PushDefault, error) { + if s.pushDefaultFn == nil { + panic("unexpected call to PushDefault") + } + return s.pushDefaultFn(ctx) +} + +func (s stubGitConfigClient) RemotePushDefault(ctx context.Context) (string, error) { + if s.remotePushDefaultFn == nil { + panic("unexpected call to RemotePushDefault") + } + return s.remotePushDefaultFn(ctx) +} + +func (s stubGitConfigClient) PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) { + if s.pushRevisionFn == nil { + panic("unexpected call to PushRevision") + } + return s.pushRevisionFn(ctx, branchName) +} diff --git a/pkg/cmd/pr/shared/git_cached_config_client.go b/pkg/cmd/pr/shared/git_cached_config_client.go new file mode 100644 index 000000000..aea25abee --- /dev/null +++ b/pkg/cmd/pr/shared/git_cached_config_client.go @@ -0,0 +1,18 @@ +package shared + +import ( + "context" + + "github.com/cli/cli/v2/git" +) + +var _ GitConfigClient = &CachedBranchConfigGitConfigClient{} + +type CachedBranchConfigGitConfigClient struct { + CachedBranchConfig git.BranchConfig + GitConfigClient +} + +func (c CachedBranchConfigGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) { + return c.CachedBranchConfig, nil +} diff --git a/pkg/cmd/pr/status/status.go b/pkg/cmd/pr/status/status.go index eb120e5a7..60202594f 100644 --- a/pkg/cmd/pr/status/status.go +++ b/pkg/cmd/pr/status/status.go @@ -102,43 +102,34 @@ func statusRun(opts *StatusOptions) error { return fmt.Errorf("could not query for pull request for current branch: %w", err) } - branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName) - if err != nil { - return err - } - // Determine if the branch is configured to merge to a special PR ref - prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) - if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - currentPRNumber, _ = strconv.Atoi(m[1]) - } - - if currentPRNumber == 0 { - remotes, err := opts.Remotes() + if !errors.Is(err, git.ErrNotOnAnyBranch) { + branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName) if err != nil { return err } - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, currentBranchName) - - remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) - if err != nil { - return err + // Determine if the branch is configured to merge to a special PR ref + prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + currentPRNumber, _ = strconv.Atoi(m[1]) } - pushDefault, err := opts.GitClient.PushDefault(ctx) - if err != nil { - return err - } + if currentPRNumber == 0 { + prRefsResolver := shared.NewPullRequestFindRefsResolver( + // We requested the branch config already, so let's cache that + shared.CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: opts.GitClient, + }, + opts.Remotes, + ) - prRefs, err := shared.ParsePRRefs(currentBranchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRefRepo, remotes) - if err != nil { - return err - } - currentHeadRefBranchName = prRefs.BranchName - } + prRefs, err := prRefsResolver.ResolvePullRequestRefs(baseRefRepo, "", currentBranchName) + if err != nil { + return err + } - if err != nil { - return fmt.Errorf("could not query for pull request for current branch: %w", err) + currentHeadRefBranchName = prRefs.QualifiedHeadRef() + } } } diff --git a/pkg/cmd/pr/status/status_test.go b/pkg/cmd/pr/status/status_test.go index c55604c28..41c01e915 100644 --- a/pkg/cmd/pr/status/status_test.go +++ b/pkg/cmd/pr/status/status_test.go @@ -98,10 +98,10 @@ func TestPRStatus(t *testing.T) { // stub successful git commands rs, cleanup := run.Stub() defer cleanup(t) + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -133,8 +133,8 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -166,8 +166,8 @@ func TestPRStatus_reviewsAndChecksWithStatesByCount(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommandWithDetector(http, "blueberries", true, "", &fd.EnabledDetectorMock{}) if err != nil { @@ -198,8 +198,8 @@ func TestPRStatus_currentBranch_showTheMostRecentPR(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -234,8 +234,8 @@ func TestPRStatus_currentBranch_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -276,8 +276,8 @@ func TestPRStatus_currentBranch_Closed(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -301,8 +301,8 @@ func TestPRStatus_currentBranch_Closed_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -326,8 +326,8 @@ func TestPRStatus_currentBranch_Merged(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -351,8 +351,8 @@ func TestPRStatus_currentBranch_Merged_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -376,8 +376,8 @@ func TestPRStatus_blankSlate(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -432,14 +432,6 @@ func TestPRStatus_detachedHead(t *testing.T) { defer http.Verify(t) http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.StringResponse(`{"data": {}}`)) - // stub successful git command - rs, cleanup := run.Stub() - defer cleanup(t) - rs.Register(`git config --get-regexp \^branch\\.`, 0, "") - rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref @{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") - output, err := runCommand(http, "", true, "") if err != nil { t.Errorf("error running command `pr status`: %v", err) diff --git a/pkg/httpmock/registry.go b/pkg/httpmock/registry.go index 387d0fc95..51aa5a898 100644 --- a/pkg/httpmock/registry.go +++ b/pkg/httpmock/registry.go @@ -3,6 +3,8 @@ package httpmock import ( "fmt" "net/http" + "runtime/debug" + "strings" "sync" "testing" @@ -23,6 +25,7 @@ type Registry struct { func (r *Registry) Register(m Matcher, resp Responder) { r.stubs = append(r.stubs, &Stub{ + Stack: string(debug.Stack()), Matcher: m, Responder: resp, }) @@ -46,17 +49,24 @@ type Testing interface { } func (r *Registry) Verify(t Testing) { - n := 0 + var unmatchedStubStacks []string for _, s := range r.stubs { if !s.matched && !s.exclude { - n++ + unmatchedStubStacks = append(unmatchedStubStacks, s.Stack) } } - if n > 0 { + if len(unmatchedStubStacks) > 0 { t.Helper() - // NOTE: stubs offer no useful reflection, so we can't print details + stacks := strings.Builder{} + for i, stack := range unmatchedStubStacks { + stacks.WriteString(fmt.Sprintf("Stub %d:\n", i+1)) + stacks.WriteString(fmt.Sprintf("\t%s", stack)) + if stack != unmatchedStubStacks[len(unmatchedStubStacks)-1] { + stacks.WriteString("\n") + } + } // about dead stubs and what they were trying to match - t.Errorf("%d unmatched HTTP stubs", n) + t.Errorf("%d HTTP stubs unmatched, stacks:\n%s", len(unmatchedStubStacks), stacks.String()) } } @@ -84,7 +94,7 @@ func (r *Registry) RoundTrip(req *http.Request) (*http.Response, error) { if stub == nil { r.mu.Unlock() - return nil, fmt.Errorf("no registered stubs matched %v", req) + return nil, fmt.Errorf("no registered HTTP stubs matched %v", req) } r.Requests = append(r.Requests, req) diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 4e61d12f4..745c12417 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -15,6 +15,7 @@ type Matcher func(req *http.Request) bool type Responder func(req *http.Request) (*http.Response, error) type Stub struct { + Stack string matched bool Matcher Matcher Responder Responder diff --git a/pkg/option/option.go b/pkg/option/option.go index 8d3b70f3f..caf26dd0b 100644 --- a/pkg/option/option.go +++ b/pkg/option/option.go @@ -46,6 +46,15 @@ func None[T any]() Option[T] { return Option[T]{} } +func SomeIfNonZero[T comparable](value T) Option[T] { + // value is a zero value then return a None + var zero T + if value == zero { + return None[T]() + } + return Some(value) +} + // String implements the [fmt.Stringer] interface. func (o Option[T]) String() string { if o.present {