diff --git a/internal/config/config_type.go b/internal/config/config_type.go index 9098a1c08..3714698bd 100644 --- a/internal/config/config_type.go +++ b/internal/config/config_type.go @@ -50,6 +50,11 @@ var configOptions = []ConfigOption{ Description: "the terminal pager program to send standard output to", DefaultValue: "", }, + { + Key: "http_unix_socket", + Description: "the path to a unix socket through which to make HTTP connection", + DefaultValue: "", + }, } func ConfigOptions() []ConfigOption { @@ -179,6 +184,15 @@ func NewBlankRoot() *yaml.Node { }, }, }, + { + HeadComment: "The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.", + Kind: yaml.ScalarNode, + Value: "http_unix_socket", + }, + { + Kind: yaml.ScalarNode, + Value: "", + }, }, }, }, diff --git a/internal/config/config_type_test.go b/internal/config/config_type_test.go index fca819f46..fe5e3f239 100644 --- a/internal/config/config_type_test.go +++ b/internal/config/config_type_test.go @@ -50,6 +50,8 @@ func Test_defaultConfig(t *testing.T) { # Aliases allow you to create nicknames for gh commands aliases: co: pr checkout + # The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport. + http_unix_socket: `) assert.Equal(t, expected, mainBuf.String()) assert.Equal(t, "", hostsBuf.String()) @@ -81,6 +83,9 @@ func Test_ValidateValue(t *testing.T) { err = ValidateValue("got", "123") assert.NoError(t, err) + + err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate") + assert.NoError(t, err) } func Test_ValidateKey(t *testing.T) { @@ -98,4 +103,7 @@ func Test_ValidateKey(t *testing.T) { err = ValidateKey("pager") assert.NoError(t, err) + + err = ValidateKey("http_unix_socket") + assert.NoError(t, err) } diff --git a/internal/httpunix/transport.go b/internal/httpunix/transport.go new file mode 100644 index 000000000..2326a5f91 --- /dev/null +++ b/internal/httpunix/transport.go @@ -0,0 +1,21 @@ +// package httpunix provides an http.RoundTripper which dials a server via a unix socket. +package httpunix + +import ( + "net" + "net/http" +) + +// NewRoundTripper returns an http.RoundTripper which sends requests via a unix +// socket at socketPath. +func NewRoundTripper(socketPath string) http.RoundTripper { + dial := func(network, addr string) (net.Conn, error) { + return net.Dial("unix", socketPath) + } + + return &http.Transport{ + Dial: dial, + DialTLS: dial, + DisableKeepAlives: true, + } +} diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index b7fb9dd13..32187689b 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -85,7 +85,7 @@ func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, if err != nil { return nil, err } - return NewHTTPClient(io, cfg, appVersion, true), nil + return NewHTTPClient(io, cfg, appVersion, true) } } diff --git a/pkg/cmd/factory/http.go b/pkg/cmd/factory/http.go index 23534a0d3..0c0abd7b2 100644 --- a/pkg/cmd/factory/http.go +++ b/pkg/cmd/factory/http.go @@ -9,6 +9,7 @@ import ( "github.com/cli/cli/api" "github.com/cli/cli/internal/ghinstance" + "github.com/cli/cli/internal/httpunix" "github.com/cli/cli/pkg/iostreams" ) @@ -57,8 +58,31 @@ type configGetter interface { } // generic authenticated HTTP client for commands -func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, setAccept bool) *http.Client { +func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, setAccept bool) (*http.Client, error) { var opts []api.ClientOption + + // We need to check and potentially add the unix socket roundtripper option + // before adding any other options, since if we are going to use the unix + // socket transport, it needs to form the base of the transport chain + // represented by invocations of opts... + // + // Another approach might be to change the signature of api.NewHTTPClient to + // take an explicit base http.RoundTripper as its first parameter (it + // currently defaults internally to http.DefaultTransport), or add another + // variant like api.NewHTTPClientWithBaseRoundTripper. But, the only caller + // which would use that non-default behavior is right here, and it doesn't + // seem worth the cognitive overhead everywhere else just to serve this one + // use case. + unixSocket, err := cfg.Get("", "http_unix_socket") + if err != nil { + return nil, err + } + if unixSocket != "" { + opts = append(opts, api.ClientOption(func(http.RoundTripper) http.RoundTripper { + return httpunix.NewRoundTripper(unixSocket) + })) + } + if verbose := os.Getenv("DEBUG"); verbose != "" { logTraffic := strings.Contains(verbose, "api") opts = append(opts, api.VerboseLog(io.ErrOut, logTraffic, io.IsStderrTTY())) @@ -98,7 +122,7 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, ) } - return api.NewHTTPClient(opts...) + return api.NewHTTPClient(opts...), nil } func getHost(r *http.Request) string { diff --git a/pkg/cmd/factory/http_test.go b/pkg/cmd/factory/http_test.go index 6172d846c..ae4096594 100644 --- a/pkg/cmd/factory/http_test.go +++ b/pkg/cmd/factory/http_test.go @@ -135,7 +135,8 @@ func TestNewHTTPClient(t *testing.T) { }) io, _, _, stderr := iostreams.Test() - client := NewHTTPClient(io, tt.args.config, tt.args.appVersion, tt.args.setAccept) + client, err := NewHTTPClient(io, tt.args.config, tt.args.appVersion, tt.args.setAccept) + require.NoError(t, err) req, err := http.NewRequest("GET", ts.URL, nil) req.Host = tt.host diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index fe1877be5..158bb112d 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -118,6 +118,6 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er if err != nil { return nil, err } - return factory.NewHTTPClient(f.IOStreams, cfg, version, false), nil + return factory.NewHTTPClient(f.IOStreams, cfg, version, false) } }