diff --git a/pkg/cmd/extensions/command.go b/pkg/cmd/extensions/command.go index 7e6644900..e6b0a52ee 100644 --- a/pkg/cmd/extensions/command.go +++ b/pkg/cmd/extensions/command.go @@ -96,6 +96,7 @@ func NewCmdExtensions(f *cmdutil.Factory) *cobra.Command { }, func() *cobra.Command { var flagAll bool + var flagForce bool cmd := &cobra.Command{ Use: "upgrade { | --all}", Short: "Upgrade installed extensions", @@ -116,10 +117,11 @@ func NewCmdExtensions(f *cmdutil.Factory) *cobra.Command { if len(args) > 0 { name = args[0] } - return m.Upgrade(name, io.Out, io.ErrOut) + return m.Upgrade(name, flagForce, io.Out, io.ErrOut) }, } cmd.Flags().BoolVar(&flagAll, "all", false, "Upgrade all extensions") + cmd.Flags().BoolVar(&flagForce, "force", false, "Force upgrade extension") return cmd }(), &cobra.Command{ diff --git a/pkg/cmd/extensions/command_test.go b/pkg/cmd/extensions/command_test.go index ac173f3c8..59f4486d0 100644 --- a/pkg/cmd/extensions/command_test.go +++ b/pkg/cmd/extensions/command_test.go @@ -90,7 +90,7 @@ func TestNewCmdExtensions(t *testing.T) { name: "upgrade an extension", args: []string{"upgrade", "hello"}, managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) { - em.UpgradeFunc = func(name string, out, errOut io.Writer) error { + em.UpgradeFunc = func(name string, force bool, out, errOut io.Writer) error { return nil } return func(t *testing.T) { @@ -104,7 +104,7 @@ func TestNewCmdExtensions(t *testing.T) { name: "upgrade all", args: []string{"upgrade", "--all"}, managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) { - em.UpgradeFunc = func(name string, out, errOut io.Writer) error { + em.UpgradeFunc = func(name string, force bool, out, errOut io.Writer) error { return nil } return func(t *testing.T) { diff --git a/pkg/cmd/extensions/manager.go b/pkg/cmd/extensions/manager.go index bbca3395e..4f7d48fd5 100644 --- a/pkg/cmd/extensions/manager.go +++ b/pkg/cmd/extensions/manager.go @@ -167,7 +167,7 @@ func (m *Manager) Install(cloneURL string, stdout, stderr io.Writer) error { var localExtensionUpgradeError = errors.New("local extensions can not be upgraded") -func (m *Manager) Upgrade(name string, stdout, stderr io.Writer) error { +func (m *Manager) Upgrade(name string, force bool, stdout, stderr io.Writer) error { exe, err := m.lookPath("git") if err != nil { return err @@ -195,11 +195,17 @@ func (m *Manager) Upgrade(name string, stdout, stderr io.Writer) error { continue } + var cmds []*exec.Cmd dir := filepath.Dir(f.Path()) - externalCmd := m.newCommand(exe, "-C", dir, "--git-dir="+filepath.Join(dir, ".git"), "pull", "--ff-only") - externalCmd.Stdout = stdout - externalCmd.Stderr = stderr - if e := externalCmd.Run(); e != nil { + if force { + fetchCmd := m.newCommand(exe, "-C", dir, "--git-dir="+filepath.Join(dir, ".git"), "fetch", "origin", "HEAD") + resetCmd := m.newCommand(exe, "-C", dir, "--git-dir="+filepath.Join(dir, ".git"), "reset", "--hard", "origin/HEAD") + cmds = []*exec.Cmd{fetchCmd, resetCmd} + } else { + pullCmd := m.newCommand(exe, "-C", dir, "--git-dir="+filepath.Join(dir, ".git"), "pull", "--ff-only") + cmds = []*exec.Cmd{pullCmd} + } + if e := runCmds(cmds, stdout, stderr); e != nil { err = e } someUpgraded = true @@ -221,3 +227,14 @@ func (m *Manager) Remove(name string) error { func (m *Manager) installDir() string { return filepath.Join(m.dataDir(), "extensions") } + +func runCmds(cmds []*exec.Cmd, stdout, stderr io.Writer) error { + for _, cmd := range cmds { + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + return err + } + } + return nil +} diff --git a/pkg/cmd/extensions/manager_test.go b/pkg/cmd/extensions/manager_test.go index 693bfed89..513f19710 100644 --- a/pkg/cmd/extensions/manager_test.go +++ b/pkg/cmd/extensions/manager_test.go @@ -100,7 +100,7 @@ func TestManager_Upgrade_AllExtensions(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - err := m.Upgrade("", stdout, stderr) + err := m.Upgrade("", false, stdout, stderr) assert.NoError(t, err) assert.Equal(t, heredoc.Docf( @@ -125,7 +125,7 @@ func TestManager_Upgrade_RemoteExtension(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - err := m.Upgrade("remote", stdout, stderr) + err := m.Upgrade("remote", false, stdout, stderr) assert.NoError(t, err) assert.Equal(t, heredoc.Docf( ` @@ -145,12 +145,38 @@ func TestManager_Upgrade_LocalExtension(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - err := m.Upgrade("local", stdout, stderr) + err := m.Upgrade("local", false, stdout, stderr) assert.EqualError(t, err, "local extensions can not be upgraded") assert.Equal(t, "", stdout.String()) assert.Equal(t, "", stderr.String()) } +func TestManager_Upgrade_Force(t *testing.T) { + tempDir := t.TempDir() + extensionDir := filepath.Join(tempDir, "extensions", "gh-remote") + gitDir := filepath.Join(tempDir, "extensions", "gh-remote", ".git") + + assert.NoError(t, stubExtension(filepath.Join(tempDir, "extensions", "gh-remote", "gh-remote"))) + + m := newTestManager(tempDir) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + err := m.Upgrade("remote", true, stdout, stderr) + assert.NoError(t, err) + assert.Equal(t, heredoc.Docf( + ` + [git -C %s --git-dir=%s fetch origin HEAD] + [git -C %s --git-dir=%s reset --hard origin/HEAD] + `, + extensionDir, + gitDir, + extensionDir, + gitDir, + ), stdout.String()) + assert.Equal(t, "", stderr.String()) +} + func TestManager_Upgrade_NoExtensions(t *testing.T) { tempDir := t.TempDir() @@ -158,7 +184,7 @@ func TestManager_Upgrade_NoExtensions(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - err := m.Upgrade("", stdout, stderr) + err := m.Upgrade("", false, stdout, stderr) assert.EqualError(t, err, "no extensions installed") assert.Equal(t, "", stdout.String()) assert.Equal(t, "", stderr.String()) diff --git a/pkg/extensions/extension.go b/pkg/extensions/extension.go index 4007e7d28..4760e0300 100644 --- a/pkg/extensions/extension.go +++ b/pkg/extensions/extension.go @@ -18,7 +18,7 @@ type ExtensionManager interface { List() []Extension Install(url string, stdout, stderr io.Writer) error InstallLocal(dir string) error - Upgrade(name string, stdout, stderr io.Writer) error + Upgrade(name string, force bool, stdout, stderr io.Writer) error Remove(name string) error Dispatch(args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, error) } diff --git a/pkg/extensions/manager_mock.go b/pkg/extensions/manager_mock.go index af71a7904..224288f4b 100644 --- a/pkg/extensions/manager_mock.go +++ b/pkg/extensions/manager_mock.go @@ -33,7 +33,7 @@ var _ ExtensionManager = &ExtensionManagerMock{} // RemoveFunc: func(name string) error { // panic("mock out the Remove method") // }, -// UpgradeFunc: func(name string, stdout io.Writer, stderr io.Writer) error { +// UpgradeFunc: func(name string, force bool, stdout io.Writer, stderr io.Writer) error { // panic("mock out the Upgrade method") // }, // } @@ -59,7 +59,7 @@ type ExtensionManagerMock struct { RemoveFunc func(name string) error // UpgradeFunc mocks the Upgrade method. - UpgradeFunc func(name string, stdout io.Writer, stderr io.Writer) error + UpgradeFunc func(name string, force bool, stdout io.Writer, stderr io.Writer) error // calls tracks calls to the methods. calls struct { @@ -100,6 +100,8 @@ type ExtensionManagerMock struct { Upgrade []struct { // Name is the name argument value. Name string + // Force is the force argument value. + Force bool // Stdout is the stdout argument value. Stdout io.Writer // Stderr is the stderr argument value. @@ -285,23 +287,25 @@ func (mock *ExtensionManagerMock) RemoveCalls() []struct { } // Upgrade calls UpgradeFunc. -func (mock *ExtensionManagerMock) Upgrade(name string, stdout io.Writer, stderr io.Writer) error { +func (mock *ExtensionManagerMock) Upgrade(name string, force bool, stdout io.Writer, stderr io.Writer) error { if mock.UpgradeFunc == nil { panic("ExtensionManagerMock.UpgradeFunc: method is nil but ExtensionManager.Upgrade was just called") } callInfo := struct { Name string + Force bool Stdout io.Writer Stderr io.Writer }{ Name: name, + Force: force, Stdout: stdout, Stderr: stderr, } mock.lockUpgrade.Lock() mock.calls.Upgrade = append(mock.calls.Upgrade, callInfo) mock.lockUpgrade.Unlock() - return mock.UpgradeFunc(name, stdout, stderr) + return mock.UpgradeFunc(name, force, stdout, stderr) } // UpgradeCalls gets all the calls that were made to Upgrade. @@ -309,11 +313,13 @@ func (mock *ExtensionManagerMock) Upgrade(name string, stdout io.Writer, stderr // len(mockedExtensionManager.UpgradeCalls()) func (mock *ExtensionManagerMock) UpgradeCalls() []struct { Name string + Force bool Stdout io.Writer Stderr io.Writer } { var calls []struct { Name string + Force bool Stdout io.Writer Stderr io.Writer }