diff --git a/internal/ghcmd/cmd.go b/internal/ghcmd/cmd.go index 8690078c6..3bf0f5f4a 100644 --- a/internal/ghcmd/cmd.go +++ b/internal/ghcmd/cmd.go @@ -14,15 +14,18 @@ import ( surveyCore "github.com/AlecAivazis/survey/v2/core" "github.com/AlecAivazis/survey/v2/terminal" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/agents" "github.com/cli/cli/v2/internal/build" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/config/migration" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/internal/update" "github.com/cli/cli/v2/pkg/cmd/factory" "github.com/cli/cli/v2/pkg/cmd/root" "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/extensions" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/utils" "github.com/cli/safeexec" @@ -140,6 +143,18 @@ func Main() exitCode { return exitCode(extError.ExitCode()) } + // Check if any of the provided args match a known official extension. + // We scan all args rather than just the first because global flags + // (e.g. --repo) may precede the unknown command name. + if strings.HasPrefix(err.Error(), "unknown command ") { + for _, arg := range expandedArgs { + if ext := extensions.FindOfficialExtension(arg); ext != nil { + handleOfficialExtension(cmdFactory.IOStreams, cmdFactory.Prompter, cmdFactory.ExtensionManager, ext, err) + return exitError + } + } + } + printError(stderr, err, cmd, hasDebug) if strings.Contains(err.Error(), "Incorrect function") { @@ -245,3 +260,41 @@ func isUnderHomebrew(ghBinary string) bool { brewBinPrefix := filepath.Join(strings.TrimSpace(string(brewPrefixBytes)), "bin") + string(filepath.Separator) return strings.HasPrefix(ghBinary, brewBinPrefix) } + +// handleOfficialExtension prints a suggestion for the matched official extension +// and, in interactive TTY sessions, prompts the user to install it. +func handleOfficialExtension(io *iostreams.IOStreams, p prompter.Prompter, em extensions.ExtensionManager, ext *extensions.OfficialExtension, err error) { + stderr := io.ErrOut + + fmt.Fprintln(stderr, err) + + if !io.CanPrompt() { + fmt.Fprint(stderr, heredoc.Docf(` + %q is also available as an official extension. + To install it, run: + gh extension install github.com/%s/%s + `, fmt.Sprintf("gh %s", ext.Name), ext.Owner, ext.Repo)) + return + } + + prompt := heredoc.Docf(` + %q is also available as an official extension. + Would you like to install it now? + `, fmt.Sprintf("gh %s", ext.Name)) + confirmed, promptErr := p.Confirm(prompt, true) + if promptErr != nil || !confirmed { + return + } + + repo := ext.Repository() + io.StartProgressIndicatorWithLabel(fmt.Sprintf("Installing %s/%s...", ext.Owner, ext.Repo)) + defer io.StopProgressIndicator() + installErr := em.Install(repo, "") + io.StopProgressIndicator() + if installErr != nil { + fmt.Fprintf(stderr, "Failed to install extension: %s\n", installErr) + return + } + + fmt.Fprintf(stderr, "Successfully installed %s/%s\n", ext.Owner, ext.Repo) +} diff --git a/pkg/extensions/official.go b/pkg/extensions/official.go new file mode 100644 index 000000000..a1e6996db --- /dev/null +++ b/pkg/extensions/official.go @@ -0,0 +1,40 @@ +package extensions + +import ( + "github.com/cli/cli/v2/internal/ghrepo" +) + +// OfficialExtension describes a GitHub-owned CLI extension that can be +// suggested to users when they invoke an unknown command. +type OfficialExtension struct { + Name string + Owner string + Repo string +} + +// Repository returns a ghrepo.Interface pinned to github.com for use with +// ExtensionManager.Install. +func (e *OfficialExtension) Repository() ghrepo.Interface { + return ghrepo.NewWithHost(e.Owner, e.Repo, "github.com") +} + +// officialExtensions is the hard-coded registry of GitHub-owned extensions +// that gh will suggest installing when the user invokes an unknown command +// matching one of their names. +// Install suggestions include the "github.com/" host prefix so that GHES users +// install from github.com rather than their enterprise host. +var officialExtensions = []OfficialExtension{ + {Name: "aw", Owner: "github", Repo: "gh-aw"}, + {Name: "stack", Owner: "github", Repo: "gh-stack"}, +} + +// FindOfficialExtension returns the matching official extension for +// commandName, or nil if none matches. +func FindOfficialExtension(commandName string) *OfficialExtension { + for _, ext := range officialExtensions { + if ext.Name == commandName { + return &ext + } + } + return nil +} diff --git a/pkg/extensions/official_test.go b/pkg/extensions/official_test.go new file mode 100644 index 000000000..0a0b5ec52 --- /dev/null +++ b/pkg/extensions/official_test.go @@ -0,0 +1,41 @@ +package extensions + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFindOfficialExtension(t *testing.T) { + tests := []struct { + name string + commandName string + wantNil bool + wantRepo string + }{ + {name: "found", commandName: "stack", wantNil: false, wantRepo: "gh-stack"}, + {name: "not found", commandName: "xyzzy", wantNil: true}, + {name: "empty", commandName: "", wantNil: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ext := FindOfficialExtension(tt.commandName) + if tt.wantNil { + assert.Nil(t, ext) + } else { + require.NotNil(t, ext) + assert.Equal(t, tt.wantRepo, ext.Repo) + } + }) + } +} + +func TestOfficialExtension_Repository(t *testing.T) { + ext := &OfficialExtension{Name: "stack", Owner: "github", Repo: "gh-stack"} + repo := ext.Repository() + assert.Equal(t, "github", repo.RepoOwner()) + assert.Equal(t, "gh-stack", repo.RepoName()) + assert.Equal(t, "github.com", repo.RepoHost()) +}