Use go-gh/auth package for IsEnterprise, IsTenancy, and NormalizeHostname

This commit is contained in:
Tyler McGoffin 2024-10-11 15:06:11 -07:00
parent 44fdb3320d
commit 81591a09b8
18 changed files with 45 additions and 212 deletions

View file

@ -10,8 +10,8 @@ import (
"regexp"
"strings"
"github.com/cli/cli/v2/internal/ghinstance"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
const (
@ -249,7 +249,7 @@ func generateScopesSuggestion(statusCode int, endpointNeedsScopes, tokenHasScope
return fmt.Sprintf(
"This API operation needs the %[1]q scope. To request it, run: gh auth refresh -h %[2]s -s %[1]s",
s,
ghinstance.NormalizeHostname(hostname),
ghauth.NormalizeHostname(hostname),
)
}

View file

@ -7,9 +7,9 @@ import (
"strings"
"time"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/utils"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
type tokenGetter interface {
@ -98,7 +98,7 @@ func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper
// Only set header if an initial request or redirect request to the same host as the initial request.
// If the host has changed during a redirect do not add the authentication token header.
if !redirectHostnameChange {
hostname := ghinstance.NormalizeHostname(getHost(req))
hostname := ghauth.NormalizeHostname(getHost(req))
if token, _ := cfg.ActiveToken(hostname); token != "" {
req.Header.Set(authorization, fmt.Sprintf("token %s", token))
}

View file

@ -16,6 +16,8 @@ import (
"github.com/cli/cli/v2/utils"
"github.com/cli/oauth"
"github.com/henvic/httpretty"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
var (
@ -105,7 +107,7 @@ func AuthFlow(oauthHost string, IO *iostreams.IOStreams, notice string, addition
func getCallbackURI(oauthHost string) string {
callbackURI := "http://127.0.0.1/callback"
if ghinstance.IsEnterprise(oauthHost) {
if ghauth.IsEnterprise(oauthHost) {
// the OAuth app on Enterprise hosts is still registered with a legacy callback URL
// see https://github.com/cli/cli/pull/222, https://github.com/cli/cli/pull/650
callbackURI = "http://localhost/"

View file

@ -10,7 +10,7 @@ import (
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/keyring"
o "github.com/cli/cli/v2/pkg/option"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
ghConfig "github.com/cli/go-gh/v2/pkg/config"
)
@ -206,7 +206,7 @@ func (c *AuthConfig) ActiveToken(hostname string) (string, string) {
if c.tokenOverride != nil {
return c.tokenOverride(hostname)
}
token, source := ghAuth.TokenFromEnvOrConfig(hostname)
token, source := ghauth.TokenFromEnvOrConfig(hostname)
if token == "" {
var err error
token, err = c.TokenFromKeyring(hostname)
@ -240,7 +240,7 @@ func (c *AuthConfig) HasEnvToken() bool {
// It has to use a hostname that is not going to be found in the hosts so that it
// can guarantee that tokens will only be returned from a set env var.
// Discussed here, but maybe worth revisiting: https://github.com/cli/cli/pull/7169#discussion_r1136979033
token, _ := ghAuth.TokenFromEnvOrConfig(hostname)
token, _ := ghauth.TokenFromEnvOrConfig(hostname)
return token != ""
}
@ -282,7 +282,7 @@ func (c *AuthConfig) Hosts() []string {
if c.hostsOverride != nil {
return c.hostsOverride()
}
return ghAuth.KnownHosts()
return ghauth.KnownHosts()
}
// SetHosts will override any hosts resolution and return the given
@ -297,7 +297,7 @@ func (c *AuthConfig) DefaultHost() (string, string) {
if c.defaultHostOverride != nil {
return c.defaultHostOverride()
}
return ghAuth.DefaultHost()
return ghauth.DefaultHost()
}
// SetDefaultHost will override any host resolution and return the given

View file

@ -4,8 +4,9 @@ import (
"net/http"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
"golang.org/x/sync/errgroup"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
type Detector interface {
@ -62,7 +63,7 @@ func NewDetector(httpClient *http.Client, host string) Detector {
}
func (d *detector) IssueFeatures() (IssueFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
if !ghauth.IsEnterprise(d.host) {
return allIssueFeatures, nil
}
@ -163,7 +164,7 @@ func (d *detector) PullRequestFeatures() (PullRequestFeatures, error) {
}
func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
if !ghauth.IsEnterprise(d.host) {
return allRepositoryFeatures, nil
}

View file

@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"strings"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
// DefaultHostname is the domain name of the default GitHub instance.
@ -20,22 +22,10 @@ func Default() string {
return defaultHostname
}
// IsEnterprise reports whether a non-normalized host name looks like a GHE instance.
func IsEnterprise(h string) bool {
normalizedHostName := NormalizeHostname(h)
return normalizedHostName != defaultHostname && normalizedHostName != localhost
}
// IsTenancy reports whether a non-normalized host name looks like a tenancy instance.
func IsTenancy(h string) bool {
normalizedHostName := NormalizeHostname(h)
return strings.HasSuffix(normalizedHostName, "."+tenancyHost)
}
// TenantName extracts the tenant name from tenancy host name and
// reports whether it found the tenant name.
func TenantName(h string) (string, bool) {
normalizedHostName := NormalizeHostname(h)
normalizedHostName := ghauth.NormalizeHostname(h)
return cutSuffix(normalizedHostName, "."+tenancyHost)
}
@ -43,22 +33,6 @@ func isGarage(h string) bool {
return strings.EqualFold(h, "garage.github.com")
}
// NormalizeHostname returns the canonical host name of a GitHub instance.
func NormalizeHostname(h string) string {
hostname := strings.ToLower(h)
if strings.HasSuffix(hostname, "."+defaultHostname) {
return defaultHostname
}
if strings.HasSuffix(hostname, "."+localhost) {
return localhost
}
if before, found := cutSuffix(hostname, "."+tenancyHost); found {
idx := strings.LastIndex(before, ".")
return fmt.Sprintf("%s.%s", before[idx+1:], tenancyHost)
}
return hostname
}
func HostnameValidator(hostname string) error {
if len(strings.TrimSpace(hostname)) < 1 {
return errors.New("a value is required")
@ -77,10 +51,10 @@ func GraphQLEndpoint(hostname string) string {
// conditional can be removed as the flow will fall through to the bottom.
// However, we can't do that until we've investigated all places in which
// Tenancy is currently treated as Enterprise.
if IsTenancy(hostname) {
if ghauth.IsTenancy(hostname) {
return fmt.Sprintf("https://api.%s/graphql", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/graphql", hostname)
}
if strings.EqualFold(hostname, localhost) {
@ -97,10 +71,10 @@ func RESTPrefix(hostname string) string {
// conditional can be removed as the flow will fall through to the bottom.
// However, we can't do that until we've investigated all places in which
// Tenancy is currently treated as Enterprise.
if IsTenancy(hostname) {
if ghauth.IsTenancy(hostname) {
return fmt.Sprintf("https://api.%s/", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/v3/", hostname)
}
if strings.EqualFold(hostname, localhost) {
@ -121,7 +95,7 @@ func GistHost(hostname string) string {
if isGarage(hostname) {
return fmt.Sprintf("%s/gist/", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("%s/gist/", hostname)
}
if strings.EqualFold(hostname, localhost) {

View file

@ -6,88 +6,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestIsEnterprise(t *testing.T) {
tests := []struct {
host string
want bool
}{
{
host: "github.com",
want: false,
},
{
host: "api.github.com",
want: false,
},
{
host: "github.localhost",
want: false,
},
{
host: "api.github.localhost",
want: false,
},
{
host: "garage.github.com",
want: false,
},
{
host: "ghe.io",
want: true,
},
{
host: "example.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := IsEnterprise(tt.host); got != tt.want {
t.Errorf("IsEnterprise() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsTenancy(t *testing.T) {
tests := []struct {
host string
want bool
}{
{
host: "github.com",
want: false,
},
{
host: "github.localhost",
want: false,
},
{
host: "garage.github.com",
want: false,
},
{
host: "ghe.com",
want: false,
},
{
host: "tenant.ghe.com",
want: true,
},
{
host: "api.tenant.ghe.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := IsTenancy(tt.host); got != tt.want {
t.Errorf("IsTenancy() = %v, want %v", got, tt.want)
}
})
}
}
func TestTenantName(t *testing.T) {
tests := []struct {
host string
@ -130,69 +48,6 @@ func TestTenantName(t *testing.T) {
}
}
func TestNormalizeHostname(t *testing.T) {
tests := []struct {
host string
want string
}{
{
host: "GitHub.com",
want: "github.com",
},
{
host: "api.github.com",
want: "github.com",
},
{
host: "ssh.github.com",
want: "github.com",
},
{
host: "upload.github.com",
want: "github.com",
},
{
host: "GitHub.localhost",
want: "github.localhost",
},
{
host: "api.github.localhost",
want: "github.localhost",
},
{
host: "garage.github.com",
want: "github.com",
},
{
host: "GHE.IO",
want: "ghe.io",
},
{
host: "git.my.org",
want: "git.my.org",
},
{
host: "ghe.com",
want: "ghe.com",
},
{
host: "tenant.ghe.com",
want: "tenant.ghe.com",
},
{
host: "api.tenant.ghe.com",
want: "tenant.ghe.com",
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := NormalizeHostname(tt.host); got != tt.want {
t.Errorf("NormalizeHostname() = %v, want %v", got, tt.want)
}
})
}
}
func TestHostnameValidator(t *testing.T) {
tests := []struct {
name string

View file

@ -6,7 +6,7 @@ import (
"strings"
"github.com/cli/cli/v2/internal/ghinstance"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/cli/go-gh/v2/pkg/repository"
)
@ -37,7 +37,7 @@ func FullName(r Interface) string {
}
func defaultHost() string {
host, _ := ghAuth.DefaultHost()
host, _ := ghauth.DefaultHost()
return host
}

View file

@ -3,7 +3,7 @@ package auth
import (
"errors"
"github.com/cli/cli/v2/internal/ghinstance"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
var ErrUnsupportedHost = errors.New("An unsupported host was detected. Note that gh attestation does not currently support GHES")
@ -11,7 +11,7 @@ var ErrUnsupportedHost = errors.New("An unsupported host was detected. Note that
func IsHostSupported(host string) error {
// Note that this check is slightly redundant as Tenancy should not be considered Enterprise
// but the ghinstance package has not been updated to reflect this yet.
if ghinstance.IsEnterprise(host) && !ghinstance.IsTenancy(host) {
if ghauth.IsEnterprise(host) && !ghauth.IsTenancy(host) {
return ErrUnsupportedHost
}
return nil

View file

@ -85,7 +85,7 @@ func NewInspectCmd(f *cmdutil.Factory, runF func(*Options) error) *cobra.Command
Logger: opts.Logger,
}
// Prepare for tenancy if detected
if ghinstance.IsTenancy(opts.Hostname) {
if ghauth.IsTenancy(opts.Hostname) {
hc, err := f.HttpClient()
if err != nil {
return err

View file

@ -6,7 +6,6 @@ import (
"fmt"
"os"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/pkg/cmd/attestation/api"
"github.com/cli/cli/v2/pkg/cmd/attestation/auth"
"github.com/cli/cli/v2/pkg/cmd/attestation/io"
@ -68,7 +67,7 @@ func NewTrustedRootCmd(f *cmdutil.Factory, runF func(*Options) error) *cobra.Com
return err
}
if ghinstance.IsTenancy(opts.Hostname) {
if ghauth.IsTenancy(opts.Hostname) {
c, err := f.Config()
if err != nil {
return err

View file

@ -138,7 +138,7 @@ func NewVerifyCmd(f *cmdutil.Factory, runF func(*Options) error) *cobra.Command
}
// Prepare for tenancy if detected
if ghinstance.IsTenancy(opts.Hostname) {
if ghauth.IsTenancy(opts.Hostname) {
td, err := opts.APIClient.GetTrustDomain()
if err != nil {
return fmt.Errorf("error getting trust domain, make sure you are authenticated against the host: %w", err)

View file

@ -15,7 +15,7 @@ import (
"github.com/cli/cli/v2/pkg/cmd/auth/shared/gitcredentials"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/spf13/cobra"
)
@ -123,7 +123,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm
}
if opts.Hostname == "" && (!opts.Interactive || opts.Web) {
opts.Hostname, _ = ghAuth.DefaultHost()
opts.Hostname, _ = ghauth.DefaultHost()
}
opts.MainExecutable = f.Executable()

View file

@ -8,11 +8,12 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
type CloneOptions struct {
@ -93,7 +94,7 @@ func cloneRun(opts *CloneOptions) error {
}
func formatRemoteURL(hostname string, gistID string, protocol string) string {
if ghinstance.IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) || ghauth.IsTenancy(hostname) {
if protocol == "ssh" {
return fmt.Sprintf("git@%s:gist/%s.git", hostname, gistID)
}

View file

@ -6,9 +6,10 @@ import (
"strings"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/set"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)
type requestOptions struct {
@ -186,7 +187,7 @@ func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options r
}
func getCurrentUsername(username string, hostname string, apiClient *api.Client) (string, error) {
if username == "@me" && ghinstance.IsEnterprise(hostname) {
if username == "@me" && ghauth.IsEnterprise(hostname) {
var err error
username, err = api.CurrentLoginName(apiClient, hostname)
if err != nil {

View file

@ -10,7 +10,7 @@ import (
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/spf13/cobra"
)
@ -85,7 +85,7 @@ func deleteRun(opts *DeleteOptions) error {
} else {
repoSelector := opts.RepoArg
if !strings.Contains(repoSelector, "/") {
defaultHost, _ := ghAuth.DefaultHost()
defaultHost, _ := ghauth.DefaultHost()
currentUser, err := api.CurrentLoginName(apiClient, defaultHost)
if err != nil {
return err

View file

@ -15,7 +15,7 @@ import (
"github.com/cli/cli/v2/pkg/cmd/ruleset/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/spf13/cobra"
)
@ -109,7 +109,7 @@ func listRun(opts *ListOptions) error {
}
}
hostname, _ := ghAuth.DefaultHost()
hostname, _ := ghauth.DefaultHost()
if opts.WebMode {
var rulesetURL string

View file

@ -15,7 +15,7 @@ import (
"github.com/cli/cli/v2/pkg/cmd/ruleset/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/spf13/cobra"
)
@ -124,7 +124,7 @@ func viewRun(opts *ViewOptions) error {
}
}
hostname, _ := ghAuth.DefaultHost()
hostname, _ := ghauth.DefaultHost()
cs := opts.IO.ColorScheme()
if opts.InteractiveMode {