cli/github/http.go
2019-10-07 16:22:53 +02:00

490 lines
11 KiB
Go

package github
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
"github.com/github/gh-cli/ui"
"github.com/github/gh-cli/utils"
)
// Media type strings and the cache format version used by the HTTP client.
const (
	// apiPayloadVersion is the Accept value set on every request before any
	// per-request configuration runs.
	apiPayloadVersion = "application/vnd.github.v3+json;charset=utf-8"
	// patchMediaType is the v3 "patch" media type.
	patchMediaType = "application/vnd.github.v3.patch;charset=utf-8"
	// textMediaType is plain UTF-8 text.
	textMediaType = "text/plain;charset=utf-8"
	// checksType is the "antiope" preview media type.
	// NOTE(review): presumably for the Checks API, judging by the name — confirm.
	checksType = "application/vnd.github.antiope-preview+json;charset=utf-8"
	// draftsType is the "shadow-cat" preview media type.
	// NOTE(review): presumably for draft pull requests, judging by the name — confirm.
	draftsType = "application/vnd.github.shadow-cat-preview+json;charset=utf-8"
	// cacheVersion is hashed into every cache key, so changing it makes
	// previously written cache entries unreachable.
	cacheVersion = 2
)
// inspectHeaders lists the header names that verbose mode prints, in the
// order they are printed. Names are matched case-insensitively.
var inspectHeaders = []string{"Authorization", "X-GitHub-OTP", "Location", "Link", "Accept"}
// verboseTransport is an http.RoundTripper that delegates to an
// *http.Transport, optionally printing each request and response and
// optionally redirecting requests to an override host.
type verboseTransport struct {
// Transport performs the actual round trip.
Transport *http.Transport
// Verbose enables printing of every request and response to Out.
Verbose bool
// OverrideURL, when non-nil, supplies the scheme and host that outgoing
// requests are rewritten to.
OverrideURL *url.URL
// Out receives the verbose output.
Out io.Writer
// Colorized wraps each printed line in ANSI color escape sequences.
Colorized bool
}
// RoundTrip implements http.RoundTripper.
//
// In verbose mode the request is printed before it is sent (and before any
// host override is applied) and the response is printed after a successful
// round trip. When OverrideURL is set, a clone of the request is sent to the
// override scheme/host; the original scheme and port travel in the
// X-Original-Scheme and X-Original-Port headers and the original host is kept
// in the Host header.
func (t *verboseTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	if t.Verbose {
		t.dumpRequest(req)
	}

	if override := t.OverrideURL; override != nil {
		// The port defaults to "80" whenever the host carries no explicit port.
		originalPort := "80"
		if hostParts := strings.Split(req.URL.Host, ":"); len(hostParts) > 1 {
			originalPort = hostParts[1]
		}

		// Work on a clone so the caller's request is left untouched.
		req = cloneRequest(req)
		req.Header.Set("X-Original-Scheme", req.URL.Scheme)
		req.Header.Set("X-Original-Port", originalPort)
		req.Host = req.URL.Host
		req.URL.Scheme = override.Scheme
		req.URL.Host = override.Host
	}

	resp, err = t.Transport.RoundTrip(req)
	if err != nil || !t.Verbose {
		return resp, err
	}
	t.dumpResponse(resp)
	return resp, err
}
// dumpRequest prints the request line, the inspected headers and the body of
// req, each header line prefixed with ">".
func (t *verboseTransport) dumpRequest(req *http.Request) {
	target := req.URL
	t.verbosePrintln(fmt.Sprintf("> %s %s://%s%s", req.Method, target.Scheme, target.Host, target.RequestURI()))
	t.dumpHeaders(req.Header, ">")

	// Printing consumes the body, so install the buffered replacement.
	if replay := t.dumpBody(req.Body); replay != nil {
		req.Body = replay
	}
}
// dumpResponse prints the status code, the inspected headers and the body of
// resp, each header line prefixed with "<".
func (t *verboseTransport) dumpResponse(resp *http.Response) {
	t.verbosePrintln(fmt.Sprintf("< HTTP %d", resp.StatusCode))
	t.dumpHeaders(resp.Header, "<")

	// Printing consumes the body, so install the buffered replacement.
	if replay := t.dumpBody(resp.Body); replay != nil {
		resp.Body = replay
	}
}
// authCredentialPattern matches header values of the form "basic <secret>" or
// "token <secret>" (case-insensitive). It is compiled once at package scope
// rather than once per printed header value.
var authCredentialPattern = regexp.MustCompile("(?i)^(basic|token) (.+)")

