Refactor to remove code duplication

This commit is contained in:
bagtoad 2024-10-21 12:37:27 -06:00
parent 6923fb5cc8
commit cc32f33583
2 changed files with 6 additions and 26 deletions

View file

@ -318,7 +318,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
if err != nil {
return err
}
err = checkValidLocalExtension(cmd.Root(), m, wd)
_, err = checkValidExtension(cmd.Root(), m, filepath.Base(wd), "")
if err != nil {
return err
}
@ -644,29 +644,9 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command {
return &extCmd
}
func checkValidLocalExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extDir string) error {
extName := filepath.Base(extDir)
if !strings.HasPrefix(extName, "gh-") {
return errors.New("extension directory name must start with `gh-`")
}
commandName := strings.TrimPrefix(extName, "gh-")
if c, _, _ := rootCmd.Find([]string{commandName}); c != rootCmd && c.GroupID != "extension" {
return fmt.Errorf("%q matches the name of a built-in command or alias", commandName)
}
for _, ext := range m.List() {
if ext.Name() == commandName {
return fmt.Errorf("there is already an installed extension that provides the %q command", commandName)
}
}
return nil
}
func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager, extName, extOwner string) (extensions.Extension, error) {
if !strings.HasPrefix(extName, "gh-") {
return nil, errors.New("extension repository name must start with `gh-`")
return nil, errors.New("extension name must start with `gh-`")
}
commandName := strings.TrimPrefix(extName, "gh-")
@ -676,7 +656,7 @@ func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager,
for _, ext := range m.List() {
if ext.Name() == commandName {
if ext.Owner() == extOwner {
if ext.Owner() != "" && ext.Owner() == extOwner {
return ext, alreadyInstalledError
}
return ext, fmt.Errorf("there is already an installed extension that provides the %q command", commandName)

View file

@ -1030,7 +1030,7 @@ func Test_checkValidExtension(t *testing.T) {
}
}
func Test_checkValidLocalExtension(t *testing.T) {
func Test_checkValidExtensionWithLocalExtension(t *testing.T) {
fakeRootCmd := &cobra.Command{}
fakeRootCmd.AddCommand(&cobra.Command{Use: "help"})
fakeRootCmd.AddCommand(&cobra.Command{Use: "auth"})
@ -1077,7 +1077,7 @@ func Test_checkValidLocalExtension(t *testing.T) {
manager: m,
dir: "some/install/dir/hello",
},
wantError: "extension directory name must start with `gh-`",
wantError: "extension name must start with `gh-`",
},
{
name: "clashes with built-in command",
@ -1100,7 +1100,7 @@ func Test_checkValidLocalExtension(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkValidLocalExtension(tt.args.rootCmd, tt.args.manager, tt.args.dir)
_, err := checkValidExtension(tt.args.rootCmd, tt.args.manager, filepath.Base(tt.args.dir), "")
if tt.wantError == "" {
assert.NoError(t, err)
} else {