From fa3e25bb4db25ad03a26a9e20e5d08d2992e10d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 14 May 2020 11:42:03 +0200 Subject: [PATCH] Serialize GraphQL parameters under `variables` --- api/client.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/api/client.go b/api/client.go index a0007c465..0e620365f 100644 --- a/api/client.go +++ b/api/client.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "regexp" "strings" @@ -212,15 +213,23 @@ func (c Client) DirectRequest(method string, p string, params interface{}, heade url := "https://api.github.com/" + p var body io.Reader var bodyIsJSON bool + isGraphQL := p == "graphql" switch pp := params.(type) { case map[string]interface{}: - b, err := json.Marshal(pp) - if err != nil { - return nil, fmt.Errorf("error serializing parameters: %w", err) + if strings.EqualFold(method, "GET") { + url = addQuery(url, pp) + } else { + if isGraphQL { + pp = groupGraphQLVariables(pp) + } + b, err := json.Marshal(pp) + if err != nil { + return nil, fmt.Errorf("error serializing parameters: %w", err) + } + body = bytes.NewBuffer(b) + bodyIsJSON = true } - body = bytes.NewBuffer(b) - bodyIsJSON = true case io.Reader: body = pp default: @@ -246,6 +255,55 @@ func (c Client) DirectRequest(method string, p string, params interface{}, heade return c.http.Do(req) } +func groupGraphQLVariables(params map[string]interface{}) map[string]interface{} { + topLevel := make(map[string]interface{}) + variables := make(map[string]interface{}) + + for key, val := range params { + switch key { + case "query": + topLevel[key] = val + default: + variables[key] = val + } + } + + if len(variables) > 0 { + topLevel["variables"] = variables + } + return topLevel +} + +func addQuery(path string, params map[string]interface{}) string { + if len(params) == 0 { + return path + } + + query := url.Values{} + for key, value := range params { + switch v := value.(type) { + case string: + query.Add(key, v) + case []byte: + query.Add(key, string(v)) + case nil: + query.Add(key, "") + case int: + query.Add(key, fmt.Sprintf("%d", v)) + case bool: + query.Add(key, fmt.Sprintf("%v", v)) + default: + panic(fmt.Sprintf("unknown type %v", v)) + } + } + + sep := "?" + if strings.ContainsRune(path, '?') { + sep = "&" + } + return path + sep + query.Encode() +} + func handleResponse(resp *http.Response, data interface{}) error { success := resp.StatusCode >= 200 && resp.StatusCode < 300