diff --git a/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar new file mode 100644 index 000000000..ef80cd8ba --- /dev/null +++ b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar @@ -0,0 +1,45 @@ +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} +exec gh repo view ${ORG}/${REPO} --json id --jq '.id' +stdout2env REPO_ID + +# Create a user fork of repository as opposed to private organization fork +exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} + +# Defer repo cleanup of fork +defer gh repo delete --yes ${ORG}/${FORK} +sleep 5 +exec gh repo view ${ORG}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${ORG}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch upstream/main +exec git config branch.feature-branch.pushRemote origin +exec git commit --allow-empty -m 'Empty Commit' +exec git push + +# Create the PR spanning upstream and fork repositories, gh pr create does not support headRepositoryId needed for private forks +exec gh api graphql -F repositoryId="${REPO_ID}" -F headRepositoryId="${FORK_ID}" -F query='mutation CreatePullRequest($headRepositoryId: ID!, $repositoryId: ID!) { createPullRequest(input:{ baseRefName: "main", body: "Feature Body", draft: false, headRefName: "feature-branch", headRepositoryId: $headRepositoryId, repositoryId: $repositoryId, title:"Feature Title" }){ pullRequest{ id url } } }' + +# View the PR +exec gh pr view +stdout 'Feature Title' + +# Check the PR status +env PR_STATUS_BRANCH=#1 Feature Title [${ORG}:feature-branch] +exec gh pr status +stdout $PR_STATUS_BRANCH diff --git a/acceptance/testdata/pr/pr-view-status-respects-push-destination.txtar b/acceptance/testdata/pr/pr-view-status-respects-push-destination.txtar new file mode 100644 index 000000000..ff9db4037 --- /dev/null +++ b/acceptance/testdata/pr/pr-view-status-respects-push-destination.txtar @@ -0,0 +1,38 @@ +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} +cd ${REPO} + +# Configure default push behavior so local and remote branches will be the same +exec git config push.default current + +# Prepare a branch where changes are pulled from the default branch instead of remote branch of same name +exec git checkout -b feature-branch +exec git branch --set-upstream-to origin/main +exec git rev-parse --abbrev-ref feature-branch@{upstream} +stdout origin/main + +# Create the PR +exec git commit --allow-empty -m 'Empty Commit' +exec git push +exec gh pr create -B main -H feature-branch --title 'Feature Title' --body 'Feature Body' + +# View the PR +exec gh pr view +stdout 'Feature Title' + +# Check the PR status +env PR_STATUS_BRANCH=#1 Feature Title [feature-branch] +exec gh pr status +stdout $PR_STATUS_BRANCH diff --git a/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar new file mode 100644 index 000000000..8bfac2837 --- /dev/null +++ b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar @@ -0,0 +1,45 @@ +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} +exec gh repo view ${ORG}/${REPO} --json id --jq '.id' +stdout2env REPO_ID + +# Create a user fork of repository as opposed to private organization fork +exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} + +# Defer repo cleanup of fork +defer gh repo delete --yes ${ORG}/${FORK} +sleep 5 +exec gh repo view ${ORG}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${ORG}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch upstream/main +exec git config remote.pushDefault origin +exec git commit --allow-empty -m 'Empty Commit' +exec git push + +# Create the PR spanning upstream and fork repositories, gh pr create does not support headRepositoryId needed for private forks +exec gh api graphql -F repositoryId="${REPO_ID}" -F headRepositoryId="${FORK_ID}" -F query='mutation CreatePullRequest($headRepositoryId: ID!, $repositoryId: ID!) { createPullRequest(input:{ baseRefName: "main", body: "Feature Body", draft: false, headRefName: "feature-branch", headRepositoryId: $headRepositoryId, repositoryId: $repositoryId, title:"Feature Title" }){ pullRequest{ id url } } }' + +# View the PR +exec gh pr view +stdout 'Feature Title' + +# Check the PR status +env PR_STATUS_BRANCH=#1 Feature Title [${ORG}:feature-branch] +exec gh pr status +stdout $PR_STATUS_BRANCH diff --git a/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar b/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar new file mode 100644 index 000000000..114f401ec --- /dev/null +++ b/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar @@ -0,0 +1,35 @@ +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} +cd ${REPO} + +# Configure default push behavior so local and remote branches have to be the same +exec git config push.default simple + +# Prepare a branch where changes are pulled from the default branch instead of remote branch of same name +exec git checkout -b feature-branch origin/main + +# Create the PR +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch +exec gh pr create -H feature-branch --title 'Feature Title' --body 'Feature Body' + +# View the PR +exec gh pr view +stdout 'Feature Title' + +# Check the PR status +env PR_STATUS_BRANCH=#1 Feature Title [feature-branch] +exec gh pr status +stdout $PR_STATUS_BRANCH diff --git a/git/client.go b/git/client.go index 19688ca51..11a2e2e20 100644 --- a/git/client.go +++ b/git/client.go @@ -376,20 +376,20 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte, return out, nil } -// ReadBranchConfig parses the `branch.BRANCH.(remote|merge|gh-merge-base)` part of git config. +// ReadBranchConfig parses the `branch.BRANCH.(remote|merge|pushremote|gh-merge-base)` part of git config. // If no branch config is found or there is an error in the command, it returns an empty BranchConfig. // Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors, // as an empty config is not necessarily breaking. func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) { prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch)) - args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|%s)$", prefix, MergeBaseConfig)} + args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|pushremote|%s)$", prefix, MergeBaseConfig)} cmd, err := c.Command(ctx, args...) if err != nil { return BranchConfig{}, err } - out, err := cmd.Output() + branchCfgOut, err := cmd.Output() if err != nil { // This is the error we expect if the git command does not run successfully. // If the ExitCode is 1, then we just didn't find any config for the branch. @@ -400,13 +400,14 @@ func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchCon return BranchConfig{}, nil } - return parseBranchConfig(outputLines(out)), nil + return parseBranchConfig(outputLines(branchCfgOut)), nil } -func parseBranchConfig(configLines []string) BranchConfig { +func parseBranchConfig(branchConfigLines []string) BranchConfig { var cfg BranchConfig - for _, line := range configLines { + // Read the config lines for the specific branch + for _, line := range branchConfigLines { parts := strings.SplitN(line, " ", 2) if len(parts) < 2 { continue @@ -414,21 +415,16 @@ func parseBranchConfig(configLines []string) BranchConfig { keys := strings.Split(parts[0], ".") switch keys[len(keys)-1] { case "remote": - if strings.Contains(parts[1], ":") { - u, err := ParseURL(parts[1]) - if err != nil { - continue - } - cfg.RemoteURL = u - } else if !isFilesystemPath(parts[1]) { - cfg.RemoteName = parts[1] - } + cfg.RemoteURL, cfg.RemoteName = parseRemoteURLOrName(parts[1]) + case "pushremote": + cfg.PushRemoteURL, cfg.PushRemoteName = parseRemoteURLOrName(parts[1]) case "merge": cfg.MergeRef = parts[1] case MergeBaseConfig: cfg.MergeBase = parts[1] } } + return cfg } @@ -445,6 +441,47 @@ func (c *Client) SetBranchConfig(ctx context.Context, branch, name, value string return err } +// PushDefault returns the value of push.default in the config. If the value +// is not set, it returns "simple" (the default git value). See +// https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault +func (c *Client) PushDefault(ctx context.Context) (string, error) { + pushDefault, err := c.Config(ctx, "push.default") + if err == nil { + return pushDefault, nil + } + + var gitError *GitError + if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 { + return "simple", nil + } + return "", err +} + +// RemotePushDefault returns the value of remote.pushDefault in the config. If +// the value is not set, it returns an empty string. +func (c *Client) RemotePushDefault(ctx context.Context) (string, error) { + remotePushDefault, err := c.Config(ctx, "remote.pushDefault") + if err == nil { + return remotePushDefault, nil + } + + var gitError *GitError + if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 { + return "", nil + } + + return "", err +} + +// ParsePushRevision gets the value of the @{push} revision syntax +// An error here doesn't necessarily mean something is broken, but may mean that the @{push} +// revision syntax couldn't be resolved, such as in non-centralized workflows with +// push.default = simple. Downstream consumers should consider how to handle this error. +func (c *Client) ParsePushRevision(ctx context.Context, branch string) (string, error) { + revParseOut, err := c.revParse(ctx, "--abbrev-ref", branch+"@{push}") + return firstLine(revParseOut), err +} + func (c *Client) DeleteLocalTag(ctx context.Context, tag string) error { args := []string{"tag", "-d", tag} cmd, err := c.Command(ctx, args...) @@ -790,6 +827,17 @@ func parseRemotes(remotesStr []string) RemoteSet { return remotes } +func parseRemoteURLOrName(value string) (*url.URL, string) { + if strings.Contains(value, ":") { + if u, err := ParseURL(value); err == nil { + return u, "" + } + } else if !isFilesystemPath(value) { + return nil, value + } + return nil, "" +} + func populateResolvedRemotes(remotes RemoteSet, resolved []string) { for _, l := range resolved { parts := strings.SplitN(l, " ", 2) diff --git a/git/client_test.go b/git/client_test.go index fe3047db9..9fa076199 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -14,6 +14,7 @@ import ( "strings" "testing" + "github.com/MakeNowJust/heredoc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -728,56 +729,76 @@ func TestClientCommitBody(t *testing.T) { func TestClientReadBranchConfig(t *testing.T) { tests := []struct { name string - cmdExitStatus int - cmdStdout string - cmdStderr string + cmds mockedCommands branch string wantBranchConfig BranchConfig wantError *GitError }{ { - name: "read branch config", - cmdExitStatus: 0, - cmdStdout: "branch.trunk.remote origin\nbranch.trunk.merge refs/heads/trunk\nbranch.trunk.gh-merge-base trunk", - branch: "trunk", - wantBranchConfig: BranchConfig{RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"}, - wantError: nil, - }, - { - name: "git config runs successfully but returns no output (Exit Code 1)", - cmdExitStatus: 1, - cmdStdout: "", - cmdStderr: "", + name: "when the git config has no (remote|merge|pushremote|gh-merge-base) keys, it should return an empty BranchConfig and no error", + cmds: mockedCommands{ + `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|pushremote|gh-merge-base)$`: { + ExitStatus: 1, + }, + }, branch: "trunk", wantBranchConfig: BranchConfig{}, wantError: nil, }, { - name: "output error (Exit Code > 1)", - cmdExitStatus: 2, - cmdStdout: "", - cmdStderr: "git error message", + name: "when the git fails to read the config, it should return an empty BranchConfig and the error", + cmds: mockedCommands{ + `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|pushremote|gh-merge-base)$`: { + ExitStatus: 2, + Stderr: "git error", + }, + }, branch: "trunk", wantBranchConfig: BranchConfig{}, - wantError: &GitError{}, + wantError: &GitError{ + ExitCode: 2, + Stderr: "git error", + }, + }, + { + name: "when the config is read, it should return the correct BranchConfig", + cmds: mockedCommands{ + `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|pushremote|gh-merge-base)$`: { + Stdout: heredoc.Doc(` + branch.trunk.remote upstream + branch.trunk.merge refs/heads/trunk + branch.trunk.pushremote origin + branch.trunk.gh-merge-base gh-merge-base + `), + }, + }, + branch: "trunk", + wantBranchConfig: BranchConfig{ + RemoteName: "upstream", + PushRemoteName: "origin", + MergeRef: "refs/heads/trunk", + MergeBase: "gh-merge-base", + }, + wantError: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd, cmdCtx := createCommandContext(t, tt.cmdExitStatus, tt.cmdStdout, tt.cmdStderr) + cmdCtx := createMockedCommandContext(t, tt.cmds) client := Client{ GitPath: "path/to/git", commandContext: cmdCtx, } branchConfig, err := client.ReadBranchConfig(context.Background(), tt.branch) - wantCmdArgs := fmt.Sprintf("path/to/git config --get-regexp ^branch\\.%s\\.(remote|merge|gh-merge-base)$", tt.branch) - assert.Equal(t, wantCmdArgs, strings.Join(cmd.Args[3:], " ")) - assert.Equal(t, tt.wantBranchConfig, branchConfig) if tt.wantError != nil { - assert.ErrorAs(t, err, &tt.wantError) + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) + assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) } else { - assert.NoError(t, err) + require.NoError(t, err) } + assert.Equal(t, tt.wantBranchConfig, branchConfig) }) } } @@ -810,44 +831,297 @@ func Test_parseBranchConfig(t *testing.T) { }, }, { - name: "remote, merge ref, and merge base all specified", - configLines: []string{ - "branch.trunk.remote origin", - "branch.trunk.merge refs/heads/trunk", - "branch.trunk.gh-merge-base gh-merge-base", - }, + name: "pushremote", + configLines: []string{"branch.trunk.pushremote pushremote"}, wantBranchConfig: BranchConfig{ - RemoteName: "origin", - MergeRef: "refs/heads/trunk", - MergeBase: "gh-merge-base", + PushRemoteName: "pushremote", }, }, { - name: "remote URL", + name: "remote and pushremote are specified by name", configLines: []string{ - "branch.Frederick888/main.remote git@github.com:Frederick888/playground.git", - "branch.Frederick888/main.merge refs/heads/main", + "branch.trunk.remote upstream", + "branch.trunk.pushremote origin", + }, + wantBranchConfig: BranchConfig{ + RemoteName: "upstream", + PushRemoteName: "origin", + }, + }, + { + name: "remote and pushremote are specified by url", + configLines: []string{ + "branch.trunk.remote git@github.com:UPSTREAMOWNER/REPO.git", + "branch.trunk.pushremote git@github.com:ORIGINOWNER/REPO.git", }, wantBranchConfig: BranchConfig{ - MergeRef: "refs/heads/main", RemoteURL: &url.URL{ Scheme: "ssh", User: url.User("git"), Host: "github.com", - Path: "/Frederick888/playground.git", + Path: "/UPSTREAMOWNER/REPO.git", }, + PushRemoteURL: &url.URL{ + Scheme: "ssh", + User: url.User("git"), + Host: "github.com", + Path: "/ORIGINOWNER/REPO.git", + }, + }, + }, + { + name: "remote, pushremote, gh-merge-base, and merge ref all specified", + configLines: []string{ + "branch.trunk.remote remote", + "branch.trunk.pushremote pushremote", + "branch.trunk.gh-merge-base gh-merge-base", + "branch.trunk.merge refs/heads/trunk", + }, + wantBranchConfig: BranchConfig{ + RemoteName: "remote", + PushRemoteName: "pushremote", + MergeBase: "gh-merge-base", + MergeRef: "refs/heads/trunk", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { branchConfig := parseBranchConfig(tt.configLines) - assert.Equal(t, tt.wantBranchConfig.RemoteName, branchConfig.RemoteName) - assert.Equal(t, tt.wantBranchConfig.MergeRef, branchConfig.MergeRef) - assert.Equal(t, tt.wantBranchConfig.MergeBase, branchConfig.MergeBase) + assert.Equalf(t, tt.wantBranchConfig.RemoteName, branchConfig.RemoteName, "unexpected RemoteName") + assert.Equalf(t, tt.wantBranchConfig.MergeRef, branchConfig.MergeRef, "unexpected MergeRef") + assert.Equalf(t, tt.wantBranchConfig.MergeBase, branchConfig.MergeBase, "unexpected MergeBase") + assert.Equalf(t, tt.wantBranchConfig.PushRemoteName, branchConfig.PushRemoteName, "unexpected PushRemoteName") if tt.wantBranchConfig.RemoteURL != nil { - assert.Equal(t, tt.wantBranchConfig.RemoteURL.String(), branchConfig.RemoteURL.String()) + assert.Equalf(t, tt.wantBranchConfig.RemoteURL.String(), branchConfig.RemoteURL.String(), "unexpected RemoteURL") } + if tt.wantBranchConfig.PushRemoteURL != nil { + assert.Equalf(t, tt.wantBranchConfig.PushRemoteURL.String(), branchConfig.PushRemoteURL.String(), "unexpected PushRemoteURL") + } + }) + } +} + +func Test_parseRemoteURLOrName(t *testing.T) { + tests := []struct { + name string + value string + wantRemoteURL *url.URL + wantRemoteName string + }{ + { + name: "empty value", + value: "", + wantRemoteURL: nil, + wantRemoteName: "", + }, + { + name: "remote URL", + value: "git@github.com:foo/bar.git", + wantRemoteURL: &url.URL{ + Scheme: "ssh", + User: url.User("git"), + Host: "github.com", + Path: "/foo/bar.git", + }, + wantRemoteName: "", + }, + { + name: "remote name", + value: "origin", + wantRemoteURL: nil, + wantRemoteName: "origin", + }, + { + name: "remote name is from filesystem", + value: "./path/to/repo", + wantRemoteURL: nil, + wantRemoteName: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + remoteURL, remoteName := parseRemoteURLOrName(tt.value) + assert.Equal(t, tt.wantRemoteURL, remoteURL) + assert.Equal(t, tt.wantRemoteName, remoteName) + }) + } +} + +func TestClientPushDefault(t *testing.T) { + tests := []struct { + name string + commandResult commandResult + wantPushDefault string + wantError *GitError + }{ + { + name: "push default is not set", + commandResult: commandResult{ + ExitStatus: 1, + Stderr: "error: key does not contain a section: remote.pushDefault", + }, + wantPushDefault: "simple", + wantError: nil, + }, + { + name: "push default is set to current", + commandResult: commandResult{ + ExitStatus: 0, + Stdout: "current", + }, + wantPushDefault: "current", + wantError: nil, + }, + { + name: "push default errors", + commandResult: commandResult{ + ExitStatus: 128, + Stderr: "fatal: git error", + }, + wantPushDefault: "", + wantError: &GitError{ + ExitCode: 128, + Stderr: "fatal: git error", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmdCtx := createMockedCommandContext(t, mockedCommands{ + `path/to/git config push.default`: tt.commandResult, + }, + ) + client := Client{ + GitPath: "path/to/git", + commandContext: cmdCtx, + } + pushDefault, err := client.PushDefault(context.Background()) + if tt.wantError != nil { + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) + assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantPushDefault, pushDefault) + }) + } +} + +func TestClientRemotePushDefault(t *testing.T) { + tests := []struct { + name string + commandResult commandResult + wantRemotePushDefault string + wantError *GitError + }{ + { + name: "remote.pushDefault is not set", + commandResult: commandResult{ + ExitStatus: 1, + Stderr: "error: key does not contain a section: remote.pushDefault", + }, + wantRemotePushDefault: "", + wantError: nil, + }, + { + name: "remote.pushDefault is set to origin", + commandResult: commandResult{ + ExitStatus: 0, + Stdout: "origin", + }, + wantRemotePushDefault: "origin", + wantError: nil, + }, + { + name: "remote.pushDefault errors", + commandResult: commandResult{ + ExitStatus: 128, + Stderr: "fatal: git error", + }, + wantRemotePushDefault: "", + wantError: &GitError{ + ExitCode: 128, + Stderr: "fatal: git error", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmdCtx := createMockedCommandContext(t, mockedCommands{ + `path/to/git config remote.pushDefault`: tt.commandResult, + }, + ) + client := Client{ + GitPath: "path/to/git", + commandContext: cmdCtx, + } + pushDefault, err := client.RemotePushDefault(context.Background()) + if tt.wantError != nil { + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) + assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantRemotePushDefault, pushDefault) + }) + } +} + +func TestClientParsePushRevision(t *testing.T) { + tests := []struct { + name string + branch string + commandResult commandResult + wantParsedPushRevision string + wantError *GitError + }{ + { + name: "@{push} resolves to origin/branchName", + branch: "branchName", + commandResult: commandResult{ + ExitStatus: 0, + Stdout: "origin/branchName", + }, + wantParsedPushRevision: "origin/branchName", + }, + { + name: "@{push} doesn't resolve", + commandResult: commandResult{ + ExitStatus: 128, + Stderr: "fatal: git error", + }, + wantParsedPushRevision: "", + wantError: &GitError{ + ExitCode: 128, + Stderr: "fatal: git error", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := fmt.Sprintf("path/to/git rev-parse --abbrev-ref %s@{push}", tt.branch) + cmdCtx := createMockedCommandContext(t, mockedCommands{ + args(cmd): tt.commandResult, + }) + client := Client{ + GitPath: "path/to/git", + commandContext: cmdCtx, + } + pushDefault, err := client.ParsePushRevision(context.Background(), tt.branch) + if tt.wantError != nil { + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) + assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantParsedPushRevision, pushDefault) }) } } @@ -1601,13 +1875,17 @@ func TestCommandMocking(t *testing.T) { jsonVar, ok := os.LookupEnv("GH_HELPER_PROCESS_RICH_COMMANDS") if !ok { fmt.Fprint(os.Stderr, "missing GH_HELPER_PROCESS_RICH_COMMANDS") - os.Exit(1) + // Exit 1 is used for empty key values in the git config. This is non-breaking in those use cases, + // so this is returning a non-zero exit code to avoid suppressing this error for those use cases. + os.Exit(16) } var commands mockedCommands if err := json.Unmarshal([]byte(jsonVar), &commands); err != nil { fmt.Fprint(os.Stderr, "failed to unmarshal GH_HELPER_PROCESS_RICH_COMMANDS") - os.Exit(1) + // Exit 1 is used for empty key values in the git config. This is non-breaking in those use cases, + // so this is returning a non-zero exit code to avoid suppressing this error for those use cases. + os.Exit(16) } // The discarded args are those for the go test binary itself, e.g. `-test.run=TestHelperProcessRich` @@ -1616,7 +1894,9 @@ func TestCommandMocking(t *testing.T) { commandResult, ok := commands[args(strings.Join(realArgs, " "))] if !ok { fmt.Fprintf(os.Stderr, "unexpected command: %s\n", strings.Join(realArgs, " ")) - os.Exit(1) + // Exit 1 is used for empty key values in the git config. This is non-breaking in those use cases, + // so this is returning a non-zero exit code to avoid suppressing this error for those use cases. + os.Exit(16) } if commandResult.Stdout != "" { diff --git a/git/objects.go b/git/objects.go index c09683042..9db528b8c 100644 --- a/git/objects.go +++ b/git/objects.go @@ -60,10 +60,14 @@ type Commit struct { Body string } +// These are the keys we read from the git branch. config. type BranchConfig struct { - RemoteName string - RemoteURL *url.URL + RemoteName string // .remote if string + RemoteURL *url.URL // .remote if url + MergeRef string // .merge + PushRemoteName string // .pushremote if string + PushRemoteURL *url.URL // .pushremote if url + // MergeBase is the optional base branch to target in a new PR if `--base` is not specified. MergeBase string - MergeRef string } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 124bd4b07..b2abe0938 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -518,6 +518,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } +// TODO: Replace with the finder's PullRequestRefs struct // trackingRef represents a ref for a remote tracking branch. type trackingRef struct { remoteName string @@ -685,7 +686,9 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { return nil, err } if isPushEnabled { - // determine whether the head branch is already pushed to a remote + // TODO: This doesn't respect the @{push} revision resolution or triagular workflows assembled with + // remote.pushDefault, or branch..pushremote config settings. The finder's ParsePRRefs + // may be able to replace this function entirely. if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { isPushEnabled = false if r, err := remotes.FindByName(trackingRef.remoteName); err == nil { diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 992439df1..6df3f9880 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1515,7 +1515,7 @@ func Test_createRun(t *testing.T) { }, customBranchConfig: true, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|gh-merge-base\)\$`, 0, heredoc.Doc(` + cs.Register(`git config --get-regexp \^branch\\\.task1\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, heredoc.Doc(` branch.task1.remote origin branch.task1.merge refs/heads/task1 branch.task1.gh-merge-base feature/feat2`)) // ReadBranchConfig @@ -1549,7 +1549,7 @@ func Test_createRun(t *testing.T) { defer cmdTeardown(t) cs.Register(`git status --porcelain`, 0, "") if !tt.customBranchConfig { - cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|gh-merge-base\)\$`, 0, "") + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") } if tt.cmdStubs != nil { diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 3f036c0cd..e4f89502c 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -33,16 +33,19 @@ type progressIndicator interface { } type finder struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - remotesFn func() (remotes.Remotes, error) - httpClient func() (*http.Client, error) - branchConfig func(string) (git.BranchConfig, error) - progress progressIndicator + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + remotesFn func() (remotes.Remotes, error) + httpClient func() (*http.Client, error) + pushDefault func() (string, error) + remotePushDefault func() (string, error) + parsePushRevision func(string) (string, error) + branchConfig func(string) (git.BranchConfig, error) + progress progressIndicator - repo ghrepo.Interface - prNumber int - branchName string + baseRefRepo ghrepo.Interface + prNumber int + branchName string } func NewFinder(factory *cmdutil.Factory) PRFinder { @@ -57,7 +60,16 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { branchFn: factory.Branch, remotesFn: factory.Remotes, httpClient: factory.HttpClient, - progress: factory.IOStreams, + pushDefault: func() (string, error) { + return factory.GitClient.PushDefault(context.Background()) + }, + remotePushDefault: func() (string, error) { + return factory.GitClient.RemotePushDefault(context.Background()) + }, + parsePushRevision: func(branch string) (string, error) { + return factory.GitClient.ParsePushRevision(context.Background(), branch) + }, + progress: factory.IOStreams, branchConfig: func(s string) (git.BranchConfig, error) { return factory.GitClient.ReadBranchConfig(context.Background(), s) }, @@ -85,6 +97,28 @@ type FindOptions struct { States []string } +// TODO: Does this also need the BaseBranchName? +// PR's are represented by the following: +// baseRef -----PR-----> headRef +// +// A ref is described as "remoteName/branchName", so +// baseRepoName/baseBranchName -----PR-----> headRepoName/headBranchName +type PullRequestRefs struct { + BranchName string + HeadRepo ghrepo.Interface + BaseRepo ghrepo.Interface +} + +// GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is +// either just the branch name or, if the PR is originating from a fork, the fork owner +// and the branch name, like :. +func (s *PullRequestRefs) GetPRHeadLabel() string { + if ghrepo.IsSame(s.HeadRepo, s.BaseRepo) { + return s.BranchName + } + return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName) +} + func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { if len(opts.Fields) == 0 { return nil, nil, errors.New("Find error: no fields specified") @@ -92,26 +126,18 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err if repo, prNumber, err := f.parseURL(opts.Selector); err == nil { f.prNumber = prNumber - f.repo = repo + f.baseRefRepo = repo } - if f.repo == nil { + if f.baseRefRepo == nil { repo, err := f.baseRepoFn() if err != nil { return nil, nil, err } - f.repo = repo + f.baseRefRepo = repo } - if opts.Selector == "" { - if branch, prNumber, err := f.parseCurrentBranch(); err != nil { - return nil, nil, err - } else if prNumber > 0 { - f.prNumber = prNumber - } else { - f.branchName = branch - } - } else if f.prNumber == 0 { + if f.prNumber == 0 && opts.Selector != "" { // If opts.Selector is a valid number then assume it is the // PR number unless opts.BaseBranch is specified. This is a // special case for PR create command which will always want @@ -123,8 +149,28 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err } else { f.branchName = opts.Selector } + } else { + currentBranchName, err := f.branchFn() + if err != nil { + return nil, nil, err + } + f.branchName = currentBranchName } + // Get the branch config for the current branchName + branchConfig, err := f.branchConfig(f.branchName) + if err != nil { + return nil, nil, err + } + + // Determine if the branch is configured to merge to a special PR ref + prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + prNumber, _ := strconv.Atoi(m[1]) + f.prNumber = prNumber + } + + // Set up HTTP client httpClient, err := f.httpClient() if err != nil { return nil, nil, err @@ -143,7 +189,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") { cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) - detector := fd.NewDetector(cachedClient, f.repo.RepoHost()) + detector := fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) prFeatures, err := detector.PullRequestFeatures() if err != nil { return nil, nil, err @@ -164,36 +210,62 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err if f.prNumber > 0 { if numberFieldOnly { // avoid hitting the API if we already have all the information - return &api.PullRequest{Number: f.prNumber}, f.repo, nil + return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil + } + pr, err = findByNumber(httpClient, f.baseRefRepo, f.prNumber, fields.ToSlice()) + if err != nil { + return pr, f.baseRefRepo, err } - pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice()) } else { - pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice()) - } - if err != nil { - return pr, f.repo, err + rems, err := f.remotesFn() + if err != nil { + return nil, nil, err + } + + pushDefault, err := f.pushDefault() + if err != nil { + return nil, nil, err + } + + // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. + parsedPushRevision, _ := f.parsePushRevision(f.branchName) + + remotePushDefault, err := f.remotePushDefault() + if err != nil { + return nil, nil, err + } + + prRefs, err := ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems) + if err != nil { + return nil, nil, err + } + + pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRHeadLabel(), opts.States, fields.ToSlice()) + if err != nil { + return pr, f.baseRefRepo, err + } } g, _ := errgroup.WithContext(context.Background()) if fields.Contains("reviews") { g.Go(func() error { - return preloadPrReviews(httpClient, f.repo, pr) + return preloadPrReviews(httpClient, f.baseRefRepo, pr) }) } if fields.Contains("comments") { g.Go(func() error { - return preloadPrComments(httpClient, f.repo, pr) + return preloadPrComments(httpClient, f.baseRefRepo, pr) }) } if fields.Contains("statusCheckRollup") { g.Go(func() error { - return preloadPrChecks(httpClient, f.repo, pr) + return preloadPrChecks(httpClient, f.baseRefRepo, pr) }) } if getProjectItems { g.Go(func() error { apiClient := api.NewClientFromHTTP(httpClient) - err := api.ProjectsV2ItemsForPullRequest(apiClient, f.repo, pr) + err := api.ProjectsV2ItemsForPullRequest(apiClient, f.baseRefRepo, pr) if err != nil && !api.ProjectsV2IgnorableError(err) { return err } @@ -201,7 +273,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err }) } - return pr, f.repo, g.Wait() + return pr, f.baseRefRepo, g.Wait() } var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) @@ -230,59 +302,70 @@ func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { return repo, prNumber, nil } -var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) - -func (f *finder) parseCurrentBranch() (string, int, error) { - prHeadRef, err := f.branchFn() - if err != nil { - return "", 0, err +func ParsePRRefs(currentBranchName string, branchConfig git.BranchConfig, parsedPushRevision string, pushDefault string, remotePushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PullRequestRefs, error) { + prRefs := PullRequestRefs{ + BaseRepo: baseRefRepo, } - branchConfig, err := f.branchConfig(prHeadRef) - if err != nil { - return "", 0, err + // If @{push} resolves, then we have all the information we need to determine the head repo + // and branch name. It is of the form /. + if parsedPushRevision != "" { + for _, r := range rems { + // Find the remote who's name matches the push prefix + if strings.HasPrefix(parsedPushRevision, r.Name+"/") { + prRefs.BranchName = strings.TrimPrefix(parsedPushRevision, r.Name+"/") + prRefs.HeadRepo = r.Repo + return prRefs, nil + } + } + + remoteNames := make([]string, len(rems)) + for i, r := range rems { + remoteNames[i] = r.Name + } + return PullRequestRefs{}, fmt.Errorf("no remote for %q found in %q", parsedPushRevision, strings.Join(remoteNames, ", ")) } - // the branch is configured to merge a special PR head ref - if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - prNumber, _ := strconv.Atoi(m[1]) - return "", prNumber, nil + // We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier + // unless the user has specified push.default = upstream or tracking, then we use the + // branch name from the merge ref. + prRefs.BranchName = currentBranchName + if pushDefault == "upstream" || pushDefault == "tracking" { + prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") } - var gitRemoteRepo ghrepo.Interface - if branchConfig.RemoteURL != nil { - // the branch merges from a remote specified by URL - if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { - gitRemoteRepo = r + // To get the HeadRepo, we look to the git config. The HeadRepo comes from one of the following, in order of precedence: + // 1. branch..pushRemote + // 2. remote.pushDefault + // 3. branch..remote + if branchConfig.PushRemoteName != "" { + if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil { + prRefs.HeadRepo = r.Repo + } + } else if branchConfig.PushRemoteURL != nil { + if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil { + prRefs.HeadRepo = r + } + } else if remotePushDefault != "" { + if r, err := rems.FindByName(remotePushDefault); err == nil { + prRefs.HeadRepo = r.Repo } } else if branchConfig.RemoteName != "" { - // the branch merges from a remote specified by name - rem, _ := f.remotesFn() - if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { - gitRemoteRepo = r + if r, err := rems.FindByName(branchConfig.RemoteName); err == nil { + prRefs.HeadRepo = r.Repo + } + } else if branchConfig.RemoteURL != nil { + if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { + prRefs.HeadRepo = r } } - if gitRemoteRepo != nil { - if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { - prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - } - // prepend `OWNER:` if this branch is pushed to a fork - // This is determined by: - // - The repo having a different owner - // - The repo having the same owner but a different name (private org fork) - // I suspect that the implementation of the second case may be broken in the face - // of a repo rename, where the remote hasn't been updated locally. This is a - // frequent issue in commands that use SmartBaseRepoFunc. It's not any worse than not - // supporting this case at all though. - sameOwner := strings.EqualFold(gitRemoteRepo.RepoOwner(), f.repo.RepoOwner()) - sameOwnerDifferentRepoName := sameOwner && !strings.EqualFold(gitRemoteRepo.RepoName(), f.repo.RepoName()) - if !sameOwner || sameOwnerDifferentRepoName { - prHeadRef = fmt.Sprintf("%s:%s", gitRemoteRepo.RepoOwner(), prHeadRef) - } + // The PR merges from a branch in the same repo as the base branch (usually the default branch) + if prRefs.HeadRepo == nil { + prRefs.HeadRepo = baseRefRepo } - return prHeadRef, 0, nil + return prRefs, nil } func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { @@ -315,7 +398,7 @@ func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fi return &resp.Repository.PullRequest, nil } -func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) { +func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { PullRequests struct { @@ -342,9 +425,9 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h } }`, api.PullRequestGraphQL(fieldSet.ToSlice())) - branchWithoutOwner := headBranch - if idx := strings.Index(headBranch, ":"); idx >= 0 { - branchWithoutOwner = headBranch[idx+1:] + branchWithoutOwner := headBranchWithOwnerIfFork + if idx := strings.Index(headBranchWithOwnerIfFork, ":"); idx >= 0 { + branchWithoutOwner = headBranchWithOwnerIfFork[idx+1:] } variables := map[string]interface{}{ @@ -367,18 +450,17 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h }) for _, pr := range prs { - headBranchMatches := pr.HeadLabel() == headBranch + headBranchMatches := pr.HeadLabel() == headBranchWithOwnerIfFork baseBranchEmptyOrMatches := baseBranch == "" || pr.BaseRefName == baseBranch // When the head is the default branch, it doesn't really make sense to show merged or closed PRs. // https://github.com/cli/cli/issues/4263 - isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch - + isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranchWithOwnerIfFork if headBranchMatches && baseBranchEmptyOrMatches && isNotClosedOrMergedWhenHeadIsDefault { return &pr, nil } } - return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} + return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)} } func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index dd96e684a..694e0c20d 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -2,6 +2,7 @@ package shared import ( "errors" + "fmt" "net/http" "net/url" "testing" @@ -15,16 +16,50 @@ import ( ) type args struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - branchConfig func(string) (git.BranchConfig, error) - remotesFn func() (context.Remotes, error) - selector string - fields []string - baseBranch string + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + branchConfig func(string) (git.BranchConfig, error) + pushDefault func() (string, error) + remotePushDefault func() (string, error) + parsePushRevision func(string) (string, error) + selector string + fields []string + baseBranch string } func TestFind(t *testing.T) { + // TODO: Abstract these out meaningfully for reuse in parsePRRefs tests + originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") + if err != nil { + t.Fatal(err) + } + remoteOrigin := context.Remote{ + Remote: &git.Remote{ + Name: "origin", + FetchURL: originOwnerUrl, + }, + Repo: ghrepo.New("ORIGINOWNER", "REPO"), + } + remoteOther := context.Remote{ + Remote: &git.Remote{ + Name: "other", + FetchURL: originOwnerUrl, + }, + Repo: ghrepo.New("ORIGINOWNER", "OTHER-REPO"), + } + + upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git") + if err != nil { + t.Fatal(err) + } + remoteUpstream := context.Remote{ + Remote: &git.Remote{ + Name: "upstream", + FetchURL: upstreamOwnerUrl, + }, + Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), + } + tests := []struct { name string args args @@ -36,11 +71,13 @@ func TestFind(t *testing.T) { { name: "number argument", args: args{ - selector: "13", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") + selector: "13", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -50,7 +87,7 @@ func TestFind(t *testing.T) { }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "number argument with base branch", @@ -58,9 +95,16 @@ func TestFind(t *testing.T) { selector: "13", baseBranch: "main", fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{ + PushRemoteName: remoteOrigin.Remote.Name, + }, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -73,22 +117,26 @@ func TestFind(t *testing.T) { "baseRefName": "main", "headRefName": "13", "isCrossRepository": false, - "headRepositoryOwner": {"login":"OWNER"} + "headRepositoryOwner": {"login":"ORIGINOWNER"} } ]} }}}`)) }, wantPR: 123, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "baseRepo is error", args: args{ - selector: "13", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return nil, errors.New("baseRepoErr") + selector: "13", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(nil, errors.New("baseRepoErr")), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, wantErr: true, }, @@ -103,24 +151,32 @@ func TestFind(t *testing.T) { { name: "number only", args: args{ - selector: "13", - fields: []string{"number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") + selector: "13", + fields: []string{"number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, httpStub: nil, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "number with hash argument", args: args{ - selector: "#13", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") + selector: "#13", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -130,14 +186,20 @@ func TestFind(t *testing.T) { }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { - name: "URL argument", + name: "PR URL argument", args: args{ selector: "https://example.org/OWNER/REPO/pull/13/files", fields: []string{"id", "number"}, baseRepoFn: nil, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -150,13 +212,18 @@ func TestFind(t *testing.T) { wantRepo: "https://example.org/OWNER/REPO", }, { - name: "branch argument", + name: "when provided branch argument with an open and closed PR for that branch name, it returns the open PR", args: args{ - selector: "blueberries", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") + selector: "blueberries", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + parsePushRevision: stubParsedPushRevision("", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -169,7 +236,7 @@ func TestFind(t *testing.T) { "baseRefName": "main", "headRefName": "blueberries", "isCrossRepository": false, - "headRepositoryOwner": {"login":"OWNER"} + "headRepositoryOwner": {"login":"ORIGINOWNER"} }, { "number": 13, @@ -177,13 +244,13 @@ func TestFind(t *testing.T) { "baseRefName": "main", "headRefName": "blueberries", "isCrossRepository": false, - "headRepositoryOwner": {"login":"OWNER"} + "headRepositoryOwner": {"login":"ORIGINOWNER"} } ]} }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "branch argument with base branch", @@ -194,6 +261,13 @@ func TestFind(t *testing.T) { baseRepoFn: func() (ghrepo.Interface, error) { return ghrepo.FromFullName("OWNER/REPO") }, + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -233,7 +307,10 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -265,7 +342,10 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -303,26 +383,21 @@ func TestFind(t *testing.T) { wantErr: true, }, { - name: "current branch with upstream configuration", + name: "when the current branch is configured to push to and pull from 'upstream' and push.default = upstream but the repo push/pulls from 'origin', it finds the PR associated with the upstream repo and returns origin as the base repo", args: args{ - selector: "", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") - }, + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - RemoteName: "origin", + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteName: "upstream", }, nil), - remotesFn: func() (context.Remotes, error) { - return context.Remotes{{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), - }}, nil - }, + pushDefault: stubPushDefault("upstream", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -341,27 +416,28 @@ func TestFind(t *testing.T) { }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { - name: "current branch with upstream RemoteURL configuration", + // The current BRANCH is configured to push to and pull from a URL (upstream, in this example) + // which is different from what the REPO is configured to push to and pull from (origin, in this example) + // and push.default = upstream. It should find the PR associated with the upstream repo and return + // origin as the base repo + name: "when push.default = upstream and the current branch is configured to push/pull from a different remote than the repo", args: args{ - selector: "", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") - }, + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: func(branch string) (git.BranchConfig, error) { - u, _ := url.Parse("https://github.com/UPSTREAMOWNER/REPO") - return stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - RemoteURL: u, - }, nil)(branch) - }, - remotesFn: nil, + branchConfig: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteURL: remoteUpstream.Remote.FetchURL, + }, nil), + pushDefault: stubPushDefault("upstream", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -380,28 +456,21 @@ func TestFind(t *testing.T) { }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "current branch with upstream and fork in same org", args: args{ - selector: "", - fields: []string{"id", "number"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") - }, + selector: "", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - RemoteName: "origin", - }, nil), - remotesFn: func() (context.Remotes, error) { - return context.Remotes{{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("OWNER", "REPO-FORK"), - }}, nil - }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), + parsePushRevision: stubParsedPushRevision("other/blueberries", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -414,13 +483,13 @@ func TestFind(t *testing.T) { "baseRefName": "main", "headRefName": "blueberries", "isCrossRepository": true, - "headRepositoryOwner": {"login":"OWNER"} + "headRepositoryOwner": {"login":"ORIGINOWNER"} } ]} }}}`)) }, wantPR: 13, - wantRepo: "https://github.com/OWNER/REPO", + wantRepo: "https://github.com/ORIGINOWNER/REPO", }, { name: "current branch made by pr checkout", @@ -461,6 +530,8 @@ func TestFind(t *testing.T) { branchConfig: stubBranchConfig(git.BranchConfig{ MergeRef: "refs/pull/13/head", }, nil), + pushDefault: stubPushDefault("simple", nil), + remotePushDefault: stubRemotePushDefault("", nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -521,10 +592,17 @@ func TestFind(t *testing.T) { httpClient: func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }, - baseRepoFn: tt.args.baseRepoFn, - branchFn: tt.args.branchFn, - branchConfig: tt.args.branchConfig, - remotesFn: tt.args.remotesFn, + baseRepoFn: tt.args.baseRepoFn, + branchFn: tt.args.branchFn, + branchConfig: tt.args.branchConfig, + pushDefault: tt.args.pushDefault, + remotePushDefault: tt.args.remotePushDefault, + parsePushRevision: tt.args.parsePushRevision, + remotesFn: stubRemotes(context.Remotes{ + &remoteOrigin, + &remoteOther, + &remoteUpstream, + }, nil), } pr, repo, err := f.Find(FindOptions{ @@ -557,42 +635,307 @@ func TestFind(t *testing.T) { } } -func Test_parseCurrentBranch(t *testing.T) { +func TestParsePRRefs(t *testing.T) { + originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") + if err != nil { + t.Fatal(err) + } + remoteOrigin := context.Remote{ + Remote: &git.Remote{ + Name: "origin", + FetchURL: originOwnerUrl, + }, + Repo: ghrepo.New("ORIGINOWNER", "REPO"), + } + remoteOther := context.Remote{ + Remote: &git.Remote{ + Name: "other", + FetchURL: originOwnerUrl, + }, + Repo: ghrepo.New("ORIGINOWNER", "REPO"), + } + + upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git") + if err != nil { + t.Fatal(err) + } + remoteUpstream := context.Remote{ + Remote: &git.Remote{ + Name: "upstream", + FetchURL: upstreamOwnerUrl, + }, + Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), + } + tests := []struct { - name string - args args - wantSelector string - wantPR int - wantError error + name string + branchConfig git.BranchConfig + pushDefault string + parsedPushRevision string + remotePushDefault string + currentBranchName string + baseRefRepo ghrepo.Interface + rems context.Remotes + wantPRRefs PullRequestRefs + wantErr error }{ { - name: "failed branch config", - args: args{ - branchConfig: stubBranchConfig(git.BranchConfig{}, errors.New("branchConfigErr")), - branchFn: func() (string, error) { - return "blueberries", nil - }, + name: "When the branch is called 'blueberries' with an empty branch config, it returns the correct PullRequestRefs", + branchConfig: git.BranchConfig{}, + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteOrigin.Repo, }, - wantSelector: "", - wantPR: 0, - wantError: errors.New("branchConfigErr"), + wantErr: nil, + }, + { + name: "When the branch is called 'otherBranch' with an empty branch config, it returns the correct PullRequestRefs", + branchConfig: git.BranchConfig{}, + currentBranchName: "otherBranch", + baseRefRepo: remoteOrigin.Repo, + wantPRRefs: PullRequestRefs{ + BranchName: "otherBranch", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the branch name doesn't match the branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name", + parsedPushRevision: "origin/pushBranch", + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "pushBranch", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the push revision doesn't match a remote, it returns an error", + parsedPushRevision: "origin/differentPushBranch", + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteUpstream, + &remoteOther, + }, + wantPRRefs: PullRequestRefs{}, + wantErr: fmt.Errorf("no remote for %q found in %q", "origin/differentPushBranch", "upstream, other"), + }, + { + name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name", + parsedPushRevision: "other/pushBranch", + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOther, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "pushBranch", + HeadRepo: remoteOther.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the push remote is the same as the baseRepo, it returns the baseRepo as the PullRequestRefs HeadRepo", + branchConfig: git.BranchConfig{ + PushRemoteName: remoteOrigin.Remote.Name, + }, + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the push remote is different from the baseRepo, it returns the push remote repo as the PullRequestRefs HeadRepo", + branchConfig: git.BranchConfig{ + PushRemoteName: remoteOrigin.Remote.Name, + }, + currentBranchName: "blueberries", + baseRefRepo: remoteUpstream.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteUpstream.Repo, + }, + wantErr: nil, + }, + { + name: "When the push remote defined by a URL and the baseRepo is different from the push remote, it returns the push remote repo as the PullRequestRefs HeadRepo", + branchConfig: git.BranchConfig{ + PushRemoteURL: remoteOrigin.Remote.FetchURL, + }, + currentBranchName: "blueberries", + baseRefRepo: remoteUpstream.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteOrigin.Repo, + BaseRepo: remoteUpstream.Repo, + }, + wantErr: nil, + }, + { + name: "When the push remote and merge ref are configured to a different repo and push.default = upstream, it should return the branch name from the other repo", + branchConfig: git.BranchConfig{ + PushRemoteName: remoteUpstream.Remote.Name, + MergeRef: "refs/heads/blue-upstream-berries", + }, + pushDefault: "upstream", + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blue-upstream-berries", + HeadRepo: remoteUpstream.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the push remote and merge ref are configured to a different repo and push.default = tracking, it should return the branch name from the other repo", + branchConfig: git.BranchConfig{ + PushRemoteName: remoteUpstream.Remote.Name, + MergeRef: "refs/heads/blue-upstream-berries", + }, + pushDefault: "tracking", + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blue-upstream-berries", + HeadRepo: remoteUpstream.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When remote.pushDefault is set, it returns the correct PullRequestRefs", + branchConfig: git.BranchConfig{}, + remotePushDefault: remoteUpstream.Remote.Name, + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteUpstream.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the remote name is set on the branch, it returns the correct PullRequestRefs", + branchConfig: git.BranchConfig{ + RemoteName: remoteUpstream.Remote.Name, + }, + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteUpstream.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, + }, + { + name: "When the remote URL is set on the branch, it returns the correct PullRequestRefs", + branchConfig: git.BranchConfig{ + RemoteURL: remoteUpstream.Remote.FetchURL, + }, + currentBranchName: "blueberries", + baseRefRepo: remoteOrigin.Repo, + rems: context.Remotes{ + &remoteOrigin, + &remoteUpstream, + }, + wantPRRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: remoteUpstream.Repo, + BaseRepo: remoteOrigin.Repo, + }, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := finder{ - httpClient: func() (*http.Client, error) { - return &http.Client{}, nil - }, - baseRepoFn: tt.args.baseRepoFn, - branchFn: tt.args.branchFn, - branchConfig: tt.args.branchConfig, - remotesFn: tt.args.remotesFn, + prRefs, err := ParsePRRefs(tt.currentBranchName, tt.branchConfig, tt.parsedPushRevision, tt.pushDefault, tt.remotePushDefault, tt.baseRefRepo, tt.rems) + if tt.wantErr != nil { + require.Equal(t, tt.wantErr, err) + } else { + require.NoError(t, err) } - selector, pr, err := f.parseCurrentBranch() - assert.Equal(t, tt.wantSelector, selector) - assert.Equal(t, tt.wantPR, pr) - assert.Equal(t, tt.wantError, err) + require.Equal(t, tt.wantPRRefs, prRefs) + }) + } +} + +func TestPRRefs_GetPRHeadLabel(t *testing.T) { + originRepo := ghrepo.New("ORIGINOWNER", "REPO") + upstreamRepo := ghrepo.New("UPSTREAMOWNER", "REPO") + tests := []struct { + name string + prRefs PullRequestRefs + want string + }{ + { + name: "When the HeadRepo and BaseRepo match, it returns the branch name", + prRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: originRepo, + BaseRepo: originRepo, + }, + want: "blueberries", + }, + { + name: "When the HeadRepo and BaseRepo do not match, it returns the prepended HeadRepo owner to the branch name", + prRefs: PullRequestRefs{ + BranchName: "blueberries", + HeadRepo: originRepo, + BaseRepo: upstreamRepo, + }, + want: "ORIGINOWNER:blueberries", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.prRefs.GetPRHeadLabel()) }) } } @@ -602,3 +945,33 @@ func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (gi return branchConfig, err } } + +func stubRemotes(remotes context.Remotes, err error) func() (context.Remotes, error) { + return func() (context.Remotes, error) { + return remotes, err + } +} + +func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interface, error) { + return func() (ghrepo.Interface, error) { + return baseRepo, err + } +} + +func stubPushDefault(pushDefault string, err error) func() (string, error) { + return func() (string, error) { + return pushDefault, err + } +} + +func stubRemotePushDefault(remotePushDefault string, err error) func() (string, error) { + return func() (string, error) { + return remotePushDefault, err + } +} + +func stubParsedPushRevision(parsedPushRevision string, err error) func(string) (string, error) { + return func(_ string) (string, error) { + return parsedPushRevision, err + } +} diff --git a/pkg/cmd/pr/status/status.go b/pkg/cmd/pr/status/status.go index a97a59e44..3877fb1cf 100644 --- a/pkg/cmd/pr/status/status.go +++ b/pkg/cmd/pr/status/status.go @@ -7,7 +7,6 @@ import ( "net/http" "regexp" "strconv" - "strings" "time" "github.com/cli/cli/v2/api" @@ -78,27 +77,56 @@ func statusRun(opts *StatusOptions) error { return err } - baseRepo, err := opts.BaseRepo() + baseRefRepo, err := opts.BaseRepo() if err != nil { return err } - var currentBranch string + var currentBranchName string var currentPRNumber int - var currentPRHeadRef string + var currentHeadRefBranchName string if !opts.HasRepoOverride { - currentBranch, err = opts.Branch() + currentBranchName, err = opts.Branch() if err != nil && !errors.Is(err, git.ErrNotOnAnyBranch) { return fmt.Errorf("could not query for pull request for current branch: %w", err) } - remotes, _ := opts.Remotes() - branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranch) + branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName) if err != nil { return err } - currentPRNumber, currentPRHeadRef, err = prSelectorForCurrentBranch(branchConfig, baseRepo, currentBranch, remotes) + // Determine if the branch is configured to merge to a special PR ref + prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + currentPRNumber, _ = strconv.Atoi(m[1]) + } + + if currentPRNumber == 0 { + remotes, err := opts.Remotes() + if err != nil { + return err + } + // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. + parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, currentBranchName) + + remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) + if err != nil { + return err + } + + pushDefault, err := opts.GitClient.PushDefault(ctx) + if err != nil { + return err + } + + prRefs, err := shared.ParsePRRefs(currentBranchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRefRepo, remotes) + if err != nil { + return err + } + currentHeadRefBranchName = prRefs.BranchName + } + if err != nil { return fmt.Errorf("could not query for pull request for current branch: %w", err) } @@ -107,7 +135,7 @@ func statusRun(opts *StatusOptions) error { options := requestOptions{ Username: "@me", CurrentPR: currentPRNumber, - HeadRef: currentPRHeadRef, + HeadRef: currentHeadRefBranchName, ConflictStatus: opts.ConflictStatus, } if opts.Exporter != nil { @@ -116,7 +144,7 @@ func statusRun(opts *StatusOptions) error { if opts.Detector == nil { cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) - opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) + opts.Detector = fd.NewDetector(cachedClient, baseRefRepo.RepoHost()) } prFeatures, err := opts.Detector.PullRequestFeatures() if err != nil { @@ -124,7 +152,7 @@ func statusRun(opts *StatusOptions) error { } options.CheckRunAndStatusContextCountsSupported = prFeatures.CheckRunAndStatusContextCounts - prPayload, err := pullRequestStatus(httpClient, baseRepo, options) + prPayload, err := pullRequestStatus(httpClient, baseRefRepo, options) if err != nil { return err } @@ -151,21 +179,21 @@ func statusRun(opts *StatusOptions) error { cs := opts.IO.ColorScheme() fmt.Fprintln(out, "") - fmt.Fprintf(out, "Relevant pull requests in %s\n", ghrepo.FullName(baseRepo)) + fmt.Fprintf(out, "Relevant pull requests in %s\n", ghrepo.FullName(baseRefRepo)) fmt.Fprintln(out, "") if !opts.HasRepoOverride { shared.PrintHeader(opts.IO, "Current branch") currentPR := prPayload.CurrentPR - if currentPR != nil && currentPR.State != "OPEN" && prPayload.DefaultBranch == currentBranch { + if currentPR != nil && currentPR.State != "OPEN" && prPayload.DefaultBranch == currentBranchName { currentPR = nil } if currentPR != nil { printPrs(opts.IO, 1, *currentPR) - } else if currentPRHeadRef == "" { + } else if currentHeadRefBranchName == "" { shared.PrintMessage(opts.IO, " There is no current branch") } else { - shared.PrintMessage(opts.IO, fmt.Sprintf(" There is no pull request associated with %s", cs.Cyan("["+currentPRHeadRef+"]"))) + shared.PrintMessage(opts.IO, fmt.Sprintf(" There is no pull request associated with %s", cs.Cyan("["+currentHeadRefBranchName+"]"))) } fmt.Fprintln(out) } @@ -189,55 +217,6 @@ func statusRun(opts *StatusOptions) error { return nil } -func prSelectorForCurrentBranch(branchConfig git.BranchConfig, baseRepo ghrepo.Interface, prHeadRef string, rem ghContext.Remotes) (int, string, error) { - // the branch is configured to merge a special PR head ref - prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) - if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - prNumber, err := strconv.Atoi(m[1]) - if err != nil { - return 0, "", err - } - return prNumber, prHeadRef, nil - } - - var branchOwner string - if branchConfig.RemoteURL != nil { - // the branch merges from a remote specified by URL - r, err := ghrepo.FromURL(branchConfig.RemoteURL) - if err != nil { - // TODO: We aren't returning the error because we discovered that it was shadowed - // before refactoring to its current return pattern. Thus, we aren't confident - // that returning the error won't break existing behavior. - return 0, prHeadRef, nil - } - branchOwner = r.RepoOwner() - } else if branchConfig.RemoteName != "" { - // the branch merges from a remote specified by name - r, err := rem.FindByName(branchConfig.RemoteName) - if err != nil { - // TODO: We aren't returning the error because we discovered that it was shadowed - // before refactoring to its current return pattern. Thus, we aren't confident - // that returning the error won't break existing behavior. - return 0, prHeadRef, nil - } - branchOwner = r.RepoOwner() - } - - if branchOwner != "" { - selector := prHeadRef - if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { - selector = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - } - // prepend `OWNER:` if this branch is pushed to a fork - if !strings.EqualFold(branchOwner, baseRepo.RepoOwner()) { - selector = fmt.Sprintf("%s:%s", branchOwner, selector) - } - return 0, selector, nil - } - - return 0, prHeadRef, nil -} - func totalApprovals(pr *api.PullRequest) int { approvals := 0 for _, review := range pr.LatestReviews.Nodes { diff --git a/pkg/cmd/pr/status/status_test.go b/pkg/cmd/pr/status/status_test.go index 43cce5f4c..c55604c28 100644 --- a/pkg/cmd/pr/status/status_test.go +++ b/pkg/cmd/pr/status/status_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "net/http" - "net/url" "regexp" "strings" "testing" @@ -96,10 +95,13 @@ func TestPRStatus(t *testing.T) { defer http.Verify(t) http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatus.json")) - // stub successful git command + // stub successful git commands rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -130,6 +132,9 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -160,6 +165,9 @@ func TestPRStatus_reviewsAndChecksWithStatesByCount(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommandWithDetector(http, "blueberries", true, "", &fd.EnabledDetectorMock{}) if err != nil { @@ -189,6 +197,9 @@ func TestPRStatus_currentBranch_showTheMostRecentPR(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -222,6 +233,9 @@ func TestPRStatus_currentBranch_defaultBranch(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -261,6 +275,9 @@ func TestPRStatus_currentBranch_Closed(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -283,6 +300,9 @@ func TestPRStatus_currentBranch_Closed_defaultBranch(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -305,6 +325,9 @@ func TestPRStatus_currentBranch_Merged(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -327,6 +350,9 @@ func TestPRStatus_currentBranch_Merged_defaultBranch(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -349,6 +375,9 @@ func TestPRStatus_blankSlate(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -407,6 +436,9 @@ func TestPRStatus_detachedHead(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") + rs.Register(`git config remote.pushDefault`, 0, "") + rs.Register(`git rev-parse --abbrev-ref @{push}`, 0, "") + rs.Register(`git config push.default`, 0, "") output, err := runCommand(http, "", true, "") if err != nil { @@ -434,216 +466,8 @@ Requesting a code review from you func TestPRStatus_error_ReadBranchConfig(t *testing.T) { rs, cleanup := run.Stub() defer cleanup(t) - rs.Register(`git config --get-regexp \^branch\\.`, 1, "") - + // We only need the one stub because this fails early + rs.Register(`git config --get-regexp \^branch\\.`, 2, "") _, err := runCommand(initFakeHTTP(), "blueberries", true, "") assert.Error(t, err) } - -func Test_prSelectorForCurrentBranch(t *testing.T) { - tests := []struct { - name string - branchConfig git.BranchConfig - baseRepo ghrepo.Interface - prHeadRef string - remotes context.Remotes - wantPrNumber int - wantSelector string - wantError error - }{ - { - name: "Empty branch config", - branchConfig: git.BranchConfig{}, - prHeadRef: "monalisa/main", - wantPrNumber: 0, - wantSelector: "monalisa/main", - wantError: nil, - }, - { - name: "The branch is configured to merge a special PR head ref", - branchConfig: git.BranchConfig{ - MergeRef: "refs/pull/42/head", - }, - prHeadRef: "monalisa/main", - wantPrNumber: 42, - wantSelector: "monalisa/main", - wantError: nil, - }, - { - name: "Branch merges from a remote specified by URL", - branchConfig: git.BranchConfig{ - RemoteURL: &url.URL{ - Scheme: "ssh", - User: url.User("git"), - Host: "github.com", - Path: "monalisa/playground.git", - }, - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "monalisa/main", - wantError: nil, - }, - { - name: "Branch merges from a remote specified by name", - branchConfig: git.BranchConfig{ - RemoteName: "upstream", - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "monalisa/main", - wantError: nil, - }, - { - name: "Branch is a fork and merges from a remote specified by URL", - branchConfig: git.BranchConfig{ - RemoteURL: &url.URL{ - Scheme: "ssh", - User: url.User("git"), - Host: "github.com", - Path: "forkName/playground.git", - }, - MergeRef: "refs/heads/main", - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "forkName:main", - wantError: nil, - }, - { - name: "Branch is a fork and merges from a remote specified by name", - branchConfig: git.BranchConfig{ - RemoteName: "origin", - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "forkName:monalisa/main", - wantError: nil, - }, - { - name: "Branch specifies a mergeRef and merges from a remote specified by name", - branchConfig: git.BranchConfig{ - RemoteName: "upstream", - MergeRef: "refs/heads/main", - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "main", - wantError: nil, - }, - { - name: "Branch is a fork, specifies a mergeRef, and merges from a remote specified by name", - branchConfig: git.BranchConfig{ - RemoteName: "origin", - MergeRef: "refs/heads/main", - }, - baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "forkName:main", - wantError: nil, - }, - { - name: "Remote URL errors", - branchConfig: git.BranchConfig{ - RemoteURL: &url.URL{ - Scheme: "ssh", - User: url.User("git"), - Host: "github.com", - Path: "/\\invalid?Path/", - }, - }, - prHeadRef: "monalisa/main", - wantPrNumber: 0, - wantSelector: "monalisa/main", - wantError: nil, - }, - { - name: "Remote Name errors", - branchConfig: git.BranchConfig{ - RemoteName: "nonexistentRemote", - }, - prHeadRef: "monalisa/main", - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"), - }, - }, - wantPrNumber: 0, - wantSelector: "monalisa/main", - wantError: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - prNum, headRef, err := prSelectorForCurrentBranch(tt.branchConfig, tt.baseRepo, tt.prHeadRef, tt.remotes) - assert.Equal(t, tt.wantPrNumber, prNum) - assert.Equal(t, tt.wantSelector, headRef) - assert.Equal(t, tt.wantError, err) - }) - } -}