From 9d70d62520aea42688518a24651a896f9aa25f44 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Thu, 14 Jul 2022 14:13:34 +0200 Subject: [PATCH] Set blank headers so they are not automatically resolved by go-gh (#5935) --- api/client.go | 47 +++++++++++++++++++++++++++++++++++----------- api/client_test.go | 35 ++++++++++++++++++++++++++++++++++ api/http_client.go | 18 ++++++++++++------ 3 files changed, 83 insertions(+), 17 deletions(-) diff --git a/api/client.go b/api/client.go index f37b5306b..54c25c0ff 100644 --- a/api/client.go +++ b/api/client.go @@ -14,6 +14,17 @@ import ( ghAPI "github.com/cli/go-gh/pkg/api" ) +const ( + accept = "Accept" + authorization = "Authorization" + cacheTTL = "X-GH-CACHE-TTL" + contentType = "Content-Type" + graphqlFeatures = "GraphQL-Features" + mergeQueue = "merge_queue" + timeZone = "Time-Zone" + userAgent = "User-Agent" +) + var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`) func NewClientFromHTTP(httpClient *http.Client) *Client { @@ -45,9 +56,8 @@ func (err HTTPError) ScopesSuggestion() string { // GraphQL performs a GraphQL request and parses the response. If there are errors in the response, // GraphQLError will be returned, but the data will also be parsed into the receiver. func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error { - // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. - opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} - opts.Headers = map[string]string{"GraphQL-Features": "merge_queue"} + opts := clientOptions(hostname, c.http.Transport) + opts.Headers[graphqlFeatures] = mergeQueue gqlClient, err := gh.GQLClient(&opts) if err != nil { return err @@ -58,8 +68,7 @@ func (c Client) GraphQL(hostname string, query string, variables map[string]inte // GraphQL performs a GraphQL mutation and parses the response. If there are errors in the response, // GraphQLError will be returned, but the data will also be parsed into the receiver. func (c Client) Mutate(hostname, name string, mutation interface{}, variables map[string]interface{}) error { - // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. - opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + opts := clientOptions(hostname, c.http.Transport) gqlClient, err := gh.GQLClient(&opts) if err != nil { return err @@ -70,8 +79,7 @@ func (c Client) Mutate(hostname, name string, mutation interface{}, variables ma // GraphQL performs a GraphQL query and parses the response. If there are errors in the response, // GraphQLError will be returned, but the data will also be parsed into the receiver. func (c Client) Query(hostname, name string, query interface{}, variables map[string]interface{}) error { - // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. - opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + opts := clientOptions(hostname, c.http.Transport) gqlClient, err := gh.GQLClient(&opts) if err != nil { return err @@ -81,8 +89,7 @@ func (c Client) Query(hostname, name string, query interface{}, variables map[st // REST performs a REST request and parses the response. func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error { - // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. - opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + opts := clientOptions(hostname, c.http.Transport) restClient, err := gh.RESTClient(&opts) if err != nil { return err @@ -91,8 +98,7 @@ func (c Client) REST(hostname string, method string, p string, body io.Reader, d } func (c Client) RESTWithNext(hostname string, method string, p string, body io.Reader, data interface{}) (string, error) { - // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. - opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + opts := clientOptions(hostname, c.http.Transport) restClient, err := gh.RESTClient(&opts) if err != nil { return "", err @@ -237,3 +243,22 @@ func generateScopesSuggestion(statusCode int, endpointNeedsScopes, tokenHasScope return "" } + +func clientOptions(hostname string, transport http.RoundTripper) ghAPI.ClientOptions { + // AuthToken, and Headers are being handled by transport, + // so let go-gh know that it does not need to resolve them. + opts := ghAPI.ClientOptions{ + AuthToken: "none", + // Blank values for Accept, Authorization, Content-Type, Time-Zone, and User-Agent headers. + Headers: map[string]string{ + accept: "", + authorization: "", + contentType: "", + timeZone: "", + userAgent: "", + }, + Host: hostname, + Transport: transport, + } + return opts +} diff --git a/api/client_test.go b/api/client_test.go index 53a750e02..0f36e3f69 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -5,9 +5,11 @@ import ( "errors" "io" "net/http" + "net/http/httptest" "testing" "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/cli/v2/pkg/iostreams" "github.com/stretchr/testify/assert" ) @@ -221,3 +223,36 @@ func TestHTTPError_ScopesSuggestion(t *testing.T) { }) } } + +func TestHTTPHeaders(t *testing.T) { + var gotReq *http.Request + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReq = r + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + ios, _, _, stderr := iostreams.Test() + httpClient, err := NewHTTPClient(HTTPClientOptions{ + AppVersion: "v1.2.3", + Config: tinyConfig{ts.URL[7:] + ":oauth_token": "MYTOKEN"}, + Log: ios.ErrOut, + SkipAcceptHeaders: false, + }) + assert.NoError(t, err) + client := NewClientFromHTTP(httpClient) + + err = client.REST(ts.URL, "GET", ts.URL+"/user/repos", nil, nil) + assert.NoError(t, err) + + wantHeader := map[string]string{ + "Accept": "application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview", + "Authorization": "token MYTOKEN", + "Content-Type": "application/json; charset=utf-8", + "User-Agent": "GitHub CLI v1.2.3", + } + for name, value := range wantHeader { + assert.Equal(t, value, gotReq.Header.Get(name), name) + } + assert.Equal(t, "", stderr.String()) +} diff --git a/api/http_client.go b/api/http_client.go index bd70c1015..9ea6ca1fb 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -35,10 +35,10 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { } headers := map[string]string{ - "User-Agent": fmt.Sprintf("GitHub CLI %s", opts.AppVersion), + userAgent: fmt.Sprintf("GitHub CLI %s", opts.AppVersion), } if opts.SkipAcceptHeaders { - headers["Accept"] = "" + headers[accept] = "" } clientOpts.Headers = headers @@ -68,7 +68,10 @@ func NewCachedHTTPClient(httpClient *http.Client, ttl time.Duration) *http.Clien // should be cached for a specified amount of time. func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - req.Header.Set("X-GH-CACHE-TTL", ttl.String()) + // If the header is already set in the request, don't overwrite it. + if req.Header.Get(cacheTTL) == "" { + req.Header.Set(cacheTTL, ttl.String()) + } return rt.RoundTrip(req) }} } @@ -76,9 +79,12 @@ func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTrippe // AddAuthToken adds an authentication token header for the host specified by the request. func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - hostname := ghinstance.NormalizeHostname(getHost(req)) - if token, _ := cfg.AuthToken(hostname); token != "" { - req.Header.Set("Authorization", fmt.Sprintf("token %s", token)) + // If the header is already set in the request, don't overwrite it. + if req.Header.Get(authorization) == "" { + hostname := ghinstance.NormalizeHostname(getHost(req)) + if token, _ := cfg.AuthToken(hostname); token != "" { + req.Header.Set(authorization, fmt.Sprintf("token %s", token)) + } } return rt.RoundTrip(req) }}