From 0b80c30789840c136f8fd56423522cfea658b92a Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 28 Jun 2021 17:00:06 -0700 Subject: [PATCH] Fix remote resolving for source repo --- pkg/cmd/repo/sync/sync.go | 43 ++++++++++++++++++++-------------- pkg/cmd/repo/sync/sync_test.go | 22 ++++++++++++----- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/pkg/cmd/repo/sync/sync.go b/pkg/cmd/repo/sync/sync.go index 0073433e1..23c6d1f11 100644 --- a/pkg/cmd/repo/sync/sync.go +++ b/pkg/cmd/repo/sync/sync.go @@ -97,6 +97,14 @@ func syncLocalRepo(opts *SyncOptions) error { var err error var srcRepo ghrepo.Interface + dirtyRepo, err := opts.Git.IsDirty() + if err != nil { + return err + } + if dirtyRepo { + return fmt.Errorf("can't sync because there are local changes, please commit or stash them") + } + if opts.SrcArg != "" { srcRepo, err = ghrepo.FromFullName(opts.SrcArg) } else { @@ -106,12 +114,21 @@ func syncLocalRepo(opts *SyncOptions) error { return err } - dirtyRepo, err := opts.Git.IsDirty() + // Find remote that matches the srcRepo + var remote string + remotes, err := opts.Remotes() if err != nil { return err } - if dirtyRepo { - return fmt.Errorf("can't sync because there are local changes, please commit or stash them") + for _, r := range remotes { + if r.RepoName() == srcRepo.RepoName() && + r.RepoOwner() == srcRepo.RepoOwner() && + r.RepoHost() == srcRepo.RepoHost() { + remote = r.Name + } + } + if remote == "" { + return fmt.Errorf("can't find corresponding remote for %s", ghrepo.FullName(srcRepo)) } if opts.Branch == "" { @@ -129,7 +146,7 @@ func syncLocalRepo(opts *SyncOptions) error { } opts.IO.StartProgressIndicator() - err = executeLocalRepoSync(srcRepo, opts) + err = executeLocalRepoSync(srcRepo, remote, opts) opts.IO.StopProgressIndicator() if err != nil { if errors.Is(err, divergingError) { @@ -216,28 +233,18 @@ func syncRemoteRepo(opts *SyncOptions) error { var divergingError = errors.New("diverging changes") -func executeLocalRepoSync(srcRepo ghrepo.Interface, opts *SyncOptions) error { - // Remotes precedence by name - // 1. upstream - // 2. github - // 3. origin - // 4. other - remotes, err := opts.Remotes() - if err != nil { - return err - } - remote := remotes[0] - branch := opts.Branch +func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOptions) error { git := opts.Git + branch := opts.Branch - err = git.Fetch([]string{remote.Name, fmt.Sprintf("+refs/heads/%s", branch)}) + err := git.Fetch([]string{remote, fmt.Sprintf("+refs/heads/%s", branch)}) if err != nil { return err } hasLocalBranch := git.HasLocalBranch([]string{branch}) if hasLocalBranch { - fastForward, err := git.IsAncestor([]string{branch, fmt.Sprintf("%s/%s", remote.Name, branch)}) + fastForward, err := git.IsAncestor([]string{branch, fmt.Sprintf("%s/%s", remote, branch)}) if err != nil { return err } diff --git a/pkg/cmd/repo/sync/sync_test.go b/pkg/cmd/repo/sync/sync_test.go index d4e953597..237d7cadd 100644 --- a/pkg/cmd/repo/sync/sync_test.go +++ b/pkg/cmd/repo/sync/sync_test.go @@ -163,11 +163,11 @@ func Test_SyncRun(t *testing.T) { }, gitStubs: func(mgc *mockGitClient) { mgc.On("IsDirty").Return(false, nil).Once() - mgc.On("Fetch", []string{"origin", "+refs/heads/trunk"}).Return(nil).Once() + mgc.On("Fetch", []string{"upstream", "+refs/heads/trunk"}).Return(nil).Once() mgc.On("HasLocalBranch", []string{"trunk"}).Return(true).Once() - mgc.On("IsAncestor", []string{"trunk", "origin/trunk"}).Return(true, nil).Once() + mgc.On("IsAncestor", []string{"trunk", "upstream/trunk"}).Return(true, nil).Once() mgc.On("CurrentBranch").Return("trunk", nil).Once() - mgc.On("Merge", []string{"--ff-only", "refs/remotes/origin/trunk"}).Return(nil).Once() + mgc.On("Merge", []string{"--ff-only", "refs/remotes/upstream/trunk"}).Return(nil).Once() }, wantStdout: "✓ Synced .:trunk from OWNER2/REPO2:trunk\n", }, @@ -423,14 +423,24 @@ func Test_SyncRun(t *testing.T) { io.SetStdoutTTY(tt.tty) tt.opts.IO = io + repo1, _ := ghrepo.FromFullName("OWNER/REPO") + repo2, _ := ghrepo.FromFullName("OWNER2/REPO2") tt.opts.BaseRepo = func() (ghrepo.Interface, error) { - repo, _ := ghrepo.FromFullName("OWNER/REPO") - return repo, nil + return repo1, nil } tt.opts.Remotes = func() (context.Remotes, error) { if tt.remotes == nil { - return []*context.Remote{{Remote: &git.Remote{Name: "origin"}}}, nil + return []*context.Remote{ + { + Remote: &git.Remote{Name: "origin"}, + Repo: repo1, + }, + { + Remote: &git.Remote{Name: "upstream"}, + Repo: repo2, + }, + }, nil } return tt.remotes, nil }