Rework local extensions for Windows

Replace the implementation that relied on symlinks with the one that
create regular files that act like symlinks: they contain a reference to
the local directory where to find the extension.
This commit is contained in:
Mislav Marohnić 2021-07-28 22:47:54 +02:00
parent 4b499be96b
commit 0d999ddaa1
11 changed files with 107 additions and 68 deletions

View file

@ -166,7 +166,7 @@ func mainRun() exitCode {
}
}
}
for _, ext := range cmdFactory.ExtensionManager.List() {
for _, ext := range cmdFactory.ExtensionManager.List(false) {
if strings.HasPrefix(ext.Name(), toComplete) {
results = append(results, ext.Name())
}

View file

@ -82,7 +82,7 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command
return true
}
for _, ext := range f.ExtensionManager.List() {
for _, ext := range f.ExtensionManager.List(false) {
if ext.Name() == split[0] {
return true
}

View file

@ -30,7 +30,7 @@ func runCommand(cfg config.Config, isTTY bool, cli string, in string) (*test.Cmd
return cfg, nil
},
ExtensionManager: &extensions.ExtensionManagerMock{
ListFunc: func() []extensions.Extension {
ListFunc: func(bool) []extensions.Extension {
return []extensions.Extension{}
},
},

View file

@ -39,7 +39,7 @@ func NewCmdExtensions(f *cmdutil.Factory) *cobra.Command {
Short: "List installed extension commands",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
cmds := m.List()
cmds := m.List(true)
if len(cmds) == 0 {
return errors.New("no extensions installed")
}
@ -158,7 +158,7 @@ func checkValidExtension(rootCmd *cobra.Command, m extensions.ExtensionManager,
return fmt.Errorf("%q matches the name of a built-in command", commandName)
}
for _, ext := range m.List() {
for _, ext := range m.List(false) {
if ext.Name() == commandName {
return fmt.Errorf("there is already an installed extension that provides the %q command", commandName)
}

View file

@ -35,7 +35,7 @@ func TestNewCmdExtensions(t *testing.T) {
name: "install an extension",
args: []string{"install", "owner/gh-some-ext"},
managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) {
em.ListFunc = func() []extensions.Extension {
em.ListFunc = func(bool) []extensions.Extension {
return []extensions.Extension{}
}
em.InstallFunc = func(s string, out, errOut io.Writer) error {
@ -54,7 +54,7 @@ func TestNewCmdExtensions(t *testing.T) {
name: "install an extension with same name as existing extension",
args: []string{"install", "owner/gh-existing-ext"},
managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) {
em.ListFunc = func() []extensions.Extension {
em.ListFunc = func(bool) []extensions.Extension {
e := &Extension{path: "owner2/gh-existing-ext"}
return []extensions.Extension{e}
}
@ -150,7 +150,7 @@ func TestNewCmdExtensions(t *testing.T) {
name: "list extensions",
args: []string{"list"},
managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) {
em.ListFunc = func() []extensions.Extension {
em.ListFunc = func(bool) []extensions.Extension {
ex1 := &Extension{path: "cli/gh-test", url: "https://github.com/cli/gh-test", updateAvailable: false}
ex2 := &Extension{path: "cli/gh-test2", url: "https://github.com/cli/gh-test2", updateAvailable: true}
return []extensions.Extension{ex1, ex2}
@ -215,7 +215,7 @@ func Test_checkValidExtension(t *testing.T) {
rootCmd.AddCommand(&cobra.Command{Use: "auth"})
m := &extensions.ExtensionManagerMock{
ListFunc: func() []extensions.Extension {
ListFunc: func(bool) []extensions.Extension {
return []extensions.Extension{
&extensions.ExtensionMock{
NameFunc: func() string { return "screensaver" },

View file

@ -1,8 +1,6 @@
package extensions
import (
"errors"
"os"
"path/filepath"
"strings"
)
@ -10,6 +8,7 @@ import (
type Extension struct {
path string
url string
isLocal bool
updateAvailable bool
}
@ -26,20 +25,7 @@ func (e *Extension) URL() string {
}
func (e *Extension) IsLocal() bool {
dir := filepath.Dir(e.path)
fileInfo, err := os.Lstat(dir)
if err != nil {
return false
}
// Check if extension is a symlink
if fileInfo.Mode()&os.ModeSymlink != 0 {
return true
}
// Check if extension does not have a git directory
if _, err = os.Stat(filepath.Join(dir, ".git")); errors.Is(err, os.ErrNotExist) {
return true
}
return false
return e.isLocal
}
func (e *Extension) UpdateAvailable() bool {

View file

@ -44,7 +44,8 @@ func (m *Manager) Dispatch(args []string, stdin io.Reader, stdout, stderr io.Wri
extName := args[0]
forwardArgs := args[1:]
for _, e := range m.list(false) {
exts, _ := m.list(false)
for _, e := range exts {
if e.Name() == extName {
exe = e.Path()
break
@ -77,34 +78,56 @@ func (m *Manager) Dispatch(args []string, stdin io.Reader, stdout, stderr io.Wri
return true, externalCmd.Run()
}
func (m *Manager) List() []extensions.Extension {
return m.list(true)
func (m *Manager) List(includeMetadata bool) []extensions.Extension {
exts, _ := m.list(includeMetadata)
return exts
}
func (m *Manager) list(includeMetadata bool) []extensions.Extension {
func (m *Manager) list(includeMetadata bool) ([]extensions.Extension, error) {
dir := m.installDir()
entries, err := ioutil.ReadDir(dir)
if err != nil {
return nil
return nil, err
}
var results []extensions.Extension
for _, f := range entries {
if !strings.HasPrefix(f.Name(), "gh-") || !(f.IsDir() || f.Mode()&os.ModeSymlink != 0) {
if !strings.HasPrefix(f.Name(), "gh-") {
continue
}
var remoteUrl string
var updateAvailable bool
if includeMetadata {
remoteUrl = m.getRemoteUrl(f.Name())
updateAvailable = m.checkUpdateAvailable(f.Name())
updateAvailable := false
isLocal := false
exePath := filepath.Join(dir, f.Name(), f.Name())
if f.IsDir() {
if includeMetadata {
remoteUrl = m.getRemoteUrl(f.Name())
updateAvailable = m.checkUpdateAvailable(f.Name())
}
} else {
isLocal = true
if f.Mode()&os.ModeSymlink == 0 {
// if this is a regular file, its contents is the local directory of the extension
exeFile, err := os.Open(filepath.Join(dir, f.Name()))
if err != nil {
return nil, err
}
b := make([]byte, 1024)
n, err := exeFile.Read(b)
if err != nil {
return nil, err
}
exePath = filepath.Join(strings.TrimSpace(string(b[:n])), f.Name())
}
}
results = append(results, &Extension{
path: filepath.Join(dir, f.Name(), f.Name()),
path: exePath,
url: remoteUrl,
isLocal: isLocal,
updateAvailable: updateAvailable,
})
}
return results
return results, nil
}
func (m *Manager) getRemoteUrl(extension string) string {
@ -146,8 +169,19 @@ func (m *Manager) checkUpdateAvailable(extension string) bool {
func (m *Manager) InstallLocal(dir string) error {
name := filepath.Base(dir)
targetDir := filepath.Join(m.installDir(), name)
return os.Symlink(dir, targetDir)
targetLink := filepath.Join(m.installDir(), name)
if err := os.MkdirAll(filepath.Dir(targetLink), 0755); err != nil {
return err
}
// Create a regular file that contains the location of the directory where to find this extension. We
// avoid relying on symlinks because creating them on Windows requires administrator privileges.
f, err := os.OpenFile(targetLink, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteString(dir)
return err
}
func (m *Manager) Install(cloneURL string, stdout, stderr io.Writer) error {
@ -173,7 +207,7 @@ func (m *Manager) Upgrade(name string, force bool, stdout, stderr io.Writer) err
return err
}
exts := m.List()
exts := m.List(false)
if len(exts) == 0 {
return errors.New("no extensions installed")
}

View file

@ -48,7 +48,7 @@ func TestManager_List(t *testing.T) {
assert.NoError(t, stubExtension(filepath.Join(tempDir, "extensions", "gh-two", "gh-two")))
m := newTestManager(tempDir)
exts := m.List()
exts := m.List(false)
assert.Equal(t, 2, len(exts))
assert.Equal(t, "hello", exts[0].Name())
assert.Equal(t, "two", exts[1].Name())
@ -94,7 +94,7 @@ func TestManager_Upgrade_AllExtensions(t *testing.T) {
tempDir := t.TempDir()
assert.NoError(t, stubExtension(filepath.Join(tempDir, "extensions", "gh-hello", "gh-hello")))
assert.NoError(t, stubExtension(filepath.Join(tempDir, "extensions", "gh-two", "gh-two")))
assert.NoError(t, stubLocalExtension(filepath.Join(tempDir, "extensions", "gh-local", "gh-local")))
assert.NoError(t, stubLocalExtension(tempDir, filepath.Join(tempDir, "extensions", "gh-local", "gh-local")))
m := newTestManager(tempDir)
@ -139,7 +139,7 @@ func TestManager_Upgrade_RemoteExtension(t *testing.T) {
func TestManager_Upgrade_LocalExtension(t *testing.T) {
tempDir := t.TempDir()
assert.NoError(t, stubLocalExtension(filepath.Join(tempDir, "extensions", "gh-local", "gh-local")))
assert.NoError(t, stubLocalExtension(tempDir, filepath.Join(tempDir, "extensions", "gh-local", "gh-local")))
m := newTestManager(tempDir)
@ -203,22 +203,6 @@ func TestManager_Install(t *testing.T) {
}
func stubExtension(path string) error {
dir := filepath.Dir(path)
gitDir := filepath.Join(dir, ".git")
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
if err := os.Mkdir(gitDir, 0755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE, 0755)
if err != nil {
return err
}
return f.Close()
}
func stubLocalExtension(path string) error {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}
@ -228,3 +212,31 @@ func stubLocalExtension(path string) error {
}
return f.Close()
}
func stubLocalExtension(tempDir, path string) error {
extDir, err := os.MkdirTemp(tempDir, "local-ext")
if err != nil {
return err
}
extFile, err := os.OpenFile(filepath.Join(extDir, filepath.Base(path)), os.O_CREATE, 0755)
if err != nil {
return err
}
if err := extFile.Close(); err != nil {
return err
}
linkPath := filepath.Dir(path)
if err := os.MkdirAll(filepath.Dir(linkPath), 0755); err != nil {
return err
}
f, err := os.OpenFile(linkPath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}
_, err = f.WriteString(extDir)
if err != nil {
return err
}
return f.Close()
}

View file

@ -145,7 +145,7 @@ func rootHelpFunc(f *cmdutil.Factory, command *cobra.Command, args []string) {
}
if isRootCmd(command) {
if exts := f.ExtensionManager.List(); len(exts) > 0 {
if exts := f.ExtensionManager.List(false); len(exts) > 0 {
var names []string
for _, ext := range exts {
names = append(names, ext.Name())

View file

@ -4,7 +4,7 @@ import (
"io"
)
//go:generate moq -out extension_mock.go . Extension
//go:generate moq -rm -out extension_mock.go . Extension
type Extension interface {
Name() string
Path() string
@ -13,9 +13,9 @@ type Extension interface {
UpdateAvailable() bool
}
//go:generate moq -out manager_mock.go . ExtensionManager
//go:generate moq -rm -out manager_mock.go . ExtensionManager
type ExtensionManager interface {
List() []Extension
List(includeMetadata bool) []Extension
Install(url string, stdout, stderr io.Writer) error
InstallLocal(dir string) error
Upgrade(name string, force bool, stdout, stderr io.Writer) error

View file

@ -27,7 +27,7 @@ var _ ExtensionManager = &ExtensionManagerMock{}
// InstallLocalFunc: func(dir string) error {
// panic("mock out the InstallLocal method")
// },
// ListFunc: func() []Extension {
// ListFunc: func(includeMetadata bool) []Extension {
// panic("mock out the List method")
// },
// RemoveFunc: func(name string) error {
@ -53,7 +53,7 @@ type ExtensionManagerMock struct {
InstallLocalFunc func(dir string) error
// ListFunc mocks the List method.
ListFunc func() []Extension
ListFunc func(includeMetadata bool) []Extension
// RemoveFunc mocks the Remove method.
RemoveFunc func(name string) error
@ -90,6 +90,8 @@ type ExtensionManagerMock struct {
}
// List holds details about calls to the List method.
List []struct {
// IncludeMetadata is the includeMetadata argument value.
IncludeMetadata bool
}
// Remove holds details about calls to the Remove method.
Remove []struct {
@ -230,24 +232,29 @@ func (mock *ExtensionManagerMock) InstallLocalCalls() []struct {
}
// List calls ListFunc.
func (mock *ExtensionManagerMock) List() []Extension {
func (mock *ExtensionManagerMock) List(includeMetadata bool) []Extension {
if mock.ListFunc == nil {
panic("ExtensionManagerMock.ListFunc: method is nil but ExtensionManager.List was just called")
}
callInfo := struct {
}{}
IncludeMetadata bool
}{
IncludeMetadata: includeMetadata,
}
mock.lockList.Lock()
mock.calls.List = append(mock.calls.List, callInfo)
mock.lockList.Unlock()
return mock.ListFunc()
return mock.ListFunc(includeMetadata)
}
// ListCalls gets all the calls that were made to List.
// Check the length with:
// len(mockedExtensionManager.ListCalls())
func (mock *ExtensionManagerMock) ListCalls() []struct {
IncludeMetadata bool
} {
var calls []struct {
IncludeMetadata bool
}
mock.lockList.RLock()
calls = mock.calls.List