diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 9ec4de435..67883d2b8 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -21,12 +21,14 @@ import ( "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" "github.com/cli/cli/pkg/jsoncolor" + "github.com/cli/cli/utils" "github.com/spf13/cobra" ) type ApiOptions struct { IO *iostreams.IOStreams + Hostname string RequestMethod string RequestMethodPassed bool RequestPath string @@ -101,7 +103,7 @@ original query accepts an '$endCursor: String' variable and that it fetches the } } ' - + $ gh api graphql --paginate -f query=' query($endCursor: String) { viewer { @@ -128,6 +130,12 @@ original query accepts an '$endCursor: String' variable and that it fetches the opts.RequestPath = args[0] opts.RequestMethodPassed = c.Flags().Changed("method") + if c.Flags().Changed("hostname") { + if err := utils.HostnameValidator(opts.Hostname); err != nil { + return &cmdutil.FlagError{Err: fmt.Errorf("error parsing --hostname: %w", err)} + } + } + if opts.Paginate && !strings.EqualFold(opts.RequestMethod, "GET") && opts.RequestPath != "graphql" { return &cmdutil.FlagError{Err: errors.New(`the '--paginate' option is not supported for non-GET requests`)} } @@ -142,6 +150,7 @@ original query accepts an '$endCursor: String' variable and that it fetches the }, } + cmd.Flags().StringVar(&opts.Hostname, "hostname", "", "The hostname of the GitHub instance for the request") cmd.Flags().StringVarP(&opts.RequestMethod, "method", "X", "GET", "The HTTP method for the request") cmd.Flags().StringArrayVarP(&opts.MagicFields, "field", "F", nil, "Add a parameter of inferred type") cmd.Flags().StringArrayVarP(&opts.RawFields, "raw-field", "f", nil, "Add a string parameter") @@ -206,6 +215,9 @@ func apiRun(opts *ApiOptions) error { } host := ghinstance.OverridableDefault() + if opts.Hostname != "" { + host = opts.Hostname + } hasNextPage := true for hasNextPage { diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index a810fa7fc..6cd693ce1 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -31,6 +31,7 @@ func Test_NewCmdApi(t *testing.T) { name: "no flags", cli: "graphql", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "graphql", @@ -48,6 +49,7 @@ func Test_NewCmdApi(t *testing.T) { name: "override method", cli: "repos/octocat/Spoon-Knife -XDELETE", wants: ApiOptions{ + Hostname: "", RequestMethod: "DELETE", RequestMethodPassed: true, RequestPath: "repos/octocat/Spoon-Knife", @@ -65,6 +67,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with fields", cli: "graphql -f query=QUERY -F body=@file.txt", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "graphql", @@ -82,6 +85,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with headers", cli: "user -H 'accept: text/plain' -i", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "user", @@ -99,6 +103,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with pagination", cli: "repos/OWNER/REPO/issues --paginate", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "repos/OWNER/REPO/issues", @@ -116,6 +121,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with silenced output", cli: "repos/OWNER/REPO/issues --silent", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "repos/OWNER/REPO/issues", @@ -138,6 +144,7 @@ func Test_NewCmdApi(t *testing.T) { name: "GraphQL pagination", cli: "-XPOST graphql --paginate", wants: ApiOptions{ + Hostname: "", RequestMethod: "POST", RequestMethodPassed: true, RequestPath: "graphql", @@ -160,6 +167,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with request body from file", cli: "user --input myfile", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "user", @@ -178,10 +186,29 @@ func Test_NewCmdApi(t *testing.T) { cli: "", wantsErr: true, }, + { + name: "with hostname", + cli: "graphql --hostname tom.petty", + wants: ApiOptions{ + Hostname: "tom.petty", + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "graphql", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: false, + Silent: false, + }, + wantsErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cmd := NewCmdApi(f, func(o *ApiOptions) error { + assert.Equal(t, tt.wants.Hostname, o.Hostname) assert.Equal(t, tt.wants.RequestMethod, o.RequestMethod) assert.Equal(t, tt.wants.RequestMethodPassed, o.RequestMethodPassed) assert.Equal(t, tt.wants.RequestPath, o.RequestPath) diff --git a/pkg/cmd/auth/login/login.go b/pkg/cmd/auth/login/login.go index 7ded3056b..c3c310cdb 100644 --- a/pkg/cmd/auth/login/login.go +++ b/pkg/cmd/auth/login/login.go @@ -88,7 +88,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm } if cmd.Flags().Changed("hostname") { - if err := hostnameValidator(opts.Hostname); err != nil { + if err := utils.HostnameValidator(opts.Hostname); err != nil { return &cmdutil.FlagError{Err: fmt.Errorf("error parsing --hostname: %w", err)} } } @@ -166,7 +166,7 @@ func loginRun(opts *LoginOptions) error { if isEnterprise { err := prompt.SurveyAskOne(&survey.Input{ Message: "GHE hostname:", - }, &hostname, survey.WithValidator(hostnameValidator)) + }, &hostname, survey.WithValidator(utils.HostnameValidator)) if err != nil { return fmt.Errorf("could not prompt: %w", err) } @@ -307,17 +307,6 @@ func loginRun(opts *LoginOptions) error { return nil } -func hostnameValidator(v interface{}) error { - val := v.(string) - if len(strings.TrimSpace(val)) < 1 { - return errors.New("a value is required") - } - if strings.ContainsRune(val, '/') || strings.ContainsRune(val, ':') { - return errors.New("invalid hostname") - } - return nil -} - func getAccessTokenTip(hostname string) string { ghHostname := hostname if ghHostname == "" { diff --git a/utils/utils.go b/utils/utils.go index 69e48c150..155a04fca 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "errors" "fmt" "io" "net/url" @@ -28,6 +29,21 @@ func OpenInBrowser(url string) error { return err } +func HostnameValidator(v interface{}) error { + hostname, valid := v.(string) + if !valid { + return errors.New("hostname is not a string") + } + + if len(strings.TrimSpace(hostname)) < 1 { + return errors.New("a value is required") + } + if strings.ContainsRune(hostname, '/') || strings.ContainsRune(hostname, ':') { + return errors.New("invalid hostname") + } + return nil +} + func Pluralize(num int, thing string) string { if num == 1 { return fmt.Sprintf("%d %s", num, thing) diff --git a/utils/utils_test.go b/utils/utils_test.go index 0891c2a39..9ed8d2afa 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -3,6 +3,8 @@ package utils import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestFuzzyAgo(t *testing.T) { @@ -36,3 +38,48 @@ func TestFuzzyAgo(t *testing.T) { } } } + +func TestHostnameValidator(t *testing.T) { + tests := []struct { + name string + input interface{} + wantsErr bool + }{ + { + name: "valid hostname", + input: "internal.instance", + wantsErr: false, + }, + { + name: "hostname with slashes", + input: "//internal.instance", + wantsErr: true, + }, + { + name: "empty hostname", + input: " ", + wantsErr: true, + }, + { + name: "hostname with colon", + input: "internal.instance:2205", + wantsErr: true, + }, + { + name: "non-string hostname", + input: 62, + wantsErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := HostnameValidator(tt.input) + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.Equal(t, nil, err) + }) + } +}