diff --git a/acceptance/testdata/telemetry/no-telemetry-for-ghes-user.txtar b/acceptance/testdata/telemetry/no-telemetry-for-ghes-user.txtar new file mode 100644 index 000000000..0fe6f4bb2 --- /dev/null +++ b/acceptance/testdata/telemetry/no-telemetry-for-ghes-user.txtar @@ -0,0 +1,8 @@ +# GHES users should not get telemetry even when telemetry is enabled +env GH_PRIVATE_ENABLE_TELEMETRY=1 +env GH_TELEMETRY=log +env GH_TELEMETRY_SAMPLE_RATE=100 +env GH_ENTERPRISE_TOKEN=fake-enterprise-token + +exec gh version +! stderr 'Telemetry payload:' diff --git a/api/http_client.go b/api/http_client.go index 532f79c7f..be7a6b8a7 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/cli/cli/v2/internal/gh/ghtelemetry" "github.com/cli/cli/v2/utils" ghAPI "github.com/cli/go-gh/v2/pkg/api" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -26,6 +27,7 @@ type HTTPClientOptions struct { LogColorize bool LogVerboseHTTP bool SkipDefaultHeaders bool + TelemetryDisabler ghtelemetry.Disabler } func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { @@ -74,6 +76,13 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) } + if opts.TelemetryDisabler != nil { + client.Transport = telemetryDisablerTransport{ + wrappedTransport: client.Transport, + telemetryDisabler: opts.TelemetryDisabler, + } + } + return client, nil } @@ -147,3 +156,15 @@ func getHost(r *http.Request) string { } return r.URL.Host } + +type telemetryDisablerTransport struct { + wrappedTransport http.RoundTripper + telemetryDisabler ghtelemetry.Disabler +} + +func (t telemetryDisablerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if ghauth.IsEnterprise(getHost(req)) { + t.telemetryDisabler.Disable() + } + return t.wrappedTransport.RoundTrip(req) +} diff --git a/api/http_client_test.go b/api/http_client_test.go index 1c81b4aa7..198c08491 100644 --- a/api/http_client_test.go +++ b/api/http_client_test.go @@ -315,6 +315,80 @@ func TestHTTPClientSanitizeControlCharactersC1(t *testing.T) { assert.Equal(t, "monalisa¡", issue.Author.Login) } +func TestNewHTTPClientTelemetryDisabler(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + tests := []struct { + name string + host string + wantDisabled bool + }{ + { + name: "enterprise host triggers disable", + host: "ghes.example.com", + wantDisabled: true, + }, + { + name: "github.com does not trigger disable", + host: "github.com", + wantDisabled: false, + }, + { + name: "tenancy host does not trigger disable", + host: "my-company.ghe.com", + wantDisabled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + disabler := &fakeTelemetryDisabler{} + client, err := NewHTTPClient(HTTPClientOptions{ + TelemetryDisabler: disabler, + }) + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.URL, nil) + require.NoError(t, err) + req.Host = tt.host + + res, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 204, res.StatusCode) + assert.Equal(t, tt.wantDisabled, disabler.disabled, "Disable() called") + }) + } +} + +func TestNewHTTPClientWithoutTelemetryDisabler(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + client, err := NewHTTPClient(HTTPClientOptions{}) + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.URL, nil) + require.NoError(t, err) + req.Host = "ghes.example.com" + + res, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 204, res.StatusCode) +} + +type fakeTelemetryDisabler struct { + disabled bool +} + +func (f *fakeTelemetryDisabler) Disable() { + f.disabled = true +} + type tinyConfig map[string]string func (c tinyConfig) ActiveToken(host string) (string, string) { diff --git a/internal/gh/ghtelemetry/telemetry.go b/internal/gh/ghtelemetry/telemetry.go index c9256361b..197b955b4 100644 --- a/internal/gh/ghtelemetry/telemetry.go +++ b/internal/gh/ghtelemetry/telemetry.go @@ -10,8 +10,13 @@ type Event struct { Measures Measures } +type Disabler interface { + Disable() +} + type EventRecorder interface { Record(event Event) + Disabler } type CommandRecorder interface { diff --git a/internal/ghcmd/cmd.go b/internal/ghcmd/cmd.go index 9112d4283..eab842c5a 100644 --- a/internal/ghcmd/cmd.go +++ b/internal/ghcmd/cmd.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "slices" "strconv" "strings" "time" @@ -20,6 +21,7 @@ import ( "github.com/cli/cli/v2/internal/build" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/config/migration" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/gh/ghtelemetry" "github.com/cli/cli/v2/internal/telemetry" "github.com/cli/cli/v2/internal/update" @@ -28,6 +30,8 @@ import ( "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/utils" + ghauth "github.com/cli/go-gh/v2/pkg/auth" + xcolor "github.com/cli/go-gh/v2/pkg/x/color" "github.com/cli/safeexec" "github.com/mgutz/ansi" "github.com/spf13/cobra" @@ -48,33 +52,34 @@ func Main() exitCode { buildVersion := build.Version hasDebug, _ := utils.IsDebugEnabled() - cmdFactory := factory.New(buildVersion, string(agents.Detect())) - stderr := cmdFactory.IOStreams.ErrOut - - cfg, err := cmdFactory.Config() + cfg, err := config.NewConfig() if err != nil { - fmt.Fprintf(stderr, "failed to load config: %s\n", err) + fmt.Fprintf(os.Stderr, "failed to load config: %s\n", err) return exitError } + ioStreams := newIOStreams(cfg) + stderr := ioStreams.ErrOut + + ghExecutablePath := executablePath("gh") + additionalCommonDimensions := ghtelemetry.Dimensions{ "version": strings.TrimPrefix(buildVersion, "v"), - "is_tty": strconv.FormatBool(cmdFactory.IOStreams.IsStdoutTTY()), + "is_tty": strconv.FormatBool(ioStreams.IsStdoutTTY()), "agent": string(agents.Detect()), } var telemetryService ghtelemetry.Service - if os.Getenv("GH_PRIVATE_ENABLE_TELEMETRY") == "" { + if os.Getenv("GH_PRIVATE_ENABLE_TELEMETRY") == "" || mightBeGHESUser(cfg) { telemetryService = &telemetry.NoOpService{} } else { - telemetryState := telemetry.ParseTelemetryState(cfg.Telemetry().Value) switch telemetryState { case telemetry.Disabled: telemetryService = &telemetry.NoOpService{} case telemetry.Logged: telemetryService = telemetry.NewService( - telemetry.LogFlusher(cmdFactory.IOStreams.ErrOut, cmdFactory.IOStreams.ColorEnabled()), + telemetry.LogFlusher(ioStreams.ErrOut, ioStreams.ColorEnabled()), telemetry.WithAdditionalCommonDimensions(additionalCommonDimensions), ) case telemetry.Enabled: @@ -84,7 +89,7 @@ func Main() exitCode { } additionalCommonDimensions["sample_rate"] = strconv.Itoa(sampleRate) telemetryService = telemetry.NewService( - telemetry.GitHubFlusher(cmdFactory.Executable()), + telemetry.GitHubFlusher(ghExecutablePath), telemetry.WithAdditionalCommonDimensions(additionalCommonDimensions), telemetry.WithSampleRate(sampleRate), ) @@ -95,6 +100,8 @@ func Main() exitCode { } defer telemetryService.Flush() + cmdFactory := factory.New(buildVersion, string(agents.Detect()), cfg, ioStreams, ghExecutablePath, telemetryService) + var m migration.MultiAccount if err := cfg.Migrate(m); err != nil { fmt.Fprintln(stderr, err) @@ -211,7 +218,7 @@ func Main() exitCode { updateCancel() // if the update checker hasn't completed by now, abort it newRelease := <-updateMessageChan if newRelease != nil { - isHomebrew := isUnderHomebrew(cmdFactory.Executable()) + isHomebrew := isUnderHomebrew(cmdFactory.ExecutablePath) if isHomebrew && isRecentRelease(newRelease.PublishedAt) { // do not notify Homebrew users before the version bump had a chance to get merged into homebrew-core return exitOK @@ -289,3 +296,148 @@ func isUnderHomebrew(ghBinary string) bool { brewBinPrefix := filepath.Join(strings.TrimSpace(string(brewPrefixBytes)), "bin") + string(filepath.Separator) return strings.HasPrefix(ghBinary, brewBinPrefix) } + +func newIOStreams(cfg gh.Config) *iostreams.IOStreams { + io := iostreams.System() + + if _, ghPromptDisabled := os.LookupEnv("GH_PROMPT_DISABLED"); ghPromptDisabled { + io.SetNeverPrompt(true) + } else if prompt := cfg.Prompt(""); prompt.Value == "disabled" { + io.SetNeverPrompt(true) + } + + falseyValues := []string{"false", "0", "no", ""} + + accessiblePrompterValue, accessiblePrompterIsSet := os.LookupEnv("GH_ACCESSIBLE_PROMPTER") + if accessiblePrompterIsSet { + if !slices.Contains(falseyValues, accessiblePrompterValue) { + io.SetAccessiblePrompterEnabled(true) + } + } else if prompt := cfg.AccessiblePrompter(""); prompt.Value == "enabled" { + io.SetAccessiblePrompterEnabled(true) + } + + experimentalPrompterValue, experimentalPrompterIsSet := os.LookupEnv("GH_EXPERIMENTAL_PROMPTER") + if experimentalPrompterIsSet { + if !slices.Contains(falseyValues, experimentalPrompterValue) { + io.SetExperimentalPrompterEnabled(true) + } + } + + ghSpinnerDisabledValue, ghSpinnerDisabledIsSet := os.LookupEnv("GH_SPINNER_DISABLED") + if ghSpinnerDisabledIsSet { + if !slices.Contains(falseyValues, ghSpinnerDisabledValue) { + io.SetSpinnerDisabled(true) + } + } else if spinnerDisabled := cfg.Spinner(""); spinnerDisabled.Value == "disabled" { + io.SetSpinnerDisabled(true) + } + + // Pager precedence + // 1. GH_PAGER + // 2. pager from config + // 3. PAGER + if ghPager, ghPagerExists := os.LookupEnv("GH_PAGER"); ghPagerExists { + io.SetPager(ghPager) + } else if pager := cfg.Pager(""); pager.Value != "" { + io.SetPager(pager.Value) + } + + if ghColorLabels, ghColorLabelsExists := os.LookupEnv("GH_COLOR_LABELS"); ghColorLabelsExists { + switch ghColorLabels { + case "", "0", "false", "no": + io.SetColorLabels(false) + default: + io.SetColorLabels(true) + } + } else if prompt := cfg.ColorLabels(""); prompt.Value == "enabled" { + io.SetColorLabels(true) + } + + io.SetAccessibleColorsEnabled(xcolor.IsAccessibleColorsEnabled()) + + return io +} + +// Executable is the path to the currently invoked binary +func executablePath(executableName string) string { + ghPath := os.Getenv("GH_PATH") + if ghPath != "" { + return ghPath + } + + if strings.ContainsRune(executableName, os.PathSeparator) { + return executableName + } + + return executable(executableName) +} + +// Finds the location of the executable for the current process as it's found in PATH, respecting symlinks. +// If the process couldn't determine its location, return fallbackName. If the executable wasn't found in +// PATH, return the absolute location to the program. +// +// The idea is that the result of this function is callable in the future and refers to the same +// installation of gh, even across upgrades. This is needed primarily for Homebrew, which installs software +// under a location such as `/usr/local/Cellar/gh/1.13.1/bin/gh` and symlinks it from `/usr/local/bin/gh`. +// When the version is upgraded, Homebrew will often delete older versions, but keep the symlink. Because of +// this, we want to refer to the `gh` binary as `/usr/local/bin/gh` and not as its internal Homebrew +// location. +// +// None of this would be needed if we could just refer to GitHub CLI as `gh`, i.e. without using an absolute +// path. However, for some reason Homebrew does not include `/usr/local/bin` in PATH when it invokes git +// commands to update its taps. If `gh` (no path) is being used as git credential helper, as set up by `gh +// auth login`, running `brew update` will print out authentication errors as git is unable to locate +// Homebrew-installed `gh` +func executable(fallback string) string { + exe, err := os.Executable() + if err != nil { + return fallback + } + + base := filepath.Base(exe) + path := os.Getenv("PATH") + for _, dir := range filepath.SplitList(path) { + p, err := filepath.Abs(filepath.Join(dir, base)) + if err != nil { + continue + } + f, err := os.Lstat(p) + if err != nil { + continue + } + + if p == exe { + return p + } else if f.Mode()&os.ModeSymlink != 0 { + realP, err := filepath.EvalSymlinks(p) + if err != nil { + continue + } + realExe, err := filepath.EvalSymlinks(exe) + if err != nil { + continue + } + if realP == realExe { + return p + } + } + } + + return exe +} + +func mightBeGHESUser(cfg gh.Config) bool { + if os.Getenv("GH_ENTERPRISE_TOKEN") != "" || os.Getenv("GITHUB_ENTERPRISE_TOKEN") != "" { + return true + } + + if host := os.Getenv("GH_HOST"); host != "" && ghauth.IsEnterprise(host) { + return true + } + + // If any targeted host is Enterprise, then the user is likely a GHES user. + return slices.ContainsFunc(cfg.Authentication().Hosts(), func(host string) bool { + return ghauth.IsEnterprise(host) + }) +} diff --git a/internal/ghcmd/cmd_test.go b/internal/ghcmd/cmd_test.go index 08bbceb85..65bcc0f28 100644 --- a/internal/ghcmd/cmd_test.go +++ b/internal/ghcmd/cmd_test.go @@ -7,8 +7,11 @@ import ( "net" "testing" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" ) func Test_printError(t *testing.T) { @@ -76,3 +79,435 @@ check your internet connection or https://githubstatus.com }) } } + +func Test_newIOStreams_pager(t *testing.T) { + tests := []struct { + name string + env map[string]string + config gh.Config + wantPager string + }{ + { + name: "GH_PAGER and PAGER set", + env: map[string]string{ + "GH_PAGER": "GH_PAGER", + "PAGER": "PAGER", + }, + wantPager: "GH_PAGER", + }, + { + name: "GH_PAGER and config pager set", + env: map[string]string{ + "GH_PAGER": "GH_PAGER", + }, + config: pagerConfig(), + wantPager: "GH_PAGER", + }, + { + name: "config pager and PAGER set", + env: map[string]string{ + "PAGER": "PAGER", + }, + config: pagerConfig(), + wantPager: "CONFIG_PAGER", + }, + { + name: "only PAGER set", + env: map[string]string{ + "PAGER": "PAGER", + }, + wantPager: "PAGER", + }, + { + name: "GH_PAGER set to blank string", + env: map[string]string{ + "GH_PAGER": "", + "PAGER": "PAGER", + }, + wantPager: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env != nil { + for k, v := range tt.env { + t.Setenv(k, v) + } + } + var cfg gh.Config + if tt.config != nil { + cfg = tt.config + } else { + cfg = config.NewBlankConfig() + } + io := newIOStreams(cfg) + assert.Equal(t, tt.wantPager, io.GetPager()) + }) + } +} + +func Test_newIOStreams_prompt(t *testing.T) { + tests := []struct { + name string + config gh.Config + promptDisabled bool + env map[string]string + }{ + { + name: "default config", + promptDisabled: false, + }, + { + name: "config with prompt disabled", + config: disablePromptConfig(), + promptDisabled: true, + }, + { + name: "prompt disabled via GH_PROMPT_DISABLED env var", + env: map[string]string{"GH_PROMPT_DISABLED": "1"}, + promptDisabled: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env != nil { + for k, v := range tt.env { + t.Setenv(k, v) + } + } + var cfg gh.Config + if tt.config != nil { + cfg = tt.config + } else { + cfg = config.NewBlankConfig() + } + io := newIOStreams(cfg) + assert.Equal(t, tt.promptDisabled, io.GetNeverPrompt()) + }) + } +} + +func Test_newIOStreams_spinnerDisabled(t *testing.T) { + tests := []struct { + name string + config gh.Config + spinnerDisabled bool + env map[string]string + }{ + { + name: "default config", + spinnerDisabled: false, + }, + { + name: "config with spinner disabled", + config: disableSpinnersConfig(), + spinnerDisabled: true, + }, + { + name: "config with spinner enabled", + config: enableSpinnersConfig(), + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = 0", + env: map[string]string{"GH_SPINNER_DISABLED": "0"}, + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = false", + env: map[string]string{"GH_SPINNER_DISABLED": "false"}, + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = no", + env: map[string]string{"GH_SPINNER_DISABLED": "no"}, + spinnerDisabled: false, + }, + { + name: "spinner enabled via GH_SPINNER_DISABLED env var = 1", + env: map[string]string{"GH_SPINNER_DISABLED": "1"}, + spinnerDisabled: true, + }, + { + name: "spinner enabled via GH_SPINNER_DISABLED env var = true", + env: map[string]string{"GH_SPINNER_DISABLED": "true"}, + spinnerDisabled: true, + }, + { + name: "config enabled but env disabled, respects env", + config: enableSpinnersConfig(), + env: map[string]string{"GH_SPINNER_DISABLED": "true"}, + spinnerDisabled: true, + }, + { + name: "config disabled but env enabled, respects env", + config: disableSpinnersConfig(), + env: map[string]string{"GH_SPINNER_DISABLED": "false"}, + spinnerDisabled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + var cfg gh.Config + if tt.config != nil { + cfg = tt.config + } else { + cfg = config.NewBlankConfig() + } + io := newIOStreams(cfg) + assert.Equal(t, tt.spinnerDisabled, io.GetSpinnerDisabled()) + }) + } +} + +func Test_newIOStreams_accessiblePrompterEnabled(t *testing.T) { + tests := []struct { + name string + config gh.Config + accessiblePrompterEnabled bool + env map[string]string + }{ + { + name: "default config", + accessiblePrompterEnabled: false, + }, + { + name: "config with accessible prompter enabled", + config: enableAccessiblePrompterConfig(), + accessiblePrompterEnabled: true, + }, + { + name: "config with accessible prompter disabled", + config: disableAccessiblePrompterConfig(), + accessiblePrompterEnabled: false, + }, + { + name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = 1", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "1"}, + accessiblePrompterEnabled: true, + }, + { + name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = true", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, + accessiblePrompterEnabled: true, + }, + { + name: "accessible prompter disabled via GH_ACCESSIBLE_PROMPTER env var = 0", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "0"}, + accessiblePrompterEnabled: false, + }, + { + name: "config disabled but env enabled, respects env", + config: disableAccessiblePrompterConfig(), + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, + accessiblePrompterEnabled: true, + }, + { + name: "config enabled but env disabled, respects env", + config: enableAccessiblePrompterConfig(), + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "false"}, + accessiblePrompterEnabled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + var cfg gh.Config + if tt.config != nil { + cfg = tt.config + } else { + cfg = config.NewBlankConfig() + } + io := newIOStreams(cfg) + assert.Equal(t, tt.accessiblePrompterEnabled, io.AccessiblePrompterEnabled()) + }) + } +} + +func Test_newIOStreams_colorLabels(t *testing.T) { + tests := []struct { + name string + config gh.Config + colorLabelsEnabled bool + env map[string]string + }{ + { + name: "default config", + colorLabelsEnabled: false, + }, + { + name: "config with colorLabels enabled", + config: enableColorLabelsConfig(), + colorLabelsEnabled: true, + }, + { + name: "config with colorLabels disabled", + config: disableColorLabelsConfig(), + colorLabelsEnabled: false, + }, + { + name: "colorLabels enabled via `1` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "1"}, + colorLabelsEnabled: true, + }, + { + name: "colorLabels enabled via `true` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "true"}, + colorLabelsEnabled: true, + }, + { + name: "colorLabels enabled via `yes` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "yes"}, + colorLabelsEnabled: true, + }, + { + name: "colorLabels disable via empty string in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": ""}, + colorLabelsEnabled: false, + }, + { + name: "colorLabels disabled via `0` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "0"}, + colorLabelsEnabled: false, + }, + { + name: "colorLabels disabled via `false` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "false"}, + colorLabelsEnabled: false, + }, + { + name: "colorLabels disabled via `no` in GH_COLOR_LABELS env var", + env: map[string]string{"GH_COLOR_LABELS": "no"}, + colorLabelsEnabled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env != nil { + for k, v := range tt.env { + t.Setenv(k, v) + } + } + var cfg gh.Config + if tt.config != nil { + cfg = tt.config + } else { + cfg = config.NewBlankConfig() + } + io := newIOStreams(cfg) + assert.Equal(t, tt.colorLabelsEnabled, io.ColorLabels()) + }) + } +} + +func Test_mightBeGHESUser(t *testing.T) { + tests := []struct { + name string + env map[string]string + config gh.Config + want bool + }{ + { + name: "GH_ENTERPRISE_TOKEN set", + env: map[string]string{"GH_ENTERPRISE_TOKEN": "some-token"}, + config: config.NewBlankConfig(), + want: true, + }, + { + name: "GITHUB_ENTERPRISE_TOKEN set", + env: map[string]string{"GITHUB_ENTERPRISE_TOKEN": "some-token"}, + config: config.NewBlankConfig(), + want: true, + }, + { + name: "no env vars, config has enterprise host", + config: config.NewFromString("hosts:\n ghes.example.com:\n oauth_token: abc123\n"), + want: true, + }, + { + name: "no env vars, config has only github.com", + config: config.NewFromString("hosts:\n github.com:\n oauth_token: abc123\n"), + want: false, + }, + { + name: "no env vars, config has no hosts", + config: config.NewBlankConfig(), + want: false, + }, + { + name: "no env vars, config has github.com and enterprise host", + config: config.NewFromString("hosts:\n github.com:\n oauth_token: abc123\n ghes.example.com:\n oauth_token: def456\n"), + want: true, + }, + { + name: "no env vars, config has tenancy host", + config: config.NewFromString("hosts:\n my-company.ghe.com:\n oauth_token: abc123\n"), + want: false, + }, + { + name: "GH_HOST set to enterprise host", + env: map[string]string{"GH_HOST": "ghes.example.com"}, + config: config.NewBlankConfig(), + want: true, + }, + { + name: "GH_HOST set to github.com", + env: map[string]string{"GH_HOST": "github.com"}, + config: config.NewBlankConfig(), + want: false, + }, + { + name: "GH_HOST set to tenancy host", + env: map[string]string{"GH_HOST": "my-company.ghe.com"}, + config: config.NewBlankConfig(), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + got := mightBeGHESUser(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} + +func pagerConfig() gh.Config { + return config.NewFromString("pager: CONFIG_PAGER") +} + +func disablePromptConfig() gh.Config { + return config.NewFromString("prompt: disabled") +} + +func enableAccessiblePrompterConfig() gh.Config { + return config.NewFromString("accessible_prompter: enabled") +} + +func disableAccessiblePrompterConfig() gh.Config { + return config.NewFromString("accessible_prompter: disabled") +} + +func disableSpinnersConfig() gh.Config { + return config.NewFromString("spinner: disabled") +} + +func enableSpinnersConfig() gh.Config { + return config.NewFromString("spinner: enabled") +} + +func disableColorLabelsConfig() gh.Config { + return config.NewFromString("color_labels: disabled") +} + +func enableColorLabelsConfig() gh.Config { + return config.NewFromString("color_labels: enabled") +} diff --git a/pkg/cmdutil/factory_test.go b/internal/ghcmd/executable_test.go similarity index 94% rename from pkg/cmdutil/factory_test.go rename to internal/ghcmd/executable_test.go index 0103a04f1..f0374429b 100644 --- a/pkg/cmdutil/factory_test.go +++ b/internal/ghcmd/executable_test.go @@ -1,10 +1,12 @@ -package cmdutil +package ghcmd import ( "os" "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/require" ) func Test_executable(t *testing.T) { @@ -113,11 +115,8 @@ func Test_executable_relative(t *testing.T) { } } -func Test_Executable_override(t *testing.T) { +func TestExecutablePath(t *testing.T) { override := strings.Join([]string{"C:", "cygwin64", "home", "gh.exe"}, string(os.PathSeparator)) t.Setenv("GH_PATH", override) - f := Factory{} - if got := f.Executable(); got != override { - t.Errorf("executable() = %q, want %q", got, override) - } + require.Equal(t, override, executablePath("gh")) } diff --git a/internal/telemetry/fake.go b/internal/telemetry/fake.go index ee38262d9..1dc45ab26 100644 --- a/internal/telemetry/fake.go +++ b/internal/telemetry/fake.go @@ -10,4 +10,6 @@ func (r *EventRecorderSpy) Record(event ghtelemetry.Event) { r.Events = append(r.Events, event) } +func (r *EventRecorderSpy) Disable() {} + func (r *EventRecorderSpy) Flush() {} diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index f8698706a..b046ec77d 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -248,6 +248,15 @@ type service struct { sampleBucket byte events []recordedEvent + + disabled bool +} + +func (s *service) Disable() { + s.mu.Lock() + defer s.mu.Unlock() + + s.disabled = true } func (s *service) Record(event ghtelemetry.Event) { @@ -269,6 +278,10 @@ func (s *service) Flush() { s.mu.Lock() defer s.mu.Unlock() + if s.disabled { + return + } + if s.previouslyCalled { return } @@ -379,6 +392,8 @@ type NoOpService struct{} func (s *NoOpService) Record(event ghtelemetry.Event) {} +func (s *NoOpService) Disable() {} + func (s *NoOpService) SetSampleRate(rate int) {} func (s *NoOpService) Flush() {} diff --git a/internal/telemetry/telemetry_test.go b/internal/telemetry/telemetry_test.go index 0142d4d16..207d611ee 100644 --- a/internal/telemetry/telemetry_test.go +++ b/internal/telemetry/telemetry_test.go @@ -598,10 +598,54 @@ func TestWithAdditionalCommonDimensions(t *testing.T) { assert.NotEmpty(t, captured.Events[0].Dimensions["architecture"]) } +func TestServiceDisable(t *testing.T) { + t.Run("prevents flush from sending events", func(t *testing.T) { + t.Cleanup(stubDeviceID("test-device")) + + called := false + svc := newService(func(SendTelemetryPayload) { called = true }, nil) + + svc.Record(ghtelemetry.Event{Type: "test"}) + svc.Disable() + svc.Flush() + + assert.False(t, called, "flusher should not be called after Disable()") + }) + + t.Run("prevents flush even with multiple recorded events", func(t *testing.T) { + t.Cleanup(stubDeviceID("test-device")) + + called := false + svc := newService(func(SendTelemetryPayload) { called = true }, nil) + + svc.Record(ghtelemetry.Event{Type: "event1"}) + svc.Record(ghtelemetry.Event{Type: "event2"}) + svc.Record(ghtelemetry.Event{Type: "event3"}) + svc.Disable() + svc.Flush() + + assert.False(t, called, "flusher should not be called after Disable()") + }) + + t.Run("can be called before any events are recorded", func(t *testing.T) { + t.Cleanup(stubDeviceID("test-device")) + + called := false + svc := newService(func(SendTelemetryPayload) { called = true }, nil) + + svc.Disable() + svc.Record(ghtelemetry.Event{Type: "test"}) + svc.Flush() + + assert.False(t, called, "flusher should not be called when disabled before recording") + }) +} + func TestNoOpService(t *testing.T) { svc := &NoOpService{} // All methods should be safe to call without panicking svc.Record(ghtelemetry.Event{Type: "test"}) + svc.Disable() svc.SetSampleRate(50) svc.Flush() } diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index fb641457f..4a87e0f8c 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -223,7 +223,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command }, Args: cobra.ExactArgs(1), PreRun: func(c *cobra.Command, args []string) { - opts.BaseRepo = cmdutil.OverrideBaseRepoFunc(f, "") + opts.BaseRepo = cmdutil.OverrideBaseRepoFunc(f.BaseRepo, "") }, RunE: func(c *cobra.Command, args []string) error { opts.RequestPath = args[0] diff --git a/pkg/cmd/attestation/verify/verify_integration_test.go b/pkg/cmd/attestation/verify/verify_integration_test.go index ec64cefa7..10a1e5216 100644 --- a/pkg/cmd/attestation/verify/verify_integration_test.go +++ b/pkg/cmd/attestation/verify/verify_integration_test.go @@ -1,17 +1,19 @@ -//go:build integration - package verify import ( "net/http" "testing" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" + "github.com/cli/cli/v2/internal/telemetry" "github.com/cli/cli/v2/pkg/cmd/attestation/api" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/io" "github.com/cli/cli/v2/pkg/cmd/attestation/test" "github.com/cli/cli/v2/pkg/cmd/attestation/verification" "github.com/cli/cli/v2/pkg/cmd/factory" + "github.com/cli/cli/v2/pkg/iostreams" o "github.com/cli/cli/v2/pkg/option" "github.com/cli/go-gh/v2/pkg/auth" "github.com/stretchr/testify/require" @@ -26,12 +28,15 @@ func TestVerifyIntegration(t *testing.T) { TUFMetadataDir: o.Some(t.TempDir()), } - cmdFactory := factory.New("test", "") - - hc, err := cmdFactory.HttpClient() - if err != nil { - t.Fatal(err) - } + ios, _, _, _ := iostreams.Test() + hc, err := factory.HttpClientFunc( + &config.AuthConfig{}, + ios, + "test", + "", + &telemetry.NoOpService{}, + )() + require.NoError(t, err) host, _ := auth.DefaultHost() @@ -143,12 +148,15 @@ func TestVerifyIntegrationCustomIssuer(t *testing.T) { TUFMetadataDir: o.Some(t.TempDir()), } - cmdFactory := factory.New("test", "") - - hc, err := cmdFactory.HttpClient() - if err != nil { - t.Fatal(err) - } + ios, _, _, _ := iostreams.Test() + hc, err := factory.HttpClientFunc( + &config.AuthConfig{}, + ios, + "test", + "", + &telemetry.NoOpService{}, + )() + require.NoError(t, err) host, _ := auth.DefaultHost() @@ -217,12 +225,16 @@ func TestVerifyIntegrationReusableWorkflow(t *testing.T) { TUFMetadataDir: o.Some(t.TempDir()), } - cmdFactory := factory.New("test", "") - - hc, err := cmdFactory.HttpClient() - if err != nil { - t.Fatal(err) - } + cfg := config.NewBlankConfig() + ios, _, _, _ := iostreams.Test() + hc, err := factory.HttpClientFunc( + cfg.Authentication(), + ios, + "test", + "", + &telemetry.NoOpService{}, + )() + require.NoError(t, err) host, _ := auth.DefaultHost() @@ -310,22 +322,28 @@ func TestVerifyIntegrationReusableWorkflowSignerWorkflow(t *testing.T) { TUFMetadataDir: o.Some(t.TempDir()), } - cmdFactory := factory.New("test", "") - - hc, err := cmdFactory.HttpClient() - if err != nil { - t.Fatal(err) - } + cfg := config.NewBlankConfig() + ios, _, _, _ := iostreams.Test() + hc, err := factory.HttpClientFunc( + cfg.Authentication(), + ios, + "test", + "", + &telemetry.NoOpService{}, + )() + require.NoError(t, err) host, _ := auth.DefaultHost() sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) require.NoError(t, err) baseOpts := Options{ - APIClient: api.NewLiveClient(hc, host, logger), - ArtifactPath: artifactPath, - BundlePath: bundlePath, - Config: cmdFactory.Config, + APIClient: api.NewLiveClient(hc, host, logger), + ArtifactPath: artifactPath, + BundlePath: bundlePath, + Config: func() (gh.Config, error) { + return cfg, nil + }, DigestAlgorithm: "sha256", Logger: logger, OCIClient: oci.NewLiveClient(), diff --git a/pkg/cmd/auth/login/login.go b/pkg/cmd/auth/login/login.go index 88bc09f63..24d30c562 100644 --- a/pkg/cmd/auth/login/login.go +++ b/pkg/cmd/auth/login/login.go @@ -138,7 +138,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm opts.Hostname, _ = ghauth.DefaultHost() } - opts.MainExecutable = f.Executable() + opts.MainExecutable = f.ExecutablePath if runF != nil { return runF(opts) } diff --git a/pkg/cmd/auth/refresh/refresh.go b/pkg/cmd/auth/refresh/refresh.go index c025df465..842902502 100644 --- a/pkg/cmd/auth/refresh/refresh.go +++ b/pkg/cmd/auth/refresh/refresh.go @@ -101,7 +101,7 @@ func NewCmdRefresh(f *cmdutil.Factory, runF func(*RefreshOptions) error) *cobra. return cmdutil.FlagErrorf("--hostname required when not running interactively") } - opts.MainExecutable = f.Executable() + opts.MainExecutable = f.ExecutablePath if runF != nil { return runF(opts) } diff --git a/pkg/cmd/auth/setupgit/setupgit.go b/pkg/cmd/auth/setupgit/setupgit.go index 0ff7b6903..a146a579f 100644 --- a/pkg/cmd/auth/setupgit/setupgit.go +++ b/pkg/cmd/auth/setupgit/setupgit.go @@ -53,7 +53,7 @@ func NewCmdSetupGit(f *cmdutil.Factory, runF func(*SetupGitOptions) error) *cobr `), RunE: func(cmd *cobra.Command, args []string) error { opts.CredentialsHelperConfig = &gitcredentials.HelperConfig{ - SelfExecutablePath: f.Executable(), + SelfExecutablePath: f.ExecutablePath, GitClient: f.GitClient, } if opts.Hostname == "" && opts.Force { diff --git a/pkg/cmd/codespace/root.go b/pkg/cmd/codespace/root.go index d1675a8f7..5d3bff3d6 100644 --- a/pkg/cmd/codespace/root.go +++ b/pkg/cmd/codespace/root.go @@ -7,6 +7,14 @@ import ( "github.com/spf13/cobra" ) +type ghExecutable struct { + executablePath string +} + +func (e *ghExecutable) Executable() string { + return e.executablePath +} + func NewCmdCodespace(f *cmdutil.Factory) *cobra.Command { root := &cobra.Command{ Use: "codespace", @@ -17,7 +25,7 @@ func NewCmdCodespace(f *cmdutil.Factory) *cobra.Command { app := NewApp( f.IOStreams, - f, + &ghExecutable{executablePath: f.ExecutablePath}, codespacesAPI.New(f), f.Browser, f.Remotes, diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index 7afc6baa7..bf203bd43 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -4,46 +4,46 @@ import ( "context" "fmt" "net/http" - "os" "regexp" - "slices" "time" "github.com/cli/cli/v2/api" ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/browser" - "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" + "github.com/cli/cli/v2/internal/gh/ghtelemetry" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/cmd/extension" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" - xcolor "github.com/cli/go-gh/v2/pkg/x/color" ) var ssoHeader string var ssoURLRE = regexp.MustCompile(`\burl=([^;]+)`) -func New(appVersion string, invokingAgent string) *cmdutil.Factory { +func New(appVersion string, invokingAgent string, cfg gh.Config, ios *iostreams.IOStreams, executablePath string, telemetryDisabler ghtelemetry.Disabler) *cmdutil.Factory { f := &cmdutil.Factory{ - AppVersion: appVersion, - InvokingAgent: invokingAgent, - Config: configFunc(), // No factory dependencies - ExecutableName: "gh", + AppVersion: appVersion, + InvokingAgent: invokingAgent, + Cfg: cfg, + Config: func() (gh.Config, error) { + return cfg, nil + }, // No factory dependencies + ExecutablePath: executablePath, } - f.IOStreams = ioStreams(f) // Depends on Config - f.HttpClient = httpClientFunc(f, appVersion, invokingAgent) // Depends on Config, IOStreams, appVersion, and invokingAgent - f.PlainHttpClient = plainHttpClientFunc(f, appVersion, invokingAgent) // Depends on IOStreams, appVersion, and invokingAgent - f.GitClient = newGitClient(f) // Depends on IOStreams, and Executable - f.Remotes = remotesFunc(f) // Depends on Config, and GitClient - f.BaseRepo = BaseRepoFunc(f) // Depends on Remotes - f.Prompter = newPrompter(f) // Depends on Config and IOStreams - f.Browser = newBrowser(f) // Depends on Config, and IOStreams - f.ExtensionManager = extensionManager(f) // Depends on Config, HttpClient, and IOStreams - f.Branch = branchFunc(f) // Depends on GitClient + f.IOStreams = ios + f.HttpClient = HttpClientFunc(cfg.Authentication(), ios, appVersion, invokingAgent, telemetryDisabler) + f.PlainHttpClient = plainHttpClientFunc(ios, appVersion, invokingAgent, telemetryDisabler) + f.GitClient = newGitClient(f) // Depends on IOStreams, and Executable + f.Remotes = remotesFunc(f) // Depends on Config, and GitClient + f.BaseRepo = BaseRepoFunc(f.Remotes) + f.Prompter = newPrompter(f) // Depends on Config and IOStreams + f.Browser = newBrowser(f) // Depends on Config, and IOStreams + f.ExtensionManager = extensionManager(f) // Depends on Config, HttpClient, and IOStreams + f.Branch = branchFunc(f) // Depends on GitClient return f } @@ -73,9 +73,9 @@ func New(appVersion string, invokingAgent string) *cmdutil.Factory { // origin https://github.com/cli/cli-fork.git (push) // // With this resolution function, the upstream will always be chosen (assuming we have authenticated with github.com). -func BaseRepoFunc(f *cmdutil.Factory) func() (ghrepo.Interface, error) { +func BaseRepoFunc(remotesFunc func() (ghContext.Remotes, error)) func() (ghrepo.Interface, error) { return func() (ghrepo.Interface, error) { - remotes, err := f.Remotes() + remotes, err := remotesFunc() if err != nil { return nil, err } @@ -187,19 +187,15 @@ func remotesFunc(f *cmdutil.Factory) func() (ghContext.Remotes, error) { return rr.Resolver() } -func httpClientFunc(f *cmdutil.Factory, appVersion string, invokingAgent string) func() (*http.Client, error) { +func HttpClientFunc(authCfg gh.AuthConfig, ios *iostreams.IOStreams, appVersion string, invokingAgent string, telemetryDisabler ghtelemetry.Disabler) func() (*http.Client, error) { return func() (*http.Client, error) { - io := f.IOStreams - cfg, err := f.Config() - if err != nil { - return nil, err - } opts := api.HTTPClientOptions{ - Config: cfg.Authentication(), - Log: io.ErrOut, - LogColorize: io.ColorEnabled(), - AppVersion: appVersion, - InvokingAgent: invokingAgent, + Config: authCfg, + Log: ios.ErrOut, + LogColorize: ios.ColorEnabled(), + AppVersion: appVersion, + InvokingAgent: invokingAgent, + TelemetryDisabler: telemetryDisabler, } client, err := api.NewHTTPClient(opts) if err != nil { @@ -210,16 +206,16 @@ func httpClientFunc(f *cmdutil.Factory, appVersion string, invokingAgent string) } } -func plainHttpClientFunc(f *cmdutil.Factory, appVersion string, invokingAgent string) func() (*http.Client, error) { +func plainHttpClientFunc(ios *iostreams.IOStreams, appVersion string, invokingAgent string, telemetryDisabler ghtelemetry.Disabler) func() (*http.Client, error) { return func() (*http.Client, error) { - io := f.IOStreams opts := api.HTTPClientOptions{ - Log: io.ErrOut, - LogColorize: io.ColorEnabled(), + Log: ios.ErrOut, + LogColorize: ios.ColorEnabled(), AppVersion: appVersion, InvokingAgent: invokingAgent, // This is required to prevent automatic setting of auth and other headers. SkipDefaultHeaders: true, + TelemetryDisabler: telemetryDisabler, } client, err := api.NewHTTPClient(opts) if err != nil { @@ -231,9 +227,8 @@ func plainHttpClientFunc(f *cmdutil.Factory, appVersion string, invokingAgent st func newGitClient(f *cmdutil.Factory) *git.Client { io := f.IOStreams - ghPath := f.Executable() client := &git.Client{ - GhPath: ghPath, + GhPath: f.ExecutablePath, Stderr: io.ErrOut, Stdin: io.In, Stdout: io.Out, @@ -252,18 +247,6 @@ func newPrompter(f *cmdutil.Factory) prompter.Prompter { return prompter.New(editor, io) } -func configFunc() func() (gh.Config, error) { - var cachedConfig gh.Config - var configError error - return func() (gh.Config, error) { - if cachedConfig != nil || configError != nil { - return cachedConfig, configError - } - cachedConfig, configError = config.NewConfig() - return cachedConfig, configError - } -} - func branchFunc(f *cmdutil.Factory) func() (string, error) { return func() (string, error) { currentBranch, err := f.GitClient.CurrentBranch(context.Background()) @@ -293,72 +276,6 @@ func extensionManager(f *cmdutil.Factory) *extension.Manager { return em } -func ioStreams(f *cmdutil.Factory) *iostreams.IOStreams { - io := iostreams.System() - cfg, err := f.Config() - if err != nil { - return io - } - - if _, ghPromptDisabled := os.LookupEnv("GH_PROMPT_DISABLED"); ghPromptDisabled { - io.SetNeverPrompt(true) - } else if prompt := cfg.Prompt(""); prompt.Value == "disabled" { - io.SetNeverPrompt(true) - } - - falseyValues := []string{"false", "0", "no", ""} - - accessiblePrompterValue, accessiblePrompterIsSet := os.LookupEnv("GH_ACCESSIBLE_PROMPTER") - if accessiblePrompterIsSet { - if !slices.Contains(falseyValues, accessiblePrompterValue) { - io.SetAccessiblePrompterEnabled(true) - } - } else if prompt := cfg.AccessiblePrompter(""); prompt.Value == "enabled" { - io.SetAccessiblePrompterEnabled(true) - } - - experimentalPrompterValue, experimentalPrompterIsSet := os.LookupEnv("GH_EXPERIMENTAL_PROMPTER") - if experimentalPrompterIsSet { - if !slices.Contains(falseyValues, experimentalPrompterValue) { - io.SetExperimentalPrompterEnabled(true) - } - } - - ghSpinnerDisabledValue, ghSpinnerDisabledIsSet := os.LookupEnv("GH_SPINNER_DISABLED") - if ghSpinnerDisabledIsSet { - if !slices.Contains(falseyValues, ghSpinnerDisabledValue) { - io.SetSpinnerDisabled(true) - } - } else if spinnerDisabled := cfg.Spinner(""); spinnerDisabled.Value == "disabled" { - io.SetSpinnerDisabled(true) - } - - // Pager precedence - // 1. GH_PAGER - // 2. pager from config - // 3. PAGER - if ghPager, ghPagerExists := os.LookupEnv("GH_PAGER"); ghPagerExists { - io.SetPager(ghPager) - } else if pager := cfg.Pager(""); pager.Value != "" { - io.SetPager(pager.Value) - } - - if ghColorLabels, ghColorLabelsExists := os.LookupEnv("GH_COLOR_LABELS"); ghColorLabelsExists { - switch ghColorLabels { - case "", "0", "false", "no": - io.SetColorLabels(false) - default: - io.SetColorLabels(true) - } - } else if prompt := cfg.ColorLabels(""); prompt.Value == "enabled" { - io.SetColorLabels(true) - } - - io.SetAccessibleColorsEnabled(xcolor.IsAccessibleColorsEnabled()) - - return io -} - // SSOURL returns the URL of a SAML SSO challenge received by the server for clients that use ExtractHeader // to extract the value of the "X-GitHub-SSO" response header. func SSOURL() string { diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go index 7d84caa8f..6f376e48c 100644 --- a/pkg/cmd/factory/default_test.go +++ b/pkg/cmd/factory/default_test.go @@ -11,6 +11,7 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" ghmock "github.com/cli/cli/v2/internal/gh/mock" + "github.com/cli/cli/v2/internal/telemetry" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" @@ -66,7 +67,6 @@ func Test_BaseRepo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := New("1", "") rr := &remoteResolver{ readRemotes: func() (git.RemoteSet, error) { return tt.remotes, nil @@ -90,8 +90,10 @@ func Test_BaseRepo(t *testing.T) { return cfg, nil }, } - f.Remotes = rr.Resolver() - f.BaseRepo = BaseRepoFunc(f) + remotes := rr.Resolver() + f := &cmdutil.Factory{ + BaseRepo: BaseRepoFunc(remotes), + } repo, err := f.BaseRepo() if tt.wantsErr { assert.Error(t, err) @@ -204,7 +206,7 @@ func Test_SmartBaseRepo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := New("1", "") + f := &cmdutil.Factory{} rr := &remoteResolver{ readRemotes: func() (git.RemoteSet, error) { return tt.remotes, nil @@ -297,7 +299,6 @@ func Test_OverrideBaseRepo(t *testing.T) { if tt.envOverride != "" { t.Setenv("GH_REPO", tt.envOverride) } - f := New("1", "") rr := &remoteResolver{ readRemotes: func() (git.RemoteSet, error) { return tt.remotes, nil @@ -306,8 +307,10 @@ func Test_OverrideBaseRepo(t *testing.T) { return tt.config, nil }, } - f.Remotes = rr.Resolver() - f.BaseRepo = cmdutil.OverrideBaseRepoFunc(f, tt.argOverride) + remotes := rr.Resolver() + f := &cmdutil.Factory{ + BaseRepo: cmdutil.OverrideBaseRepoFunc(BaseRepoFunc(remotes), tt.argOverride), + } repo, err := f.BaseRepo() if tt.wantsErr { assert.Error(t, err) @@ -321,341 +324,6 @@ func Test_OverrideBaseRepo(t *testing.T) { } } -func Test_ioStreams_pager(t *testing.T) { - tests := []struct { - name string - env map[string]string - config gh.Config - wantPager string - }{ - { - name: "GH_PAGER and PAGER set", - env: map[string]string{ - "GH_PAGER": "GH_PAGER", - "PAGER": "PAGER", - }, - wantPager: "GH_PAGER", - }, - { - name: "GH_PAGER and config pager set", - env: map[string]string{ - "GH_PAGER": "GH_PAGER", - }, - config: pagerConfig(), - wantPager: "GH_PAGER", - }, - { - name: "config pager and PAGER set", - env: map[string]string{ - "PAGER": "PAGER", - }, - config: pagerConfig(), - wantPager: "CONFIG_PAGER", - }, - { - name: "only PAGER set", - env: map[string]string{ - "PAGER": "PAGER", - }, - wantPager: "PAGER", - }, - { - name: "GH_PAGER set to blank string", - env: map[string]string{ - "GH_PAGER": "", - "PAGER": "PAGER", - }, - wantPager: "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - t.Setenv(k, v) - } - } - f := New("1", "") - f.Config = func() (gh.Config, error) { - if tt.config == nil { - return config.NewBlankConfig(), nil - } else { - return tt.config, nil - } - } - io := ioStreams(f) - assert.Equal(t, tt.wantPager, io.GetPager()) - }) - } -} - -func Test_ioStreams_prompt(t *testing.T) { - tests := []struct { - name string - config gh.Config - promptDisabled bool - env map[string]string - }{ - { - name: "default config", - promptDisabled: false, - }, - { - name: "config with prompt disabled", - config: disablePromptConfig(), - promptDisabled: true, - }, - { - name: "prompt disabled via GH_PROMPT_DISABLED env var", - env: map[string]string{"GH_PROMPT_DISABLED": "1"}, - promptDisabled: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - t.Setenv(k, v) - } - } - f := New("1", "") - f.Config = func() (gh.Config, error) { - if tt.config == nil { - return config.NewBlankConfig(), nil - } else { - return tt.config, nil - } - } - io := ioStreams(f) - assert.Equal(t, tt.promptDisabled, io.GetNeverPrompt()) - }) - } -} - -func Test_ioStreams_spinnerDisabled(t *testing.T) { - tests := []struct { - name string - config gh.Config - spinnerDisabled bool - env map[string]string - }{ - { - name: "default config", - spinnerDisabled: false, - }, - { - name: "config with spinner disabled", - config: disableSpinnersConfig(), - spinnerDisabled: true, - }, - { - name: "config with spinner enabled", - config: enableSpinnersConfig(), - spinnerDisabled: false, - }, - { - name: "spinner disabled via GH_SPINNER_DISABLED env var = 0", - env: map[string]string{"GH_SPINNER_DISABLED": "0"}, - spinnerDisabled: false, - }, - { - name: "spinner disabled via GH_SPINNER_DISABLED env var = false", - env: map[string]string{"GH_SPINNER_DISABLED": "false"}, - spinnerDisabled: false, - }, - { - name: "spinner disabled via GH_SPINNER_DISABLED env var = no", - env: map[string]string{"GH_SPINNER_DISABLED": "no"}, - spinnerDisabled: false, - }, - { - name: "spinner enabled via GH_SPINNER_DISABLED env var = 1", - env: map[string]string{"GH_SPINNER_DISABLED": "1"}, - spinnerDisabled: true, - }, - { - name: "spinner enabled via GH_SPINNER_DISABLED env var = true", - env: map[string]string{"GH_SPINNER_DISABLED": "true"}, - spinnerDisabled: true, - }, - { - name: "config enabled but env disabled, respects env", - config: enableSpinnersConfig(), - env: map[string]string{"GH_SPINNER_DISABLED": "true"}, - spinnerDisabled: true, - }, - { - name: "config disabled but env enabled, respects env", - config: disableSpinnersConfig(), - env: map[string]string{"GH_SPINNER_DISABLED": "false"}, - spinnerDisabled: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for k, v := range tt.env { - t.Setenv(k, v) - } - f := New("1", "") - f.Config = func() (gh.Config, error) { - if tt.config == nil { - return config.NewBlankConfig(), nil - } else { - return tt.config, nil - } - } - io := ioStreams(f) - assert.Equal(t, tt.spinnerDisabled, io.GetSpinnerDisabled()) - }) - } -} - -func Test_ioStreams_accessiblePrompterEnabled(t *testing.T) { - tests := []struct { - name string - config gh.Config - accessiblePrompterEnabled bool - env map[string]string - }{ - { - name: "default config", - accessiblePrompterEnabled: false, - }, - { - name: "config with accessible prompter enabled", - config: enableAccessiblePrompterConfig(), - accessiblePrompterEnabled: true, - }, - { - name: "config with accessible prompter disabled", - config: disableAccessiblePrompterConfig(), - accessiblePrompterEnabled: false, - }, - { - name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = 1", - env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "1"}, - accessiblePrompterEnabled: true, - }, - { - name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = true", - env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, - accessiblePrompterEnabled: true, - }, - { - name: "accessible prompter disabled via GH_ACCESSIBLE_PROMPTER env var = 0", - env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "0"}, - accessiblePrompterEnabled: false, - }, - { - name: "config disabled but env enabled, respects env", - config: disableAccessiblePrompterConfig(), - env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, - accessiblePrompterEnabled: true, - }, - { - name: "config enabled but env disabled, respects env", - config: enableAccessiblePrompterConfig(), - env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "false"}, - accessiblePrompterEnabled: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for k, v := range tt.env { - t.Setenv(k, v) - } - f := New("1", "") - f.Config = func() (gh.Config, error) { - if tt.config == nil { - return config.NewBlankConfig(), nil - } else { - return tt.config, nil - } - } - io := ioStreams(f) - assert.Equal(t, tt.accessiblePrompterEnabled, io.AccessiblePrompterEnabled()) - }) - } -} - -func Test_ioStreams_colorLabels(t *testing.T) { - tests := []struct { - name string - config gh.Config - colorLabelsEnabled bool - env map[string]string - }{ - { - name: "default config", - colorLabelsEnabled: false, - }, - { - name: "config with colorLabels enabled", - config: enableColorLabelsConfig(), - colorLabelsEnabled: true, - }, - { - name: "config with colorLabels disabled", - config: disableColorLabelsConfig(), - colorLabelsEnabled: false, - }, - { - name: "colorLabels enabled via `1` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "1"}, - colorLabelsEnabled: true, - }, - { - name: "colorLabels enabled via `true` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "true"}, - colorLabelsEnabled: true, - }, - { - name: "colorLabels enabled via `yes` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "yes"}, - colorLabelsEnabled: true, - }, - { - name: "colorLabels disable via empty string in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": ""}, - colorLabelsEnabled: false, - }, - { - name: "colorLabels disabled via `0` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "0"}, - colorLabelsEnabled: false, - }, - { - name: "colorLabels disabled via `false` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "false"}, - colorLabelsEnabled: false, - }, - { - name: "colorLabels disabled via `no` in GH_COLOR_LABELS env var", - env: map[string]string{"GH_COLOR_LABELS": "no"}, - colorLabelsEnabled: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.env != nil { - for k, v := range tt.env { - t.Setenv(k, v) - } - } - f := New("1", "") - f.Config = func() (gh.Config, error) { - if tt.config == nil { - return config.NewBlankConfig(), nil - } else { - return tt.config, nil - } - } - io := ioStreams(f) - assert.Equal(t, tt.colorLabelsEnabled, io.ColorLabels()) - }) - } -} - func TestSSOURL(t *testing.T) { tests := []struct { name string @@ -683,13 +351,9 @@ func TestSSOURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := New("1", "") - f.Config = func() (gh.Config, error) { - return config.NewBlankConfig(), nil - } + cfg := config.NewBlankConfig() ios, _, _, stderr := iostreams.Test() - f.IOStreams = ios - client, err := httpClientFunc(f, "v1.2.3", "")() + client, err := HttpClientFunc(cfg.Authentication(), ios, "v1.2.3", "", &telemetry.NoOpService{})() require.NoError(t, err) req, err := http.NewRequest("GET", ts.URL, nil) if tt.sso != "" { @@ -718,13 +382,8 @@ func TestPlainHttpClient(t *testing.T) { })) defer ts.Close() - f := New("1", "") - f.Config = func() (gh.Config, error) { - return config.NewBlankConfig(), nil - } ios, _, _, _ := iostreams.Test() - f.IOStreams = ios - client, err := plainHttpClientFunc(f, "v1.2.3", "")() + client, err := plainHttpClientFunc(ios, "v1.2.3", "", &telemetry.NoOpService{})() require.NoError(t, err) req, err := http.NewRequest("GET", ts.URL, nil) @@ -759,7 +418,7 @@ func TestNewGitClient(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f := New("1", "") + f := &cmdutil.Factory{} f.Config = func() (gh.Config, error) { if tt.config == nil { return config.NewBlankConfig(), nil @@ -767,7 +426,7 @@ func TestNewGitClient(t *testing.T) { return tt.config, nil } } - f.ExecutableName = tt.executable + f.ExecutablePath = tt.executable ios, _, _, _ := iostreams.Test() f.IOStreams = ios c := newGitClient(f) @@ -784,35 +443,3 @@ func defaultConfig() *ghmock.ConfigMock { cfg.Set("nonsense.com", "oauth_token", "BLAH") return cfg } - -func pagerConfig() gh.Config { - return config.NewFromString("pager: CONFIG_PAGER") -} - -func disablePromptConfig() gh.Config { - return config.NewFromString("prompt: disabled") -} - -func enableAccessiblePrompterConfig() gh.Config { - return config.NewFromString("accessible_prompter: enabled") -} - -func disableAccessiblePrompterConfig() gh.Config { - return config.NewFromString("accessible_prompter: disabled") -} - -func disableSpinnersConfig() gh.Config { - return config.NewFromString("spinner: disabled") -} - -func enableSpinnersConfig() gh.Config { - return config.NewFromString("spinner: enabled") -} - -func disableColorLabelsConfig() gh.Config { - return config.NewFromString("color_labels: disabled") -} - -func enableColorLabelsConfig() gh.Config { - return config.NewFromString("color_labels: enabled") -} diff --git a/pkg/cmd/search/shared/shared_test.go b/pkg/cmd/search/shared/shared_test.go index c66e0908f..bd8060943 100644 --- a/pkg/cmd/search/shared/shared_test.go +++ b/pkg/cmd/search/shared/shared_test.go @@ -2,22 +2,27 @@ package shared import ( "fmt" + "net/http" "testing" "time" "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" - "github.com/cli/cli/v2/pkg/cmd/factory" + "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/search" "github.com/stretchr/testify/assert" ) func TestSearcher(t *testing.T) { - f := factory.New("1", "") - f.Config = func() (gh.Config, error) { - return config.NewBlankConfig(), nil + f := &cmdutil.Factory{ + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + HttpClient: func() (*http.Client, error) { + return &http.Client{}, nil + }, } _, err := Searcher(f) assert.NoError(t, err) diff --git a/pkg/cmd/skills/preview/preview.go b/pkg/cmd/skills/preview/preview.go index e39886ecd..4bdfdd416 100644 --- a/pkg/cmd/skills/preview/preview.go +++ b/pkg/cmd/skills/preview/preview.go @@ -22,11 +22,11 @@ import ( ) type PreviewOptions struct { - IO *iostreams.IOStreams - HttpClient func() (*http.Client, error) - Prompter prompter.Prompter - Executable func() string - RenderFile func(string, string) string + IO *iostreams.IOStreams + HttpClient func() (*http.Client, error) + Prompter prompter.Prompter + ExecutablePath string + RenderFile func(string, string) string RepoArg string SkillName string @@ -38,10 +38,10 @@ type PreviewOptions struct { // NewCmdPreview creates the "skills preview" command. func NewCmdPreview(f *cmdutil.Factory, runF func(*PreviewOptions) error) *cobra.Command { opts := &PreviewOptions{ - IO: f.IOStreams, - HttpClient: f.HttpClient, - Prompter: f.Prompter, - Executable: f.Executable, + IO: f.IOStreams, + HttpClient: f.HttpClient, + Prompter: f.Prompter, + ExecutablePath: f.ExecutablePath, } opts.RenderFile = func(filePath, content string) string { return renderMarkdownPreview(opts.IO, filePath, content) diff --git a/pkg/cmd/skills/search/search.go b/pkg/cmd/skills/search/search.go index 05511484e..2542b9d90 100644 --- a/pkg/cmd/skills/search/search.go +++ b/pkg/cmd/skills/search/search.go @@ -47,12 +47,12 @@ var SkillSearchFields = []string{ } type SearchOptions struct { - IO *iostreams.IOStreams - HttpClient func() (*http.Client, error) - Config func() (gh.Config, error) - Prompter prompter.Prompter - Executable string // path to the current gh binary for install subprocess - Exporter cmdutil.Exporter + IO *iostreams.IOStreams + HttpClient func() (*http.Client, error) + Config func() (gh.Config, error) + Prompter prompter.Prompter + ExecutablePath string // path to the current gh binary for install subprocess + Exporter cmdutil.Exporter // User inputs Query string @@ -64,11 +64,11 @@ type SearchOptions struct { // NewCmdSearch creates the "skills search" command. func NewCmdSearch(f *cmdutil.Factory, runF func(*SearchOptions) error) *cobra.Command { opts := &SearchOptions{ - IO: f.IOStreams, - HttpClient: f.HttpClient, - Config: f.Config, - Prompter: f.Prompter, - Executable: f.Executable(), + IO: f.IOStreams, + HttpClient: f.HttpClient, + Config: f.Config, + Prompter: f.Prompter, + ExecutablePath: f.ExecutablePath, } cmd := &cobra.Command{ @@ -585,7 +585,7 @@ func promptInstall(opts *SearchOptions, skills []skillResult) error { } //nolint:gosec // arguments are from user-selected search results, not arbitrary input - cmd := exec.Command(opts.Executable, "skills", "install", s.Repo, installArg, + cmd := exec.Command(opts.ExecutablePath, "skills", "install", s.Repo, installArg, "--agent", host.ID, "--scope", scope) cmd.Stdin = os.Stdin cmd.Stdout = opts.IO.Out diff --git a/pkg/cmdutil/factory.go b/pkg/cmdutil/factory.go index f746ec897..9ea8ce4a6 100644 --- a/pkg/cmdutil/factory.go +++ b/pkg/cmdutil/factory.go @@ -2,9 +2,6 @@ package cmdutil import ( "net/http" - "os" - "path/filepath" - "strings" "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" @@ -18,7 +15,7 @@ import ( type Factory struct { AppVersion string - ExecutableName string + ExecutablePath string InvokingAgent string Browser browser.Browser @@ -27,8 +24,11 @@ type Factory struct { IOStreams *iostreams.IOStreams Prompter prompter.Prompter - BaseRepo func() (ghrepo.Interface, error) - Branch func() (string, error) + BaseRepo func() (ghrepo.Interface, error) + Branch func() (string, error) + Cfg gh.Config + // TODO: Config should be removed in favour of cfg being passed to the right place, + // but this is going to be very invasive and shouldn't be done as part of a feature change. Config func() (gh.Config, error) HttpClient func() (*http.Client, error) // PlainHttpClient is a special HTTP client that does not automatically set @@ -37,69 +37,3 @@ type Factory struct { PlainHttpClient func() (*http.Client, error) Remotes func() (context.Remotes, error) } - -// Executable is the path to the currently invoked binary -func (f *Factory) Executable() string { - ghPath := os.Getenv("GH_PATH") - if ghPath != "" { - return ghPath - } - if !strings.ContainsRune(f.ExecutableName, os.PathSeparator) { - f.ExecutableName = executable(f.ExecutableName) - } - return f.ExecutableName -} - -// Finds the location of the executable for the current process as it's found in PATH, respecting symlinks. -// If the process couldn't determine its location, return fallbackName. If the executable wasn't found in -// PATH, return the absolute location to the program. -// -// The idea is that the result of this function is callable in the future and refers to the same -// installation of gh, even across upgrades. This is needed primarily for Homebrew, which installs software -// under a location such as `/usr/local/Cellar/gh/1.13.1/bin/gh` and symlinks it from `/usr/local/bin/gh`. -// When the version is upgraded, Homebrew will often delete older versions, but keep the symlink. Because of -// this, we want to refer to the `gh` binary as `/usr/local/bin/gh` and not as its internal Homebrew -// location. -// -// None of this would be needed if we could just refer to GitHub CLI as `gh`, i.e. without using an absolute -// path. However, for some reason Homebrew does not include `/usr/local/bin` in PATH when it invokes git -// commands to update its taps. If `gh` (no path) is being used as git credential helper, as set up by `gh -// auth login`, running `brew update` will print out authentication errors as git is unable to locate -// Homebrew-installed `gh`. -func executable(fallbackName string) string { - exe, err := os.Executable() - if err != nil { - return fallbackName - } - - base := filepath.Base(exe) - path := os.Getenv("PATH") - for _, dir := range filepath.SplitList(path) { - p, err := filepath.Abs(filepath.Join(dir, base)) - if err != nil { - continue - } - f, err := os.Lstat(p) - if err != nil { - continue - } - - if p == exe { - return p - } else if f.Mode()&os.ModeSymlink != 0 { - realP, err := filepath.EvalSymlinks(p) - if err != nil { - continue - } - realExe, err := filepath.EvalSymlinks(exe) - if err != nil { - continue - } - if realP == realExe { - return p - } - } - } - - return exe -} diff --git a/pkg/cmdutil/repo_override.go b/pkg/cmdutil/repo_override.go index 791dd919a..b037859e7 100644 --- a/pkg/cmdutil/repo_override.go +++ b/pkg/cmdutil/repo_override.go @@ -52,12 +52,12 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) { return err } repoOverride, _ := cmd.Flags().GetString("repo") - f.BaseRepo = OverrideBaseRepoFunc(f, repoOverride) + f.BaseRepo = OverrideBaseRepoFunc(f.BaseRepo, repoOverride) return nil } } -func OverrideBaseRepoFunc(f *Factory, override string) func() (ghrepo.Interface, error) { +func OverrideBaseRepoFunc(baseRepoFunc func() (ghrepo.Interface, error), override string) func() (ghrepo.Interface, error) { if override == "" { override = os.Getenv("GH_REPO") } @@ -66,5 +66,5 @@ func OverrideBaseRepoFunc(f *Factory, override string) func() (ghrepo.Interface, return ghrepo.FromFullName(override) } } - return f.BaseRepo + return baseRepoFunc }