diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index 52837b252..b48840387 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -33,15 +33,16 @@ func New(appVersion string) *cmdutil.Factory { ExecutableName: "gh", } - f.IOStreams = ioStreams(f) // Depends on Config - f.HttpClient = httpClientFunc(f, appVersion) // Depends on Config, IOStreams, and appVersion - f.GitClient = newGitClient(f) // Depends on IOStreams, and Executable - f.Remotes = remotesFunc(f) // Depends on Config, and GitClient - f.BaseRepo = BaseRepoFunc(f) // Depends on Remotes - f.Prompter = newPrompter(f) // Depends on Config and IOStreams - f.Browser = newBrowser(f) // Depends on Config, and IOStreams - f.ExtensionManager = extensionManager(f) // Depends on Config, HttpClient, and IOStreams - f.Branch = branchFunc(f) // Depends on GitClient + f.IOStreams = ioStreams(f) // Depends on Config + f.HttpClient = httpClientFunc(f, appVersion) // Depends on Config, IOStreams, and appVersion + f.PlainHttpClient = plainHttpClientFunc(f, appVersion) // Depends on IOStreams, and appVersion + f.GitClient = newGitClient(f) // Depends on IOStreams, and Executable + f.Remotes = remotesFunc(f) // Depends on Config, and GitClient + f.BaseRepo = BaseRepoFunc(f) // Depends on Remotes + f.Prompter = newPrompter(f) // Depends on Config and IOStreams + f.Browser = newBrowser(f) // Depends on Config, and IOStreams + f.ExtensionManager = extensionManager(f) // Depends on Config, HttpClient, and IOStreams + f.Branch = branchFunc(f) // Depends on GitClient return f } @@ -207,6 +208,24 @@ func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, } } +func plainHttpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, error) { + return func() (*http.Client, error) { + io := f.IOStreams + opts := api.HTTPClientOptions{ + Log: io.ErrOut, + LogColorize: io.ColorEnabled(), + AppVersion: appVersion, + // This is required to prevent automatic setting of auth and other headers. + SkipDefaultHeaders: true, + } + client, err := api.NewHTTPClient(opts) + if err != nil { + return nil, err + } + return client, nil + } +} + func newGitClient(f *cmdutil.Factory) *git.Client { io := f.IOStreams ghPath := f.Executable() diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go index d7bfe39fd..54ffb8d59 100644 --- a/pkg/cmd/factory/default_test.go +++ b/pkg/cmd/factory/default_test.go @@ -710,6 +710,36 @@ func TestSSOURL(t *testing.T) { } } +func TestPlainHttpClient(t *testing.T) { + var receivedHeaders *http.Header + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = &r.Header + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + f := New("1") + f.Config = func() (gh.Config, error) { + return config.NewBlankConfig(), nil + } + ios, _, _, _ := iostreams.Test() + f.IOStreams = ios + client, err := plainHttpClientFunc(f, "v1.2.3")() + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.URL, nil) + require.NoError(t, err) + res, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, 204, res.StatusCode) + assert.Equal(t, []string{"GitHub CLI v1.2.3"}, receivedHeaders.Values("User-Agent")) + assert.Nil(t, receivedHeaders.Values("Authorization")) + assert.Nil(t, receivedHeaders.Values("Content-Type")) + assert.Nil(t, receivedHeaders.Values("Accept")) + assert.Nil(t, receivedHeaders.Values("Time-Zone")) +} + func TestNewGitClient(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmdutil/factory.go b/pkg/cmdutil/factory.go index 07ffbee64..b90960b31 100644 --- a/pkg/cmdutil/factory.go +++ b/pkg/cmdutil/factory.go @@ -30,7 +30,11 @@ type Factory struct { Branch func() (string, error) Config func() (gh.Config, error) HttpClient func() (*http.Client, error) - Remotes func() (context.Remotes, error) + // PlainHttpClient is a special HTTP client that does not automatically set + // auth and other headers. This is meant to be used in situations where the + // client needs to specify the headers itself (e.g. during login). + PlainHttpClient func() (*http.Client, error) + Remotes func() (context.Remotes, error) } // Executable is the path to the currently invoked binary