diff --git a/pkg/cmd/factory/remote_resolver.go b/pkg/cmd/factory/remote_resolver.go index b008ade7e..b5cad393e 100644 --- a/pkg/cmd/factory/remote_resolver.go +++ b/pkg/cmd/factory/remote_resolver.go @@ -21,25 +21,24 @@ type remoteResolver struct { readRemotes func() (git.RemoteSet, error) getConfig func() (gh.Config, error) urlTranslator context.Translator + cachedRemotes context.Remotes + remotesError error } func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { - var cachedRemotes context.Remotes - var remotesError error - return func() (context.Remotes, error) { - if cachedRemotes != nil || remotesError != nil { - return cachedRemotes, remotesError + if rr.cachedRemotes != nil || rr.remotesError != nil { + return rr.cachedRemotes, rr.remotesError } gitRemotes, err := rr.readRemotes() if err != nil { - remotesError = err + rr.remotesError = err return nil, err } if len(gitRemotes) == 0 { - remotesError = errors.New("no git remotes found") - return nil, remotesError + rr.remotesError = errors.New("no git remotes found") + return nil, rr.remotesError } sshTranslate := rr.urlTranslator @@ -68,30 +67,31 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { // Sort remotes sort.Sort(resolvedRemotes) - // Filter remotes by hosts - // Note that this is not caching correctly: https://github.com/cli/cli/issues/10103 - cachedRemotes := resolvedRemotes.FilterByHosts(hosts) + rr.cachedRemotes = resolvedRemotes.FilterByHosts(hosts) // Filter again by default host if one is set // For config file default host fallback to cachedRemotes if none match // For environment default host (GH_HOST) do not fallback to cachedRemotes if none match if src != "default" { - filteredRemotes := cachedRemotes.FilterByHosts([]string{defaultHost}) + filteredRemotes := rr.cachedRemotes.FilterByHosts([]string{defaultHost}) if isHostEnv(src) || len(filteredRemotes) > 0 { - cachedRemotes = filteredRemotes + rr.cachedRemotes = filteredRemotes } } - if len(cachedRemotes) == 0 { + if len(rr.cachedRemotes) == 0 { if isHostEnv(src) { - return nil, fmt.Errorf("none of the git remotes configured for this repository correspond to the %s environment variable. Try adding a matching remote or unsetting the variable.", src) + rr.remotesError = fmt.Errorf("none of the git remotes configured for this repository correspond to the %s environment variable. Try adding a matching remote or unsetting the variable", src) + return nil, rr.remotesError } else if cfg.Authentication().HasEnvToken() { - return nil, errors.New("set the GH_HOST environment variable to specify which GitHub host to use") + rr.remotesError = errors.New("set the GH_HOST environment variable to specify which GitHub host to use") + return nil, rr.remotesError } - return nil, errors.New("none of the git remotes configured for this repository point to a known GitHub host. To tell gh about a new GitHub host, please use `gh auth login`") + rr.remotesError = errors.New("none of the git remotes configured for this repository point to a known GitHub host. To tell gh about a new GitHub host, please use `gh auth login`") + return nil, rr.remotesError } - return cachedRemotes, nil + return rr.cachedRemotes, nil } } diff --git a/pkg/cmd/factory/remote_resolver_test.go b/pkg/cmd/factory/remote_resolver_test.go index 8d537826e..6250be11d 100644 --- a/pkg/cmd/factory/remote_resolver_test.go +++ b/pkg/cmd/factory/remote_resolver_test.go @@ -1,6 +1,7 @@ package factory import ( + "errors" "net/url" "testing" @@ -9,6 +10,7 @@ import ( "github.com/cli/cli/v2/internal/gh" ghmock "github.com/cli/cli/v2/internal/gh/mock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type identityTranslator struct{} @@ -288,3 +290,103 @@ func Test_remoteResolver(t *testing.T) { }) } } + +func Test_remoteResolver_Caching(t *testing.T) { + t.Run("cache remotes", func(t *testing.T) { + var readRemotesCalled bool + + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + if readRemotesCalled { + return git.RemoteSet{}, errors.New("readRemotes should only be called once") + } + + readRemotesCalled = true + return git.RemoteSet{ + git.NewRemote("origin", "https://github.com/owner/repo.git"), + }, nil + }, + getConfig: func() (gh.Config, error) { + cfg := &ghmock.ConfigMock{} + cfg.AuthenticationFunc = func() gh.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"github.com"}) + authCfg.SetDefaultHost("github.com", "default") + return authCfg + } + return cfg, nil + }, + urlTranslator: identityTranslator{}, + } + + resolver := rr.Resolver() + + remotes, err := resolver() + require.NoError(t, err) + names := []string{} + for _, remote := range remotes { + names = append(names, remote.Name) + } + require.Equal(t, []string{"origin"}, names) + + require.Equal(t, readRemotesCalled, true) + + cachedRemotes, err := resolver() + require.NoError(t, err) + cachedNames := []string{} + for _, remote := range cachedRemotes { + cachedNames = append(cachedNames, remote.Name) + } + require.Equal(t, []string{"origin"}, cachedNames) + }) + + t.Run("cache error", func(t *testing.T) { + var readRemotesCalled bool + + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + if readRemotesCalled { + return git.RemoteSet{ + git.NewRemote("origin", "https://github.com/owner/repo.git"), + }, nil + } + + readRemotesCalled = true + return git.RemoteSet{}, errors.New("error to be cached") + }, + getConfig: func() (gh.Config, error) { + cfg := &ghmock.ConfigMock{} + cfg.AuthenticationFunc = func() gh.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"github.com"}) + authCfg.SetDefaultHost("github.com", "default") + return authCfg + } + return cfg, nil + }, + urlTranslator: identityTranslator{}, + } + + resolver := rr.Resolver() + + remotes, err := resolver() + require.Error(t, err) + require.Equal(t, err.Error(), "error to be cached") + names := []string{} + for _, remote := range remotes { + names = append(names, remote.Name) + } + require.Equal(t, []string{}, names) + + require.Equal(t, readRemotesCalled, true) + + cachedRemotes, err := resolver() + require.Error(t, err) + require.Equal(t, err.Error(), "error to be cached") + cachedNames := []string{} + for _, remote := range cachedRemotes { + cachedNames = append(cachedNames, remote.Name) + } + require.Equal(t, []string{}, cachedNames) + }) +}