Merge pull request #3779 from jgold-stripe/unix

Add ability to dial API via unix socket
This commit is contained in:
Nate Smith 2021-06-29 09:46:33 -07:00 committed by GitHub
commit 554250bc4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 73 additions and 5 deletions

View file

@ -50,6 +50,11 @@ var configOptions = []ConfigOption{
Description: "the terminal pager program to send standard output to",
DefaultValue: "",
},
{
Key: "http_unix_socket",
Description: "the path to a unix socket through which to make HTTP connection",
DefaultValue: "",
},
}
func ConfigOptions() []ConfigOption {
@ -179,6 +184,15 @@ func NewBlankRoot() *yaml.Node {
},
},
},
{
HeadComment: "The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.",
Kind: yaml.ScalarNode,
Value: "http_unix_socket",
},
{
Kind: yaml.ScalarNode,
Value: "",
},
},
},
},

View file

@ -50,6 +50,8 @@ func Test_defaultConfig(t *testing.T) {
# Aliases allow you to create nicknames for gh commands
aliases:
co: pr checkout
# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport.
http_unix_socket:
`)
assert.Equal(t, expected, mainBuf.String())
assert.Equal(t, "", hostsBuf.String())
@ -81,6 +83,9 @@ func Test_ValidateValue(t *testing.T) {
err = ValidateValue("got", "123")
assert.NoError(t, err)
err = ValidateValue("http_unix_socket", "really_anything/is/allowed/and/net.Dial\\(...\\)/will/ultimately/validate")
assert.NoError(t, err)
}
func Test_ValidateKey(t *testing.T) {
@ -98,4 +103,7 @@ func Test_ValidateKey(t *testing.T) {
err = ValidateKey("pager")
assert.NoError(t, err)
err = ValidateKey("http_unix_socket")
assert.NoError(t, err)
}

View file

@ -0,0 +1,21 @@
// package httpunix provides an http.RoundTripper which dials a server via a unix socket.
package httpunix
import (
"net"
"net/http"
)
// NewRoundTripper returns an http.RoundTripper which sends requests via a unix
// socket at socketPath.
func NewRoundTripper(socketPath string) http.RoundTripper {
dial := func(network, addr string) (net.Conn, error) {
return net.Dial("unix", socketPath)
}
return &http.Transport{
Dial: dial,
DialTLS: dial,
DisableKeepAlives: true,
}
}

View file

@ -85,7 +85,7 @@ func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client,
if err != nil {
return nil, err
}
return NewHTTPClient(io, cfg, appVersion, true), nil
return NewHTTPClient(io, cfg, appVersion, true)
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/internal/httpunix"
"github.com/cli/cli/pkg/iostreams"
)
@ -57,8 +58,31 @@ type configGetter interface {
}
// generic authenticated HTTP client for commands
func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, setAccept bool) *http.Client {
func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, setAccept bool) (*http.Client, error) {
var opts []api.ClientOption
// We need to check and potentially add the unix socket roundtripper option
// before adding any other options, since if we are going to use the unix
// socket transport, it needs to form the base of the transport chain
// represented by invocations of opts...
//
// Another approach might be to change the signature of api.NewHTTPClient to
// take an explicit base http.RoundTripper as its first parameter (it
// currently defaults internally to http.DefaultTransport), or add another
// variant like api.NewHTTPClientWithBaseRoundTripper. But, the only caller
// which would use that non-default behavior is right here, and it doesn't
// seem worth the cognitive overhead everywhere else just to serve this one
// use case.
unixSocket, err := cfg.Get("", "http_unix_socket")
if err != nil {
return nil, err
}
if unixSocket != "" {
opts = append(opts, api.ClientOption(func(http.RoundTripper) http.RoundTripper {
return httpunix.NewRoundTripper(unixSocket)
}))
}
if verbose := os.Getenv("DEBUG"); verbose != "" {
logTraffic := strings.Contains(verbose, "api")
opts = append(opts, api.VerboseLog(io.ErrOut, logTraffic, io.IsStderrTTY()))
@ -98,7 +122,7 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string,
)
}
return api.NewHTTPClient(opts...)
return api.NewHTTPClient(opts...), nil
}
func getHost(r *http.Request) string {

View file

@ -135,7 +135,8 @@ func TestNewHTTPClient(t *testing.T) {
})
io, _, _, stderr := iostreams.Test()
client := NewHTTPClient(io, tt.args.config, tt.args.appVersion, tt.args.setAccept)
client, err := NewHTTPClient(io, tt.args.config, tt.args.appVersion, tt.args.setAccept)
require.NoError(t, err)
req, err := http.NewRequest("GET", ts.URL, nil)
req.Host = tt.host

View file

@ -118,6 +118,6 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er
if err != nil {
return nil, err
}
return factory.NewHTTPClient(f.IOStreams, cfg, version, false), nil
return factory.NewHTTPClient(f.IOStreams, cfg, version, false)
}
}