Refactor finder to work with URL selectors
This commit is contained in:
parent
29e57ee8e7
commit
058cb2c220
4 changed files with 163 additions and 47 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/run"
|
||||
)
|
||||
|
|
@ -48,6 +49,10 @@ func (gc *Command) Output() ([]byte, error) {
|
|||
ge.Stderr = string(exitError.Stderr)
|
||||
ge.ExitCode = exitError.ExitCode()
|
||||
}
|
||||
|
||||
if strings.Contains(ge.Stderr, "fatal: not a git repository") {
|
||||
ge.err = ErrNoGitRepository
|
||||
}
|
||||
err = &ge
|
||||
}
|
||||
return out, err
|
||||
|
|
|
|||
98
git/command_test.go
Normal file
98
git/command_test.go
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOutput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exitCode int
|
||||
stdout string
|
||||
stderr string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful command",
|
||||
stdout: "hello world",
|
||||
stderr: "",
|
||||
exitCode: 0,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "not a repo failure",
|
||||
stdout: "",
|
||||
stderr: "fatal: not a git repository (or any of the parent directories): .git",
|
||||
exitCode: 128,
|
||||
wantErr: ErrNoGitRepository,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
cmd := Command{
|
||||
&exec.Cmd{
|
||||
Path: createMockExecutable(t, tt.stdout, tt.stderr, tt.exitCode),
|
||||
},
|
||||
}
|
||||
|
||||
out, err := cmd.Output()
|
||||
if tt.wantErr != nil {
|
||||
require.Error(t, err)
|
||||
var gitError *GitError
|
||||
require.ErrorAs(t, err, &gitError)
|
||||
assert.Equal(t, tt.wantErr, gitError.err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.stdout, string(out))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createMockExecutable(t *testing.T, stdout string, stderr string, exitCode int) string {
|
||||
tmpDir := t.TempDir()
|
||||
sourcePath := filepath.Join(tmpDir, "main.go")
|
||||
binaryPath := filepath.Join(tmpDir, "mockexec")
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryPath += ".exe"
|
||||
}
|
||||
|
||||
// Create Go source
|
||||
source := buildCommandSourceCode(exitCode, stdout, stderr)
|
||||
|
||||
// Write source file
|
||||
if err := os.WriteFile(sourcePath, []byte(source), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Compile
|
||||
cmd := exec.Command("go", "build", "-o", binaryPath, sourcePath)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
t.Fatalf("failed to compile: %v\n%s", err, out)
|
||||
}
|
||||
return binaryPath
|
||||
|
||||
}
|
||||
|
||||
func buildCommandSourceCode(exitCode int, stdout, stderr string) string {
|
||||
return fmt.Sprintf(`package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
func main() {
|
||||
fmt.Printf(%q)
|
||||
fmt.Fprintf(os.Stderr, %q)
|
||||
os.Exit(%d)
|
||||
}`, stdout, stderr, exitCode)
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
// ErrNotOnAnyBranch indicates that the user is in detached HEAD state.
|
||||
var ErrNotOnAnyBranch = errors.New("git: not on any branch")
|
||||
var ErrNoGitRepository = errors.New("git: not in a git repository")
|
||||
|
||||
type NotInstalled struct {
|
||||
message string
|
||||
|
|
|
|||
|
|
@ -120,6 +120,8 @@ func (s *PullRequestRefs) GetPRHeadLabel() string {
|
|||
}
|
||||
|
||||
func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
|
||||
// If we have a URL, we don't need git stuff
|
||||
|
||||
if len(opts.Fields) == 0 {
|
||||
return nil, nil, errors.New("Find error: no fields specified")
|
||||
}
|
||||
|
|
@ -137,37 +139,70 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
|
|||
f.baseRefRepo = repo
|
||||
}
|
||||
|
||||
if f.prNumber == 0 && opts.Selector != "" {
|
||||
// If opts.Selector is a valid number then assume it is the
|
||||
// PR number unless opts.BaseBranch is specified. This is a
|
||||
// special case for PR create command which will always want
|
||||
// to assume that a numerical selector is a branch name rather
|
||||
// than PR number.
|
||||
prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#"))
|
||||
if opts.BaseBranch == "" && err == nil {
|
||||
f.prNumber = prNumber
|
||||
} else {
|
||||
f.branchName = opts.Selector
|
||||
}
|
||||
} else {
|
||||
var prRefs PullRequestRefs
|
||||
if opts.Selector == "" {
|
||||
// You must be in a git repo for this case to work
|
||||
currentBranchName, err := f.branchFn()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
f.branchName = currentBranchName
|
||||
}
|
||||
|
||||
// Get the branch config for the current branchName
|
||||
branchConfig, err := f.branchConfig(f.branchName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// Get the branch config for the current branchName
|
||||
branchConfig, err := f.branchConfig(f.branchName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Determine if the branch is configured to merge to a special PR ref
|
||||
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
|
||||
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
|
||||
prNumber, _ := strconv.Atoi(m[1])
|
||||
f.prNumber = prNumber
|
||||
// Determine if the branch is configured to merge to a special PR ref
|
||||
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
|
||||
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
|
||||
prNumber, _ := strconv.Atoi(m[1])
|
||||
f.prNumber = prNumber
|
||||
}
|
||||
|
||||
// Determine the PullRequestRefs from config
|
||||
if f.prNumber == 0 {
|
||||
rems, err := f.remotesFn()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
|
||||
parsedPushRevision, _ := f.parsePushRevision(f.branchName)
|
||||
|
||||
pushDefault, err := f.pushDefault()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
remotePushDefault, err := f.remotePushDefault()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
prRefs, err = ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
} else if f.prNumber == 0 {
|
||||
// You gave me a selector but I couldn't find a PR number (it wasn't a URL)
|
||||
|
||||
// Try to get a PR number from the selector
|
||||
prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#"))
|
||||
// If opts.Selector is a valid number then assume it is the
|
||||
// PR number unless opts.BaseBranch is specified. This is a
|
||||
// special case for PR create command which will always want
|
||||
// to assume that a numerical selector is a branch name rather
|
||||
// than PR number.
|
||||
if opts.BaseBranch == "" && err == nil {
|
||||
f.prNumber = prNumber
|
||||
} else {
|
||||
f.branchName = opts.Selector
|
||||
prRefs, _ = ParsePRRefs(f.branchName, git.BranchConfig{}, "", "", "", f.baseRefRepo, remotes.Remotes{})
|
||||
}
|
||||
}
|
||||
|
||||
// Set up HTTP client
|
||||
|
|
@ -217,29 +252,6 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err
|
|||
return pr, f.baseRefRepo, err
|
||||
}
|
||||
} else {
|
||||
rems, err := f.remotesFn()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pushDefault, err := f.pushDefault()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Suppressing these errors as we have other means of computing the PullRequestRefs when these fail.
|
||||
parsedPushRevision, _ := f.parsePushRevision(f.branchName)
|
||||
|
||||
remotePushDefault, err := f.remotePushDefault()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
prRefs, err := ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRHeadLabel(), opts.States, fields.ToSlice())
|
||||
if err != nil {
|
||||
return pr, f.baseRefRepo, err
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue