diff --git a/api/client.go b/api/client.go index b8c6857fb..1c8c1eaa2 100644 --- a/api/client.go +++ b/api/client.go @@ -35,7 +35,11 @@ func NewClient(opts ...ClientOption) *Client { func AddHeader(name, value string) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - req.Header.Add(name, value) + // prevent the token from leaking to non-GitHub hosts + // TODO: GHE support + if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") { + req.Header.Add(name, value) + } return tr.RoundTrip(req) }} } @@ -45,7 +49,11 @@ func AddHeader(name, value string) ClientOption { func AddHeaderFunc(name string, value func() string) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - req.Header.Add(name, value()) + // prevent the token from leaking to non-GitHub hosts + // TODO: GHE support + if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") { + req.Header.Add(name, value()) + } return tr.RoundTrip(req) }} } diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go index 812b95af4..4db21b286 100644 --- a/pkg/cmd/api/http.go +++ b/pkg/cmd/api/http.go @@ -11,8 +11,14 @@ import ( ) func httpRequest(client *http.Client, method string, p string, params interface{}, headers []string) (*http.Response, error) { + var requestURL string // TODO: GHE support - url := "https://api.github.com/" + p + if strings.Contains(p, "://") { + requestURL = p + } else { + requestURL = "https://api.github.com/" + p + } + var body io.Reader var bodyIsJSON bool isGraphQL := p == "graphql" @@ -20,7 +26,7 @@ func httpRequest(client *http.Client, method string, p string, params interface{ switch pp := params.(type) { case map[string]interface{}: if strings.EqualFold(method, "GET") { - url = addQuery(url, pp) + requestURL = addQuery(requestURL, pp) } else { for key, value := range pp { switch vv := value.(type) { @@ -46,7 +52,7 @@ func httpRequest(client *http.Client, method string, p string, params interface{ return nil, fmt.Errorf("unrecognized parameters type: %v", params) } - req, err := http.NewRequest(method, url, body) + req, err := http.NewRequest(method, requestURL, body) if err != nil { return nil, err }