diff --git a/pkg/cmd/attestation/auth/host.go b/pkg/cmd/attestation/auth/host.go index 998dcb7f5..1b5a344ea 100644 --- a/pkg/cmd/attestation/auth/host.go +++ b/pkg/cmd/attestation/auth/host.go @@ -3,14 +3,18 @@ package auth import ( "errors" + "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/go-gh/v2/pkg/auth" ) -var ErrUnsupportedHost = errors.New("The GH_HOST environment variable is set to a custom GitHub host. gh attestation does not currently support custom GitHub Enterprise hosts") +var ErrUnsupportedHost = errors.New("An unsupported host was detected. Note that gh attestation does not currently support GHES") func IsHostSupported() error { host, _ := auth.DefaultHost() - if host != "github.com" { + + // Note that this check is slightly redundant as Tenancy should not be considered Enterprise + // but the ghinstance package has not been updated to reflect this yet. + if ghinstance.IsEnterprise(host) && !ghinstance.IsTenancy(host) { return ErrUnsupportedHost } return nil diff --git a/pkg/cmd/attestation/auth/host_test.go b/pkg/cmd/attestation/auth/host_test.go new file mode 100644 index 000000000..1d84888c4 --- /dev/null +++ b/pkg/cmd/attestation/auth/host_test.go @@ -0,0 +1,49 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsHostSupported(t *testing.T) { + testcases := []struct { + name string + expectedErr bool + host string + }{ + { + name: "Default github.com host", + expectedErr: false, + host: "github.com", + }, + { + name: "Localhost", + expectedErr: false, + host: "github.localhost", + }, + { + name: "No host set", + expectedErr: false, + host: "", + }, + { + name: "GHE tenant host", + expectedErr: false, + host: "some-tenant.ghe.com", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("GH_HOST", tc.host) + + err := IsHostSupported() + if tc.expectedErr { + require.ErrorIs(t, err, ErrUnsupportedHost) + } else { + require.NoError(t, err) + } + }) + } +}