// dumpHeaders prints every non-empty value of the headers named in
// inspectHeaders, in inspectHeaders order, as "<indent> <name>: <value>".
// Header names are matched case-insensitively. Credentials in values that
// start with "basic " or "token " are replaced with "[REDACTED]".
//
// NOTE(review): other credential schemes (e.g. "Bearer ...") are printed
// verbatim — confirm whether they should be redacted as well.
func (t *verboseTransport) dumpHeaders(header http.Header, indent string) {
	for _, listed := range inspectHeaders {
		for name, values := range header {
			if !strings.EqualFold(name, listed) {
				continue
			}
			for _, v := range values {
				if v == "" {
					continue
				}
				// ReplaceAllString returns v unchanged when nothing matches.
				v = authCredentialPattern.ReplaceAllString(v, "$1 [REDACTED]")
				t.verbosePrintln(fmt.Sprintf("%s %s: %s", indent, name, v))
			}
		}
	}
}
// dumpBody reads body to the end, prints it when it is non-empty, closes it,
// and returns a replacement reader over the buffered bytes so the content can
// be read again. It returns nil when body is nil.
//
// A read error is passed to utils.Check.
// NOTE(review): utils.Check is defined outside this file; presumably it
// aborts on a non-nil error — confirm.
func (t *verboseTransport) dumpBody(body io.ReadCloser) io.ReadCloser {
	if body == nil {
		return nil
	}
	defer body.Close()

	var captured bytes.Buffer
	_, err := io.Copy(&captured, body)
	utils.Check(err)

	if captured.Len() > 0 {
		t.verbosePrintln(captured.String())
	}
	return ioutil.NopCloser(&captured)
}
// verbosePrintln writes msg followed by a newline to t.Out, wrapping it in
// ANSI color escape sequences when t.Colorized is set.
func (t *verboseTransport) verbosePrintln(msg string) {
	line := msg
	if t.Colorized {
		line = fmt.Sprintf("\033[36m%s\033[0m", msg)
	}
	fmt.Fprintln(t.Out, line)
}
// newHttpClient builds an *http.Client whose transport is a verboseTransport.
//
// testHost, when non-empty, is parsed as the override URL that all requests
// are redirected to; a parse error is ignored. verbose enables
// request/response dumping to ui.Stderr. unixSocket, when non-empty, makes
// every connection (plain and TLS) dial that Unix socket instead of the
// network address.
func newHttpClient(testHost string, verbose bool, unixSocket string) *http.Client {
	var overrideURL *url.URL
	if testHost != "" {
		overrideURL, _ = url.Parse(testHost)
	}

	var transport *http.Transport
	switch {
	case unixSocket != "":
		// Both dial hooks ignore the requested network/address and connect to
		// the socket.
		dialSocket := func() (net.Conn, error) {
			return net.Dial("unix", unixSocket)
		}
		transport = &http.Transport{
			DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
				return dialSocket()
			},
			DialTLS: func(_, _ string) (net.Conn, error) {
				return dialSocket()
			},
			ResponseHeaderTimeout: 30 * time.Second,
			ExpectContinueTimeout: 10 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
		}
	default:
		dialer := &net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}
		transport = &http.Transport{
			DialContext:         dialer.DialContext,
			TLSHandshakeTimeout: 10 * time.Second,
		}
	}

	return &http.Client{
		Transport: &verboseTransport{
			Transport:   transport,
			Verbose:     verbose,
			OverrideURL: overrideURL,
			Out:         ui.Stderr,
			Colorized:   ui.IsTerminal(os.Stderr),
		},
	}
}
// cloneRequest returns a copy of req whose URL and Header can be modified
// without affecting req. All other fields (including Body) are shared with
// the original.
//
// The URL is copied by value rather than re-parsed from its string form, so
// the copy cannot fail and leave dup.URL nil. Header value slices are
// duplicated so that in-place edits on the clone do not leak into req.
func cloneRequest(req *http.Request) *http.Request {
	dup := new(http.Request)
	*dup = *req

	if req.URL != nil {
		u := *req.URL
		if req.URL.User != nil {
			// Userinfo is held by pointer; copy it too.
			user := *req.URL.User
			u.User = &user
		}
		dup.URL = &u
	}

	dup.Header = make(http.Header, len(req.Header))
	for k, values := range req.Header {
		dup.Header[k] = append([]string(nil), values...)
	}
	return dup
}
// simpleClient issues HTTP requests relative to a root URL and can serve
// cacheable responses from an on-disk cache.
type simpleClient struct {
// httpClient executes the requests.
httpClient *http.Client
// rootUrl is the base that request paths are resolved against.
rootUrl *url.URL
// PrepareRequest, when non-nil, is called on every new request before the
// default User-Agent and Accept headers are set.
PrepareRequest func(*http.Request)
// CacheTTL is the maximum age in seconds of a cached response; caching is
// only active when it is greater than zero.
CacheTTL int
}
// performRequest resolves path against the client's root URL and performs the
// request via performRequestUrl. It returns an error when path cannot be
// parsed as a URL reference.
func (c *simpleClient) performRequest(method, path string, body io.Reader, configure func(*http.Request)) (*simpleResponse, error) {
	ref, err := url.Parse(path)
	if err != nil {
		return nil, err
	}
	return c.performRequestUrl(method, c.rootUrl.ResolveReference(ref), body, configure)
}
// performRequestUrl builds and sends a request to url.
//
// Header setup happens in this order, so later steps can override earlier
// ones: c.PrepareRequest, the default User-Agent and Accept headers, then the
// per-call configure callback. A fresh-enough cached response is returned
// without touching the network; otherwise the request is sent and the
// response is registered for caching.
func (c *simpleClient) performRequestUrl(method string, url *url.URL, body io.Reader, configure func(*http.Request)) (res *simpleResponse, err error) {
	req, err := http.NewRequest(method, url.String(), body)
	if err != nil {
		return nil, err
	}

	if prepare := c.PrepareRequest; prepare != nil {
		prepare(req)
	}
	req.Header.Set("User-Agent", UserAgent)
	req.Header.Set("Accept", apiPayloadVersion)
	if configure != nil {
		configure(req)
	}

	// The key is computed after all headers are final because it hashes them.
	key := cacheKey(req)
	if cached := c.cacheRead(key, req); cached != nil {
		return &simpleResponse{cached}, nil
	}

	fresh, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	c.cacheWrite(key, fresh)
	return &simpleResponse{fresh}, nil
}
// isGraphQL reports whether req targets the "/graphql" path exactly.
func isGraphQL(req *http.Request) bool {
	const graphqlPath = "/graphql"
	return req.URL.Path == graphqlPath
}
// canCache reports whether responses to req are eligible for caching: GET
// requests (method compared case-insensitively) and GraphQL requests.
func canCache(req *http.Request) bool {
	if strings.EqualFold(req.Method, "GET") {
		return true
	}
	return isGraphQL(req)
}
// cacheRead returns the cached response stored under key, or nil when caching
// is disabled, req is not cacheable, the cache file is missing, unreadable or
// older than c.CacheTTL seconds, or the file has no header/body separator.
//
// The cache file holds a status line, header lines and a blank line (all
// CRLF-terminated) followed by the raw body. Only Proto, Status, StatusCode,
// Header and Body are populated on the returned response.
func (c *simpleClient) cacheRead(key string, req *http.Request) (res *http.Response) {
	if c.CacheTTL <= 0 || !canCache(req) {
		return nil
	}

	file := cacheFile(key)
	info, err := os.Stat(file)
	if err != nil {
		return nil
	}
	// Freshness is judged by the file's modification time.
	if age := time.Since(info.ModTime()); age.Seconds() > float64(c.CacheTTL) {
		return nil
	}

	raw, err := ioutil.ReadFile(file)
	if err != nil {
		return nil
	}

	sections := strings.SplitN(string(raw), "\r\n\r\n", 2)
	if len(sections) < 2 {
		return nil
	}
	head, payload := sections[0], sections[1]

	res = &http.Response{
		Body:   ioutil.NopCloser(bytes.NewBufferString(payload)),
		Header: http.Header{},
	}

	// strings.Split always yields at least one element, so lines[0] is safe.
	lines := strings.Split(head, "\r\n")
	if status := strings.SplitN(lines[0], " ", 3); len(status) >= 3 {
		res.Proto = status[0]
		res.Status = fmt.Sprintf("%s %s", status[1], status[2])
		if code, _ := strconv.Atoi(status[1]); code > 0 {
			res.StatusCode = code
		}
	}
	for _, line := range lines[1:] {
		if kv := strings.SplitN(line, ":", 2); len(kv) >= 2 {
			res.Header.Add(kv[0], strings.TrimLeft(kv[1], " "))
		}
	}
	return res
}
// cacheWrite arranges for res to be stored under key once its body has been
// closed. Nothing happens when caching is disabled, the originating request
// is not cacheable, or the status is 403 or >= 500.
//
// The body is replaced with a wrapper that records every byte the consumer
// reads; when the consumer closes the body successfully, the status line,
// headers and recorded bytes are written to the cache file. Failures while
// writing are ignored (the cache is best-effort).
//
// NOTE(review): only the bytes actually read are recorded, so a body that is
// closed before being read to the end is cached truncated.
func (c *simpleClient) cacheWrite(key string, res *http.Response) {
	if c.CacheTTL <= 0 || !canCache(res.Request) {
		return
	}
	if res.StatusCode >= 500 || res.StatusCode == 403 {
		return
	}

	captured := &bytes.Buffer{}
	persist := func() {
		file := cacheFile(key)
		if err := os.MkdirAll(filepath.Dir(file), 0771); err != nil {
			return
		}
		out, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
		if err != nil {
			return
		}
		defer out.Close()

		fmt.Fprintf(out, "%s %s\r\n", res.Proto, res.Status)
		res.Header.Write(out)
		fmt.Fprintf(out, "\r\n")
		io.Copy(out, captured)
	}

	res.Body = &readCloserCallback{
		Reader:   io.TeeReader(res.Body, captured),
		Closer:   res.Body,
		Callback: persist,
	}
}
// readCloserCallback is an io.ReadCloser assembled from a separate Reader and
// Closer that invokes Callback after a successful Close.
type readCloserCallback struct {
	// Callback runs after Closer.Close returns nil.
	Callback func()
	// Closer is closed when Close is called.
	Closer io.Closer
	// Reader supplies the Read method.
	io.Reader
}

