diff --git a/internal/ghinstance/host.go b/internal/ghinstance/host.go index 642dd0846..76639ed26 100644 --- a/internal/ghinstance/host.go +++ b/internal/ghinstance/host.go @@ -1,6 +1,7 @@ package ghinstance import ( + "errors" "fmt" "strings" ) @@ -42,6 +43,21 @@ func NormalizeHostname(h string) string { return hostname } +func HostnameValidator(v interface{}) error { + hostname, valid := v.(string) + if !valid { + return errors.New("hostname is not a string") + } + + if len(strings.TrimSpace(hostname)) < 1 { + return errors.New("a value is required") + } + if strings.ContainsRune(hostname, '/') || strings.ContainsRune(hostname, ':') { + return errors.New("invalid hostname") + } + return nil +} + func GraphQLEndpoint(hostname string) string { if IsEnterprise(hostname) { return fmt.Sprintf("https://%s/api/graphql", hostname) diff --git a/internal/ghinstance/host_test.go b/internal/ghinstance/host_test.go index 10a40a432..787569c68 100644 --- a/internal/ghinstance/host_test.go +++ b/internal/ghinstance/host_test.go @@ -2,6 +2,8 @@ package ghinstance import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestOverridableDefault(t *testing.T) { @@ -97,6 +99,50 @@ func TestNormalizeHostname(t *testing.T) { } } +func TestHostnameValidator(t *testing.T) { + tests := []struct { + name string + input interface{} + wantsErr bool + }{ + { + name: "valid hostname", + input: "internal.instance", + wantsErr: false, + }, + { + name: "hostname with slashes", + input: "//internal.instance", + wantsErr: true, + }, + { + name: "empty hostname", + input: " ", + wantsErr: true, + }, + { + name: "hostname with colon", + input: "internal.instance:2205", + wantsErr: true, + }, + { + name: "non-string hostname", + input: 62, + wantsErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := HostnameValidator(tt.input) + if tt.wantsErr { + assert.Error(t, err) + return + } + assert.Equal(t, nil, err) + }) + } +} func TestGraphQLEndpoint(t *testing.T) { tests := []struct { host string diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 9ec4de435..4d64a260b 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -27,6 +27,7 @@ import ( type ApiOptions struct { IO *iostreams.IOStreams + Hostname string RequestMethod string RequestMethodPassed bool RequestPath string @@ -101,7 +102,7 @@ original query accepts an '$endCursor: String' variable and that it fetches the } } ' - + $ gh api graphql --paginate -f query=' query($endCursor: String) { viewer { @@ -121,6 +122,8 @@ original query accepts an '$endCursor: String' variable and that it fetches the GITHUB_TOKEN: an authentication token for github.com API requests. GITHUB_ENTERPRISE_TOKEN: an authentication token for API requests to GitHub Enterprise. + + GH_HOST: make the request to a GitHub host other than github.com. `), }, Args: cobra.ExactArgs(1), @@ -128,6 +131,12 @@ original query accepts an '$endCursor: String' variable and that it fetches the opts.RequestPath = args[0] opts.RequestMethodPassed = c.Flags().Changed("method") + if c.Flags().Changed("hostname") { + if err := ghinstance.HostnameValidator(opts.Hostname); err != nil { + return &cmdutil.FlagError{Err: fmt.Errorf("error parsing --hostname: %w", err)} + } + } + if opts.Paginate && !strings.EqualFold(opts.RequestMethod, "GET") && opts.RequestPath != "graphql" { return &cmdutil.FlagError{Err: errors.New(`the '--paginate' option is not supported for non-GET requests`)} } @@ -142,6 +151,7 @@ original query accepts an '$endCursor: String' variable and that it fetches the }, } + cmd.Flags().StringVar(&opts.Hostname, "hostname", "", "The GitHub hostname for the request (default \"github.com\")") cmd.Flags().StringVarP(&opts.RequestMethod, "method", "X", "GET", "The HTTP method for the request") cmd.Flags().StringArrayVarP(&opts.MagicFields, "field", "F", nil, "Add a parameter of inferred type") cmd.Flags().StringArrayVarP(&opts.RawFields, "raw-field", "f", nil, "Add a string parameter") @@ -206,6 +216,9 @@ func apiRun(opts *ApiOptions) error { } host := ghinstance.OverridableDefault() + if opts.Hostname != "" { + host = opts.Hostname + } hasNextPage := true for hasNextPage { diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index a810fa7fc..6cd693ce1 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -31,6 +31,7 @@ func Test_NewCmdApi(t *testing.T) { name: "no flags", cli: "graphql", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "graphql", @@ -48,6 +49,7 @@ func Test_NewCmdApi(t *testing.T) { name: "override method", cli: "repos/octocat/Spoon-Knife -XDELETE", wants: ApiOptions{ + Hostname: "", RequestMethod: "DELETE", RequestMethodPassed: true, RequestPath: "repos/octocat/Spoon-Knife", @@ -65,6 +67,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with fields", cli: "graphql -f query=QUERY -F body=@file.txt", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "graphql", @@ -82,6 +85,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with headers", cli: "user -H 'accept: text/plain' -i", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "user", @@ -99,6 +103,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with pagination", cli: "repos/OWNER/REPO/issues --paginate", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "repos/OWNER/REPO/issues", @@ -116,6 +121,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with silenced output", cli: "repos/OWNER/REPO/issues --silent", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "repos/OWNER/REPO/issues", @@ -138,6 +144,7 @@ func Test_NewCmdApi(t *testing.T) { name: "GraphQL pagination", cli: "-XPOST graphql --paginate", wants: ApiOptions{ + Hostname: "", RequestMethod: "POST", RequestMethodPassed: true, RequestPath: "graphql", @@ -160,6 +167,7 @@ func Test_NewCmdApi(t *testing.T) { name: "with request body from file", cli: "user --input myfile", wants: ApiOptions{ + Hostname: "", RequestMethod: "GET", RequestMethodPassed: false, RequestPath: "user", @@ -178,10 +186,29 @@ func Test_NewCmdApi(t *testing.T) { cli: "", wantsErr: true, }, + { + name: "with hostname", + cli: "graphql --hostname tom.petty", + wants: ApiOptions{ + Hostname: "tom.petty", + RequestMethod: "GET", + RequestMethodPassed: false, + RequestPath: "graphql", + RequestInputFile: "", + RawFields: []string(nil), + MagicFields: []string(nil), + RequestHeaders: []string(nil), + ShowResponseHeaders: false, + Paginate: false, + Silent: false, + }, + wantsErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cmd := NewCmdApi(f, func(o *ApiOptions) error { + assert.Equal(t, tt.wants.Hostname, o.Hostname) assert.Equal(t, tt.wants.RequestMethod, o.RequestMethod) assert.Equal(t, tt.wants.RequestMethodPassed, o.RequestMethodPassed) assert.Equal(t, tt.wants.RequestPath, o.RequestPath) diff --git a/pkg/cmd/auth/login/login.go b/pkg/cmd/auth/login/login.go index 7ded3056b..929d6b61d 100644 --- a/pkg/cmd/auth/login/login.go +++ b/pkg/cmd/auth/login/login.go @@ -88,7 +88,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm } if cmd.Flags().Changed("hostname") { - if err := hostnameValidator(opts.Hostname); err != nil { + if err := ghinstance.HostnameValidator(opts.Hostname); err != nil { return &cmdutil.FlagError{Err: fmt.Errorf("error parsing --hostname: %w", err)} } } @@ -166,7 +166,7 @@ func loginRun(opts *LoginOptions) error { if isEnterprise { err := prompt.SurveyAskOne(&survey.Input{ Message: "GHE hostname:", - }, &hostname, survey.WithValidator(hostnameValidator)) + }, &hostname, survey.WithValidator(ghinstance.HostnameValidator)) if err != nil { return fmt.Errorf("could not prompt: %w", err) } @@ -307,17 +307,6 @@ func loginRun(opts *LoginOptions) error { return nil } -func hostnameValidator(v interface{}) error { - val := v.(string) - if len(strings.TrimSpace(val)) < 1 { - return errors.New("a value is required") - } - if strings.ContainsRune(val, '/') || strings.ContainsRune(val, ':') { - return errors.New("invalid hostname") - } - return nil -} - func getAccessTokenTip(hostname string) string { ghHostname := hostname if ghHostname == "" {