Serialize GraphQL parameters under variables

This commit is contained in:
Mislav Marohnić 2020-05-14 11:42:03 +02:00
parent 1609afe993
commit fa3e25bb4d

View file

@ -7,6 +7,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strings"
@ -212,15 +213,23 @@ func (c Client) DirectRequest(method string, p string, params interface{}, heade
url := "https://api.github.com/" + p
var body io.Reader
var bodyIsJSON bool
isGraphQL := p == "graphql"
switch pp := params.(type) {
case map[string]interface{}:
b, err := json.Marshal(pp)
if err != nil {
return nil, fmt.Errorf("error serializing parameters: %w", err)
if strings.EqualFold(method, "GET") {
url = addQuery(url, pp)
} else {
if isGraphQL {
pp = groupGraphQLVariables(pp)
}
b, err := json.Marshal(pp)
if err != nil {
return nil, fmt.Errorf("error serializing parameters: %w", err)
}
body = bytes.NewBuffer(b)
bodyIsJSON = true
}
body = bytes.NewBuffer(b)
bodyIsJSON = true
case io.Reader:
body = pp
default:
@ -246,6 +255,55 @@ func (c Client) DirectRequest(method string, p string, params interface{}, heade
return c.http.Do(req)
}
func groupGraphQLVariables(params map[string]interface{}) map[string]interface{} {
topLevel := make(map[string]interface{})
variables := make(map[string]interface{})
for key, val := range params {
switch key {
case "query":
topLevel[key] = val
default:
variables[key] = val
}
}
if len(variables) > 0 {
topLevel["variables"] = variables
}
return topLevel
}
func addQuery(path string, params map[string]interface{}) string {
if len(params) == 0 {
return path
}
query := url.Values{}
for key, value := range params {
switch v := value.(type) {
case string:
query.Add(key, v)
case []byte:
query.Add(key, string(v))
case nil:
query.Add(key, "")
case int:
query.Add(key, fmt.Sprintf("%d", v))
case bool:
query.Add(key, fmt.Sprintf("%v", v))
default:
panic(fmt.Sprintf("unknown type %v", v))
}
}
sep := "?"
if strings.ContainsRune(path, '?') {
sep = "&"
}
return path + sep + query.Encode()
}
func handleResponse(resp *http.Response, data interface{}) error {
success := resp.StatusCode >= 200 && resp.StatusCode < 300