Cleanup factory/default and add tests
This commit is contained in:
parent
e380d68ed2
commit
53fac59ef9
4 changed files with 389 additions and 80 deletions
|
|
@ -6,6 +6,8 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/cli/cli/api"
|
||||
"github.com/cli/cli/context"
|
||||
"github.com/cli/cli/git"
|
||||
"github.com/cli/cli/internal/config"
|
||||
"github.com/cli/cli/internal/ghrepo"
|
||||
|
|
@ -14,11 +16,93 @@ import (
|
|||
)
|
||||
|
||||
func New(appVersion string) *cmdutil.Factory {
|
||||
io := iostreams.System()
|
||||
f := &cmdutil.Factory{
|
||||
IOStreams: iostreams.System(), // No factory dependencies
|
||||
Config: configFunc(), // No factory dependencies
|
||||
Branch: branchFunc(), // No factory dependencies
|
||||
Executable: executable(), // No factory dependencies
|
||||
}
|
||||
|
||||
f.HttpClient = httpClientFunc(f, appVersion) // Depends on Config, IOStreams, and appVersion
|
||||
f.Remotes = remotesFunc(f) // Depends on Config
|
||||
f.BaseRepo = BaseRepoFunc(f) // Depends on Remotes
|
||||
f.Browser = browser(f) // Depends on IOStreams
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func BaseRepoFunc(f *cmdutil.Factory) func() (ghrepo.Interface, error) {
|
||||
return func() (ghrepo.Interface, error) {
|
||||
remotes, err := f.Remotes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remotes[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func SmartBaseRepoFunc(f *cmdutil.Factory) func() (ghrepo.Interface, error) {
|
||||
return func() (ghrepo.Interface, error) {
|
||||
httpClient, err := f.HttpClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiClient := api.NewClientFromHTTP(httpClient)
|
||||
|
||||
remotes, err := f.Remotes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseRepo, err := repoContext.BaseRepo(f.IOStreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return baseRepo, nil
|
||||
}
|
||||
}
|
||||
|
||||
func remotesFunc(f *cmdutil.Factory) func() (context.Remotes, error) {
|
||||
rr := &remoteResolver{
|
||||
readRemotes: git.Remotes,
|
||||
getConfig: f.Config,
|
||||
}
|
||||
return rr.Resolver()
|
||||
}
|
||||
|
||||
func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, error) {
|
||||
return func() (*http.Client, error) {
|
||||
io := f.IOStreams
|
||||
cfg, err := f.Config()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewHTTPClient(io, cfg, appVersion, true), nil
|
||||
}
|
||||
}
|
||||
|
||||
func browser(f *cmdutil.Factory) cmdutil.Browser {
|
||||
io := f.IOStreams
|
||||
return cmdutil.NewBrowser(os.Getenv("BROWSER"), io.Out, io.ErrOut)
|
||||
}
|
||||
|
||||
func executable() string {
|
||||
gh := "gh"
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
gh = exe
|
||||
}
|
||||
return gh
|
||||
}
|
||||
|
||||
func configFunc() func() (config.Config, error) {
|
||||
var cachedConfig config.Config
|
||||
var configError error
|
||||
configFunc := func() (config.Config, error) {
|
||||
return func() (config.Config, error) {
|
||||
if cachedConfig != nil || configError != nil {
|
||||
return cachedConfig, configError
|
||||
}
|
||||
|
|
@ -30,45 +114,14 @@ func New(appVersion string) *cmdutil.Factory {
|
|||
cachedConfig = config.InheritEnv(cachedConfig)
|
||||
return cachedConfig, configError
|
||||
}
|
||||
}
|
||||
|
||||
rr := &remoteResolver{
|
||||
readRemotes: git.Remotes,
|
||||
getConfig: configFunc,
|
||||
}
|
||||
remotesFunc := rr.Resolver()
|
||||
|
||||
ghExecutable := "gh"
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
ghExecutable = exe
|
||||
}
|
||||
|
||||
return &cmdutil.Factory{
|
||||
IOStreams: io,
|
||||
Config: configFunc,
|
||||
Remotes: remotesFunc,
|
||||
HttpClient: func() (*http.Client, error) {
|
||||
cfg, err := configFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewHTTPClient(io, cfg, appVersion, true), nil
|
||||
},
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
remotes, err := remotesFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remotes[0], nil
|
||||
},
|
||||
Branch: func() (string, error) {
|
||||
currentBranch, err := git.CurrentBranch()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine current branch: %w", err)
|
||||
}
|
||||
return currentBranch, nil
|
||||
},
|
||||
Executable: ghExecutable,
|
||||
Browser: cmdutil.NewBrowser(os.Getenv("BROWSER"), io.Out, io.ErrOut),
|
||||
func branchFunc() func() (string, error) {
|
||||
return func() (string, error) {
|
||||
currentBranch, err := git.CurrentBranch()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine current branch: %w", err)
|
||||
}
|
||||
return currentBranch, nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
281
pkg/cmd/factory/default_test.go
Normal file
281
pkg/cmd/factory/default_test.go
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
package factory
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/git"
|
||||
"github.com/cli/cli/internal/config"
|
||||
"github.com/cli/cli/pkg/cmdutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BaseRepo(t *testing.T) {
|
||||
orig_GH_HOST := os.Getenv("GH_HOST")
|
||||
t.Cleanup(func() {
|
||||
os.Setenv("GH_HOST", orig_GH_HOST)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remotes git.RemoteSet
|
||||
config config.Config
|
||||
override string
|
||||
wantsErr bool
|
||||
wantsName string
|
||||
wantsOwner string
|
||||
wantsHost string
|
||||
}{
|
||||
{
|
||||
name: "matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://nonsense.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
wantsName: "repo",
|
||||
wantsOwner: "owner",
|
||||
wantsHost: "nonsense.com",
|
||||
},
|
||||
{
|
||||
name: "no matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://test.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
wantsErr: true,
|
||||
},
|
||||
{
|
||||
name: "override with matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://test.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsName: "repo",
|
||||
wantsOwner: "owner",
|
||||
wantsHost: "test.com",
|
||||
},
|
||||
{
|
||||
name: "override with no matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://nonsense.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.override != "" {
|
||||
os.Setenv("GH_HOST", tt.override)
|
||||
} else {
|
||||
os.Unsetenv("GH_HOST")
|
||||
}
|
||||
f := New("1")
|
||||
rr := &remoteResolver{
|
||||
readRemotes: func() (git.RemoteSet, error) {
|
||||
return tt.remotes, nil
|
||||
},
|
||||
getConfig: func() (config.Config, error) {
|
||||
return tt.config, nil
|
||||
},
|
||||
}
|
||||
f.Remotes = rr.Resolver()
|
||||
f.BaseRepo = BaseRepoFunc(f)
|
||||
repo, err := f.BaseRepo()
|
||||
if tt.wantsErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantsName, repo.RepoName())
|
||||
assert.Equal(t, tt.wantsOwner, repo.RepoOwner())
|
||||
assert.Equal(t, tt.wantsHost, repo.RepoHost())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SmartBaseRepo(t *testing.T) {
|
||||
pu, _ := url.Parse("https://test.com/newowner/newrepo.git")
|
||||
orig_GH_HOST := os.Getenv("GH_HOST")
|
||||
t.Cleanup(func() {
|
||||
os.Setenv("GH_HOST", orig_GH_HOST)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remotes git.RemoteSet
|
||||
config config.Config
|
||||
override string
|
||||
wantsErr bool
|
||||
wantsName string
|
||||
wantsOwner string
|
||||
wantsHost string
|
||||
}{
|
||||
{
|
||||
name: "override with matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://test.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsName: "repo",
|
||||
wantsOwner: "owner",
|
||||
wantsHost: "test.com",
|
||||
},
|
||||
{
|
||||
name: "override with matching remote and base resolution",
|
||||
remotes: git.RemoteSet{
|
||||
&git.Remote{Name: "origin",
|
||||
Resolved: "base",
|
||||
FetchURL: pu,
|
||||
PushURL: pu},
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsName: "newrepo",
|
||||
wantsOwner: "newowner",
|
||||
wantsHost: "test.com",
|
||||
},
|
||||
{
|
||||
name: "override with matching remote and nonbase resolution",
|
||||
remotes: git.RemoteSet{
|
||||
&git.Remote{Name: "origin",
|
||||
Resolved: "johnny/test",
|
||||
FetchURL: pu,
|
||||
PushURL: pu},
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsName: "test",
|
||||
wantsOwner: "johnny",
|
||||
wantsHost: "test.com",
|
||||
},
|
||||
{
|
||||
name: "override with no matching remote",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://example.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
override: "test.com",
|
||||
wantsErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.override != "" {
|
||||
os.Setenv("GH_HOST", tt.override)
|
||||
} else {
|
||||
os.Unsetenv("GH_HOST")
|
||||
}
|
||||
f := New("1")
|
||||
rr := &remoteResolver{
|
||||
readRemotes: func() (git.RemoteSet, error) {
|
||||
return tt.remotes, nil
|
||||
},
|
||||
getConfig: func() (config.Config, error) {
|
||||
return tt.config, nil
|
||||
},
|
||||
}
|
||||
f.Remotes = rr.Resolver()
|
||||
f.BaseRepo = SmartBaseRepoFunc(f)
|
||||
repo, err := f.BaseRepo()
|
||||
if tt.wantsErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantsName, repo.RepoName())
|
||||
assert.Equal(t, tt.wantsOwner, repo.RepoOwner())
|
||||
assert.Equal(t, tt.wantsHost, repo.RepoHost())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Defined in pkg/cmdutil/repo_override.go but test it along with other BaseRepo functions
|
||||
func Test_OverrideBaseRepo(t *testing.T) {
|
||||
orig_GH_HOST := os.Getenv("GH_REPO")
|
||||
t.Cleanup(func() {
|
||||
os.Setenv("GH_REPO", orig_GH_HOST)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remotes git.RemoteSet
|
||||
config config.Config
|
||||
envOverride string
|
||||
argOverride string
|
||||
wantsErr bool
|
||||
wantsName string
|
||||
wantsOwner string
|
||||
wantsHost string
|
||||
}{
|
||||
{
|
||||
name: "override from argument",
|
||||
argOverride: "override/test",
|
||||
wantsHost: "github.com",
|
||||
wantsOwner: "override",
|
||||
wantsName: "test",
|
||||
},
|
||||
{
|
||||
name: "override from environment",
|
||||
envOverride: "somehost.com/override/test",
|
||||
wantsHost: "somehost.com",
|
||||
wantsOwner: "override",
|
||||
wantsName: "test",
|
||||
},
|
||||
{
|
||||
name: "no override",
|
||||
remotes: git.RemoteSet{
|
||||
git.NewRemote("origin", "https://nonsense.com/owner/repo.git"),
|
||||
},
|
||||
config: defaultConfig(),
|
||||
wantsHost: "nonsense.com",
|
||||
wantsOwner: "owner",
|
||||
wantsName: "repo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envOverride != "" {
|
||||
os.Setenv("GH_REPO", tt.envOverride)
|
||||
} else {
|
||||
os.Unsetenv("GH_REPO")
|
||||
}
|
||||
f := New("1")
|
||||
rr := &remoteResolver{
|
||||
readRemotes: func() (git.RemoteSet, error) {
|
||||
return tt.remotes, nil
|
||||
},
|
||||
getConfig: func() (config.Config, error) {
|
||||
return tt.config, nil
|
||||
},
|
||||
}
|
||||
f.Remotes = rr.Resolver()
|
||||
f.BaseRepo = cmdutil.OverrideBaseRepoFunc(f, tt.argOverride)
|
||||
repo, err := f.BaseRepo()
|
||||
if tt.wantsErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantsName, repo.RepoName())
|
||||
assert.Equal(t, tt.wantsOwner, repo.RepoOwner())
|
||||
assert.Equal(t, tt.wantsHost, repo.RepoHost())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func defaultConfig() config.Config {
|
||||
return config.InheritEnv(config.NewFromString(heredoc.Doc(`
|
||||
hosts:
|
||||
nonsense.com:
|
||||
oauth_token: BLAH
|
||||
`)))
|
||||
}
|
||||
|
|
@ -4,9 +4,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/api"
|
||||
"github.com/cli/cli/context"
|
||||
"github.com/cli/cli/internal/ghrepo"
|
||||
actionsCmd "github.com/cli/cli/pkg/cmd/actions"
|
||||
aliasCmd "github.com/cli/cli/pkg/cmd/alias"
|
||||
apiCmd "github.com/cli/cli/pkg/cmd/api"
|
||||
|
|
@ -93,7 +90,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) *cobra.Command {
|
|||
|
||||
// below here at the commands that require the "intelligent" BaseRepo resolver
|
||||
repoResolvingCmdFactory := *f
|
||||
repoResolvingCmdFactory.BaseRepo = resolvedBaseRepo(f)
|
||||
repoResolvingCmdFactory.BaseRepo = factory.SmartBaseRepoFunc(f)
|
||||
|
||||
cmd.AddCommand(prCmd.NewCmdPR(&repoResolvingCmdFactory))
|
||||
cmd.AddCommand(issueCmd.NewCmdIssue(&repoResolvingCmdFactory))
|
||||
|
|
@ -126,29 +123,3 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er
|
|||
return factory.NewHTTPClient(f.IOStreams, cfg, version, false), nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolvedBaseRepo(f *cmdutil.Factory) func() (ghrepo.Interface, error) {
|
||||
return func() (ghrepo.Interface, error) {
|
||||
httpClient, err := f.HttpClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiClient := api.NewClientFromHTTP(httpClient)
|
||||
|
||||
remotes, err := f.Remotes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseRepo, err := repoContext.BaseRepo(f.IOStreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return baseRepo, nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,14 +12,18 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) {
|
|||
|
||||
cmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
|
||||
repoOverride, _ := cmd.Flags().GetString("repo")
|
||||
if repoFromEnv := os.Getenv("GH_REPO"); repoOverride == "" && repoFromEnv != "" {
|
||||
repoOverride = repoFromEnv
|
||||
}
|
||||
if repoOverride != "" {
|
||||
// NOTE: this mutates the factory
|
||||
f.BaseRepo = func() (ghrepo.Interface, error) {
|
||||
return ghrepo.FromFullName(repoOverride)
|
||||
}
|
||||
}
|
||||
f.BaseRepo = OverrideBaseRepoFunc(f, repoOverride)
|
||||
}
|
||||
}
|
||||
|
||||
func OverrideBaseRepoFunc(f *Factory, override string) func() (ghrepo.Interface, error) {
|
||||
if override == "" {
|
||||
override = os.Getenv("GH_REPO")
|
||||
}
|
||||
if override != "" {
|
||||
return func() (ghrepo.Interface, error) {
|
||||
return ghrepo.FromFullName(override)
|
||||
}
|
||||
}
|
||||
return f.BaseRepo
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue