Support ext install --force (#7173)

Resolves #7096
This commit is contained in:
Heath Stewart 2023-03-20 07:58:30 -07:00 committed by GitHub
parent 0b9b1f710f
commit 88cae9f5be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 32 deletions

View file

@ -49,6 +49,33 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
Aliases: []string{"extensions", "ext"},
}
upgradeFunc := func(name string, flagForce, flagDryRun bool) error {
cs := io.ColorScheme()
err := m.Upgrade(name, flagForce)
if err != nil {
if name != "" {
fmt.Fprintf(io.ErrOut, "%s Failed upgrading extension %s: %s\n", cs.FailureIcon(), name, err)
} else if errors.Is(err, noExtensionsInstalledError) {
return cmdutil.NewNoResultsError("no installed extensions found")
} else {
fmt.Fprintf(io.ErrOut, "%s Failed upgrading extensions\n", cs.FailureIcon())
}
return cmdutil.SilentError
}
if io.IsStdoutTTY() {
successStr := "Successfully"
if flagDryRun {
successStr = "Would have"
}
extensionStr := "extension"
if name == "" {
extensionStr = "extensions"
}
fmt.Fprintf(io.Out, "%s %s upgraded %s\n", cs.SuccessIcon(), successStr, extensionStr)
}
return nil
}
extCmd.AddCommand(
func() *cobra.Command {
query := search.Query{
@ -267,6 +294,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
},
},
func() *cobra.Command {
var forceFlag bool
var pinFlag string
cmd := &cobra.Command{
Use: "install <repository>",
@ -305,7 +333,12 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
return err
}
if err := checkValidExtension(cmd.Root(), m, repo.RepoName()); err != nil {
if ext, err := checkValidExtension(cmd.Root(), m, repo.RepoName()); err != nil {
// If an existing extension was found and --force was specified, attempt to upgrade.
if forceFlag && ext != nil {
return upgradeFunc(ext.Name(), forceFlag, false)
}
return err
}
@ -333,6 +366,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
return nil
},
}
cmd.Flags().BoolVar(&forceFlag, "force", false, "force upgrade extension, or ignore if latest already installed")
cmd.Flags().StringVar(&pinFlag, "pin", "", "pin extension to a release tag or commit ref")
return cmd
}(),
@ -363,30 +397,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
if flagDryRun {
m.EnableDryRunMode()
}
cs := io.ColorScheme()
err := m.Upgrade(name, flagForce)
if err != nil {
if name != "" {
fmt.Fprintf(io.ErrOut, "%s Failed upgrading extension %s: %s\n", cs.FailureIcon(), name, err)
} else if errors.Is(err, noExtensionsInstalledError) {
return cmdutil.NewNoResultsError("no installed extensions found")
} else {
fmt.Fprintf(io.ErrOut, "%s Failed upgrading extensions\n", cs.FailureIcon())
}
return cmdutil.SilentError
}
if io.IsStdoutTTY() {
successStr := "Successfully"
if flagDryRun {
successStr = "Would have"
}
extensionStr := "extension"
if name == "" {
extensionStr = "extensions"
}
fmt.Fprintf(io.Out, "%s %s upgraded %s\n", cs.SuccessIcon(), successStr, extensionStr)
}
return nil
return upgradeFunc(name, flagForce, flagDryRun)
},
}
cmd.Flags().BoolVar(&flagAll, "all", false, "Upgrade all extensions")
@ -626,25 +637,25 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
return &extCmd
}
func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName string) error {
func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName string) (extensions.Extension, error) {
if !strings.HasPrefix(extName, "gh-") {
return errors.New("extension repository name must start with `gh-`")
return nil, errors.New("extension repository name must start with `gh-`")
}
commandName := strings.TrimPrefix(extName, "gh-")
if c, _, err := rootCmd.Traverse([]string{commandName}); err != nil {
return err
return nil, err
} else if c != rootCmd {
return fmt.Errorf("%q matches the name of a built-in command", commandName)
return nil, fmt.Errorf("%q matches the name of a built-in command", commandName)
}
for _, ext := range m.List() {
if ext.Name() == commandName {
return fmt.Errorf("there is already an installed extension that provides the %q command", commandName)
return ext, fmt.Errorf("there is already an installed extension that provides the %q command", commandName)
}
}
return nil
return nil, nil
}
func normalizeExtensionSelector(n string) string {

View file

@ -797,6 +797,55 @@ func TestNewCmdExtension(t *testing.T) {
wantErr: true,
errMsg: "this command runs an interactive UI and needs to be run in a terminal",
},
{
name: "force install when absent",
args: []string{"install", "owner/gh-hello", "--force"},
managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) {
em.ListFunc = func() []extensions.Extension {
return []extensions.Extension{}
}
em.InstallFunc = func(_ ghrepo.Interface, _ string) error {
return nil
}
return func(t *testing.T) {
listCalls := em.ListCalls()
assert.Equal(t, 1, len(listCalls))
installCalls := em.InstallCalls()
assert.Equal(t, 1, len(installCalls))
assert.Equal(t, "gh-hello", installCalls[0].InterfaceMoqParam.RepoName())
}
},
isTTY: true,
wantStdout: "✓ Installed extension owner/gh-hello\n",
},
{
name: "force install when present",
args: []string{"install", "owner/gh-hello", "--force"},
managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) {
em.ListFunc = func() []extensions.Extension {
return []extensions.Extension{
&Extension{path: "owner/gh-hello"},
}
}
em.InstallFunc = func(_ ghrepo.Interface, _ string) error {
return nil
}
em.UpgradeFunc = func(name string, force bool) error {
return nil
}
return func(t *testing.T) {
listCalls := em.ListCalls()
assert.Equal(t, 1, len(listCalls))
installCalls := em.InstallCalls()
assert.Equal(t, 0, len(installCalls))
upgradeCalls := em.UpgradeCalls()
assert.Equal(t, 1, len(upgradeCalls))
assert.Equal(t, "hello", upgradeCalls[0].Name)
}
},
isTTY: true,
wantStdout: "✓ Successfully upgraded extension\n",
},
}
for _, tt := range tests {
@ -939,7 +988,7 @@ func Test_checkValidExtension(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkValidExtension(tt.args.rootCmd, tt.args.manager, tt.args.extName)
_, err := checkValidExtension(tt.args.rootCmd, tt.args.manager, tt.args.extName)
if tt.wantError == "" {
assert.NoError(t, err)
} else {