diff --git a/pkg/cmd/repo/sync/git.go b/pkg/cmd/repo/sync/git.go index a63552377..54da42719 100644 --- a/pkg/cmd/repo/sync/git.go +++ b/pkg/cmd/repo/sync/git.go @@ -1,10 +1,14 @@ package sync import ( + "fmt" + "strings" + "github.com/cli/cli/git" ) type gitClient interface { + BranchRemote(string) (string, error) Checkout([]string) error CurrentBranch() (string, error) Fetch([]string) error @@ -17,8 +21,27 @@ type gitClient interface { type gitExecuter struct{} +func (g *gitExecuter) BranchRemote(branch string) (string, error) { + args := append([]string{"rev-parse", "--symbolic-full-name", "--abbrev-ref", fmt.Sprintf("%s@{u}", branch)}) + cmd, err := git.GitCommand(args...) + if err != nil { + return "", err + } + out, err := cmd.Output() + if err != nil { + return "", err + } + parts := strings.SplitN(string(out), "/", 2) + return parts[0], nil +} + func (g *gitExecuter) Checkout(args []string) error { - return git.CheckoutBranch(args[0]) + args = append([]string{"checkout"}, args...) + cmd, err := git.GitCommand(args...) + if err != nil { + return err + } + return cmd.Run() } func (g *gitExecuter) CurrentBranch() (string, error) { diff --git a/pkg/cmd/repo/sync/mocks.go b/pkg/cmd/repo/sync/mocks.go index 3ef9b1dbc..fc66a2b30 100644 --- a/pkg/cmd/repo/sync/mocks.go +++ b/pkg/cmd/repo/sync/mocks.go @@ -8,6 +8,11 @@ type mockGitClient struct { mock.Mock } +func (g *mockGitClient) BranchRemote(a string) (string, error) { + args := g.Called(a) + return args.String(0), args.Error(1) +} + func (g *mockGitClient) Checkout(a []string) error { args := g.Called(a) return args.Error(0) diff --git a/pkg/cmd/repo/sync/sync.go b/pkg/cmd/repo/sync/sync.go index 23c6d1f11..9178e2a69 100644 --- a/pkg/cmd/repo/sync/sync.go +++ b/pkg/cmd/repo/sync/sync.go @@ -152,6 +152,9 @@ func syncLocalRepo(opts *SyncOptions) error { if errors.Is(err, divergingError) { return fmt.Errorf("can't sync because there are diverging changes, you can use `--force` to overwrite the changes") } + if errors.Is(err, mismatchRemotesError) { + return fmt.Errorf("can't sync because %s is not tracking %s", opts.Branch, ghrepo.FullName(srcRepo)) + } return err } @@ -232,6 +235,7 @@ func syncRemoteRepo(opts *SyncOptions) error { } var divergingError = errors.New("diverging changes") +var mismatchRemotesError = errors.New("branch remote does not match specified source") func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOptions) error { git := opts.Git @@ -244,6 +248,11 @@ func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOpt hasLocalBranch := git.HasLocalBranch([]string{branch}) if hasLocalBranch { + branchRemote, err := git.BranchRemote(branch) + if branchRemote != remote { + return mismatchRemotesError + } + fastForward, err := git.IsAncestor([]string{branch, fmt.Sprintf("%s/%s", remote, branch)}) if err != nil { return err @@ -259,9 +268,16 @@ func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOpt return err } if startBranch != branch { - err = git.Checkout([]string{branch}) - if err != nil { - return err + if hasLocalBranch { + err = git.Checkout([]string{branch}) + if err != nil { + return err + } + } else { + err = git.Checkout([]string{"--track", fmt.Sprintf("%s/%s", remote, branch)}) + if err != nil { + return err + } } } if hasLocalBranch { diff --git a/pkg/cmd/repo/sync/sync_test.go b/pkg/cmd/repo/sync/sync_test.go index 237d7cadd..d2feb4f85 100644 --- a/pkg/cmd/repo/sync/sync_test.go +++ b/pkg/cmd/repo/sync/sync_test.go @@ -125,6 +125,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("trunk", nil).Once() mgc.On("Merge", []string{"--ff-only", "refs/remotes/origin/trunk"}).Return(nil).Once() @@ -144,6 +145,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("trunk", nil).Once() mgc.On("Merge", []string{"--ff-only", "refs/remotes/origin/trunk"}).Return(nil).Once() @@ -165,6 +167,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"upstream", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("upstream", nil).Once() mgc.On("IsAncestor", []string{"trunk", "upstream/trunk"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("trunk", nil).Once() mgc.On("Merge", []string{"--ff-only", "refs/remotes/upstream/trunk"}).Return(nil).Once() @@ -181,6 +184,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/test"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"test"}).Return(true).Once() + mgc.On("BranchRemote", "test").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"test", "origin/test"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("test", nil).Once() mgc.On("Merge", []string{"--ff-only", "refs/remotes/origin/test"}).Return(nil).Once() @@ -202,6 +206,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(false, nil).Once() mgc.On("CurrentBranch").Return("trunk", nil).Once() mgc.On("Reset", []string{"--hard", "refs/remotes/origin/trunk"}).Return(nil).Once() @@ -221,11 +226,30 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(false, nil).Once() }, wantErr: true, errMsg: "can't sync because there are diverging changes, you can use `--force` to overwrite the changes", }, + { + name: "sync local repo with parent and mismatching branch remotes", + tty: true, + opts: &SyncOptions{}, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query RepositoryInfo\b`), + httpmock.StringResponse(`{"data":{"repository":{"defaultBranchRef":{"name": "trunk"}}}}`)) + }, + gitStubs: func(mgc *mockGitClient) { + mgc.On("IsDirty").Return(false, nil).Once() + mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() + mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("upstream", nil).Once() + }, + wantErr: true, + errMsg: "can't sync because trunk is not tracking OWNER/REPO", + }, { name: "sync local repo with parent and local changes", tty: true, @@ -249,6 +273,7 @@ func Test_SyncRun(t *testing.T) { mgc.On("IsDirty").Return(false, nil).Once() mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() + mgc.On("BranchRemote", "trunk").Return("origin", nil).Once() mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("test", nil).Once() mgc.On("Checkout", []string{"trunk"}).Return(nil).Once()