diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index e9023c852..a5a71f579 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -6,6 +6,8 @@ import ( "net/http" "os" + "github.com/cli/cli/api" + "github.com/cli/cli/context" "github.com/cli/cli/git" "github.com/cli/cli/internal/config" "github.com/cli/cli/internal/ghrepo" @@ -14,11 +16,93 @@ import ( ) func New(appVersion string) *cmdutil.Factory { - io := iostreams.System() + f := &cmdutil.Factory{ + IOStreams: iostreams.System(), // No factory dependencies + Config: configFunc(), // No factory dependencies + Branch: branchFunc(), // No factory dependencies + Executable: executable(), // No factory dependencies + } + f.HttpClient = httpClientFunc(f, appVersion) // Depends on Config, IOStreams, and appVersion + f.Remotes = remotesFunc(f) // Depends on Config + f.BaseRepo = BaseRepoFunc(f) // Depends on Remotes + f.Browser = browser(f) // Depends on IOStreams + + return f +} + +func BaseRepoFunc(f *cmdutil.Factory) func() (ghrepo.Interface, error) { + return func() (ghrepo.Interface, error) { + remotes, err := f.Remotes() + if err != nil { + return nil, err + } + return remotes[0], nil + } +} + +func SmartBaseRepoFunc(f *cmdutil.Factory) func() (ghrepo.Interface, error) { + return func() (ghrepo.Interface, error) { + httpClient, err := f.HttpClient() + if err != nil { + return nil, err + } + + apiClient := api.NewClientFromHTTP(httpClient) + + remotes, err := f.Remotes() + if err != nil { + return nil, err + } + repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "") + if err != nil { + return nil, err + } + baseRepo, err := repoContext.BaseRepo(f.IOStreams) + if err != nil { + return nil, err + } + + return baseRepo, nil + } +} + +func remotesFunc(f *cmdutil.Factory) func() (context.Remotes, error) { + rr := &remoteResolver{ + readRemotes: git.Remotes, + getConfig: f.Config, + } + return rr.Resolver() +} + +func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, error) { + return func() (*http.Client, error) { + io := f.IOStreams + cfg, err := f.Config() + if err != nil { + return nil, err + } + return NewHTTPClient(io, cfg, appVersion, true), nil + } +} + +func browser(f *cmdutil.Factory) cmdutil.Browser { + io := f.IOStreams + return cmdutil.NewBrowser(os.Getenv("BROWSER"), io.Out, io.ErrOut) +} + +func executable() string { + gh := "gh" + if exe, err := os.Executable(); err == nil { + gh = exe + } + return gh +} + +func configFunc() func() (config.Config, error) { var cachedConfig config.Config var configError error - configFunc := func() (config.Config, error) { + return func() (config.Config, error) { if cachedConfig != nil || configError != nil { return cachedConfig, configError } @@ -30,45 +114,14 @@ func New(appVersion string) *cmdutil.Factory { cachedConfig = config.InheritEnv(cachedConfig) return cachedConfig, configError } +} - rr := &remoteResolver{ - readRemotes: git.Remotes, - getConfig: configFunc, - } - remotesFunc := rr.Resolver() - - ghExecutable := "gh" - if exe, err := os.Executable(); err == nil { - ghExecutable = exe - } - - return &cmdutil.Factory{ - IOStreams: io, - Config: configFunc, - Remotes: remotesFunc, - HttpClient: func() (*http.Client, error) { - cfg, err := configFunc() - if err != nil { - return nil, err - } - - return NewHTTPClient(io, cfg, appVersion, true), nil - }, - BaseRepo: func() (ghrepo.Interface, error) { - remotes, err := remotesFunc() - if err != nil { - return nil, err - } - return remotes[0], nil - }, - Branch: func() (string, error) { - currentBranch, err := git.CurrentBranch() - if err != nil { - return "", fmt.Errorf("could not determine current branch: %w", err) - } - return currentBranch, nil - }, - Executable: ghExecutable, - Browser: cmdutil.NewBrowser(os.Getenv("BROWSER"), io.Out, io.ErrOut), +func branchFunc() func() (string, error) { + return func() (string, error) { + currentBranch, err := git.CurrentBranch() + if err != nil { + return "", fmt.Errorf("could not determine current branch: %w", err) + } + return currentBranch, nil } } diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go new file mode 100644 index 000000000..1f60b6b23 --- /dev/null +++ b/pkg/cmd/factory/default_test.go @@ -0,0 +1,281 @@ +package factory + +import ( + "net/url" + "os" + "testing" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/git" + "github.com/cli/cli/internal/config" + "github.com/cli/cli/pkg/cmdutil" + "github.com/stretchr/testify/assert" +) + +func Test_BaseRepo(t *testing.T) { + orig_GH_HOST := os.Getenv("GH_HOST") + t.Cleanup(func() { + os.Setenv("GH_HOST", orig_GH_HOST) + }) + + tests := []struct { + name string + remotes git.RemoteSet + config config.Config + override string + wantsErr bool + wantsName string + wantsOwner string + wantsHost string + }{ + { + name: "matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://nonsense.com/owner/repo.git"), + }, + config: defaultConfig(), + wantsName: "repo", + wantsOwner: "owner", + wantsHost: "nonsense.com", + }, + { + name: "no matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://test.com/owner/repo.git"), + }, + config: defaultConfig(), + wantsErr: true, + }, + { + name: "override with matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://test.com/owner/repo.git"), + }, + config: defaultConfig(), + override: "test.com", + wantsName: "repo", + wantsOwner: "owner", + wantsHost: "test.com", + }, + { + name: "override with no matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://nonsense.com/owner/repo.git"), + }, + config: defaultConfig(), + override: "test.com", + wantsErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.override != "" { + os.Setenv("GH_HOST", tt.override) + } else { + os.Unsetenv("GH_HOST") + } + f := New("1") + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + return tt.remotes, nil + }, + getConfig: func() (config.Config, error) { + return tt.config, nil + }, + } + f.Remotes = rr.Resolver() + f.BaseRepo = BaseRepoFunc(f) + repo, err := f.BaseRepo() + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantsName, repo.RepoName()) + assert.Equal(t, tt.wantsOwner, repo.RepoOwner()) + assert.Equal(t, tt.wantsHost, repo.RepoHost()) + }) + } +} + +func Test_SmartBaseRepo(t *testing.T) { + pu, _ := url.Parse("https://test.com/newowner/newrepo.git") + orig_GH_HOST := os.Getenv("GH_HOST") + t.Cleanup(func() { + os.Setenv("GH_HOST", orig_GH_HOST) + }) + + tests := []struct { + name string + remotes git.RemoteSet + config config.Config + override string + wantsErr bool + wantsName string + wantsOwner string + wantsHost string + }{ + { + name: "override with matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://test.com/owner/repo.git"), + }, + config: defaultConfig(), + override: "test.com", + wantsName: "repo", + wantsOwner: "owner", + wantsHost: "test.com", + }, + { + name: "override with matching remote and base resolution", + remotes: git.RemoteSet{ + &git.Remote{Name: "origin", + Resolved: "base", + FetchURL: pu, + PushURL: pu}, + }, + config: defaultConfig(), + override: "test.com", + wantsName: "newrepo", + wantsOwner: "newowner", + wantsHost: "test.com", + }, + { + name: "override with matching remote and nonbase resolution", + remotes: git.RemoteSet{ + &git.Remote{Name: "origin", + Resolved: "johnny/test", + FetchURL: pu, + PushURL: pu}, + }, + config: defaultConfig(), + override: "test.com", + wantsName: "test", + wantsOwner: "johnny", + wantsHost: "test.com", + }, + { + name: "override with no matching remote", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://example.com/owner/repo.git"), + }, + config: defaultConfig(), + override: "test.com", + wantsErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.override != "" { + os.Setenv("GH_HOST", tt.override) + } else { + os.Unsetenv("GH_HOST") + } + f := New("1") + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + return tt.remotes, nil + }, + getConfig: func() (config.Config, error) { + return tt.config, nil + }, + } + f.Remotes = rr.Resolver() + f.BaseRepo = SmartBaseRepoFunc(f) + repo, err := f.BaseRepo() + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantsName, repo.RepoName()) + assert.Equal(t, tt.wantsOwner, repo.RepoOwner()) + assert.Equal(t, tt.wantsHost, repo.RepoHost()) + }) + } +} + +// Defined in pkg/cmdutil/repo_override.go but test it along with other BaseRepo functions +func Test_OverrideBaseRepo(t *testing.T) { + orig_GH_HOST := os.Getenv("GH_REPO") + t.Cleanup(func() { + os.Setenv("GH_REPO", orig_GH_HOST) + }) + + tests := []struct { + name string + remotes git.RemoteSet + config config.Config + envOverride string + argOverride string + wantsErr bool + wantsName string + wantsOwner string + wantsHost string + }{ + { + name: "override from argument", + argOverride: "override/test", + wantsHost: "github.com", + wantsOwner: "override", + wantsName: "test", + }, + { + name: "override from environment", + envOverride: "somehost.com/override/test", + wantsHost: "somehost.com", + wantsOwner: "override", + wantsName: "test", + }, + { + name: "no override", + remotes: git.RemoteSet{ + git.NewRemote("origin", "https://nonsense.com/owner/repo.git"), + }, + config: defaultConfig(), + wantsHost: "nonsense.com", + wantsOwner: "owner", + wantsName: "repo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envOverride != "" { + os.Setenv("GH_REPO", tt.envOverride) + } else { + os.Unsetenv("GH_REPO") + } + f := New("1") + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + return tt.remotes, nil + }, + getConfig: func() (config.Config, error) { + return tt.config, nil + }, + } + f.Remotes = rr.Resolver() + f.BaseRepo = cmdutil.OverrideBaseRepoFunc(f, tt.argOverride) + repo, err := f.BaseRepo() + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantsName, repo.RepoName()) + assert.Equal(t, tt.wantsOwner, repo.RepoOwner()) + assert.Equal(t, tt.wantsHost, repo.RepoHost()) + }) + } +} + +func defaultConfig() config.Config { + return config.InheritEnv(config.NewFromString(heredoc.Doc(` + hosts: + nonsense.com: + oauth_token: BLAH + `))) +} diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index cc971a0d4..7876eb1f9 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -4,9 +4,6 @@ import ( "net/http" "github.com/MakeNowJust/heredoc" - "github.com/cli/cli/api" - "github.com/cli/cli/context" - "github.com/cli/cli/internal/ghrepo" actionsCmd "github.com/cli/cli/pkg/cmd/actions" aliasCmd "github.com/cli/cli/pkg/cmd/alias" apiCmd "github.com/cli/cli/pkg/cmd/api" @@ -93,7 +90,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) *cobra.Command { // below here at the commands that require the "intelligent" BaseRepo resolver repoResolvingCmdFactory := *f - repoResolvingCmdFactory.BaseRepo = resolvedBaseRepo(f) + repoResolvingCmdFactory.BaseRepo = factory.SmartBaseRepoFunc(f) cmd.AddCommand(prCmd.NewCmdPR(&repoResolvingCmdFactory)) cmd.AddCommand(issueCmd.NewCmdIssue(&repoResolvingCmdFactory)) @@ -126,29 +123,3 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er return factory.NewHTTPClient(f.IOStreams, cfg, version, false), nil } } - -func resolvedBaseRepo(f *cmdutil.Factory) func() (ghrepo.Interface, error) { - return func() (ghrepo.Interface, error) { - httpClient, err := f.HttpClient() - if err != nil { - return nil, err - } - - apiClient := api.NewClientFromHTTP(httpClient) - - remotes, err := f.Remotes() - if err != nil { - return nil, err - } - repoContext, err := context.ResolveRemotesToRepos(remotes, apiClient, "") - if err != nil { - return nil, err - } - baseRepo, err := repoContext.BaseRepo(f.IOStreams) - if err != nil { - return nil, err - } - - return baseRepo, nil - } -} diff --git a/pkg/cmdutil/repo_override.go b/pkg/cmdutil/repo_override.go index 8b3d36489..c5d996c70 100644 --- a/pkg/cmdutil/repo_override.go +++ b/pkg/cmdutil/repo_override.go @@ -12,14 +12,18 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) { cmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { repoOverride, _ := cmd.Flags().GetString("repo") - if repoFromEnv := os.Getenv("GH_REPO"); repoOverride == "" && repoFromEnv != "" { - repoOverride = repoFromEnv - } - if repoOverride != "" { - // NOTE: this mutates the factory - f.BaseRepo = func() (ghrepo.Interface, error) { - return ghrepo.FromFullName(repoOverride) - } - } + f.BaseRepo = OverrideBaseRepoFunc(f, repoOverride) } } + +func OverrideBaseRepoFunc(f *Factory, override string) func() (ghrepo.Interface, error) { + if override == "" { + override = os.Getenv("GH_REPO") + } + if override != "" { + return func() (ghrepo.Interface, error) { + return ghrepo.FromFullName(override) + } + } + return f.BaseRepo +}