From 88cae9f5be67989be191cd4a1fe043ddbeafe629 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Mon, 20 Mar 2023 07:58:30 -0700 Subject: [PATCH] Support ext install --force (#7173) Resolves #7096 --- pkg/cmd/extension/command.go | 73 ++++++++++++++++++------------- pkg/cmd/extension/command_test.go | 51 ++++++++++++++++++++- 2 files changed, 92 insertions(+), 32 deletions(-) diff --git a/pkg/cmd/extension/command.go b/pkg/cmd/extension/command.go index bfb87188e..21279c1e2 100644 --- a/pkg/cmd/extension/command.go +++ b/pkg/cmd/extension/command.go @@ -49,6 +49,33 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { Aliases: []string{"extensions", "ext"}, } + upgradeFunc := func(name string, flagForce, flagDryRun bool) error { + cs := io.ColorScheme() + err := m.Upgrade(name, flagForce) + if err != nil { + if name != "" { + fmt.Fprintf(io.ErrOut, "%s Failed upgrading extension %s: %s\n", cs.FailureIcon(), name, err) + } else if errors.Is(err, noExtensionsInstalledError) { + return cmdutil.NewNoResultsError("no installed extensions found") + } else { + fmt.Fprintf(io.ErrOut, "%s Failed upgrading extensions\n", cs.FailureIcon()) + } + return cmdutil.SilentError + } + if io.IsStdoutTTY() { + successStr := "Successfully" + if flagDryRun { + successStr = "Would have" + } + extensionStr := "extension" + if name == "" { + extensionStr = "extensions" + } + fmt.Fprintf(io.Out, "%s %s upgraded %s\n", cs.SuccessIcon(), successStr, extensionStr) + } + return nil + } + extCmd.AddCommand( func() *cobra.Command { query := search.Query{ @@ -267,6 +294,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { }, }, func() *cobra.Command { + var forceFlag bool var pinFlag string cmd := &cobra.Command{ Use: "install ", @@ -305,7 +333,12 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { return err } - if err := checkValidExtension(cmd.Root(), m, repo.RepoName()); err != nil { + if ext, err := checkValidExtension(cmd.Root(), m, repo.RepoName()); err != nil { + // If an existing extension was found and --force was specified, attempt to upgrade. + if forceFlag && ext != nil { + return upgradeFunc(ext.Name(), forceFlag, false) + } + return err } @@ -333,6 +366,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { return nil }, } + cmd.Flags().BoolVar(&forceFlag, "force", false, "force upgrade extension, or ignore if latest already installed") cmd.Flags().StringVar(&pinFlag, "pin", "", "pin extension to a release tag or commit ref") return cmd }(), @@ -363,30 +397,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { if flagDryRun { m.EnableDryRunMode() } - cs := io.ColorScheme() - err := m.Upgrade(name, flagForce) - if err != nil { - if name != "" { - fmt.Fprintf(io.ErrOut, "%s Failed upgrading extension %s: %s\n", cs.FailureIcon(), name, err) - } else if errors.Is(err, noExtensionsInstalledError) { - return cmdutil.NewNoResultsError("no installed extensions found") - } else { - fmt.Fprintf(io.ErrOut, "%s Failed upgrading extensions\n", cs.FailureIcon()) - } - return cmdutil.SilentError - } - if io.IsStdoutTTY() { - successStr := "Successfully" - if flagDryRun { - successStr = "Would have" - } - extensionStr := "extension" - if name == "" { - extensionStr = "extensions" - } - fmt.Fprintf(io.Out, "%s %s upgraded %s\n", cs.SuccessIcon(), successStr, extensionStr) - } - return nil + return upgradeFunc(name, flagForce, flagDryRun) }, } cmd.Flags().BoolVar(&flagAll, "all", false, "Upgrade all extensions") @@ -626,25 +637,25 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { return &extCmd } -func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName string) error { +func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName string) (extensions.Extension, error) { if !strings.HasPrefix(extName, "gh-") { - return errors.New("extension repository name must start with `gh-`") + return nil, errors.New("extension repository name must start with `gh-`") } commandName := strings.TrimPrefix(extName, "gh-") if c, _, err := rootCmd.Traverse([]string{commandName}); err != nil { - return err + return nil, err } else if c != rootCmd { - return fmt.Errorf("%q matches the name of a built-in command", commandName) + return nil, fmt.Errorf("%q matches the name of a built-in command", commandName) } for _, ext := range m.List() { if ext.Name() == commandName { - return fmt.Errorf("there is already an installed extension that provides the %q command", commandName) + return ext, fmt.Errorf("there is already an installed extension that provides the %q command", commandName) } } - return nil + return nil, nil } func normalizeExtensionSelector(n string) string { diff --git a/pkg/cmd/extension/command_test.go b/pkg/cmd/extension/command_test.go index 050b6d6f0..ebc4716de 100644 --- a/pkg/cmd/extension/command_test.go +++ b/pkg/cmd/extension/command_test.go @@ -797,6 +797,55 @@ func TestNewCmdExtension(t *testing.T) { wantErr: true, errMsg: "this command runs an interactive UI and needs to be run in a terminal", }, + { + name: "force install when absent", + args: []string{"install", "owner/gh-hello", "--force"}, + managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) { + em.ListFunc = func() []extensions.Extension { + return []extensions.Extension{} + } + em.InstallFunc = func(_ ghrepo.Interface, _ string) error { + return nil + } + return func(t *testing.T) { + listCalls := em.ListCalls() + assert.Equal(t, 1, len(listCalls)) + installCalls := em.InstallCalls() + assert.Equal(t, 1, len(installCalls)) + assert.Equal(t, "gh-hello", installCalls[0].InterfaceMoqParam.RepoName()) + } + }, + isTTY: true, + wantStdout: "✓ Installed extension owner/gh-hello\n", + }, + { + name: "force install when present", + args: []string{"install", "owner/gh-hello", "--force"}, + managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) { + em.ListFunc = func() []extensions.Extension { + return []extensions.Extension{ + &Extension{path: "owner/gh-hello"}, + } + } + em.InstallFunc = func(_ ghrepo.Interface, _ string) error { + return nil + } + em.UpgradeFunc = func(name string, force bool) error { + return nil + } + return func(t *testing.T) { + listCalls := em.ListCalls() + assert.Equal(t, 1, len(listCalls)) + installCalls := em.InstallCalls() + assert.Equal(t, 0, len(installCalls)) + upgradeCalls := em.UpgradeCalls() + assert.Equal(t, 1, len(upgradeCalls)) + assert.Equal(t, "hello", upgradeCalls[0].Name) + } + }, + isTTY: true, + wantStdout: "✓ Successfully upgraded extension\n", + }, } for _, tt := range tests { @@ -939,7 +988,7 @@ func Test_checkValidExtension(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := checkValidExtension(tt.args.rootCmd, tt.args.manager, tt.args.extName) + _, err := checkValidExtension(tt.args.rootCmd, tt.args.manager, tt.args.extName) if tt.wantError == "" { assert.NoError(t, err) } else {