// Close closes the underlying Closer. Callback runs only when that close
// succeeds; otherwise the close error is returned and Callback is skipped.
func (rc *readCloserCallback) Close() error {
	if err := rc.Closer.Close(); err != nil {
		return err
	}
	rc.Callback()
	return nil
}
// cacheKey derives the cache key for req, of the form
// "<host>/<path-slug>_<md5-hex>".
//
// The slug is the escaped URL path with "/" replaced by "-" (minus the leading
// dash). The digest covers, in order: the cache version, the Accept header,
// the Authorization header, the query parameters sorted lexically, and — for
// GraphQL requests — the request body. Reading the body consumes it, so it is
// replaced with a fresh reader over the same bytes.
func cacheKey(req *http.Request) string {
	slug := strings.Replace(req.URL.EscapedPath(), "/", "-", -1)
	if len(slug) > 1 {
		slug = strings.TrimPrefix(slug, "-")
	}

	// Prefer the explicit Host field; fall back to the URL's host.
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}

	digest := md5.New()
	fmt.Fprintf(digest, "%d:", cacheVersion)
	io.WriteString(digest, req.Header.Get("Accept"))
	io.WriteString(digest, req.Header.Get("Authorization"))

	// Sorting makes the key independent of query parameter order.
	params := strings.Split(req.URL.RawQuery, "&")
	sort.Strings(params)
	for _, param := range params {
		fmt.Fprintf(digest, "%s&", param)
	}

	if req.Body != nil && isGraphQL(req) {
		if payload, err := ioutil.ReadAll(req.Body); err == nil {
			req.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
			digest.Write(payload)
		}
	}

	return fmt.Sprintf("%s/%s_%x", host, slug, digest.Sum(nil))
}
// cacheFile returns the filesystem path of the cache entry for key:
// <os.TempDir()>/hub/api/<key>.
//
// The path is built with filepath.Join (not path.Join) because it is an OS
// path: the temp directory may use a non-slash separator, and the slashes
// inside key must be converted to the platform's separator.
func cacheFile(key string) string {
	return filepath.Join(os.TempDir(), "hub", "api", key)
}
// jsonRequest JSON-encodes body and sends it to path with the given method.
// The Content-Type header is set to JSON before the optional configure
// callback runs, so configure may override it. It returns an error when body
// cannot be marshaled.
func (c *simpleClient) jsonRequest(method, path string, body interface{}, configure func(*http.Request)) (*simpleResponse, error) {
	encoded, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}

	withContentType := func(req *http.Request) {
		req.Header.Set("Content-Type", "application/json; charset=utf-8")
		if configure != nil {
			configure(req)
		}
	}
	return c.performRequest(method, path, bytes.NewBuffer(encoded), withContentType)
}
// Get performs a GET request for path with no body and the default headers.
func (c *simpleClient) Get(path string) (*simpleResponse, error) {
	var (
		noBody      io.Reader
		noConfigure func(*http.Request)
	)
	return c.performRequest("GET", path, noBody, noConfigure)
}
// GetFile performs a GET request for path, replacing the default Accept
// header with mimeType.
func (c *simpleClient) GetFile(path string, mimeType string) (*simpleResponse, error) {
	acceptMimeType := func(req *http.Request) {
		req.Header.Set("Accept", mimeType)
	}
	return c.performRequest("GET", path, nil, acceptMimeType)
}
// PostJSON sends payload, JSON-encoded, in a POST request to path.
func (c *simpleClient) PostJSON(path string, payload interface{}) (*simpleResponse, error) {
	var noConfigure func(*http.Request)
	return c.jsonRequest("POST", path, payload, noConfigure)
}
// PostJSONPreview sends payload, JSON-encoded, in a POST request to path,
// replacing the default Accept header with mimeType.
func (c *simpleClient) PostJSONPreview(path string, payload interface{}, mimeType string) (*simpleResponse, error) {
	acceptMimeType := func(req *http.Request) {
		req.Header.Set("Accept", mimeType)
	}
	return c.jsonRequest("POST", path, payload, acceptMimeType)
}
// PatchJSON sends payload, JSON-encoded, in a PATCH request to path.
func (c *simpleClient) PatchJSON(path string, payload interface{}) (*simpleResponse, error) {
	var noConfigure func(*http.Request)
	return c.jsonRequest("PATCH", path, payload, noConfigure)
}
// simpleResponse wraps an *http.Response with helpers for decoding JSON
// bodies, API error payloads and Link headers.
type simpleResponse struct {
	*http.Response
}

