Merge pull request #909 from cli/api-command

Add `api` command for direct API access
This commit is contained in:
Mislav Marohnić 2020-05-27 13:14:27 +02:00 committed by GitHub
commit 658d548c5e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1033 additions and 30 deletions

View file

@ -16,14 +16,18 @@ import (
// ClientOption represents an argument to NewClient
type ClientOption = func(http.RoundTripper) http.RoundTripper
// NewClient initializes a Client
func NewClient(opts ...ClientOption) *Client {
// NewHTTPClient initializes an http.Client
func NewHTTPClient(opts ...ClientOption) *http.Client {
tr := http.DefaultTransport
for _, opt := range opts {
tr = opt(tr)
}
http := &http.Client{Transport: tr}
client := &Client{http: http}
return &http.Client{Transport: tr}
}
// NewClient initializes a Client
func NewClient(opts ...ClientOption) *Client {
client := &Client{http: NewHTTPClient(opts...)}
return client
}

View file

@ -11,6 +11,7 @@ import (
"github.com/cli/cli/command"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/update"
"github.com/cli/cli/utils"
"github.com/mgutz/ansi"
@ -48,6 +49,10 @@ func main() {
}
func printError(out io.Writer, err error, cmd *cobra.Command, debug bool) {
if err == cmdutil.SilentError {
return
}
var dnsError *net.DNSError
if errors.As(err, &dnsError) {
fmt.Fprintf(out, "error connecting to %s\n", dnsError.Name)
@ -60,7 +65,7 @@ func printError(out io.Writer, err error, cmd *cobra.Command, debug bool) {
fmt.Fprintln(out, err)
var flagError *command.FlagError
var flagError *cmdutil.FlagError
if errors.As(err, &flagError) || strings.HasPrefix(err.Error(), "unknown command ") {
if !strings.HasSuffix(err.Error(), "\n") {
fmt.Fprintln(out)

View file

@ -7,7 +7,7 @@ import (
"net"
"testing"
"github.com/cli/cli/command"
"github.com/cli/cli/pkg/cmdutil"
"github.com/spf13/cobra"
)
@ -49,7 +49,7 @@ check your internet connection or githubstatus.com
{
name: "Cobra flag error",
args: args{
err: &command.FlagError{Err: errors.New("unknown flag --foo")},
err: &cmdutil.FlagError{Err: errors.New("unknown flag --foo")},
cmd: cmd,
debug: false,
},

View file

@ -73,14 +73,9 @@ var issueStatusCmd = &cobra.Command{
RunE: issueStatus,
}
var issueViewCmd = &cobra.Command{
Use: "view {<number> | <url>}",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return FlagError{errors.New("issue number or URL required as argument")}
}
return nil
},
Use: "view {<number> | <url>}",
Short: "View an issue",
Args: cobra.ExactArgs(1),
Long: `Display the title, body, and other information about an issue.
With '--web', open the issue in a web browser instead.`,

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"regexp"
"runtime/debug"
@ -13,6 +14,9 @@ import (
"github.com/cli/cli/context"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghrepo"
apiCmd "github.com/cli/cli/pkg/cmd/api"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
"github.com/cli/cli/utils"
"github.com/spf13/cobra"
@ -60,21 +64,26 @@ func init() {
if err == pflag.ErrHelp {
return err
}
return &FlagError{Err: err}
return &cmdutil.FlagError{Err: err}
})
}
// FlagError is the kind of error raised in flag processing
type FlagError struct {
Err error
}
func (fe FlagError) Error() string {
return fe.Err.Error()
}
func (fe FlagError) Unwrap() error {
return fe.Err
// TODO: iron out how a factory incorporates context
cmdFactory := &cmdutil.Factory{
IOStreams: iostreams.System(),
HttpClient: func() (*http.Client, error) {
token := os.Getenv("GITHUB_TOKEN")
if len(token) == 0 {
ctx := context.New()
var err error
token, err = ctx.AuthToken()
if err != nil {
return nil, err
}
}
return httpClient(token), nil
},
}
RootCmd.AddCommand(apiCmd.NewCmdApi(cmdFactory, nil))
}
// RootCmd is the entry point of command-line execution
@ -136,6 +145,19 @@ func contextForCommand(cmd *cobra.Command) context.Context {
return ctx
}
// for cmdutil-powered commands
func httpClient(token string) *http.Client {
var opts []api.ClientOption
if verbose := os.Getenv("DEBUG"); verbose != "" {
opts = append(opts, apiVerboseLog())
}
opts = append(opts,
api.AddHeader("Authorization", fmt.Sprintf("token %s", token)),
api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)),
)
return api.NewHTTPClient(opts...)
}
// overridden in tests
var apiClientForContext = func(ctx context.Context) (*api.Client, error) {
token, err := ctx.AuthToken()

2
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f // indirect
github.com/spf13/cobra v0.0.6
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.4.0 // indirect
github.com/stretchr/testify v1.5.1
golang.org/x/crypto v0.0.0-20200219234226-1ad67e1f0ef4
golang.org/x/net v0.0.0-20200219183655-46282727080f // indirect
golang.org/x/text v0.3.2

5
go.sum
View file

@ -167,13 +167,14 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=

190
pkg/cmd/api/api.go Normal file
View file

@ -0,0 +1,190 @@
package api
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
"github.com/spf13/cobra"
)
type ApiOptions struct {
IO *iostreams.IOStreams
RequestMethod string
RequestMethodPassed bool
RequestPath string
MagicFields []string
RawFields []string
RequestHeaders []string
ShowResponseHeaders bool
HttpClient func() (*http.Client, error)
}
func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command {
opts := ApiOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
}
cmd := &cobra.Command{
Use: "api <endpoint>",
Short: "Make an authenticated GitHub API request",
Long: `Makes an authenticated HTTP request to the GitHub API and prints the response.
The <endpoint> argument should either be a path of a GitHub API v3 endpoint, or
"graphql" to access the GitHub API v4.
The default HTTP request method is "GET" normally and "POST" if any parameters
were added. Override the method with '--method'.
Pass one or more '--raw-field' values in "<key>=<value>" format to add
JSON-encoded string parameters to the POST body.
The '--field' flag behaves like '--raw-field' with magic type conversion based
on the format of the value:
- literal values "true", "false", "null", and integer numbers get converted to
appropriate JSON types;
- if the value starts with "@", the rest of the value is interpreted as a
filename to read the value from. Pass "-" to read from standard input.
`,
Args: cobra.ExactArgs(1),
RunE: func(c *cobra.Command, args []string) error {
opts.RequestPath = args[0]
opts.RequestMethodPassed = c.Flags().Changed("method")
if runF != nil {
return runF(&opts)
}
return apiRun(&opts)
},
}
cmd.Flags().StringVarP(&opts.RequestMethod, "method", "X", "GET", "The HTTP method for the request")
cmd.Flags().StringArrayVarP(&opts.MagicFields, "field", "F", nil, "Add a parameter of inferred type")
cmd.Flags().StringArrayVarP(&opts.RawFields, "raw-field", "f", nil, "Add a string parameter")
cmd.Flags().StringArrayVarP(&opts.RequestHeaders, "header", "H", nil, "Add an additional HTTP request header")
cmd.Flags().BoolVarP(&opts.ShowResponseHeaders, "include", "i", false, "Include HTTP response headers in the output")
return cmd
}
func apiRun(opts *ApiOptions) error {
params, err := parseFields(opts)
if err != nil {
return err
}
method := opts.RequestMethod
if len(params) > 0 && !opts.RequestMethodPassed {
method = "POST"
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
resp, err := httpRequest(httpClient, method, opts.RequestPath, params, opts.RequestHeaders)
if err != nil {
return err
}
if opts.ShowResponseHeaders {
for name, vals := range resp.Header {
fmt.Fprintf(opts.IO.Out, "%s: %s\r\n", name, strings.Join(vals, ", "))
}
fmt.Fprint(opts.IO.Out, "\r\n")
}
if resp.StatusCode == 204 {
return nil
}
defer resp.Body.Close()
_, err = io.Copy(opts.IO.Out, resp.Body)
if err != nil {
return err
}
// TODO: detect GraphQL errors
if resp.StatusCode > 299 {
return cmdutil.SilentError
}
return nil
}
func parseFields(opts *ApiOptions) (map[string]interface{}, error) {
params := make(map[string]interface{})
for _, f := range opts.RawFields {
key, value, err := parseField(f)
if err != nil {
return params, err
}
params[key] = value
}
for _, f := range opts.MagicFields {
key, strValue, err := parseField(f)
if err != nil {
return params, err
}
value, err := magicFieldValue(strValue, opts.IO.In)
if err != nil {
return params, fmt.Errorf("error parsing %q value: %w", key, err)
}
params[key] = value
}
return params, nil
}
func parseField(f string) (string, string, error) {
idx := strings.IndexRune(f, '=')
if idx == -1 {
return f, "", fmt.Errorf("field %q requires a value separated by an '=' sign", f)
}
return f[0:idx], f[idx+1:], nil
}
func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) {
if strings.HasPrefix(v, "@") {
return readUserFile(v[1:], stdin)
}
if n, err := strconv.Atoi(v); err == nil {
return n, nil
}
switch v {
case "true":
return true, nil
case "false":
return false, nil
case "null":
return nil, nil
default:
return v, nil
}
}
func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) {
var r io.ReadCloser
if fn == "-" {
r = stdin
} else {
var err error
r, err = os.Open(fn)
if err != nil {
return nil, err
}
}
defer r.Close()
return ioutil.ReadAll(r)
}

307
pkg/cmd/api/api_test.go Normal file
View file

@ -0,0 +1,307 @@
package api
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"testing"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
)
func Test_NewCmdApi(t *testing.T) {
f := &cmdutil.Factory{}
tests := []struct {
name string
cli string
wants ApiOptions
wantsErr bool
}{
{
name: "no flags",
cli: "graphql",
wants: ApiOptions{
RequestMethod: "GET",
RequestMethodPassed: false,
RequestPath: "graphql",
RawFields: []string(nil),
MagicFields: []string(nil),
RequestHeaders: []string(nil),
ShowResponseHeaders: false,
},
wantsErr: false,
},
{
name: "override method",
cli: "repos/octocat/Spoon-Knife -XDELETE",
wants: ApiOptions{
RequestMethod: "DELETE",
RequestMethodPassed: true,
RequestPath: "repos/octocat/Spoon-Knife",
RawFields: []string(nil),
MagicFields: []string(nil),
RequestHeaders: []string(nil),
ShowResponseHeaders: false,
},
wantsErr: false,
},
{
name: "with fields",
cli: "graphql -f query=QUERY -F body=@file.txt",
wants: ApiOptions{
RequestMethod: "GET",
RequestMethodPassed: false,
RequestPath: "graphql",
RawFields: []string{"query=QUERY"},
MagicFields: []string{"body=@file.txt"},
RequestHeaders: []string(nil),
ShowResponseHeaders: false,
},
wantsErr: false,
},
{
name: "with headers",
cli: "user -H 'accept: text/plain' -i",
wants: ApiOptions{
RequestMethod: "GET",
RequestMethodPassed: false,
RequestPath: "user",
RawFields: []string(nil),
MagicFields: []string(nil),
RequestHeaders: []string{"accept: text/plain"},
ShowResponseHeaders: true,
},
wantsErr: false,
},
{
name: "no arguments",
cli: "",
wantsErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := NewCmdApi(f, func(o *ApiOptions) error {
assert.Equal(t, tt.wants.RequestMethod, o.RequestMethod)
assert.Equal(t, tt.wants.RequestMethodPassed, o.RequestMethodPassed)
assert.Equal(t, tt.wants.RequestPath, o.RequestPath)
assert.Equal(t, tt.wants.RawFields, o.RawFields)
assert.Equal(t, tt.wants.MagicFields, o.MagicFields)
assert.Equal(t, tt.wants.RequestHeaders, o.RequestHeaders)
assert.Equal(t, tt.wants.ShowResponseHeaders, o.ShowResponseHeaders)
return nil
})
argv, err := shlex.Split(tt.cli)
assert.NoError(t, err)
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
_, err = cmd.ExecuteC()
if tt.wantsErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
})
}
}
func Test_apiRun(t *testing.T) {
tests := []struct {
name string
options ApiOptions
httpResponse *http.Response
err error
stdout string
stderr string
}{
{
name: "success",
httpResponse: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(`bam!`)),
},
err: nil,
stdout: `bam!`,
stderr: ``,
},
{
name: "show response headers",
options: ApiOptions{
ShowResponseHeaders: true,
},
httpResponse: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(`body`)),
Header: http.Header{"Content-Type": []string{"text/plain"}},
},
err: nil,
stdout: "Content-Type: text/plain\r\n\r\nbody",
stderr: ``,
},
{
name: "success 204",
httpResponse: &http.Response{
StatusCode: 204,
Body: nil,
},
err: nil,
stdout: ``,
stderr: ``,
},
{
name: "failure",
httpResponse: &http.Response{
StatusCode: 502,
Body: ioutil.NopCloser(bytes.NewBufferString(`gateway timeout`)),
},
err: cmdutil.SilentError,
stdout: `gateway timeout`,
stderr: ``,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
io, _, stdout, stderr := iostreams.Test()
tt.options.IO = io
tt.options.HttpClient = func() (*http.Client, error) {
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
resp := tt.httpResponse
resp.Request = req
return resp, nil
}
return &http.Client{Transport: tr}, nil
}
err := apiRun(&tt.options)
if err != tt.err {
t.Errorf("expected error %v, got %v", tt.err, err)
}
if stdout.String() != tt.stdout {
t.Errorf("expected output %q, got %q", tt.stdout, stdout.String())
}
if stderr.String() != tt.stderr {
t.Errorf("expected error output %q, got %q", tt.stderr, stderr.String())
}
})
}
}
func Test_parseFields(t *testing.T) {
io, stdin, _, _ := iostreams.Test()
fmt.Fprint(stdin, "pasted contents")
opts := ApiOptions{
IO: io,
RawFields: []string{
"robot=Hubot",
"destroyer=false",
"helper=true",
"location=@work",
},
MagicFields: []string{
"input=@-",
"enabled=true",
"victories=123",
},
}
params, err := parseFields(&opts)
if err != nil {
t.Fatalf("parseFields error: %v", err)
}
expect := map[string]interface{}{
"robot": "Hubot",
"destroyer": "false",
"helper": "true",
"location": "@work",
"input": []byte("pasted contents"),
"enabled": true,
"victories": 123,
}
assert.Equal(t, expect, params)
}
func Test_magicFieldValue(t *testing.T) {
f, err := ioutil.TempFile("", "gh-test")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(f, "file contents")
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
type args struct {
v string
stdin io.ReadCloser
}
tests := []struct {
name string
args args
want interface{}
wantErr bool
}{
{
name: "string",
args: args{v: "hello"},
want: "hello",
wantErr: false,
},
{
name: "bool true",
args: args{v: "true"},
want: true,
wantErr: false,
},
{
name: "bool false",
args: args{v: "false"},
want: false,
wantErr: false,
},
{
name: "null",
args: args{v: "null"},
want: nil,
wantErr: false,
},
{
name: "file",
args: args{v: "@" + f.Name()},
want: []byte("file contents"),
wantErr: false,
},
{
name: "file error",
args: args{v: "@"},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := magicFieldValue(tt.args.v, tt.args.stdin)
if (err != nil) != tt.wantErr {
t.Errorf("magicFieldValue() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
assert.Equal(t, tt.want, got)
})
}
}

109
pkg/cmd/api/http.go Normal file
View file

@ -0,0 +1,109 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
func httpRequest(client *http.Client, method string, p string, params interface{}, headers []string) (*http.Response, error) {
// TODO: GHE support
url := "https://api.github.com/" + p
var body io.Reader
var bodyIsJSON bool
isGraphQL := p == "graphql"
switch pp := params.(type) {
case map[string]interface{}:
if strings.EqualFold(method, "GET") {
url = addQuery(url, pp)
} else {
if isGraphQL {
pp = groupGraphQLVariables(pp)
}
b, err := json.Marshal(pp)
if err != nil {
return nil, fmt.Errorf("error serializing parameters: %w", err)
}
body = bytes.NewBuffer(b)
bodyIsJSON = true
}
case io.Reader:
body = pp
case nil:
body = nil
default:
return nil, fmt.Errorf("unrecognized parameters type: %v", params)
}
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
for _, h := range headers {
idx := strings.IndexRune(h, ':')
if idx == -1 {
return nil, fmt.Errorf("header %q requires a value separated by ':'", h)
}
req.Header.Add(h[0:idx], strings.TrimSpace(h[idx+1:]))
}
if bodyIsJSON && req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}
return client.Do(req)
}
func groupGraphQLVariables(params map[string]interface{}) map[string]interface{} {
topLevel := make(map[string]interface{})
variables := make(map[string]interface{})
for key, val := range params {
switch key {
case "query":
topLevel[key] = val
default:
variables[key] = val
}
}
if len(variables) > 0 {
topLevel["variables"] = variables
}
return topLevel
}
func addQuery(path string, params map[string]interface{}) string {
if len(params) == 0 {
return path
}
query := url.Values{}
for key, value := range params {
switch v := value.(type) {
case string:
query.Add(key, v)
case []byte:
query.Add(key, string(v))
case nil:
query.Add(key, "")
case int:
query.Add(key, fmt.Sprintf("%d", v))
case bool:
query.Add(key, fmt.Sprintf("%v", v))
default:
panic(fmt.Sprintf("unknown type %v", v))
}
}
sep := "?"
if strings.ContainsRune(path, '?') {
sep = "&"
}
return path + sep + query.Encode()
}

306
pkg/cmd/api/http_test.go Normal file
View file

@ -0,0 +1,306 @@
package api
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_groupGraphQLVariables(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
want map[string]interface{}
}{
{
name: "empty",
args: map[string]interface{}{},
want: map[string]interface{}{},
},
{
name: "query only",
args: map[string]interface{}{
"query": "QUERY",
},
want: map[string]interface{}{
"query": "QUERY",
},
},
{
name: "variables only",
args: map[string]interface{}{
"name": "hubot",
},
want: map[string]interface{}{
"variables": map[string]interface{}{
"name": "hubot",
},
},
},
{
name: "query + variables",
args: map[string]interface{}{
"query": "QUERY",
"name": "hubot",
"power": 9001,
},
want: map[string]interface{}{
"query": "QUERY",
"variables": map[string]interface{}{
"name": "hubot",
"power": 9001,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := groupGraphQLVariables(tt.args)
assert.Equal(t, tt.want, got)
})
}
}
type roundTripper func(*http.Request) (*http.Response, error)
func (f roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func Test_httpRequest(t *testing.T) {
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
return &http.Response{Request: req}, nil
}
httpClient := http.Client{Transport: tr}
type args struct {
client *http.Client
method string
p string
params interface{}
headers []string
}
type expects struct {
method string
u string
body string
headers string
}
tests := []struct {
name string
args args
want expects
wantErr bool
}{
{
name: "simple GET",
args: args{
client: &httpClient,
method: "GET",
p: "repos/octocat/spoon-knife",
params: nil,
headers: []string{},
},
wantErr: false,
want: expects{
method: "GET",
u: "https://api.github.com/repos/octocat/spoon-knife",
body: "",
headers: "",
},
},
{
name: "GET with params",
args: args{
client: &httpClient,
method: "GET",
p: "repos/octocat/spoon-knife",
params: map[string]interface{}{
"a": "b",
},
headers: []string{},
},
wantErr: false,
want: expects{
method: "GET",
u: "https://api.github.com/repos/octocat/spoon-knife?a=b",
body: "",
headers: "",
},
},
{
name: "POST with params",
args: args{
client: &httpClient,
method: "POST",
p: "repos",
params: map[string]interface{}{
"a": "b",
},
headers: []string{},
},
wantErr: false,
want: expects{
method: "POST",
u: "https://api.github.com/repos",
body: `{"a":"b"}`,
headers: "Content-Type: application/json; charset=utf-8\r\n",
},
},
{
name: "POST GraphQL",
args: args{
client: &httpClient,
method: "POST",
p: "graphql",
params: map[string]interface{}{
"a": "b",
},
headers: []string{},
},
wantErr: false,
want: expects{
method: "POST",
u: "https://api.github.com/graphql",
body: `{"variables":{"a":"b"}}`,
headers: "Content-Type: application/json; charset=utf-8\r\n",
},
},
{
name: "POST with body and type",
args: args{
client: &httpClient,
method: "POST",
p: "repos",
params: bytes.NewBufferString("CUSTOM"),
headers: []string{
"content-type: text/plain",
"accept: application/json",
},
},
wantErr: false,
want: expects{
method: "POST",
u: "https://api.github.com/repos",
body: `CUSTOM`,
headers: "Accept: application/json\r\nContent-Type: text/plain\r\n",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := httpRequest(tt.args.client, tt.args.method, tt.args.p, tt.args.params, tt.args.headers)
if (err != nil) != tt.wantErr {
t.Errorf("httpRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
req := got.Request
if req.Method != tt.want.method {
t.Errorf("Request.Method = %q, want %q", req.Method, tt.want.method)
}
if req.URL.String() != tt.want.u {
t.Errorf("Request.URL = %q, want %q", req.URL.String(), tt.want.u)
}
if tt.want.body != "" {
bb, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Errorf("Request.Body ReadAll error = %v", err)
return
}
if string(bb) != tt.want.body {
t.Errorf("Request.Body = %q, want %q", string(bb), tt.want.body)
}
}
h := bytes.Buffer{}
err = req.Header.WriteSubset(&h, map[string]bool{})
if err != nil {
t.Errorf("Request.Header WriteSubset error = %v", err)
return
}
if h.String() != tt.want.headers {
t.Errorf("Request.Header = %q, want %q", h.String(), tt.want.headers)
}
})
}
}
func Test_addQuery(t *testing.T) {
type args struct {
path string
params map[string]interface{}
}
tests := []struct {
name string
args args
want string
}{
{
name: "string",
args: args{
path: "",
params: map[string]interface{}{"a": "hello"},
},
want: "?a=hello",
},
{
name: "append",
args: args{
path: "path",
params: map[string]interface{}{"a": "b"},
},
want: "path?a=b",
},
{
name: "append query",
args: args{
path: "path?foo=bar",
params: map[string]interface{}{"a": "b"},
},
want: "path?foo=bar&a=b",
},
{
name: "[]byte",
args: args{
path: "",
params: map[string]interface{}{"a": []byte("hello")},
},
want: "?a=hello",
},
{
name: "int",
args: args{
path: "",
params: map[string]interface{}{"a": 123},
},
want: "?a=123",
},
{
name: "nil",
args: args{
path: "",
params: map[string]interface{}{"a": nil},
},
want: "?a=",
},
{
name: "bool",
args: args{
path: "",
params: map[string]interface{}{"a": true, "b": false},
},
want: "?a=true&b=false",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := addQuery(tt.args.path, tt.args.params); got != tt.want {
t.Errorf("addQuery() = %v, want %v", got, tt.want)
}
})
}
}

19
pkg/cmdutil/errors.go Normal file
View file

@ -0,0 +1,19 @@
package cmdutil
import "errors"
// FlagError is the kind of error raised in flag processing
type FlagError struct {
Err error
}
func (fe FlagError) Error() string {
return fe.Err.Error()
}
func (fe FlagError) Unwrap() error {
return fe.Err
}
// SilentError is an error that triggers exit code 1 without any error messaging
var SilentError = errors.New("SilentError")

12
pkg/cmdutil/factory.go Normal file
View file

@ -0,0 +1,12 @@
package cmdutil
import (
"net/http"
"github.com/cli/cli/pkg/iostreams"
)
type Factory struct {
IOStreams *iostreams.IOStreams
HttpClient func() (*http.Client, error)
}

View file

@ -0,0 +1,33 @@
package iostreams
import (
"bytes"
"io"
"io/ioutil"
"os"
)
type IOStreams struct {
In io.ReadCloser
Out io.Writer
ErrOut io.Writer
}
func System() *IOStreams {
return &IOStreams{
In: os.Stdin,
Out: os.Stdout,
ErrOut: os.Stderr,
}
}
func Test() (*IOStreams, *bytes.Buffer, *bytes.Buffer, *bytes.Buffer) {
in := &bytes.Buffer{}
out := &bytes.Buffer{}
errOut := &bytes.Buffer{}
return &IOStreams{
In: ioutil.NopCloser(in),
Out: out,
ErrOut: errOut,
}, in, out, errOut
}