From cc32f3358346a65c3fd2d4a1f23a96c586fef476 Mon Sep 17 00:00:00 2001 From: bagtoad <47394200+BagToad@users.noreply.github.com> Date: Mon, 21 Oct 2024 12:37:27 -0600 Subject: [PATCH] Refactor to remove code duplication --- pkg/cmd/extension/command.go | 26 +++----------------------- pkg/cmd/extension/command_test.go | 6 +++--- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/pkg/cmd/extension/command.go b/pkg/cmd/extension/command.go index 43e9957fb..ded90cbe4 100644 --- a/pkg/cmd/extension/command.go +++ b/pkg/cmd/extension/command.go @@ -318,7 +318,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { if err != nil { return err } - err = checkValidLocalExtension(cmd.Root(), m, wd) + _, err = checkValidExtension(cmd.Root(), m, filepath.Base(wd), "") if err != nil { return err } @@ -644,29 +644,9 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { return &extCmd } -func checkValidLocalExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extDir string) error { - extName := filepath.Base(extDir) - if !strings.HasPrefix(extName, "gh-") { - return errors.New("extension directory name must start with `gh-`") - } - - commandName := strings.TrimPrefix(extName, "gh-") - if c, _, _ := rootCmd.Find([]string{commandName}); c != rootCmd && c.GroupID != "extension" { - return fmt.Errorf("%q matches the name of a built-in command or alias", commandName) - } - - for _, ext := range m.List() { - if ext.Name() == commandName { - return fmt.Errorf("there is already an installed extension that provides the %q command", commandName) - } - } - - return nil -} - func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName, extOwner string) (extensions.Extension, error) { if !strings.HasPrefix(extName, "gh-") { - return nil, errors.New("extension repository name must start with `gh-`") + return nil, errors.New("extension name must start with `gh-`") } commandName := strings.TrimPrefix(extName, "gh-") @@ -676,7 +656,7 @@ func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, for _, ext := range m.List() { if ext.Name() == commandName { - if ext.Owner() == extOwner { + if ext.Owner() != "" && ext.Owner() == extOwner { return ext, alreadyInstalledError } return ext, fmt.Errorf("there is already an installed extension that provides the %q command", commandName) diff --git a/pkg/cmd/extension/command_test.go b/pkg/cmd/extension/command_test.go index bddfdb93d..28364a685 100644 --- a/pkg/cmd/extension/command_test.go +++ b/pkg/cmd/extension/command_test.go @@ -1030,7 +1030,7 @@ func Test_checkValidExtension(t *testing.T) { } } -func Test_checkValidLocalExtension(t *testing.T) { +func Test_checkValidExtensionWithLocalExtension(t *testing.T) { fakeRootCmd := &cobra.Command{} fakeRootCmd.AddCommand(&cobra.Command{Use: "help"}) fakeRootCmd.AddCommand(&cobra.Command{Use: "auth"}) @@ -1077,7 +1077,7 @@ func Test_checkValidLocalExtension(t *testing.T) { manager: m, dir: "some/install/dir/hello", }, - wantError: "extension directory name must start with `gh-`", + wantError: "extension name must start with `gh-`", }, { name: "clashes with built-in command", @@ -1100,7 +1100,7 @@ func Test_checkValidLocalExtension(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := checkValidLocalExtension(tt.args.rootCmd, tt.args.manager, tt.args.dir) + _, err := checkValidExtension(tt.args.rootCmd, tt.args.manager, filepath.Base(tt.args.dir), "") if tt.wantError == "" { assert.NoError(t, err) } else {