// errorInfo is a decoded API error payload together with the response it was
// read from. It implements the error interface.
type errorInfo struct {
	// Message is the top-level error message.
	Message string `json:"message"`
	// Errors holds the individual error entries.
	Errors []fieldError `json:"errors"`
	// Response is the HTTP response the payload came from; ErrorInfo sets it.
	Response *http.Response
}

// errorInfoSimple is the alternative error payload shape in which "errors" is
// a list of plain strings rather than objects.
type errorInfoSimple struct {
	Message string   `json:"message"`
	Errors  []string `json:"errors"`
}

// fieldError is one entry of an error payload's "errors" list.
type fieldError struct {
	Resource string `json:"resource"`
	Message  string `json:"message"`
	Code     string `json:"code"`
	Field    string `json:"field"`
}

// Error returns the top-level message of the error payload.
func (e *errorInfo) Error() string {
	return e.Message
}

// Unmarshal reads the whole response body, closes it, and decodes it as JSON
// into dest. It returns any read or decode error.
func (res *simpleResponse) Unmarshal(dest interface{}) (err error) {
	defer res.Body.Close()
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return err
	}
	return json.Unmarshal(body, dest)
}

// ErrorInfo reads the whole response body, closes it, and decodes it as an
// API error payload.
//
// The structured form (errors as objects) is tried first. If that fails, the
// simple form (errors as strings) is tried and each string becomes a
// fieldError with Code "custom". On success msg.Response is set to the
// underlying response. When the body cannot be read, msg is nil; when neither
// form decodes, the decode error is returned.
func (res *simpleResponse) ErrorInfo() (msg *errorInfo, err error) {
	defer res.Body.Close()
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	msg = &errorInfo{}
	if err = json.Unmarshal(body, msg); err != nil {
		simple := &errorInfoSimple{}
		if err = json.Unmarshal(body, simple); err != nil {
			return msg, err
		}
		// encoding/json keeps decoding after a type mismatch, so the failed
		// structured attempt leaves one zero-valued fieldError per string in
		// msg.Errors. Start from a fresh value so those placeholders are not
		// returned alongside the converted entries.
		msg = &errorInfo{Message: simple.Message}
		for _, text := range simple.Errors {
			msg.Errors = append(msg.Errors, fieldError{
				Code:    "custom",
				Message: text,
			})
		}
	}

	msg.Response = res.Response
	return msg, nil
}

// linkHeaderPattern matches one `<url>; rel="name"` entry of a Link header.
// It is compiled once at package scope rather than on every Link call.
var linkHeaderPattern = regexp.MustCompile(`<([^>]+)>; rel="([^"]+)"`)

// Link returns the URL of the Link header entry whose rel equals name, or ""
// when the header is absent or has no such entry.
func (res *simpleResponse) Link(name string) string {
	linkVal := res.Header.Get("Link")
	for _, match := range linkHeaderPattern.FindAllStringSubmatch(linkVal, -1) {
		if match[2] == name {
			return match[1]
		}
	}
	return ""
}