diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 7e4106f9d..02eac439b 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -10,10 +10,14 @@ import ( "strings" "github.com/cli/cli/context" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" "github.com/spf13/cobra" ) type ApiOptions struct { + IO *iostreams.IOStreams + RequestMethod string RequestMethodPassed bool RequestPath string @@ -54,9 +58,12 @@ on the format of the value: opts.RequestPath = args[0] opts.RequestMethodPassed = c.Flags().Changed("method") + // TODO: pass in via caller + opts.IO = iostreams.System() + opts.HttpClient = func() (*http.Client, error) { ctx := context.New() - token, err := ctx.AuthLogin() + token, err := ctx.AuthToken() if err != nil { return nil, err } @@ -108,12 +115,16 @@ func apiRun(opts *ApiOptions) error { } defer resp.Body.Close() - // TODO: make stdout configurable for tests - _, err = io.Copy(os.Stdout, resp.Body) + _, err = io.Copy(opts.IO.Out, resp.Body) if err != nil { return err } + // TODO: detect GraphQL errors + if resp.StatusCode > 299 { + return cmdutil.SilentError + } + return nil } @@ -131,7 +142,7 @@ func parseFields(opts *ApiOptions) (map[string]interface{}, error) { if err != nil { return params, err } - value, err := magicFieldValue(strValue) + value, err := magicFieldValue(strValue, opts.IO.In) if err != nil { return params, fmt.Errorf("error parsing %q value: %w", key, err) } @@ -148,12 +159,12 @@ func parseField(f string) (string, string, error) { return f[0:idx], f[idx+1:], nil } -func magicFieldValue(v string) (interface{}, error) { +func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) { if strings.HasPrefix(v, "@") { - return readUserFile(v[1:]) + return readUserFile(v[1:], stdin) } - if n, err := strconv.Atoi(v); err != nil { + if n, err := strconv.Atoi(v); err == nil { return n, nil } @@ -169,18 +180,17 @@ func magicFieldValue(v string) (interface{}, error) { } } -func readUserFile(fn string) ([]byte, error) { +func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) { var r io.ReadCloser if fn == "-" { - // TODO: make stdin configurable for tests - r = os.Stdin + r = stdin } else { var err error r, err = os.Open(fn) if err != nil { return nil, err } - defer r.Close() } + defer r.Close() return ioutil.ReadAll(r) } diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go new file mode 100644 index 000000000..ad8211761 --- /dev/null +++ b/pkg/cmd/api/api_test.go @@ -0,0 +1,115 @@ +package api + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "testing" + + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" +) + +func Test_apiRun(t *testing.T) { + tests := []struct { + name string + httpResponse *http.Response + err error + stdout string + stderr string + }{ + { + name: "success", + httpResponse: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`bam!`)), + }, + err: nil, + stdout: `bam!`, + stderr: ``, + }, + { + name: "failure", + httpResponse: &http.Response{ + StatusCode: 502, + Body: ioutil.NopCloser(bytes.NewBufferString(`gateway timeout`)), + }, + err: cmdutil.SilentError, + stdout: `gateway timeout`, + stderr: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + opts := ApiOptions{ + IO: io, + HttpClient: func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + resp := tt.httpResponse + resp.Request = req + return resp, nil + } + return &http.Client{Transport: tr}, nil + }, + + RawFields: []string{}, + MagicFields: []string{}, + } + + err := apiRun(&opts) + if err != tt.err { + t.Errorf("expected error %v, got %v", tt.err, err) + } + + if stdout.String() != tt.stdout { + t.Errorf("expected output %q, got %q", tt.stdout, stdout.String()) + } + if stderr.String() != tt.stderr { + t.Errorf("expected error output %q, got %q", tt.stderr, stderr.String()) + } + }) + } +} + +func Test_parseFields(t *testing.T) { + io, stdin, _, _ := iostreams.Test() + fmt.Fprint(stdin, "pasted contents") + + opts := ApiOptions{ + IO: io, + RawFields: []string{ + "robot=Hubot", + "destroyer=false", + "helper=true", + "location=@work", + }, + MagicFields: []string{ + "input=@-", + "enabled=true", + "victories=123", + }, + } + + params, err := parseFields(&opts) + if err != nil { + t.Fatalf("parseFields error: %v", err) + } + + expect := map[string]interface{}{ + "robot": "Hubot", + "destroyer": "false", + "helper": "true", + "location": "@work", + "input": []byte("pasted contents"), + "enabled": true, + "victories": 123, + } + if !reflect.DeepEqual(params, expect) { + t.Errorf("expected %v, got %v", expect, params) + } +} diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go index 4f9064f18..5f9ebc3ca 100644 --- a/pkg/cmd/api/http.go +++ b/pkg/cmd/api/http.go @@ -34,6 +34,8 @@ func httpRequest(client *http.Client, method string, p string, params interface{ } case io.Reader: body = pp + case nil: + body = nil default: return nil, fmt.Errorf("unrecognized parameters type: %v", params) } diff --git a/pkg/cmd/api/http_test.go b/pkg/cmd/api/http_test.go new file mode 100644 index 000000000..772a90409 --- /dev/null +++ b/pkg/cmd/api/http_test.go @@ -0,0 +1,306 @@ +package api + +import ( + "bytes" + "io/ioutil" + "net/http" + "reflect" + "testing" +) + +func Test_groupGraphQLVariables(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + want map[string]interface{} + }{ + { + name: "empty", + args: map[string]interface{}{}, + want: map[string]interface{}{}, + }, + { + name: "query only", + args: map[string]interface{}{ + "query": "QUERY", + }, + want: map[string]interface{}{ + "query": "QUERY", + }, + }, + { + name: "variables only", + args: map[string]interface{}{ + "name": "hubot", + }, + want: map[string]interface{}{ + "variables": map[string]interface{}{ + "name": "hubot", + }, + }, + }, + { + name: "query + variables", + args: map[string]interface{}{ + "query": "QUERY", + "name": "hubot", + "power": 9001, + }, + want: map[string]interface{}{ + "query": "QUERY", + "variables": map[string]interface{}{ + "name": "hubot", + "power": 9001, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := groupGraphQLVariables(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("groupGraphQLVariables() = %v, want %v", got, tt.want) + } + }) + } +} + +type roundTripper func(*http.Request) (*http.Response, error) + +func (f roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func Test_httpRequest(t *testing.T) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + return &http.Response{Request: req}, nil + } + httpClient := http.Client{Transport: tr} + + type args struct { + client *http.Client + method string + p string + params interface{} + headers []string + } + type expects struct { + method string + u string + body string + headers string + } + tests := []struct { + name string + args args + want expects + wantErr bool + }{ + { + name: "simple GET", + args: args{ + client: &httpClient, + method: "GET", + p: "repos/octocat/spoon-knife", + params: nil, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "GET", + u: "https://api.github.com/repos/octocat/spoon-knife", + body: "", + headers: "", + }, + }, + { + name: "GET with params", + args: args{ + client: &httpClient, + method: "GET", + p: "repos/octocat/spoon-knife", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "GET", + u: "https://api.github.com/repos/octocat/spoon-knife?a=b", + body: "", + headers: "", + }, + }, + { + name: "POST with params", + args: args{ + client: &httpClient, + method: "POST", + p: "repos", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/repos", + body: `{"a":"b"}`, + headers: "Content-Type: application/json; charset=utf-8\r\n", + }, + }, + { + name: "POST GraphQL", + args: args{ + client: &httpClient, + method: "POST", + p: "graphql", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/graphql", + body: `{"variables":{"a":"b"}}`, + headers: "Content-Type: application/json; charset=utf-8\r\n", + }, + }, + { + name: "POST with body and type", + args: args{ + client: &httpClient, + method: "POST", + p: "repos", + params: bytes.NewBufferString("CUSTOM"), + headers: []string{ + "content-type: text/plain", + "accept: application/json", + }, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/repos", + body: `CUSTOM`, + headers: "Accept: application/json\r\nContent-Type: text/plain\r\n", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := httpRequest(tt.args.client, tt.args.method, tt.args.p, tt.args.params, tt.args.headers) + if (err != nil) != tt.wantErr { + t.Errorf("httpRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + req := got.Request + if req.Method != tt.want.method { + t.Errorf("Request.Method = %q, want %q", req.Method, tt.want.method) + } + if req.URL.String() != tt.want.u { + t.Errorf("Request.URL = %q, want %q", req.URL.String(), tt.want.u) + } + + if tt.want.body != "" { + bb, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Request.Body ReadAll error = %v", err) + return + } + if string(bb) != tt.want.body { + t.Errorf("Request.Body = %q, want %q", string(bb), tt.want.body) + } + } + + h := bytes.Buffer{} + err = req.Header.WriteSubset(&h, map[string]bool{}) + if err != nil { + t.Errorf("Request.Header WriteSubset error = %v", err) + return + } + if h.String() != tt.want.headers { + t.Errorf("Request.Header = %q, want %q", h.String(), tt.want.headers) + } + }) + } +} + +func Test_addQuery(t *testing.T) { + type args struct { + path string + params map[string]interface{} + } + tests := []struct { + name string + args args + want string + }{ + { + name: "string", + args: args{ + path: "", + params: map[string]interface{}{"a": "hello"}, + }, + want: "?a=hello", + }, + { + name: "append", + args: args{ + path: "path", + params: map[string]interface{}{"a": "b"}, + }, + want: "path?a=b", + }, + { + name: "append query", + args: args{ + path: "path?foo=bar", + params: map[string]interface{}{"a": "b"}, + }, + want: "path?foo=bar&a=b", + }, + { + name: "[]byte", + args: args{ + path: "", + params: map[string]interface{}{"a": []byte("hello")}, + }, + want: "?a=hello", + }, + { + name: "int", + args: args{ + path: "", + params: map[string]interface{}{"a": 123}, + }, + want: "?a=123", + }, + { + name: "nil", + args: args{ + path: "", + params: map[string]interface{}{"a": nil}, + }, + want: "?a=", + }, + { + name: "bool", + args: args{ + path: "", + params: map[string]interface{}{"a": true, "b": false}, + }, + want: "?a=true&b=false", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := addQuery(tt.args.path, tt.args.params); got != tt.want { + t.Errorf("addQuery() = %v, want %v", got, tt.want) + } + }) + } +}