diff --git a/git/git.go b/git/git.go index afa9ccf0f..abff866d6 100644 --- a/git/git.go +++ b/git/git.go @@ -368,30 +368,3 @@ func getBranchShortName(output []byte) string { branch := firstLine(output) return strings.TrimPrefix(branch, "refs/heads/") } - -func IsAncestor(ancestor, commit string) (bool, error) { - cmd, err := GitCommand("merge-base", "--is-ancestor", ancestor, commit) - if err != nil { - return false, err - } - err = run.PrepareCmd(cmd).Run() - return err == nil, nil -} - -func IsDirty() (bool, error) { - cmd, err := GitCommand("status", "--untracked-files=no", "--porcelain") - if err != nil { - return false, err - } - - output, err := run.PrepareCmd(cmd).Output() - if err != nil { - return true, err - } - - if len(output) > 0 { - return true, nil - } - - return false, nil -} diff --git a/pkg/cmd/repo/sync/git.go b/pkg/cmd/repo/sync/git.go new file mode 100644 index 000000000..2d0148d67 --- /dev/null +++ b/pkg/cmd/repo/sync/git.go @@ -0,0 +1,97 @@ +package sync + +import ( + "os/exec" + + "github.com/cli/cli/git" + "github.com/cli/cli/internal/run" +) + +type gitClient interface { + Checkout([]string) error + CurrentBranch() (string, error) + Fetch([]string) error + HasLocalBranch([]string) bool + IsAncestor([]string) (bool, error) + IsDirty() (bool, error) + Merge([]string) error + Reset([]string) error + Stash([]string) error +} + +type gitExecuter struct { + gitCommand func(args ...string) (*exec.Cmd, error) +} + +func (g *gitExecuter) Checkout(args []string) error { + return git.CheckoutBranch(args[0]) +} + +func (g *gitExecuter) CurrentBranch() (string, error) { + return git.CurrentBranch() +} + +func (g *gitExecuter) Fetch(args []string) error { + args = append([]string{"fetch"}, args...) + cmd, err := g.gitCommand(args...) + if err != nil { + return err + } + return run.PrepareCmd(cmd).Run() +} + +func (g *gitExecuter) HasLocalBranch(args []string) bool { + return git.HasLocalBranch(args[0]) +} + +func (g *gitExecuter) IsAncestor(args []string) (bool, error) { + args = append([]string{"merge-base", "--is-ancestor"}, args...) + cmd, err := g.gitCommand(args...) + if err != nil { + return false, err + } + err = run.PrepareCmd(cmd).Run() + return err == nil, nil +} + +func (g *gitExecuter) IsDirty() (bool, error) { + cmd, err := g.gitCommand("status", "--untracked-files=no", "--porcelain") + if err != nil { + return false, err + } + output, err := run.PrepareCmd(cmd).Output() + if err != nil { + return true, err + } + if len(output) > 0 { + return true, nil + } + return false, nil +} + +func (g *gitExecuter) Merge(args []string) error { + args = append([]string{"merge"}, args...) + cmd, err := g.gitCommand(args...) + if err != nil { + return err + } + return run.PrepareCmd(cmd).Run() +} + +func (g *gitExecuter) Reset(args []string) error { + args = append([]string{"reset"}, args...) + cmd, err := g.gitCommand(args...) + if err != nil { + return err + } + return run.PrepareCmd(cmd).Run() +} + +func (g *gitExecuter) Stash(args []string) error { + args = append([]string{"stash"}, args...) + cmd, err := g.gitCommand(args...) + if err != nil { + return err + } + return run.PrepareCmd(cmd).Run() +} diff --git a/pkg/cmd/repo/sync/sync.go b/pkg/cmd/repo/sync/sync.go index e6d5a891a..46fbb8f4b 100644 --- a/pkg/cmd/repo/sync/sync.go +++ b/pkg/cmd/repo/sync/sync.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "os/exec" "regexp" "github.com/AlecAivazis/survey/v2" @@ -13,11 +12,9 @@ import ( "github.com/cli/cli/context" "github.com/cli/cli/git" "github.com/cli/cli/internal/ghrepo" - "github.com/cli/cli/internal/run" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" "github.com/cli/cli/pkg/prompt" - "github.com/cli/safeexec" "github.com/spf13/cobra" ) @@ -27,6 +24,7 @@ type SyncOptions struct { BaseRepo func() (ghrepo.Interface, error) Remotes func() (context.Remotes, error) CurrentBranch func() (string, error) + Git gitClient DestArg string SrcArg string Branch string @@ -41,6 +39,7 @@ func NewCmdSync(f *cmdutil.Factory, runF func(*SyncOptions) error) *cobra.Comman BaseRepo: f.BaseRepo, Remotes: f.Remotes, CurrentBranch: f.Branch, + Git: &gitExecuter{gitCommand: git.GitCommand}, } cmd := &cobra.Command{ @@ -202,12 +201,16 @@ func syncLocalRepo(srcRepo ghrepo.Interface, opts *SyncOptions) error { } remote := remotes[0] branch := opts.Branch + git := opts.Git - _ = executeCmds([][]string{{"git", "fetch", remote.Name, fmt.Sprintf("+refs/heads/%s", branch)}}) + err = git.Fetch([]string{remote.Name, fmt.Sprintf("+refs/heads/%s", branch)}) + if err != nil { + return err + } - hasLocalBranch := git.HasLocalBranch(branch) + hasLocalBranch := git.HasLocalBranch([]string{branch}) if hasLocalBranch { - fastForward, err := git.IsAncestor(branch, fmt.Sprintf("%s/%s", remote.Name, branch)) + fastForward, err := git.IsAncestor([]string{branch, fmt.Sprintf("%s/%s", remote.Name, branch)}) if err != nil { return err } @@ -217,38 +220,54 @@ func syncLocalRepo(srcRepo ghrepo.Interface, opts *SyncOptions) error { } } - startBranch, err := opts.CurrentBranch() - if err != nil { - return err - } - dirtyRepo, err := git.IsDirty() if err != nil { return err } + startBranch, err := git.CurrentBranch() + if err != nil { + return err + } - var cmds [][]string if dirtyRepo { - cmds = append(cmds, []string{"git", "stash", "push"}) - } - if startBranch != branch { - cmds = append(cmds, []string{"git", "checkout", branch}) - } - if hasLocalBranch { - if opts.Force { - cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) - } else { - cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) + err = git.Stash([]string{"push"}) + if err != nil { + return err } } if startBranch != branch { - cmds = append(cmds, []string{"git", "checkout", startBranch}) + err = git.Checkout([]string{branch}) + if err != nil { + return err + } + } + if hasLocalBranch { + if opts.Force { + err = git.Reset([]string{"--hard", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) + if err != nil { + return err + } + } else { + err = git.Merge([]string{"--ff-only", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) + if err != nil { + return err + } + } + } + if startBranch != branch { + err = git.Checkout([]string{startBranch}) + if err != nil { + return err + } } if dirtyRepo { - cmds = append(cmds, []string{"git", "stash", "pop"}) + err = git.Stash([]string{"pop"}) + if err != nil { + return err + } } - return executeCmds(cmds) + return nil } func syncRemoteRepo(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts *SyncOptions) error { @@ -270,17 +289,3 @@ func syncRemoteRepo(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts return err } - -func executeCmds(cmdQueue [][]string) error { - exe, err := safeexec.LookPath("git") - if err != nil { - return err - } - for _, args := range cmdQueue { - cmd := exec.Command(exe, args[1:]...) - if err := run.PrepareCmd(cmd).Run(); err != nil { - return err - } - } - return nil -} diff --git a/pkg/cmd/repo/sync/sync_test.go b/pkg/cmd/repo/sync/sync_test.go index 85d213ad8..e3092642d 100644 --- a/pkg/cmd/repo/sync/sync_test.go +++ b/pkg/cmd/repo/sync/sync_test.go @@ -17,6 +17,7 @@ func TestNewCmdSync(t *testing.T) { input string output SyncOptions wantsErr bool + errMsg string }{ { name: "no argument", @@ -69,6 +70,7 @@ func TestNewCmdSync(t *testing.T) { tty: false, input: "", wantsErr: true, + errMsg: "`--confirm` required when not running interactively", }, } for _, tt := range tests { @@ -94,6 +96,7 @@ func TestNewCmdSync(t *testing.T) { _, err = cmd.ExecuteC() if tt.wantsErr { assert.Error(t, err) + assert.Equal(t, tt.errMsg, err.Error()) return }