diff --git a/api/client.go b/api/client.go index 88ea5b954..fe1fcf842 100644 --- a/api/client.go +++ b/api/client.go @@ -16,14 +16,18 @@ import ( // ClientOption represents an argument to NewClient type ClientOption = func(http.RoundTripper) http.RoundTripper -// NewClient initializes a Client -func NewClient(opts ...ClientOption) *Client { +// NewHTTPClient initializes an http.Client +func NewHTTPClient(opts ...ClientOption) *http.Client { tr := http.DefaultTransport for _, opt := range opts { tr = opt(tr) } - http := &http.Client{Transport: tr} - client := &Client{http: http} + return &http.Client{Transport: tr} +} + +// NewClient initializes a Client +func NewClient(opts ...ClientOption) *Client { + client := &Client{http: NewHTTPClient(opts...)} return client } diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 9ddf79315..98916523e 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -11,6 +11,7 @@ import ( "github.com/cli/cli/command" "github.com/cli/cli/internal/config" + "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/update" "github.com/cli/cli/utils" "github.com/mgutz/ansi" @@ -48,6 +49,10 @@ func main() { } func printError(out io.Writer, err error, cmd *cobra.Command, debug bool) { + if err == cmdutil.SilentError { + return + } + var dnsError *net.DNSError if errors.As(err, &dnsError) { fmt.Fprintf(out, "error connecting to %s\n", dnsError.Name) @@ -60,7 +65,7 @@ func printError(out io.Writer, err error, cmd *cobra.Command, debug bool) { fmt.Fprintln(out, err) - var flagError *command.FlagError + var flagError *cmdutil.FlagError if errors.As(err, &flagError) || strings.HasPrefix(err.Error(), "unknown command ") { if !strings.HasSuffix(err.Error(), "\n") { fmt.Fprintln(out) diff --git a/cmd/gh/main_test.go b/cmd/gh/main_test.go index 3e0a02690..9036391fd 100644 --- a/cmd/gh/main_test.go +++ b/cmd/gh/main_test.go @@ -7,7 +7,7 @@ import ( "net" "testing" - "github.com/cli/cli/command" + "github.com/cli/cli/pkg/cmdutil" "github.com/spf13/cobra" ) @@ -49,7 +49,7 @@ check your internet connection or githubstatus.com { name: "Cobra flag error", args: args{ - err: &command.FlagError{Err: errors.New("unknown flag --foo")}, + err: &cmdutil.FlagError{Err: errors.New("unknown flag --foo")}, cmd: cmd, debug: false, }, diff --git a/command/issue.go b/command/issue.go index 46648d914..d467986a1 100644 --- a/command/issue.go +++ b/command/issue.go @@ -73,14 +73,9 @@ var issueStatusCmd = &cobra.Command{ RunE: issueStatus, } var issueViewCmd = &cobra.Command{ - Use: "view { | }", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return FlagError{errors.New("issue number or URL required as argument")} - } - return nil - }, + Use: "view { | }", Short: "View an issue", + Args: cobra.ExactArgs(1), Long: `Display the title, body, and other information about an issue. With '--web', open the issue in a web browser instead.`, diff --git a/command/root.go b/command/root.go index cabf9a7c0..63ef4c4f9 100644 --- a/command/root.go +++ b/command/root.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net/http" "os" "regexp" "runtime/debug" @@ -13,6 +14,9 @@ import ( "github.com/cli/cli/context" "github.com/cli/cli/internal/config" "github.com/cli/cli/internal/ghrepo" + apiCmd "github.com/cli/cli/pkg/cmd/api" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" "github.com/cli/cli/utils" "github.com/spf13/cobra" @@ -60,21 +64,26 @@ func init() { if err == pflag.ErrHelp { return err } - return &FlagError{Err: err} + return &cmdutil.FlagError{Err: err} }) -} -// FlagError is the kind of error raised in flag processing -type FlagError struct { - Err error -} - -func (fe FlagError) Error() string { - return fe.Err.Error() -} - -func (fe FlagError) Unwrap() error { - return fe.Err + // TODO: iron out how a factory incorporates context + cmdFactory := &cmdutil.Factory{ + IOStreams: iostreams.System(), + HttpClient: func() (*http.Client, error) { + token := os.Getenv("GITHUB_TOKEN") + if len(token) == 0 { + ctx := context.New() + var err error + token, err = ctx.AuthToken() + if err != nil { + return nil, err + } + } + return httpClient(token), nil + }, + } + RootCmd.AddCommand(apiCmd.NewCmdApi(cmdFactory, nil)) } // RootCmd is the entry point of command-line execution @@ -136,6 +145,19 @@ func contextForCommand(cmd *cobra.Command) context.Context { return ctx } +// for cmdutil-powered commands +func httpClient(token string) *http.Client { + var opts []api.ClientOption + if verbose := os.Getenv("DEBUG"); verbose != "" { + opts = append(opts, apiVerboseLog()) + } + opts = append(opts, + api.AddHeader("Authorization", fmt.Sprintf("token %s", token)), + api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)), + ) + return api.NewHTTPClient(opts...) +} + // overridden in tests var apiClientForContext = func(ctx context.Context) (*api.Client, error) { token, err := ctx.AuthToken() diff --git a/go.mod b/go.mod index 193320caa..407c8ab8a 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f // indirect github.com/spf13/cobra v0.0.6 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.4.0 // indirect + github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200219234226-1ad67e1f0ef4 golang.org/x/net v0.0.0-20200219183655-46282727080f // indirect golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 807e68f62..09c195c5d 100644 --- a/go.sum +++ b/go.sum @@ -167,13 +167,14 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go new file mode 100644 index 000000000..4158ff8c1 --- /dev/null +++ b/pkg/cmd/api/api.go @@ -0,0 +1,190 @@ +package api + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" + "github.com/spf13/cobra" +) + +type ApiOptions struct { + IO *iostreams.IOStreams + + RequestMethod string + RequestMethodPassed bool + RequestPath string + MagicFields []string + RawFields []string + RequestHeaders []string + ShowResponseHeaders bool + + HttpClient func() (*http.Client, error) +} + +func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command { + opts := ApiOptions{ + IO: f.IOStreams, + HttpClient: f.HttpClient, + } + + cmd := &cobra.Command{ + Use: "api ", + Short: "Make an authenticated GitHub API request", + Long: `Makes an authenticated HTTP request to the GitHub API and prints the response. + +The argument should either be a path of a GitHub API v3 endpoint, or +"graphql" to access the GitHub API v4. + +The default HTTP request method is "GET" normally and "POST" if any parameters +were added. Override the method with '--method'. + +Pass one or more '--raw-field' values in "=" format to add +JSON-encoded string parameters to the POST body. + +The '--field' flag behaves like '--raw-field' with magic type conversion based +on the format of the value: + +- literal values "true", "false", "null", and integer numbers get converted to + appropriate JSON types; +- if the value starts with "@", the rest of the value is interpreted as a + filename to read the value from. Pass "-" to read from standard input. +`, + Args: cobra.ExactArgs(1), + RunE: func(c *cobra.Command, args []string) error { + opts.RequestPath = args[0] + opts.RequestMethodPassed = c.Flags().Changed("method") + + if runF != nil { + return runF(&opts) + } + return apiRun(&opts) + }, + } + + cmd.Flags().StringVarP(&opts.RequestMethod, "method", "X", "GET", "The HTTP method for the request") + cmd.Flags().StringArrayVarP(&opts.MagicFields, "field", "F", nil, "Add a parameter of inferred type") + cmd.Flags().StringArrayVarP(&opts.RawFields, "raw-field", "f", nil, "Add a string parameter") + cmd.Flags().StringArrayVarP(&opts.RequestHeaders, "header", "H", nil, "Add an additional HTTP request header") + cmd.Flags().BoolVarP(&opts.ShowResponseHeaders, "include", "i", false, "Include HTTP response headers in the output") + return cmd +} + +func apiRun(opts *ApiOptions) error { + params, err := parseFields(opts) + if err != nil { + return err + } + + method := opts.RequestMethod + if len(params) > 0 && !opts.RequestMethodPassed { + method = "POST" + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + resp, err := httpRequest(httpClient, method, opts.RequestPath, params, opts.RequestHeaders) + if err != nil { + return err + } + + if opts.ShowResponseHeaders { + for name, vals := range resp.Header { + fmt.Fprintf(opts.IO.Out, "%s: %s\r\n", name, strings.Join(vals, ", ")) + } + fmt.Fprint(opts.IO.Out, "\r\n") + } + + if resp.StatusCode == 204 { + return nil + } + defer resp.Body.Close() + + _, err = io.Copy(opts.IO.Out, resp.Body) + if err != nil { + return err + } + + // TODO: detect GraphQL errors + if resp.StatusCode > 299 { + return cmdutil.SilentError + } + + return nil +} + +func parseFields(opts *ApiOptions) (map[string]interface{}, error) { + params := make(map[string]interface{}) + for _, f := range opts.RawFields { + key, value, err := parseField(f) + if err != nil { + return params, err + } + params[key] = value + } + for _, f := range opts.MagicFields { + key, strValue, err := parseField(f) + if err != nil { + return params, err + } + value, err := magicFieldValue(strValue, opts.IO.In) + if err != nil { + return params, fmt.Errorf("error parsing %q value: %w", key, err) + } + params[key] = value + } + return params, nil +} + +func parseField(f string) (string, string, error) { + idx := strings.IndexRune(f, '=') + if idx == -1 { + return f, "", fmt.Errorf("field %q requires a value separated by an '=' sign", f) + } + return f[0:idx], f[idx+1:], nil +} + +func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) { + if strings.HasPrefix(v, "@") { + return readUserFile(v[1:], stdin) + } + + if n, err := strconv.Atoi(v); err == nil { + return n, nil + } + + switch v { + case "true": + return true, nil + case "false": + return false, nil + case "null": + return nil, nil + default: + return v, nil + } +} + +func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) { + var r io.ReadCloser + if fn == "-" { + r = stdin + } else { + var err error + r, err = os.Open(fn) + if err != nil { + return nil, err + } + } + defer r.Close() + return ioutil.ReadAll(r) +} diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go new file mode 100644 index 000000000..8149f5906 --- /dev/null +++ b/pkg/cmd/api/api_test.go @@ -0,0 +1,307 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "testing" + + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" + "github.com/google/shlex" + "github.com/stretchr/testify/assert" +) + +func Test_NewCmdApi(t *testing.T) { + f := &cmdutil.Factory{} + + tests := []struct { + name string + cli string + wants ApiOptions + wantsErr bool + }{ + { + name: "no flags", + cli: "graphql", + wants: ApiOptions{ + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "graphql", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + }, + wantsErr: false, + }, + { + name: "override method", + cli: "repos/octocat/Spoon-Knife -XDELETE", + wants: ApiOptions{ + RequestMethod: "DELETE", + RequestMethodPassed: true, + RequestPath: "repos/octocat/Spoon-Knife", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + }, + wantsErr: false, + }, + { + name: "with fields", + cli: "graphql -f query=QUERY -F body=@file.txt", + wants: ApiOptions{ + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "graphql", + RawFields: []string{"query=QUERY"}, + MagicFields: []string{"body=@file.txt"}, + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + }, + wantsErr: false, + }, + { + name: "with headers", + cli: "user -H 'accept: text/plain' -i", + wants: ApiOptions{ + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "user", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string{"accept: text/plain"}, + ShowResponseHeaders: true, + }, + wantsErr: false, + }, + { + name: "no arguments", + cli: "", + wantsErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewCmdApi(f, func(o *ApiOptions) error { + assert.Equal(t, tt.wants.RequestMethod, o.RequestMethod) + assert.Equal(t, tt.wants.RequestMethodPassed, o.RequestMethodPassed) + assert.Equal(t, tt.wants.RequestPath, o.RequestPath) + assert.Equal(t, tt.wants.RawFields, o.RawFields) + assert.Equal(t, tt.wants.MagicFields, o.MagicFields) + assert.Equal(t, tt.wants.RequestHeaders, o.RequestHeaders) + assert.Equal(t, tt.wants.ShowResponseHeaders, o.ShowResponseHeaders) + return nil + }) + + argv, err := shlex.Split(tt.cli) + assert.NoError(t, err) + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + _, err = cmd.ExecuteC() + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_apiRun(t *testing.T) { + tests := []struct { + name string + options ApiOptions + httpResponse *http.Response + err error + stdout string + stderr string + }{ + { + name: "success", + httpResponse: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`bam!`)), + }, + err: nil, + stdout: `bam!`, + stderr: ``, + }, + { + name: "show response headers", + options: ApiOptions{ + ShowResponseHeaders: true, + }, + httpResponse: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`body`)), + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + err: nil, + stdout: "Content-Type: text/plain\r\n\r\nbody", + stderr: ``, + }, + { + name: "success 204", + httpResponse: &http.Response{ + StatusCode: 204, + Body: nil, + }, + err: nil, + stdout: ``, + stderr: ``, + }, + { + name: "failure", + httpResponse: &http.Response{ + StatusCode: 502, + Body: ioutil.NopCloser(bytes.NewBufferString(`gateway timeout`)), + }, + err: cmdutil.SilentError, + stdout: `gateway timeout`, + stderr: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + tt.options.IO = io + tt.options.HttpClient = func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + resp := tt.httpResponse + resp.Request = req + return resp, nil + } + return &http.Client{Transport: tr}, nil + } + + err := apiRun(&tt.options) + if err != tt.err { + t.Errorf("expected error %v, got %v", tt.err, err) + } + + if stdout.String() != tt.stdout { + t.Errorf("expected output %q, got %q", tt.stdout, stdout.String()) + } + if stderr.String() != tt.stderr { + t.Errorf("expected error output %q, got %q", tt.stderr, stderr.String()) + } + }) + } +} + +func Test_parseFields(t *testing.T) { + io, stdin, _, _ := iostreams.Test() + fmt.Fprint(stdin, "pasted contents") + + opts := ApiOptions{ + IO: io, + RawFields: []string{ + "robot=Hubot", + "destroyer=false", + "helper=true", + "location=@work", + }, + MagicFields: []string{ + "input=@-", + "enabled=true", + "victories=123", + }, + } + + params, err := parseFields(&opts) + if err != nil { + t.Fatalf("parseFields error: %v", err) + } + + expect := map[string]interface{}{ + "robot": "Hubot", + "destroyer": "false", + "helper": "true", + "location": "@work", + "input": []byte("pasted contents"), + "enabled": true, + "victories": 123, + } + assert.Equal(t, expect, params) +} + +func Test_magicFieldValue(t *testing.T) { + f, err := ioutil.TempFile("", "gh-test") + if err != nil { + t.Fatal(err) + } + fmt.Fprint(f, "file contents") + f.Close() + t.Cleanup(func() { os.Remove(f.Name()) }) + + type args struct { + v string + stdin io.ReadCloser + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + { + name: "string", + args: args{v: "hello"}, + want: "hello", + wantErr: false, + }, + { + name: "bool true", + args: args{v: "true"}, + want: true, + wantErr: false, + }, + { + name: "bool false", + args: args{v: "false"}, + want: false, + wantErr: false, + }, + { + name: "null", + args: args{v: "null"}, + want: nil, + wantErr: false, + }, + { + name: "file", + args: args{v: "@" + f.Name()}, + want: []byte("file contents"), + wantErr: false, + }, + { + name: "file error", + args: args{v: "@"}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := magicFieldValue(tt.args.v, tt.args.stdin) + if (err != nil) != tt.wantErr { + t.Errorf("magicFieldValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go new file mode 100644 index 000000000..5f9ebc3ca --- /dev/null +++ b/pkg/cmd/api/http.go @@ -0,0 +1,109 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +func httpRequest(client *http.Client, method string, p string, params interface{}, headers []string) (*http.Response, error) { + // TODO: GHE support + url := "https://api.github.com/" + p + var body io.Reader + var bodyIsJSON bool + isGraphQL := p == "graphql" + + switch pp := params.(type) { + case map[string]interface{}: + if strings.EqualFold(method, "GET") { + url = addQuery(url, pp) + } else { + if isGraphQL { + pp = groupGraphQLVariables(pp) + } + b, err := json.Marshal(pp) + if err != nil { + return nil, fmt.Errorf("error serializing parameters: %w", err) + } + body = bytes.NewBuffer(b) + bodyIsJSON = true + } + case io.Reader: + body = pp + case nil: + body = nil + default: + return nil, fmt.Errorf("unrecognized parameters type: %v", params) + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + for _, h := range headers { + idx := strings.IndexRune(h, ':') + if idx == -1 { + return nil, fmt.Errorf("header %q requires a value separated by ':'", h) + } + req.Header.Add(h[0:idx], strings.TrimSpace(h[idx+1:])) + } + if bodyIsJSON && req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/json; charset=utf-8") + } + + return client.Do(req) +} + +func groupGraphQLVariables(params map[string]interface{}) map[string]interface{} { + topLevel := make(map[string]interface{}) + variables := make(map[string]interface{}) + + for key, val := range params { + switch key { + case "query": + topLevel[key] = val + default: + variables[key] = val + } + } + + if len(variables) > 0 { + topLevel["variables"] = variables + } + return topLevel +} + +func addQuery(path string, params map[string]interface{}) string { + if len(params) == 0 { + return path + } + + query := url.Values{} + for key, value := range params { + switch v := value.(type) { + case string: + query.Add(key, v) + case []byte: + query.Add(key, string(v)) + case nil: + query.Add(key, "") + case int: + query.Add(key, fmt.Sprintf("%d", v)) + case bool: + query.Add(key, fmt.Sprintf("%v", v)) + default: + panic(fmt.Sprintf("unknown type %v", v)) + } + } + + sep := "?" + if strings.ContainsRune(path, '?') { + sep = "&" + } + return path + sep + query.Encode() +} diff --git a/pkg/cmd/api/http_test.go b/pkg/cmd/api/http_test.go new file mode 100644 index 000000000..e2ba8680b --- /dev/null +++ b/pkg/cmd/api/http_test.go @@ -0,0 +1,306 @@ +package api + +import ( + "bytes" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_groupGraphQLVariables(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + want map[string]interface{} + }{ + { + name: "empty", + args: map[string]interface{}{}, + want: map[string]interface{}{}, + }, + { + name: "query only", + args: map[string]interface{}{ + "query": "QUERY", + }, + want: map[string]interface{}{ + "query": "QUERY", + }, + }, + { + name: "variables only", + args: map[string]interface{}{ + "name": "hubot", + }, + want: map[string]interface{}{ + "variables": map[string]interface{}{ + "name": "hubot", + }, + }, + }, + { + name: "query + variables", + args: map[string]interface{}{ + "query": "QUERY", + "name": "hubot", + "power": 9001, + }, + want: map[string]interface{}{ + "query": "QUERY", + "variables": map[string]interface{}{ + "name": "hubot", + "power": 9001, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := groupGraphQLVariables(tt.args) + assert.Equal(t, tt.want, got) + }) + } +} + +type roundTripper func(*http.Request) (*http.Response, error) + +func (f roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func Test_httpRequest(t *testing.T) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + return &http.Response{Request: req}, nil + } + httpClient := http.Client{Transport: tr} + + type args struct { + client *http.Client + method string + p string + params interface{} + headers []string + } + type expects struct { + method string + u string + body string + headers string + } + tests := []struct { + name string + args args + want expects + wantErr bool + }{ + { + name: "simple GET", + args: args{ + client: &httpClient, + method: "GET", + p: "repos/octocat/spoon-knife", + params: nil, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "GET", + u: "https://api.github.com/repos/octocat/spoon-knife", + body: "", + headers: "", + }, + }, + { + name: "GET with params", + args: args{ + client: &httpClient, + method: "GET", + p: "repos/octocat/spoon-knife", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "GET", + u: "https://api.github.com/repos/octocat/spoon-knife?a=b", + body: "", + headers: "", + }, + }, + { + name: "POST with params", + args: args{ + client: &httpClient, + method: "POST", + p: "repos", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/repos", + body: `{"a":"b"}`, + headers: "Content-Type: application/json; charset=utf-8\r\n", + }, + }, + { + name: "POST GraphQL", + args: args{ + client: &httpClient, + method: "POST", + p: "graphql", + params: map[string]interface{}{ + "a": "b", + }, + headers: []string{}, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/graphql", + body: `{"variables":{"a":"b"}}`, + headers: "Content-Type: application/json; charset=utf-8\r\n", + }, + }, + { + name: "POST with body and type", + args: args{ + client: &httpClient, + method: "POST", + p: "repos", + params: bytes.NewBufferString("CUSTOM"), + headers: []string{ + "content-type: text/plain", + "accept: application/json", + }, + }, + wantErr: false, + want: expects{ + method: "POST", + u: "https://api.github.com/repos", + body: `CUSTOM`, + headers: "Accept: application/json\r\nContent-Type: text/plain\r\n", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := httpRequest(tt.args.client, tt.args.method, tt.args.p, tt.args.params, tt.args.headers) + if (err != nil) != tt.wantErr { + t.Errorf("httpRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + req := got.Request + if req.Method != tt.want.method { + t.Errorf("Request.Method = %q, want %q", req.Method, tt.want.method) + } + if req.URL.String() != tt.want.u { + t.Errorf("Request.URL = %q, want %q", req.URL.String(), tt.want.u) + } + + if tt.want.body != "" { + bb, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Request.Body ReadAll error = %v", err) + return + } + if string(bb) != tt.want.body { + t.Errorf("Request.Body = %q, want %q", string(bb), tt.want.body) + } + } + + h := bytes.Buffer{} + err = req.Header.WriteSubset(&h, map[string]bool{}) + if err != nil { + t.Errorf("Request.Header WriteSubset error = %v", err) + return + } + if h.String() != tt.want.headers { + t.Errorf("Request.Header = %q, want %q", h.String(), tt.want.headers) + } + }) + } +} + +func Test_addQuery(t *testing.T) { + type args struct { + path string + params map[string]interface{} + } + tests := []struct { + name string + args args + want string + }{ + { + name: "string", + args: args{ + path: "", + params: map[string]interface{}{"a": "hello"}, + }, + want: "?a=hello", + }, + { + name: "append", + args: args{ + path: "path", + params: map[string]interface{}{"a": "b"}, + }, + want: "path?a=b", + }, + { + name: "append query", + args: args{ + path: "path?foo=bar", + params: map[string]interface{}{"a": "b"}, + }, + want: "path?foo=bar&a=b", + }, + { + name: "[]byte", + args: args{ + path: "", + params: map[string]interface{}{"a": []byte("hello")}, + }, + want: "?a=hello", + }, + { + name: "int", + args: args{ + path: "", + params: map[string]interface{}{"a": 123}, + }, + want: "?a=123", + }, + { + name: "nil", + args: args{ + path: "", + params: map[string]interface{}{"a": nil}, + }, + want: "?a=", + }, + { + name: "bool", + args: args{ + path: "", + params: map[string]interface{}{"a": true, "b": false}, + }, + want: "?a=true&b=false", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := addQuery(tt.args.path, tt.args.params); got != tt.want { + t.Errorf("addQuery() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/cmdutil/errors.go b/pkg/cmdutil/errors.go new file mode 100644 index 000000000..77ca58340 --- /dev/null +++ b/pkg/cmdutil/errors.go @@ -0,0 +1,19 @@ +package cmdutil + +import "errors" + +// FlagError is the kind of error raised in flag processing +type FlagError struct { + Err error +} + +func (fe FlagError) Error() string { + return fe.Err.Error() +} + +func (fe FlagError) Unwrap() error { + return fe.Err +} + +// SilentError is an error that triggers exit code 1 without any error messaging +var SilentError = errors.New("SilentError") diff --git a/pkg/cmdutil/factory.go b/pkg/cmdutil/factory.go new file mode 100644 index 000000000..578b29561 --- /dev/null +++ b/pkg/cmdutil/factory.go @@ -0,0 +1,12 @@ +package cmdutil + +import ( + "net/http" + + "github.com/cli/cli/pkg/iostreams" +) + +type Factory struct { + IOStreams *iostreams.IOStreams + HttpClient func() (*http.Client, error) +} diff --git a/pkg/iostreams/iostreams.go b/pkg/iostreams/iostreams.go new file mode 100644 index 000000000..028b11264 --- /dev/null +++ b/pkg/iostreams/iostreams.go @@ -0,0 +1,33 @@ +package iostreams + +import ( + "bytes" + "io" + "io/ioutil" + "os" +) + +type IOStreams struct { + In io.ReadCloser + Out io.Writer + ErrOut io.Writer +} + +func System() *IOStreams { + return &IOStreams{ + In: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stderr, + } +} + +func Test() (*IOStreams, *bytes.Buffer, *bytes.Buffer, *bytes.Buffer) { + in := &bytes.Buffer{} + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + return &IOStreams{ + In: ioutil.NopCloser(in), + Out: out, + ErrOut: errOut, + }, in, out, errOut +}