diff --git a/git/git.go b/git/git.go index 3cda262f0..53e7089da 100644 --- a/git/git.go +++ b/git/git.go @@ -308,8 +308,13 @@ func RunClone(cloneURL string, args []string) (target string, err error) { return } -func AddUpstreamRemote(upstreamURL, cloneDir string) error { - cloneCmd, err := GitCommand("-C", cloneDir, "remote", "add", "-f", "upstream", upstreamURL) +func AddUpstreamRemote(upstreamURL, cloneDir string, branches []string) error { + args := []string{"-C", cloneDir, "remote", "add"} + for _, branch := range branches { + args = append(args, "-t", branch) + } + args = append(args, "-f", "upstream", upstreamURL) + cloneCmd, err := GitCommand(args...) if err != nil { return err } diff --git a/git/git_test.go b/git/git_test.go index 34cc9c7cb..a432f74b5 100644 --- a/git/git_test.go +++ b/git/git_test.go @@ -3,10 +3,12 @@ package git import ( "os/exec" "reflect" + "strings" "testing" "github.com/cli/cli/internal/run" "github.com/cli/cli/test" + "github.com/stretchr/testify/assert" ) func Test_UncommittedChangeCount(t *testing.T) { @@ -170,5 +172,45 @@ func TestParseExtraCloneArgs(t *testing.T) { } }) } - +} + +func TestAddUpstreamRemote(t *testing.T) { + tests := []struct { + name string + upstreamURL string + cloneDir string + branches []string + want string + }{ + { + name: "fetch all", + upstreamURL: "URL", + cloneDir: "DIRECTORY", + branches: []string{}, + want: "git -C DIRECTORY remote add -f upstream URL", + }, + { + name: "fetch specific branches only", + upstreamURL: "URL", + cloneDir: "DIRECTORY", + branches: []string{"master", "dev"}, + want: "git -C DIRECTORY remote add -t master -t dev -f upstream URL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, restore := test.InitCmdStubber() + defer restore() + + cs.Stub("") // git remote add -f + + err := AddUpstreamRemote(tt.upstreamURL, tt.cloneDir, tt.branches) + if err != nil { + t.Fatalf("error running command `git remote add -f`: %v", err) + } + + assert.Equal(t, 1, cs.Count) + assert.Equal(t, tt.want, strings.Join(cs.Calls[0].Args, " ")) + }) + } } diff --git a/pkg/cmd/repo/clone/clone.go b/pkg/cmd/repo/clone/clone.go index 73f4d7187..e6a10ead0 100644 --- a/pkg/cmd/repo/clone/clone.go +++ b/pkg/cmd/repo/clone/clone.go @@ -144,7 +144,7 @@ func cloneRun(opts *CloneOptions) error { } upstreamURL := ghrepo.FormatRemoteURL(canonicalRepo.Parent, protocol) - err = git.AddUpstreamRemote(upstreamURL, cloneDir) + err = git.AddUpstreamRemote(upstreamURL, cloneDir, []string{canonicalRepo.Parent.DefaultBranchRef.Name}) if err != nil { return err } diff --git a/pkg/cmd/repo/clone/clone_test.go b/pkg/cmd/repo/clone/clone_test.go index e7aa57b08..916c94304 100644 --- a/pkg/cmd/repo/clone/clone_test.go +++ b/pkg/cmd/repo/clone/clone_test.go @@ -218,6 +218,9 @@ func Test_RepoClone_hasParent(t *testing.T) { "name": "ORIG", "owner": { "login": "hubot" + }, + "defaultBranchRef": { + "name": "master" } } } } } @@ -237,7 +240,7 @@ func Test_RepoClone_hasParent(t *testing.T) { } assert.Equal(t, 2, cs.Count) - assert.Equal(t, "git -C REPO remote add -f upstream https://github.com/hubot/ORIG.git", strings.Join(cs.Calls[1].Args, " ")) + assert.Equal(t, "git -C REPO remote add -t master -f upstream https://github.com/hubot/ORIG.git", strings.Join(cs.Calls[1].Args, " ")) } func Test_RepoClone_withoutUsername(t *testing.T) { diff --git a/pkg/cmd/repo/fork/fork.go b/pkg/cmd/repo/fork/fork.go index 760bb6537..63cbf5129 100644 --- a/pkg/cmd/repo/fork/fork.go +++ b/pkg/cmd/repo/fork/fork.go @@ -273,7 +273,7 @@ func forkRun(opts *ForkOptions) error { } upstreamURL := ghrepo.FormatRemoteURL(repoToFork, protocol) - err = git.AddUpstreamRemote(upstreamURL, cloneDir) + err = git.AddUpstreamRemote(upstreamURL, cloneDir, []string{}) if err != nil { return err }