diff --git a/pkg/cmd/attestation/auth/host.go b/pkg/cmd/attestation/auth/host.go index 1e5206813..1b5a344ea 100644 --- a/pkg/cmd/attestation/auth/host.go +++ b/pkg/cmd/attestation/auth/host.go @@ -2,32 +2,19 @@ package auth import ( "errors" - "strings" + "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/go-gh/v2/pkg/auth" ) var ErrUnsupportedHost = errors.New("An unsupported host was detected. Note that gh attestation does not currently support GHES") -const ( - github = "github.com" - localhost = "github.localhost" - // tenancyHost is the domain name of a tenancy GitHub instance - tenancyHost = "ghe.com" -) - -func isEnterprise(host string) bool { - return host != github && host != localhost && !isTenancy(host) -} - -func isTenancy(host string) bool { - return strings.HasSuffix(host, "."+tenancyHost) -} - func IsHostSupported() error { host, _ := auth.DefaultHost() - if isEnterprise(host) { + // Note that this check is slightly redundant as Tenancy should not be considered Enterprise + // but the ghinstance package has not been updated to reflect this yet. + if ghinstance.IsEnterprise(host) && !ghinstance.IsTenancy(host) { return ErrUnsupportedHost } return nil diff --git a/pkg/cmd/attestation/auth/host_test.go b/pkg/cmd/attestation/auth/host_test.go index 1192e1d9a..1d84888c4 100644 --- a/pkg/cmd/attestation/auth/host_test.go +++ b/pkg/cmd/attestation/auth/host_test.go @@ -1,10 +1,8 @@ package auth import ( - "os" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -34,23 +32,18 @@ func TestIsHostSupported(t *testing.T) { expectedErr: false, host: "some-tenant.ghe.com", }, - { - name: "Unsupported host", - expectedErr: true, - host: "my-unsupported-host.github.com", - }, } for _, tc := range testcases { - err := os.Setenv("GH_HOST", tc.host) - require.NoError(t, err) + t.Run(tc.name, func(t *testing.T) { + t.Setenv("GH_HOST", tc.host) - err = IsHostSupported() - if tc.expectedErr { - assert.Error(t, err) - assert.ErrorIs(t, err, ErrUnsupportedHost) - } else { - assert.NoError(t, err) - } + err := IsHostSupported() + if tc.expectedErr { + require.ErrorIs(t, err, ErrUnsupportedHost) + } else { + require.NoError(t, err) + } + }) } }