Refactor code to live inside default pkg

This commit is contained in:
Sam Coe 2022-04-19 13:31:59 +02:00
parent 9060b44e6d
commit 22a5d2abf8
No known key found for this signature in database
GPG key ID: 8E322C20F811D086
8 changed files with 589 additions and 268 deletions

View file

@ -4,6 +4,7 @@ package context
import (
"errors"
"sort"
"strings"
"github.com/AlecAivazis/survey/v2"
"github.com/cli/cli/v2/api"
@ -59,93 +60,6 @@ type ResolvedRemotes struct {
apiClient *api.Client
}
func GetBaseRepo(remotes Remotes) (ghrepo.Interface, error) {
for _, r := range remotes {
if r.Resolved == "base" {
return r, nil
} else if r.Resolved != "" {
repo, err := ghrepo.FromFullName(r.Resolved)
if err != nil {
return nil, err
}
return ghrepo.NewWithHost(repo.RepoOwner(), repo.RepoName(), r.RepoHost()), nil
}
}
return nil, errors.New("a default repo has not been set, use `gh repo default` to set a default repo")
}
func (r *ResolvedRemotes) SetBaseRepo(io *iostreams.IOStreams) error {
resolution := "base"
if !io.CanPrompt() {
return git.SetRemoteResolution(r.remotes[0].Name, resolution)
}
// from here on, consult the API
if r.network == nil {
err := resolveNetwork(r)
if err != nil {
return err
}
}
var repoNames []string
repoMap := map[string]*api.Repository{}
add := func(r *api.Repository) {
fn := ghrepo.FullName(r)
if _, ok := repoMap[fn]; !ok {
repoMap[fn] = r
repoNames = append(repoNames, fn)
}
}
for _, repo := range r.network.Repositories {
if repo == nil {
continue
}
if repo.Parent != nil {
add(repo.Parent)
}
add(repo)
}
if len(repoNames) == 0 {
return git.SetRemoteResolution(r.remotes[0].Name, resolution)
}
baseName := repoNames[0]
if len(repoNames) > 1 {
err := prompt.SurveyAskOne(&survey.Select{
Message: "Which should be the base repository (used for e.g. querying issues) for this directory?",
Options: repoNames,
}, &baseName)
if err != nil {
return err
}
}
// determine corresponding git remote
selectedRepo := repoMap[baseName]
remote, _ := r.RemoteForRepo(selectedRepo)
if remote == nil {
remote = r.remotes[0]
resolution = ghrepo.FullName(selectedRepo)
}
// cache the result to git config
return git.SetRemoteResolution(remote.Name, resolution)
}
func RemoveBaseRepo(remotes Remotes) error {
for _, remote := range remotes {
if remote.Resolved == "base" {
if err := git.UnsetRemoteResolution(remote.Remote.Name); err != nil {
return err
}
}
}
return nil
}
func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, error) {
if r.baseOverride != nil {
return r.baseOverride, nil
@ -169,36 +83,18 @@ func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, e
return r.remotes[0], nil
}
// from here on, consult the API
if r.network == nil {
err := resolveNetwork(r)
if err != nil {
return nil, err
}
repos, err := r.NetworkRepos()
if err != nil {
return nil, err
}
if len(repos) == 0 {
return r.remotes[0], nil
}
var repoNames []string
repoMap := map[string]*api.Repository{}
add := func(r *api.Repository) {
fn := ghrepo.FullName(r)
if _, ok := repoMap[fn]; !ok {
repoMap[fn] = r
repoNames = append(repoNames, fn)
}
}
for _, repo := range r.network.Repositories {
if repo == nil {
continue
}
if repo.Parent != nil {
add(repo.Parent)
}
add(repo)
}
if len(repoNames) == 0 {
return r.remotes[0], nil
for _, r := range repos {
repoNames = append(repoNames, ghrepo.FullName(r))
}
baseName := repoNames[0]
@ -216,7 +112,8 @@ func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, e
}
// determine corresponding git remote
selectedRepo := repoMap[baseName]
owner, repo, _ := strings.Cut(baseName, "/")
selectedRepo := ghrepo.New(owner, repo)
resolution := "base"
remote, _ := r.RemoteForRepo(selectedRepo)
if remote == nil {
@ -225,7 +122,7 @@ func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, e
}
// cache the result to git config
err := git.SetRemoteResolution(remote.Name, resolution)
err = git.SetRemoteResolution(remote.Name, resolution)
return selectedRepo, err
}
@ -246,6 +143,38 @@ func (r *ResolvedRemotes) HeadRepos() ([]*api.Repository, error) {
return results, nil
}
func (r *ResolvedRemotes) NetworkRepos() ([]*api.Repository, error) {
if r.network == nil {
err := resolveNetwork(r)
if err != nil {
return nil, err
}
}
var repos []*api.Repository
repoMap := map[string]bool{}
add := func(r *api.Repository) {
fn := ghrepo.FullName(r)
if _, ok := repoMap[fn]; !ok {
repoMap[fn] = true
repos = append(repos, r)
}
}
for _, repo := range r.network.Repositories {
if repo == nil {
continue
}
if repo.Parent != nil {
add(repo.Parent)
}
add(repo)
}
return repos, nil
}
// RemoteForRepo finds the git remote that points to a repository
func (r *ResolvedRemotes) RemoteForRepo(repo ghrepo.Interface) (*Remote, error) {
for _, remote := range r.remotes {

View file

@ -21,7 +21,7 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) {
}
}
}
return nil, fmt.Errorf("no GitHub remotes found")
return nil, fmt.Errorf("no matching remote found")
}
// FindByRepo returns the first Remote that points to a specific GitHub repository
@ -34,6 +34,29 @@ func (r Remotes) FindByRepo(owner, name string) (*Remote, error) {
return nil, fmt.Errorf("no matching remote found")
}
// Filter remotes by given hostnames, maintains original order
func (r Remotes) FilterByHosts(hosts []string) Remotes {
filtered := make(Remotes, 0)
for _, rr := range r {
for _, host := range hosts {
if strings.EqualFold(rr.RepoHost(), host) {
filtered = append(filtered, rr)
break
}
}
}
return filtered
}
func (r Remotes) ResolvedRemote() (*Remote, error) {
for _, rr := range r {
if rr.Resolved != "" {
return rr, nil
}
}
return nil, fmt.Errorf("no resolved remote found")
}
func remoteNameSortScore(name string) int {
switch strings.ToLower(name) {
case "upstream":
@ -54,20 +77,6 @@ func (r Remotes) Less(i, j int) bool {
return remoteNameSortScore(r[i].Name) > remoteNameSortScore(r[j].Name)
}
// Filter remotes by given hostnames, maintains original order
func (r Remotes) FilterByHosts(hosts []string) Remotes {
filtered := make(Remotes, 0)
for _, rr := range r {
for _, host := range hosts {
if strings.EqualFold(rr.RepoHost(), host) {
filtered = append(filtered, rr)
break
}
}
}
return filtered
}
// Remote represents a git remote mapped to a GitHub repository
type Remote struct {
*git.Remote

View file

@ -7,9 +7,3 @@
[user]
name = Mona the Cat
email = monalisa@github.com
[remote "origin"]
url = git@github.com:monathecat/cli.git
fetch = +refs/heads/*:refs/remotes/origin/*
[remote "upstream"]
url = git@github.com:cli/cli.git
fetch = +refs/heads/trunk:refs/remotes/upstream/trunk

View file

@ -394,7 +394,6 @@ func ToplevelDir() (string, error) {
}
output, err := run.PrepareCmd(showCmd).Output()
return firstLine(output), err
}
// ToplevelDirFromPath returns the top-level given path of the current repository
@ -439,3 +438,16 @@ func getBranchShortName(output []byte) string {
branch := firstLine(output)
return strings.TrimPrefix(branch, "refs/heads/")
}
func IsGitDirectory() bool {
showCmd, err := GitCommand("rev-parse", "--is-inside-work-tree")
if err != nil {
return false
}
output, err := run.PrepareCmd(showCmd).Output()
if err != nil {
return false
}
out := firstLine(output)
return out == "true"
}

View file

@ -169,7 +169,7 @@ func SetRemoteResolution(name, resolution string) error {
}
func UnsetRemoteResolution(name string) error {
unsetCmd, err := GitCommand("config", "--unset-all", fmt.Sprintf("remote.%s.gh-resolved", name))
unsetCmd, err := GitCommand("config", "--unset", fmt.Sprintf("remote.%s.gh-resolved", name))
if err != nil {
return err
}

View file

@ -1,15 +1,21 @@
package base
import (
"errors"
"fmt"
"net/http"
"sort"
"strings"
"github.com/AlecAivazis/survey/v2"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/prompt"
"github.com/spf13/cobra"
)
@ -18,10 +24,11 @@ type DefaultOptions struct {
Remotes func() (context.Remotes, error)
HttpClient func() (*http.Client, error)
ViewFlag bool
Repo ghrepo.Interface
ViewMode bool
}
func NewCmdDefault(f *cmdutil.Factory) *cobra.Command {
func NewCmdDefault(f *cmdutil.Factory, runF func(*DefaultOptions) error) *cobra.Command {
opts := &DefaultOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
@ -29,51 +36,60 @@ func NewCmdDefault(f *cmdutil.Factory) *cobra.Command {
}
cmd := &cobra.Command{
Use: "default",
Short: "Configure the default repository used for various commands",
Long: heredoc.Doc(`
The default repository is used to determine which remote
repository gh should automatically point to.
`),
Example: heredoc.Doc(`
$ gh repo default
#=> prompts remote options
Use: "default [<repository>]",
Short: "Configure default repository",
Long: heredoc.Docf(`
Set default repository for current directory.
$ gh repo default -v
#=> cli/cli
`),
Annotations: map[string]string{
"help:environment": heredoc.Doc(`
To manually configure a remote for gh to use, modify your local repo's git config
; Ex: setting gh to use the upstream remote
[remote "upstream"]
gh-resolved = base
...
`),
},
Args: cobra.NoArgs,
The default repository is used as the target
repository for various commands such as %[1]spr%[1]s, %[1]sissue%[1]s,
and %[1]srepo%[1]s.
`, "`"),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
return runDefault(opts)
if len(args) > 0 {
var err error
opts.Repo, err = ghrepo.FromFullName(args[0])
if err != nil {
return err
}
}
if !opts.IO.CanPrompt() && opts.Repo == nil {
return cmdutil.FlagErrorf("repository required when not running interactively")
}
if !git.IsGitDirectory() {
return errors.New("must be run from inside a git repository")
}
if runF != nil {
return runF(opts)
}
return defaultRun(opts)
},
}
cmd.Flags().BoolVarP(&opts.ViewFlag, "view", "v", false, "view the default repository used for various commands")
cmd.Flags().BoolVarP(&opts.ViewMode, "view", "v", false, "view the current default repository")
return cmd
}
func runDefault(opts *DefaultOptions) error {
func defaultRun(opts *DefaultOptions) error {
remotes, err := opts.Remotes()
if err != nil {
return err
}
if opts.ViewFlag {
baseRepo, err := context.GetBaseRepo(remotes)
if err != nil {
return err
currentDefaultRepo, _ := remotes.ResolvedRemote()
if opts.ViewMode {
if currentDefaultRepo == nil {
fmt.Fprintln(opts.IO.Out, "no default repo has been set; use `gh repo default` to select one")
} else {
fmt.Fprintln(opts.IO.Out, displayRemoteRepoName(currentDefaultRepo))
}
fmt.Fprintln(opts.IO.Out, ghrepo.FullName(baseRepo))
return nil
}
@ -82,12 +98,107 @@ func runDefault(opts *DefaultOptions) error {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "")
resolvedRemotes, err := context.ResolveRemotesToRepos(remotes, apiClient, "")
if err != nil {
return err
}
if err = context.RemoveBaseRepo(remotes); err != nil {
knownRepos, err := resolvedRemotes.NetworkRepos()
if err != nil {
return err
}
return repoContext.SetBaseRepo(opts.IO)
if len(knownRepos) == 0 {
return errors.New("none of the git remotes correspond to a valid remote repository")
}
var selectedRepo ghrepo.Interface
if opts.Repo != nil {
for _, knownRepo := range knownRepos {
if ghrepo.IsSame(opts.Repo, knownRepo) {
selectedRepo = opts.Repo
break
}
}
if selectedRepo == nil {
return fmt.Errorf("%s does not correspond to any git remotes", ghrepo.FullName(opts.Repo))
}
}
if selectedRepo == nil {
if len(knownRepos) == 1 {
selectedRepo = knownRepos[0]
} else {
var repoNames []string
var selectedName string
current := ""
if currentDefaultRepo != nil {
current = ghrepo.FullName(currentDefaultRepo)
}
for _, knownRepo := range knownRepos {
repoNames = append(repoNames, ghrepo.FullName(knownRepo))
}
err := prompt.SurveyAskOne(&survey.Select{
Message: "Which should be the default repository (used for e.g. querying issues) for this directory?",
Options: repoNames,
Default: current,
}, &selectedName)
if err != nil {
return err
}
owner, repo, _ := strings.Cut(selectedName, "/")
selectedRepo = ghrepo.New(owner, repo)
}
}
resolution := "base"
selectedRemote, _ := resolvedRemotes.RemoteForRepo(selectedRepo)
if selectedRemote == nil {
sort.Stable(remotes)
selectedRemote = remotes[0]
resolution = ghrepo.FullName(selectedRepo)
}
if currentDefaultRepo != nil {
if err := unsetDefaultRepo(currentDefaultRepo); err != nil {
return err
}
}
err = setDefaultRepo(selectedRemote, resolution)
if err != nil {
return err
}
if opts.IO.IsStdoutTTY() {
cs := opts.IO.ColorScheme()
fmt.Fprintf(opts.IO.Out, "%s Set %s as the default repository for the current directory\n", cs.SuccessIcon(), ghrepo.FullName(selectedRepo))
}
return nil
}
func displayRemoteRepoName(remote *context.Remote) string {
if remote.Resolved == "" || remote.Resolved == "base" {
return ghrepo.FullName(remote)
}
repo, err := ghrepo.FromFullName(remote.Resolved)
if err != nil {
return ghrepo.FullName(remote)
}
return ghrepo.FullName(repo)
}
func setDefaultRepo(remote *context.Remote, resolution string) error {
return git.SetRemoteResolution(remote.Name, resolution)
}
func unsetDefaultRepo(remote *context.Remote) error {
return git.UnsetRemoteResolution(remote.Name)
}

View file

@ -1,129 +1,395 @@
package base
import (
"errors"
"fmt"
"bytes"
"net/http"
"os"
"testing"
"github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/run"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/prompt"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
)
func Test_defaultRun(t *testing.T) {
setGitDir(t, "../../../../git/fixtures/simple.git")
func TestNewCmdDefault(t *testing.T) {
tests := []struct {
name string
opts DefaultOptions
wantedErr error
wantedStdOut string
wantedResolvedName string
name string
gitStubs func(*run.CommandStubber)
input string
output DefaultOptions
wantErr bool
errMsg string
}{
{
name: "Base repo set with view option",
opts: DefaultOptions{
Remotes: func() (context.Remotes, error) {
return []*context.Remote{
{
Remote: &git.Remote{
Name: "origin",
Resolved: "base",
},
Repo: ghrepo.New("hubot", "Spoon-Knife"),
},
}, nil
},
ViewFlag: true,
name: "no argument",
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true")
},
wantedStdOut: "hubot/Spoon-Knife",
input: "",
output: DefaultOptions{},
},
{
name: "Base repo not set with view option",
opts: DefaultOptions{
Remotes: func() (context.Remotes, error) {
return []*context.Remote{
{
Remote: &git.Remote{
Name: "origin",
},
},
}, nil
},
ViewFlag: true,
name: "repo argument",
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true")
},
wantedErr: errors.New("a default repo has not been set, use `gh repo default` to set a default repo"),
input: "cli/cli",
output: DefaultOptions{Repo: ghrepo.New("cli", "cli")},
},
{
name: "Base repo not set, assign non-interactively",
opts: DefaultOptions{
Remotes: func() (context.Remotes, error) {
return []*context.Remote{
{
Remote: &git.Remote{
Name: "origin",
},
},
{
Remote: &git.Remote{
Name: "upstream",
},
},
}, nil
},
name: "invalid repo argument",
gitStubs: func(cs *run.CommandStubber) {},
input: "some_invalid_format",
wantErr: true,
errMsg: `expected the "[HOST/]OWNER/REPO" format, got "some_invalid_format"`,
},
{
name: "view flag",
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true")
},
wantedResolvedName: "upstream",
input: "--view",
output: DefaultOptions{ViewMode: true},
},
{
name: "run from non-git directory",
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git rev-parse --is-inside-work-tree`, 1, "")
},
input: "",
wantErr: true,
errMsg: "must be run from inside a git repository",
},
}
for _, tt := range tests {
io, _, _, _ := iostreams.Test()
io.SetStdoutTTY(true)
io.SetStdinTTY(true)
io.SetStderrTTY(true)
f := &cmdutil.Factory{
IOStreams: io,
}
var gotOpts *DefaultOptions
cmd := NewCmdDefault(f, func(opts *DefaultOptions) error {
gotOpts = opts
return nil
})
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
t.Run(tt.name, func(t *testing.T) {
io, _, stdout, stderr := iostreams.Test()
argv, err := shlex.Split(tt.input)
assert.NoError(t, err)
opts := tt.opts
opts.IO = io
opts.HttpClient = func() (*http.Client, error) { return nil, nil }
cmd.SetArgs(argv)
err := runDefault(&opts)
if tt.wantedErr != nil {
assert.EqualError(t, tt.wantedErr, err.Error())
} else {
assert.NoError(t, err)
if opts.ViewFlag {
assert.Equal(t, fmt.Sprintf("%s\n", tt.wantedStdOut), stdout.String())
} else {
assert.Equal(t, "", stdout.String())
assert.Equal(t, "", stderr.String())
}
}
if tt.wantedResolvedName != "" {
resolvedAmount := 0
remotes, err := git.Remotes()
if err != nil {
panic(err)
}
for _, r := range remotes {
if r.Resolved == "base" {
assert.Equal(t, r.Name, tt.wantedResolvedName)
resolvedAmount++
}
}
assert.Equal(t, 1, resolvedAmount)
cs, teardown := run.Stub()
defer teardown(t)
tt.gitStubs(cs)
_, err = cmd.ExecuteC()
if tt.wantErr {
assert.EqualError(t, err, tt.errMsg)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.output.Repo, gotOpts.Repo)
assert.Equal(t, tt.output.ViewMode, gotOpts.ViewMode)
})
}
}
func setGitDir(t *testing.T, dir string) {
old_GIT_DIR := os.Getenv("GIT_DIR")
os.Setenv("GIT_DIR", dir)
t.Cleanup(func() {
if err := git.UnsetRemoteResolution("upstream"); err != nil {
panic(err)
func TestDefaultRun(t *testing.T) {
repo1, _ := ghrepo.FromFullName("OWNER/REPO")
repo2, _ := ghrepo.FromFullName("OWNER2/REPO2")
repo3, _ := ghrepo.FromFullName("OWNER3/REPO3")
tests := []struct {
name string
tty bool
opts DefaultOptions
remotes []*context.Remote
httpStubs func(*httpmock.Registry)
gitStubs func(*run.CommandStubber)
askStubs func(*prompt.AskStubber)
wantStdout string
wantErr bool
errMsg string
}{
{
name: "view mode no current default",
opts: DefaultOptions{ViewMode: true},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
},
wantStdout: "no default repo has been set; use `gh repo default` to select one\n",
},
{
name: "view mode with base resolved current default",
opts: DefaultOptions{ViewMode: true},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin", Resolved: "base"},
Repo: repo1,
},
},
wantStdout: "OWNER/REPO\n",
},
{
name: "view mode with non-base resolved current default",
opts: DefaultOptions{ViewMode: true},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin", Resolved: "PARENT/REPO"},
Repo: repo1,
},
},
wantStdout: "PARENT/REPO\n",
},
{
name: "tty non-interactive mode no current default",
tty: true,
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo2,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO2","owner":{"login":"OWNER2"}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "")
},
wantStdout: "✓ Set OWNER2/REPO2 as the default repository for the current directory\n",
},
{
name: "tty non-interactive mode set non-base default",
tty: true,
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo3,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO","owner":{"login":"OWNER"},"parent":{"name":"REPO2","owner":{"login":"OWNER2"}}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --add remote.upstream.gh-resolved OWNER2/REPO2`, 0, "")
},
wantStdout: "✓ Set OWNER2/REPO2 as the default repository for the current directory\n",
},
{
name: "non-tty non-interactive mode no current default",
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo2,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO2","owner":{"login":"OWNER2"}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "")
},
wantStdout: "",
},
{
name: "non-interactive mode with current default",
tty: true,
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin", Resolved: "base"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo2,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO2","owner":{"login":"OWNER2"}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --unset remote.origin.gh-resolved`, 0, "")
cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "")
},
wantStdout: "✓ Set OWNER2/REPO2 as the default repository for the current directory\n",
},
{
name: "non-interactive mode no known hosts",
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{}}`),
)
},
wantErr: true,
errMsg: "none of the git remotes correspond to a valid remote repository",
},
{
name: "non-interactive mode no matching remotes",
opts: DefaultOptions{Repo: repo2},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO","owner":{"login":"OWNER"}}}}`),
)
},
wantErr: true,
errMsg: "OWNER2/REPO2 does not correspond to any git remotes",
},
{
name: "interactive mode",
tty: true,
opts: DefaultOptions{},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo2,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO","owner":{"login":"OWNER"}},"repo_001":{"name":"REPO2","owner":{"login":"OWNER2"}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "")
},
askStubs: func(as *prompt.AskStubber) {
as.StubPrompt("Which should be the default repository (used for e.g. querying issues) for this directory?").
AssertOptions([]string{"OWNER/REPO", "OWNER2/REPO2"}).
AnswerWith("OWNER2/REPO2")
},
wantStdout: "✓ Set OWNER2/REPO2 as the default repository for the current directory\n",
},
{
name: "interactive mode only one known host",
tty: true,
opts: DefaultOptions{},
remotes: []*context.Remote{
{
Remote: &git.Remote{Name: "origin"},
Repo: repo1,
},
{
Remote: &git.Remote{Name: "upstream"},
Repo: repo2,
},
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query RepositoryNetwork\b`),
httpmock.StringResponse(`{"data":{"repo_000":{"name":"REPO2","owner":{"login":"OWNER2"}}}}`),
)
},
gitStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "")
},
wantStdout: "✓ Set OWNER2/REPO2 as the default repository for the current directory\n",
},
}
for _, tt := range tests {
reg := &httpmock.Registry{}
if tt.httpStubs != nil {
tt.httpStubs(reg)
}
os.Setenv("GIT_DIR", old_GIT_DIR)
})
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
io, _, stdout, _ := iostreams.Test()
io.SetStdinTTY(tt.tty)
io.SetStdoutTTY(tt.tty)
io.SetStderrTTY(tt.tty)
tt.opts.IO = io
tt.opts.Remotes = func() (context.Remotes, error) {
return tt.remotes, nil
}
as := prompt.NewAskStubber(t)
if tt.askStubs != nil {
tt.askStubs(as)
}
t.Run(tt.name, func(t *testing.T) {
cs, teardown := run.Stub()
defer teardown(t)
if tt.gitStubs != nil {
tt.gitStubs(cs)
}
defer reg.Verify(t)
err := defaultRun(&tt.opts)
if tt.wantErr {
assert.EqualError(t, err, tt.errMsg)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantStdout, stdout.String())
})
}
}

View file

@ -53,7 +53,7 @@ func NewCmdRepo(f *cmdutil.Factory) *cobra.Command {
cmd.AddCommand(repoRenameCmd.NewCmdRename(f, nil))
cmd.AddCommand(repoDeleteCmd.NewCmdDelete(f, nil))
cmd.AddCommand(repoArchiveCmd.NewCmdArchive(f, nil))
cmd.AddCommand(repoDefaultCmd.NewCmdDefault(f))
cmd.AddCommand(repoDefaultCmd.NewCmdDefault(f, nil))
return cmd
}