diff --git a/api/cache.go b/api/cache.go deleted file mode 100644 index 59cf3f344..000000000 --- a/api/cache.go +++ /dev/null @@ -1,178 +0,0 @@ -package api - -import ( - "bufio" - "bytes" - "crypto/sha256" - "errors" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "time" -) - -func NewCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client { - cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") - return &http.Client{ - Transport: CacheResponse(cacheTTL, cacheDir)(httpClient.Transport), - } -} - -func isCacheableRequest(req *http.Request) bool { - if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { - return true - } - - if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { - return true - } - - return false -} - -func isCacheableResponse(res *http.Response) bool { - return res.StatusCode < 500 && res.StatusCode != 403 -} - -// CacheResponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time -func CacheResponse(ttl time.Duration, dir string) ClientOption { - fs := fileStorage{ - dir: dir, - ttl: ttl, - mu: &sync.RWMutex{}, - } - - return func(tr http.RoundTripper) http.RoundTripper { - return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - if !isCacheableRequest(req) { - return tr.RoundTrip(req) - } - - key, keyErr := cacheKey(req) - if keyErr == nil { - if res, err := fs.read(key); err == nil { - res.Request = req - return res, nil - } - } - - res, err := tr.RoundTrip(req) - if err == nil && keyErr == nil && isCacheableResponse(res) { - _ = fs.store(key, res) - } - return res, err - }} - } -} - -func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { - b := &bytes.Buffer{} - nr := io.TeeReader(r, b) - return io.NopCloser(b), &readCloser{ - Reader: nr, - Closer: r, - } -} - -type readCloser struct { - io.Reader - io.Closer -} - -func cacheKey(req *http.Request) (string, error) { - h := sha256.New() - fmt.Fprintf(h, "%s:", req.Method) - fmt.Fprintf(h, "%s:", req.URL.String()) - fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) - fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) - - if req.Body != nil { - var bodyCopy io.ReadCloser - req.Body, bodyCopy = copyStream(req.Body) - defer bodyCopy.Close() - if _, err := io.Copy(h, bodyCopy); err != nil { - return "", err - } - } - - digest := h.Sum(nil) - return fmt.Sprintf("%x", digest), nil -} - -type fileStorage struct { - dir string - ttl time.Duration - mu *sync.RWMutex -} - -func (fs *fileStorage) filePath(key string) string { - if len(key) >= 6 { - return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) - } - return filepath.Join(fs.dir, key) -} - -func (fs *fileStorage) read(key string) (*http.Response, error) { - cacheFile := fs.filePath(key) - - fs.mu.RLock() - defer fs.mu.RUnlock() - - f, err := os.Open(cacheFile) - if err != nil { - return nil, err - } - defer f.Close() - - stat, err := f.Stat() - if err != nil { - return nil, err - } - - age := time.Since(stat.ModTime()) - if age > fs.ttl { - return nil, errors.New("cache expired") - } - - body := &bytes.Buffer{} - _, err = io.Copy(body, f) - if err != nil { - return nil, err - } - - res, err := http.ReadResponse(bufio.NewReader(body), nil) - return res, err -} - -func (fs *fileStorage) store(key string, res *http.Response) error { - cacheFile := fs.filePath(key) - - fs.mu.Lock() - defer fs.mu.Unlock() - - err := os.MkdirAll(filepath.Dir(cacheFile), 0755) - if err != nil { - return err - } - - f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer f.Close() - - var origBody io.ReadCloser - if res.Body != nil { - origBody, res.Body = copyStream(res.Body) - defer res.Body.Close() - } - err = res.Write(f) - if origBody != nil { - res.Body = origBody - } - return err -} diff --git a/api/cache_test.go b/api/cache_test.go deleted file mode 100644 index d02fae917..000000000 --- a/api/cache_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package api - -import ( - "bytes" - "fmt" - "io" - "net/http" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_CacheResponse(t *testing.T) { - counter := 0 - fakeHTTP := funcTripper{ - roundTrip: func(req *http.Request) (*http.Response, error) { - counter += 1 - body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) - status := 200 - if req.URL.Path == "/error" { - status = 500 - } - return &http.Response{ - StatusCode: status, - Body: io.NopCloser(bytes.NewBufferString(body)), - }, nil - }, - } - - cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") - httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheResponse(time.Minute, cacheDir)) - - do := func(method, url string, body io.Reader) (string, error) { - req, err := http.NewRequest(method, url, body) - if err != nil { - return "", err - } - res, err := httpClient.Do(req) - if err != nil { - return "", err - } - defer res.Body.Close() - resBody, err := io.ReadAll(res.Body) - if err != nil { - err = fmt.Errorf("ReadAll: %w", err) - } - return string(resBody), err - } - - var res string - var err error - - res, err = do("GET", "http://example.com/path", nil) - require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res) - res, err = do("GET", "http://example.com/path", nil) - require.NoError(t, err) - assert.Equal(t, "1: GET http://example.com/path", res) - - res, err = do("GET", "http://example.com/path2", nil) - require.NoError(t, err) - assert.Equal(t, "2: GET http://example.com/path2", res) - - res, err = do("POST", "http://example.com/path2", nil) - require.NoError(t, err) - assert.Equal(t, "3: POST http://example.com/path2", res) - - res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) - require.NoError(t, err) - assert.Equal(t, "4: POST http://example.com/graphql", res) - res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) - require.NoError(t, err) - assert.Equal(t, "4: POST http://example.com/graphql", res) - - res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) - require.NoError(t, err) - assert.Equal(t, "5: POST http://example.com/graphql", res) - - res, err = do("GET", "http://example.com/error", nil) - require.NoError(t, err) - assert.Equal(t, "6: GET http://example.com/error", res) - res, err = do("GET", "http://example.com/error", nil) - require.NoError(t, err) - assert.Equal(t, "7: GET http://example.com/error", res) -} diff --git a/api/client.go b/api/client.go index 3a2aa83ca..f37b5306b 100644 --- a/api/client.go +++ b/api/client.go @@ -1,127 +1,26 @@ package api import ( - "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" - "net/url" "regexp" "strings" "github.com/cli/cli/v2/internal/ghinstance" - graphql "github.com/cli/shurcooL-graphql" - "github.com/henvic/httpretty" + "github.com/cli/go-gh" + ghAPI "github.com/cli/go-gh/pkg/api" ) -// ClientOption represents an argument to NewClient -type ClientOption = func(http.RoundTripper) http.RoundTripper +var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`) -// NewHTTPClient initializes an http.Client -func NewHTTPClient(opts ...ClientOption) *http.Client { - tr := http.DefaultTransport - for _, opt := range opts { - tr = opt(tr) - } - return &http.Client{Transport: tr} -} - -// NewClient initializes a Client -func NewClient(opts ...ClientOption) *Client { - client := &Client{http: NewHTTPClient(opts...)} - return client -} - -// NewClientFromHTTP takes in an http.Client instance func NewClientFromHTTP(httpClient *http.Client) *Client { client := &Client{http: httpClient} return client } -// AddHeader turns a RoundTripper into one that adds a request header -func AddHeader(name, value string) ClientOption { - return func(tr http.RoundTripper) http.RoundTripper { - return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - if req.Header.Get(name) == "" { - req.Header.Add(name, value) - } - return tr.RoundTrip(req) - }} - } -} - -// AddHeaderFunc is an AddHeader that gets the string value from a function -func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption { - return func(tr http.RoundTripper) http.RoundTripper { - return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - if req.Header.Get(name) != "" { - return tr.RoundTrip(req) - } - value, err := getValue(req) - if err != nil { - return nil, err - } - if value != "" { - req.Header.Add(name, value) - } - return tr.RoundTrip(req) - }} - } -} - -// VerboseLog enables request/response logging within a RoundTripper -func VerboseLog(out io.Writer, logTraffic bool, colorize bool) ClientOption { - logger := &httpretty.Logger{ - Time: true, - TLS: false, - Colors: colorize, - RequestHeader: logTraffic, - RequestBody: logTraffic, - ResponseHeader: logTraffic, - ResponseBody: logTraffic, - Formatters: []httpretty.Formatter{&httpretty.JSONFormatter{}}, - MaxResponseBody: 10000, - } - logger.SetOutput(out) - logger.SetBodyFilter(func(h http.Header) (skip bool, err error) { - return !inspectableMIMEType(h.Get("Content-Type")), nil - }) - return logger.RoundTripper -} - -// ReplaceTripper substitutes the underlying RoundTripper with a custom one -func ReplaceTripper(tr http.RoundTripper) ClientOption { - return func(http.RoundTripper) http.RoundTripper { - return tr - } -} - -// ExtractHeader extracts a named header from any response received by this client and, if non-blank, saves -// it to dest. -func ExtractHeader(name string, dest *string) ClientOption { - return func(tr http.RoundTripper) http.RoundTripper { - return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - res, err := tr.RoundTrip(req) - if err == nil { - if value := res.Header.Get(name); value != "" { - *dest = value - } - } - return res, err - }} - } -} - -type funcTripper struct { - roundTrip func(*http.Request) (*http.Response, error) -} - -func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return tr.roundTrip(req) -} - -// Client facilitates making HTTP requests to the GitHub API type Client struct { http *http.Client } @@ -130,103 +29,168 @@ func (c *Client) HTTP() *http.Client { return c.http } -type graphQLResponse struct { - Data interface{} - Errors []GraphQLError -} - -// GraphQLError is a single error returned in a GraphQL response type GraphQLError struct { - Type string - Message string - Path []interface{} // mixed strings and numbers + ghAPI.GQLError } -func (ge GraphQLError) PathString() string { - var res strings.Builder - for i, v := range ge.Path { - if i > 0 { - res.WriteRune('.') - } - fmt.Fprintf(&res, "%v", v) - } - return res.String() -} - -// GraphQLErrorResponse contains errors returned in a GraphQL response -type GraphQLErrorResponse struct { - Errors []GraphQLError -} - -func (gr GraphQLErrorResponse) Error() string { - errorMessages := make([]string, 0, len(gr.Errors)) - for _, e := range gr.Errors { - msg := e.Message - if p := e.PathString(); p != "" { - msg = fmt.Sprintf("%s (%s)", msg, p) - } - errorMessages = append(errorMessages, msg) - } - return fmt.Sprintf("GraphQL: %s", strings.Join(errorMessages, ", ")) -} - -// Match checks if this error is only about a specific type on a specific path. If the path argument ends -// with a ".", it will match all its subpaths as well. -func (gr GraphQLErrorResponse) Match(expectType, expectPath string) bool { - for _, e := range gr.Errors { - if e.Type != expectType || !matchPath(e.PathString(), expectPath) { - return false - } - } - return true -} - -func matchPath(p, expect string) bool { - if strings.HasSuffix(expect, ".") { - return strings.HasPrefix(p, expect) || p == strings.TrimSuffix(expect, ".") - } - return p == expect -} - -// HTTPError is an error returned by a failed API call type HTTPError struct { - StatusCode int - RequestURL *url.URL - Message string - Errors []HTTPErrorItem - + ghAPI.HTTPError scopesSuggestion string } -type HTTPErrorItem struct { - Message string - Resource string - Field string - Code string -} - -func (err HTTPError) Error() string { - if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 { - return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1]) - } else if err.Message != "" { - return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) - } - return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) -} - func (err HTTPError) ScopesSuggestion() string { return err.scopesSuggestion } +// GraphQL performs a GraphQL request and parses the response. If there are errors in the response, +// GraphQLError will be returned, but the data will also be parsed into the receiver. +func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error { + // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. + opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + opts.Headers = map[string]string{"GraphQL-Features": "merge_queue"} + gqlClient, err := gh.GQLClient(&opts) + if err != nil { + return err + } + return handleResponse(gqlClient.Do(query, variables, data)) +} + +// GraphQL performs a GraphQL mutation and parses the response. If there are errors in the response, +// GraphQLError will be returned, but the data will also be parsed into the receiver. +func (c Client) Mutate(hostname, name string, mutation interface{}, variables map[string]interface{}) error { + // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. + opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + gqlClient, err := gh.GQLClient(&opts) + if err != nil { + return err + } + return handleResponse(gqlClient.Mutate(name, mutation, variables)) +} + +// GraphQL performs a GraphQL query and parses the response. If there are errors in the response, +// GraphQLError will be returned, but the data will also be parsed into the receiver. +func (c Client) Query(hostname, name string, query interface{}, variables map[string]interface{}) error { + // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. + opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + gqlClient, err := gh.GQLClient(&opts) + if err != nil { + return err + } + return handleResponse(gqlClient.Query(name, query, variables)) +} + +// REST performs a REST request and parses the response. +func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error { + // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. + opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + restClient, err := gh.RESTClient(&opts) + if err != nil { + return err + } + return handleResponse(restClient.Do(method, p, body, data)) +} + +func (c Client) RESTWithNext(hostname string, method string, p string, body io.Reader, data interface{}) (string, error) { + // AuthToken is being handled by Transport, so let go-gh know that it does not need to resolve it. + opts := ghAPI.ClientOptions{Host: hostname, AuthToken: "none", Transport: c.http.Transport} + restClient, err := gh.RESTClient(&opts) + if err != nil { + return "", err + } + + resp, err := restClient.Request(method, p, body) + if err != nil { + return "", err + } + defer resp.Body.Close() + + success := resp.StatusCode >= 200 && resp.StatusCode < 300 + if !success { + return "", HandleHTTPError(resp) + } + + if resp.StatusCode == http.StatusNoContent { + return "", nil + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + err = json.Unmarshal(b, &data) + if err != nil { + return "", err + } + + var next string + for _, m := range linkRE.FindAllStringSubmatch(resp.Header.Get("Link"), -1) { + if len(m) > 2 && m[2] == "next" { + next = m[1] + } + } + + return next, nil +} + +// HandleHTTPError parses a http.Response into a HTTPError. +func HandleHTTPError(resp *http.Response) error { + return handleResponse(ghAPI.HandleHTTPError(resp)) +} + +// handleResponse takes a ghAPI.HTTPError or ghAPI.GQLError and converts it into an +// HTTPError or GraphQLError respectively. +func handleResponse(err error) error { + if err == nil { + return nil + } + + var restErr ghAPI.HTTPError + if errors.As(err, &restErr) { + return HTTPError{ + HTTPError: restErr, + scopesSuggestion: generateScopesSuggestion(restErr.StatusCode, + restErr.Headers.Get("X-Accepted-Oauth-Scopes"), + restErr.Headers.Get("X-Oauth-Scopes"), + restErr.RequestURL.Hostname()), + } + } + + var gqlErr ghAPI.GQLError + if errors.As(err, &gqlErr) { + return GraphQLError{ + GQLError: gqlErr, + } + } + + return err +} + // ScopesSuggestion is an error messaging utility that prints the suggestion to request additional OAuth // scopes in case a server response indicates that there are missing scopes. func ScopesSuggestion(resp *http.Response) string { - if resp.StatusCode < 400 || resp.StatusCode > 499 || resp.StatusCode == 422 { + return generateScopesSuggestion(resp.StatusCode, + resp.Header.Get("X-Accepted-Oauth-Scopes"), + resp.Header.Get("X-Oauth-Scopes"), + resp.Request.URL.Hostname()) +} + +// EndpointNeedsScopes adds additional OAuth scopes to an HTTP response as if they were returned from the +// server endpoint. This improves HTTP 4xx error messaging for endpoints that don't explicitly list the +// OAuth scopes they need. +func EndpointNeedsScopes(resp *http.Response, s string) *http.Response { + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + oldScopes := resp.Header.Get("X-Accepted-Oauth-Scopes") + resp.Header.Set("X-Accepted-Oauth-Scopes", fmt.Sprintf("%s, %s", oldScopes, s)) + } + return resp +} + +func generateScopesSuggestion(statusCode int, endpointNeedsScopes, tokenHasScopes, hostname string) string { + if statusCode < 400 || statusCode > 499 || statusCode == 422 { return "" } - endpointNeedsScopes := resp.Header.Get("X-Accepted-Oauth-Scopes") - tokenHasScopes := resp.Header.Get("X-Oauth-Scopes") if tokenHasScopes == "" { return "" } @@ -267,206 +231,9 @@ func ScopesSuggestion(resp *http.Response) string { return fmt.Sprintf( "This API operation needs the %[1]q scope. To request it, run: gh auth refresh -h %[2]s -s %[1]s", s, - ghinstance.NormalizeHostname(resp.Request.URL.Hostname()), + ghinstance.NormalizeHostname(hostname), ) } return "" } - -// EndpointNeedsScopes adds additional OAuth scopes to an HTTP response as if they were returned from the -// server endpoint. This improves HTTP 4xx error messaging for endpoints that don't explicitly list the -// OAuth scopes they need. -func EndpointNeedsScopes(resp *http.Response, s string) *http.Response { - if resp.StatusCode >= 400 && resp.StatusCode < 500 { - oldScopes := resp.Header.Get("X-Accepted-Oauth-Scopes") - resp.Header.Set("X-Accepted-Oauth-Scopes", fmt.Sprintf("%s, %s", oldScopes, s)) - } - return resp -} - -// GraphQL performs a GraphQL request and parses the response. If there are errors in the response, -// *GraphQLErrorResponse will be returned, but the data will also be parsed into the receiver. -func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error { - reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) - if err != nil { - return err - } - - req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody)) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json; charset=utf-8") - req.Header.Set("GraphQL-Features", "merge_queue") - - resp, err := c.http.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - return handleResponse(resp, data) -} - -func graphQLClient(h *http.Client, hostname string) *graphql.Client { - return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h) -} - -// REST performs a REST request and parses the response. -func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error { - _, err := c.RESTWithNext(hostname, method, p, body, data) - return err -} - -func (c Client) RESTWithNext(hostname string, method string, p string, body io.Reader, data interface{}) (string, error) { - req, err := http.NewRequest(method, restURL(hostname, p), body) - if err != nil { - return "", err - } - - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - resp, err := c.http.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - success := resp.StatusCode >= 200 && resp.StatusCode < 300 - if !success { - return "", HandleHTTPError(resp) - } - - if resp.StatusCode == http.StatusNoContent { - return "", nil - } - - b, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - err = json.Unmarshal(b, &data) - if err != nil { - return "", err - } - - var next string - for _, m := range linkRE.FindAllStringSubmatch(resp.Header.Get("Link"), -1) { - if len(m) > 2 && m[2] == "next" { - next = m[1] - } - } - - return next, nil -} - -var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`) - -func restURL(hostname string, pathOrURL string) string { - if strings.HasPrefix(pathOrURL, "https://") || strings.HasPrefix(pathOrURL, "http://") { - return pathOrURL - } - return ghinstance.RESTPrefix(hostname) + pathOrURL -} - -func handleResponse(resp *http.Response, data interface{}) error { - success := resp.StatusCode >= 200 && resp.StatusCode < 300 - - if !success { - return HandleHTTPError(resp) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - gr := &graphQLResponse{Data: data} - err = json.Unmarshal(body, &gr) - if err != nil { - return err - } - - if len(gr.Errors) > 0 { - return &GraphQLErrorResponse{Errors: gr.Errors} - } - return nil -} - -func HandleHTTPError(resp *http.Response) error { - httpError := HTTPError{ - StatusCode: resp.StatusCode, - RequestURL: resp.Request.URL, - scopesSuggestion: ScopesSuggestion(resp), - } - - if !jsonTypeRE.MatchString(resp.Header.Get("Content-Type")) { - httpError.Message = resp.Status - return httpError - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - httpError.Message = err.Error() - return httpError - } - - var parsedBody struct { - Message string `json:"message"` - Errors []json.RawMessage - } - if err := json.Unmarshal(body, &parsedBody); err != nil { - return httpError - } - - var messages []string - if parsedBody.Message != "" { - messages = append(messages, parsedBody.Message) - } - for _, raw := range parsedBody.Errors { - switch raw[0] { - case '"': - var errString string - _ = json.Unmarshal(raw, &errString) - messages = append(messages, errString) - httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString}) - case '{': - var errInfo HTTPErrorItem - _ = json.Unmarshal(raw, &errInfo) - msg := errInfo.Message - if errInfo.Code != "" && errInfo.Code != "custom" { - msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) - } - if msg != "" { - messages = append(messages, msg) - } - httpError.Errors = append(httpError.Errors, errInfo) - } - } - httpError.Message = strings.Join(messages, "\n") - - return httpError -} - -func errorCodeToMessage(code string) string { - // https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors - switch code { - case "missing", "missing_field": - return "is missing" - case "invalid", "unprocessable": - return "is invalid" - case "already_exists": - return "already exists" - default: - return code - } -} - -var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) - -func inspectableMIMEType(t string) bool { - return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t) -} diff --git a/api/client_test.go b/api/client_test.go index 666667d8b..53a750e02 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -11,12 +11,15 @@ import ( "github.com/stretchr/testify/assert" ) +func newTestClient(reg *httpmock.Registry) *Client { + client := &http.Client{} + httpmock.ReplaceTripper(client, reg) + return NewClientFromHTTP(client) +} + func TestGraphQL(t *testing.T) { http := &httpmock.Registry{} - client := NewClient( - ReplaceTripper(http), - AddHeader("Authorization", "token OTOKEN"), - ) + client := newTestClient(http) vars := map[string]interface{}{"name": "Mona"} response := struct { @@ -37,16 +40,15 @@ func TestGraphQL(t *testing.T) { req := http.Requests[0] reqBody, _ := io.ReadAll(req.Body) assert.Equal(t, `{"query":"QUERY","variables":{"name":"Mona"}}`, string(reqBody)) - assert.Equal(t, "token OTOKEN", req.Header.Get("Authorization")) } func TestGraphQLError(t *testing.T) { - http := &httpmock.Registry{} - client := NewClient(ReplaceTripper(http)) + reg := &httpmock.Registry{} + client := newTestClient(reg) response := struct{}{} - http.Register( + reg.Register( httpmock.GraphQL(""), httpmock.StringResponse(` { "errors": [ @@ -73,10 +75,7 @@ func TestGraphQLError(t *testing.T) { func TestRESTGetDelete(t *testing.T) { http := &httpmock.Registry{} - - client := NewClient( - ReplaceTripper(http), - ) + client := newTestClient(http) http.Register( httpmock.REST("DELETE", "applications/CLIENTID/grant"), @@ -90,7 +89,7 @@ func TestRESTGetDelete(t *testing.T) { func TestRESTWithFullURL(t *testing.T) { http := &httpmock.Registry{} - client := NewClient(ReplaceTripper(http)) + client := newTestClient(http) http.Register( httpmock.REST("GET", "api/v3/user/repos"), @@ -110,7 +109,7 @@ func TestRESTWithFullURL(t *testing.T) { func TestRESTError(t *testing.T) { fakehttp := &httpmock.Registry{} - client := NewClient(ReplaceTripper(fakehttp)) + client := newTestClient(fakehttp) fakehttp.Register(httpmock.MatchAny, func(req *http.Request) (*http.Response, error) { return &http.Response{ @@ -134,7 +133,6 @@ func TestRESTError(t *testing.T) { } if httpErr.Error() != "HTTP 422: OH NO (https://api.github.com/repos/branch)" { t.Errorf("got %q", httpErr.Error()) - } } diff --git a/api/http_client.go b/api/http_client.go new file mode 100644 index 000000000..daeb0f3da --- /dev/null +++ b/api/http_client.go @@ -0,0 +1,114 @@ +package api + +import ( + "fmt" + "io" + "net/http" + "time" + + "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/utils" + "github.com/cli/go-gh" + ghAPI "github.com/cli/go-gh/pkg/api" +) + +type configGetter interface { + Get(string, string) (string, error) +} + +type HTTPClientOptions struct { + AppVersion string + CacheTTL time.Duration + Config configGetter + EnableCache bool + Log io.Writer + SkipAcceptHeaders bool +} + +func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { + // Provide invalid host, and token values so gh.HTTPClient will not automatically resolve them. + // The real host and token are inserted at request time. + clientOpts := ghAPI.ClientOptions{Host: "none", AuthToken: "none"} + + if debugEnabled, _ := utils.IsDebugEnabled(); debugEnabled { + clientOpts.Log = opts.Log + } + + headers := map[string]string{ + "User-Agent": fmt.Sprintf("GitHub CLI %s", opts.AppVersion), + } + if opts.SkipAcceptHeaders { + headers["Accept"] = "" + } + clientOpts.Headers = headers + + if opts.EnableCache { + clientOpts.EnableCache = opts.EnableCache + clientOpts.CacheTTL = opts.CacheTTL + } + + client, err := gh.HTTPClient(&clientOpts) + if err != nil { + return nil, err + } + + client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) + + return client, nil +} + +func NewCachedHTTPClient(httpClient *http.Client, ttl time.Duration) *http.Client { + httpClient.Transport = AddCacheTTLHeader(httpClient.Transport, ttl) + return httpClient +} + +// AddCacheTTLHeader adds an header to the request telling the cache that the request +// should be cached for a specified amount of time. +func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + req.Header.Set("X-GH-CACHE-TTL", ttl.String()) + return rt.RoundTrip(req) + }} +} + +// AddAuthToken adds an authentication token header for the host specified by the request. +func AddAuthTokenHeader(rt http.RoundTripper, cfg configGetter) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + hostname := ghinstance.NormalizeHostname(getHost(req)) + if token, err := cfg.Get(hostname, "oauth_token"); err == nil && token != "" { + req.Header.Set("Authorization", fmt.Sprintf("token %s", token)) + } + return rt.RoundTrip(req) + }} +} + +// ExtractHeader extracts a named header from any response received by this client and, +// if non-blank, saves it to dest. +func ExtractHeader(name string, dest *string) func(http.RoundTripper) http.RoundTripper { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + res, err := tr.RoundTrip(req) + if err == nil { + if value := res.Header.Get(name); value != "" { + *dest = value + } + } + return res, err + }} + } +} + +type funcTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return tr.roundTrip(req) +} + +func getHost(r *http.Request) string { + if r.Host != "" { + return r.Host + } + return r.URL.Hostname() +} diff --git a/pkg/cmd/factory/http_test.go b/api/http_client_test.go similarity index 86% rename from pkg/cmd/factory/http_test.go rename to api/http_client_test.go index b5911b957..06021e2f1 100644 --- a/pkg/cmd/factory/http_test.go +++ b/api/http_client_test.go @@ -1,4 +1,4 @@ -package factory +package api import ( "fmt" @@ -27,10 +27,8 @@ func TestNewHTTPClient(t *testing.T) { setGhDebug bool envGhDebug string host string - sso string wantHeader map[string]string wantStderr string - wantSSO string }{ { name: "github.com with Accept header", @@ -99,6 +97,8 @@ func TestNewHTTPClient(t *testing.T) { > Host: github.com > Accept: application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview > Authorization: token ████████████████████ + > Content-Type: application/json; charset=utf-8 + > Time-Zone: > User-Agent: GitHub CLI v1.2.3 < HTTP/1.1 204 No Content @@ -129,6 +129,8 @@ func TestNewHTTPClient(t *testing.T) { > Host: github.com > Accept: application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview > Authorization: token ████████████████████ + > Content-Type: application/json; charset=utf-8 + > Time-Zone: > User-Agent: GitHub CLI v1.2.3 < HTTP/1.1 204 No Content @@ -148,29 +150,15 @@ func TestNewHTTPClient(t *testing.T) { wantHeader: map[string]string{ "authorization": "token GHETOKEN", "user-agent": "GitHub CLI v1.2.3", - "accept": "application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview, application/vnd.github.antiope-preview, application/vnd.github.shadow-cat-preview", + "accept": "application/vnd.github.merge-info-preview+json, application/vnd.github.nebula-preview", }, wantStderr: "", }, - { - name: "SSO challenge in response header", - args: args{ - config: tinyConfig{}, - appVersion: "v1.2.3", - }, - host: "github.com", - sso: "required; url=https://github.com/login/sso?return_to=xyz¶m=123abc; another", - wantStderr: "", - wantSSO: "https://github.com/login/sso?return_to=xyz¶m=123abc", - }, } var gotReq *http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotReq = r - if sso := r.URL.Query().Get("sso"); sso != "" { - w.Header().Set("X-GitHub-SSO", sso) - } w.WriteHeader(http.StatusNoContent) })) defer ts.Close() @@ -191,19 +179,20 @@ func TestNewHTTPClient(t *testing.T) { }) ios, _, _, stderr := iostreams.Test() - client, err := NewHTTPClient(ios, tt.args.config, tt.args.appVersion, tt.args.setAccept) + client, err := NewHTTPClient(HTTPClientOptions{ + AppVersion: tt.args.appVersion, + Config: tt.args.config, + Log: ios.ErrOut, + SkipAcceptHeaders: !tt.args.setAccept, + }) require.NoError(t, err) req, err := http.NewRequest("GET", ts.URL, nil) - if tt.sso != "" { - q := req.URL.Query() - q.Set("sso", tt.sso) - req.URL.RawQuery = q.Encode() - } req.Host = tt.host require.NoError(t, err) res, err := client.Do(req) + require.NoError(t, err) for name, value := range tt.wantHeader { @@ -212,7 +201,6 @@ func TestNewHTTPClient(t *testing.T) { assert.Equal(t, 204, res.StatusCode) assert.Equal(t, tt.wantStderr, normalizeVerboseLog(stderr.String())) - assert.Equal(t, tt.wantSSO, SSOURL()) }) } } @@ -227,11 +215,13 @@ var requestAtRE = regexp.MustCompile(`(?m)^\* Request at .+`) var dateRE = regexp.MustCompile(`(?m)^< Date: .+`) var hostWithPortRE = regexp.MustCompile(`127\.0\.0\.1:\d+`) var durationRE = regexp.MustCompile(`(?m)^\* Request took .+`) +var timezoneRE = regexp.MustCompile(`(?m)^> Time-Zone: .+`) func normalizeVerboseLog(t string) string { t = requestAtRE.ReplaceAllString(t, "* Request at