Merge pull request #13182 from cli/wm/first-party-extension-suggestions

Suggest first party extensions
This commit is contained in:
William Martin 2026-04-16 18:13:59 +02:00 committed by GitHub
commit 81afcd0de0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 245 additions and 98 deletions

View file

@ -135,6 +135,7 @@ for _, tt := range tests {
- Add godoc comments to all exported functions, types, and constants
- Avoid unnecessary code comments — only comment when the *why* isn't obvious from the code
- Do not comment just to restate what the code does
- Never use em dashes (—) in code, comments, or documentation; use regular dashes (-) or rewrite the sentence instead
## Error Handling

View file

@ -14,18 +14,15 @@ import (
surveyCore "github.com/AlecAivazis/survey/v2/core"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/agents"
"github.com/cli/cli/v2/internal/build"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/config/migration"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/internal/update"
"github.com/cli/cli/v2/pkg/cmd/factory"
"github.com/cli/cli/v2/pkg/cmd/root"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/utils"
"github.com/cli/safeexec"
@ -143,18 +140,6 @@ func Main() exitCode {
return exitCode(extError.ExitCode())
}
// Check if any of the provided args match a known official extension.
// We scan all args rather than just the first because global flags
// (e.g. --repo) may precede the unknown command name.
if strings.HasPrefix(err.Error(), "unknown command ") {
for _, arg := range expandedArgs {
if ext := extensions.FindOfficialExtension(arg); ext != nil {
handleOfficialExtension(cmdFactory.IOStreams, cmdFactory.Prompter, cmdFactory.ExtensionManager, ext, err)
return exitError
}
}
}
printError(stderr, err, cmd, hasDebug)
if strings.Contains(err.Error(), "Incorrect function") {
@ -260,41 +245,3 @@ func isUnderHomebrew(ghBinary string) bool {
brewBinPrefix := filepath.Join(strings.TrimSpace(string(brewPrefixBytes)), "bin") + string(filepath.Separator)
return strings.HasPrefix(ghBinary, brewBinPrefix)
}
// handleOfficialExtension prints a suggestion for the matched official extension
// and, in interactive TTY sessions, prompts the user to install it.
func handleOfficialExtension(io *iostreams.IOStreams, p prompter.Prompter, em extensions.ExtensionManager, ext *extensions.OfficialExtension, err error) {
stderr := io.ErrOut
fmt.Fprintln(stderr, err)
if !io.CanPrompt() {
fmt.Fprint(stderr, heredoc.Docf(`
%q is also available as an official extension.
To install it, run:
gh extension install github.com/%s/%s
`, fmt.Sprintf("gh %s", ext.Name), ext.Owner, ext.Repo))
return
}
prompt := heredoc.Docf(`
%q is also available as an official extension.
Would you like to install it now?
`, fmt.Sprintf("gh %s", ext.Name))
confirmed, promptErr := p.Confirm(prompt, true)
if promptErr != nil || !confirmed {
return
}
repo := ext.Repository()
io.StartProgressIndicatorWithLabel(fmt.Sprintf("Installing %s/%s...", ext.Owner, ext.Repo))
defer io.StopProgressIndicator()
installErr := em.Install(repo, "")
io.StopProgressIndicator()
if installErr != nil {
fmt.Fprintf(stderr, "Failed to install extension: %s\n", installErr)
return
}
fmt.Fprintf(stderr, "Successfully installed %s/%s\n", ext.Owner, ext.Repo)
}

View file

@ -0,0 +1,78 @@
package root
import (
"fmt"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
)
// NewCmdOfficialExtension creates a hidden stub command for an official
// extension that has not yet been installed. When invoked, it suggests
// installing the extension and, in interactive sessions, offers to do so
// immediately. After a successful install, the extension is dispatched with
// the original arguments.
func NewCmdOfficialExtension(io *iostreams.IOStreams, p prompter.Prompter, em extensions.ExtensionManager, ext *extensions.OfficialExtension) *cobra.Command {
cmd := &cobra.Command{
Use: ext.Name,
Short: fmt.Sprintf("Install the official %s extension", ext.Name),
Hidden: true,
GroupID: "extension",
// Accept any args/flags the user may have passed so we don't get
// cobra validation errors before reaching RunE.
DisableFlagParsing: true,
RunE: func(cmd *cobra.Command, args []string) error {
return officialExtensionRun(io, p, em, ext, args)
},
}
cmdutil.DisableAuthCheck(cmd)
return cmd
}
func officialExtensionRun(io *iostreams.IOStreams, p prompter.Prompter, em extensions.ExtensionManager, ext *extensions.OfficialExtension, args []string) error {
stderr := io.ErrOut
if !io.CanPrompt() {
fmt.Fprint(stderr, heredoc.Docf(`
%[1]s is available as an official extension.
To install it, run:
gh extension install %[2]s/%[3]s
`, fmt.Sprintf("gh %s", ext.Name), ext.Owner, ext.Repo))
return nil
}
prompt := heredoc.Docf(`
%[1]s is available as an official extension.
Would you like to install it now?
`, fmt.Sprintf("gh %s", ext.Name))
confirmed, err := p.Confirm(prompt, true)
if err != nil {
return err
}
if !confirmed {
return nil
}
repo := ext.Repository()
io.StartProgressIndicatorWithLabel(fmt.Sprintf("Installing %s/%s...", ext.Owner, ext.Repo))
installErr := em.Install(repo, "")
io.StopProgressIndicator()
if installErr != nil {
return fmt.Errorf("failed to install extension: %w", installErr)
}
fmt.Fprintf(stderr, "Successfully installed %s/%s\n", ext.Owner, ext.Repo)
// Dispatch the newly installed extension with the original arguments.
dispatchArgs := append([]string{ext.Name}, args...)
if _, dispatchErr := em.Dispatch(dispatchArgs, io.In, io.Out, stderr); dispatchErr != nil {
return dispatchErr
}
return nil
}

View file

@ -0,0 +1,149 @@
package root
import (
"fmt"
"io"
"testing"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOfficialExtensionRun(t *testing.T) {
ext := &extensions.OfficialExtension{Name: "cool", Owner: "github", Repo: "gh-cool"}
tests := []struct {
name string
isTTY bool
confirmResult bool
confirmErr error
installErr error
dispatchErr error
args []string
wantErr string
wantStderr string
wantInstalled bool
wantDispatched bool
wantDispArgs []string
}{
{
name: "non-TTY prints install instructions",
isTTY: false,
wantStderr: "gh extension install github/gh-cool",
},
{
name: "TTY confirmed installs and dispatches",
isTTY: true,
confirmResult: true,
args: []string{"--help"},
wantStderr: "Successfully installed github/gh-cool",
wantInstalled: true,
wantDispatched: true,
wantDispArgs: []string{"cool", "--help"},
},
{
name: "TTY declined does not install",
isTTY: true,
confirmResult: false,
},
{
name: "TTY prompt error is propagated",
isTTY: true,
confirmErr: fmt.Errorf("prompt interrupted"),
wantErr: "prompt interrupted",
},
{
name: "TTY install error is propagated",
isTTY: true,
confirmResult: true,
installErr: fmt.Errorf("network error"),
wantErr: "network error",
wantInstalled: true,
},
{
name: "TTY dispatch error is propagated",
isTTY: true,
confirmResult: true,
dispatchErr: fmt.Errorf("dispatch failed"),
wantErr: "dispatch failed",
wantInstalled: true,
wantDispatched: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, stderr := iostreams.Test()
if tt.isTTY {
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
ios.SetStderrTTY(true)
}
em := &extensions.ExtensionManagerMock{
InstallFunc: func(_ ghrepo.Interface, _ string) error {
return tt.installErr
},
DispatchFunc: func(_ []string, _ io.Reader, _, _ io.Writer) (bool, error) {
if tt.dispatchErr != nil {
return false, tt.dispatchErr
}
return true, nil
},
}
p := &prompter.PrompterMock{
ConfirmFunc: func(_ string, _ bool) (bool, error) {
return tt.confirmResult, tt.confirmErr
},
}
err := officialExtensionRun(ios, p, em, ext, tt.args)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
if tt.wantStderr != "" {
assert.Contains(t, stderr.String(), tt.wantStderr)
}
if tt.wantInstalled {
require.NotEmpty(t, em.InstallCalls())
repo := em.InstallCalls()[0].InterfaceMoqParam
assert.Equal(t, "github", repo.RepoOwner())
assert.Equal(t, "gh-cool", repo.RepoName())
assert.Equal(t, "github.com", repo.RepoHost())
} else if tt.isTTY && !tt.confirmResult && tt.confirmErr == nil {
assert.Empty(t, em.InstallCalls())
}
if tt.wantDispatched {
require.NotEmpty(t, em.DispatchCalls())
if tt.wantDispArgs != nil {
assert.Equal(t, tt.wantDispArgs, em.DispatchCalls()[0].Args)
}
}
})
}
}
func TestNewCmdOfficialExtension_Properties(t *testing.T) {
ios, _, _, _ := iostreams.Test()
ext := &extensions.OfficialExtension{Name: "cool", Owner: "github", Repo: "gh-cool"}
em := &extensions.ExtensionManagerMock{}
p := &prompter.PrompterMock{}
cmd := NewCmdOfficialExtension(ios, p, em, ext)
assert.Equal(t, "cool", cmd.Use)
assert.True(t, cmd.Hidden)
assert.Equal(t, "extension", cmd.GroupID)
assert.True(t, cmd.DisableFlagParsing)
}

View file

@ -44,6 +44,7 @@ import (
versionCmd "github.com/cli/cli/v2/pkg/cmd/version"
workflowCmd "github.com/cli/cli/v2/pkg/cmd/workflow"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/google/shlex"
"github.com/spf13/cobra"
)
@ -229,6 +230,17 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) (*cobra.Command,
}
}
// Official extension stubs — hidden commands that suggest installing
// GitHub-owned extensions when invoked. Registered after real extensions
// and aliases so that both take priority over stubs.
for i := range extensions.OfficialExtensions {
ext := &extensions.OfficialExtensions[i]
if _, _, err := cmd.Find([]string{ext.Name}); err == nil {
continue
}
cmd.AddCommand(NewCmdOfficialExtension(io, f.Prompter, em, ext))
}
cmdutil.DisableAuthCheck(cmd)
// The reference command produces paged output that displays information on every other command.

View file

@ -12,29 +12,15 @@ type OfficialExtension struct {
Repo string
}
// Repository returns a ghrepo.Interface pinned to github.com for use with
// ExtensionManager.Install.
// Repository returns a ghrepo.Interface pinned to github.com so that GHES
// users install from github.com rather than their enterprise host.
func (e *OfficialExtension) Repository() ghrepo.Interface {
return ghrepo.NewWithHost(e.Owner, e.Repo, "github.com")
}
// officialExtensions is the hard-coded registry of GitHub-owned extensions
// that gh will suggest installing when the user invokes an unknown command
// matching one of their names.
// Install suggestions include the "github.com/" host prefix so that GHES users
// install from github.com rather than their enterprise host.
var officialExtensions = []OfficialExtension{
// OfficialExtensions is the registry of GitHub-owned extensions that gh will
// offer to install when the user invokes the corresponding command name.
var OfficialExtensions = []OfficialExtension{
{Name: "aw", Owner: "github", Repo: "gh-aw"},
{Name: "stack", Owner: "github", Repo: "gh-stack"},
}
// FindOfficialExtension returns the matching official extension for
// commandName, or nil if none matches.
func FindOfficialExtension(commandName string) *OfficialExtension {
for _, ext := range officialExtensions {
if ext.Name == commandName {
return &ext
}
}
return nil
}

View file

@ -4,34 +4,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFindOfficialExtension(t *testing.T) {
tests := []struct {
name string
commandName string
wantNil bool
wantRepo string
}{
{name: "found", commandName: "stack", wantNil: false, wantRepo: "gh-stack"},
{name: "not found", commandName: "xyzzy", wantNil: true},
{name: "empty", commandName: "", wantNil: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ext := FindOfficialExtension(tt.commandName)
if tt.wantNil {
assert.Nil(t, ext)
} else {
require.NotNil(t, ext)
assert.Equal(t, tt.wantRepo, ext.Repo)
}
})
}
}
func TestOfficialExtension_Repository(t *testing.T) {
ext := &OfficialExtension{Name: "stack", Owner: "github", Repo: "gh-stack"}
repo := ext.Repository()