From fac497108b75158ba09966fca81e500476a1a0f7 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 12 Jun 2023 09:23:56 +0900 Subject: [PATCH] Set upstream remote to track all branches after initial fetch (#7542) --- git/client.go | 76 +++++++++++++++++++++---------- git/client_test.go | 78 ++++++++++++++++++++++++++++++++ pkg/cmd/extension/git.go | 13 ++---- pkg/cmd/repo/clone/clone.go | 11 ++++- pkg/cmd/repo/clone/clone_test.go | 2 + pkg/cmd/repo/create/create.go | 15 +----- 6 files changed, 145 insertions(+), 50 deletions(-) diff --git a/git/client.go b/git/client.go index a7d59899e..a029c8ba6 100644 --- a/git/client.go +++ b/git/client.go @@ -36,6 +36,19 @@ type Client struct { mu sync.Mutex } +func (c *Client) Copy() *Client { + return &Client{ + GhPath: c.GhPath, + RepoDir: c.RepoDir, + GitPath: c.GitPath, + Stderr: c.Stderr, + Stdin: c.Stdin, + Stdout: c.Stdout, + + commandContext: c.commandContext, + } +} + func (c *Client) Command(ctx context.Context, args ...string) (*Command, error) { if c.RepoDir != "" { args = append([]string{"-C", c.RepoDir}, args...) @@ -408,6 +421,44 @@ func (c *Client) revParse(ctx context.Context, args ...string) ([]byte, error) { return cmd.Output() } +func (c *Client) IsLocalGitRepo(ctx context.Context) (bool, error) { + _, err := c.GitDir(ctx) + if err != nil { + var execError errWithExitCode + if errors.As(err, &execError) && execError.ExitCode() == 128 { + return false, nil + } + return false, err + } + return true, nil +} + +func (c *Client) UnsetRemoteResolution(ctx context.Context, name string) error { + args := []string{"config", "--unset", fmt.Sprintf("remote.%s.gh-resolved", name)} + cmd, err := c.Command(ctx, args...) + if err != nil { + return err + } + _, err = cmd.Output() + if err != nil { + return err + } + return nil +} + +func (c *Client) SetRemoteBranches(ctx context.Context, remote string, refspec string) error { + args := []string{"remote", "set-branches", remote, refspec} + cmd, err := c.Command(ctx, args...) + if err != nil { + return err + } + _, err = cmd.Output() + if err != nil { + return err + } + return nil +} + // Below are commands that make network calls and need authentication credentials supplied from gh. func (c *Client) Fetch(ctx context.Context, remote string, refspec string, mods ...CommandModifier) error { @@ -513,31 +564,6 @@ func (c *Client) AddRemote(ctx context.Context, name, urlStr string, trackingBra return remote, nil } -func (c *Client) IsLocalGitRepo(ctx context.Context) (bool, error) { - _, err := c.GitDir(ctx) - if err != nil { - var execError errWithExitCode - if errors.As(err, &execError) && execError.ExitCode() == 128 { - return false, nil - } - return false, err - } - return true, nil -} - -func (c *Client) UnsetRemoteResolution(ctx context.Context, name string) error { - args := []string{"config", "--unset", fmt.Sprintf("remote.%s.gh-resolved", name)} - cmd, err := c.Command(ctx, args...) - if err != nil { - return err - } - _, err = cmd.Output() - if err != nil { - return err - } - return nil -} - func resolveGitPath() (string, error) { path, err := safeexec.LookPath("git") if err != nil { diff --git a/git/client_test.go b/git/client_test.go index 6bbeee27e..5e96bb1b7 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -834,6 +834,84 @@ func TestClientPathFromRoot(t *testing.T) { } } +func TestClientUnsetRemoteResolution(t *testing.T) { + tests := []struct { + name string + cmdExitStatus int + cmdStdout string + cmdStderr string + wantCmdArgs string + wantErrorMsg string + }{ + { + name: "unset remote resolution", + wantCmdArgs: `path/to/git config --unset remote.origin.gh-resolved`, + }, + { + name: "git error", + cmdExitStatus: 1, + cmdStderr: "git error message", + wantCmdArgs: `path/to/git config --unset remote.origin.gh-resolved`, + wantErrorMsg: "failed to run git: git error message", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, cmdCtx := createCommandContext(t, tt.cmdExitStatus, tt.cmdStdout, tt.cmdStderr) + client := Client{ + GitPath: "path/to/git", + commandContext: cmdCtx, + } + err := client.UnsetRemoteResolution(context.Background(), "origin") + assert.Equal(t, tt.wantCmdArgs, strings.Join(cmd.Args[3:], " ")) + if tt.wantErrorMsg == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.wantErrorMsg) + } + }) + } +} + +func TestClientSetRemoteBranches(t *testing.T) { + tests := []struct { + name string + cmdExitStatus int + cmdStdout string + cmdStderr string + wantCmdArgs string + wantErrorMsg string + }{ + { + name: "set remote branches", + wantCmdArgs: `path/to/git remote set-branches origin trunk`, + }, + { + name: "git error", + cmdExitStatus: 1, + cmdStderr: "git error message", + wantCmdArgs: `path/to/git remote set-branches origin trunk`, + wantErrorMsg: "failed to run git: git error message", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, cmdCtx := createCommandContext(t, tt.cmdExitStatus, tt.cmdStdout, tt.cmdStderr) + client := Client{ + GitPath: "path/to/git", + commandContext: cmdCtx, + } + err := client.SetRemoteBranches(context.Background(), "origin", "trunk") + assert.Equal(t, tt.wantCmdArgs, strings.Join(cmd.Args[3:], " ")) + if tt.wantErrorMsg == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.wantErrorMsg) + } + }) + } +} + func TestClientFetch(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/extension/git.go b/pkg/cmd/extension/git.go index e85ff7371..58ef0ca12 100644 --- a/pkg/cmd/extension/git.go +++ b/pkg/cmd/extension/git.go @@ -46,16 +46,9 @@ func (g *gitExecuter) Fetch(remote string, refspec string) error { } func (g *gitExecuter) ForRepo(repoDir string) gitClient { - return &gitExecuter{ - client: &git.Client{ - GhPath: g.client.GhPath, - RepoDir: repoDir, - GitPath: g.client.GitPath, - Stderr: g.client.Stderr, - Stdin: g.client.Stdin, - Stdout: g.client.Stdout, - }, - } + gc := g.client.Copy() + gc.RepoDir = repoDir + return &gitExecuter{client: gc} } func (g *gitExecuter) Pull(remote, branch string) error { diff --git a/pkg/cmd/repo/clone/clone.go b/pkg/cmd/repo/clone/clone.go index fc39013c0..8d04ac630 100644 --- a/pkg/cmd/repo/clone/clone.go +++ b/pkg/cmd/repo/clone/clone.go @@ -175,12 +175,19 @@ func cloneRun(opts *CloneOptions) error { upstreamName = canonicalRepo.Parent.RepoOwner() } - _, err = gitClient.AddRemote(ctx, upstreamName, upstreamURL, []string{canonicalRepo.Parent.DefaultBranchRef.Name}, git.WithRepoDir(cloneDir)) + gc := gitClient.Copy() + gc.RepoDir = cloneDir + + _, err = gc.AddRemote(ctx, upstreamName, upstreamURL, []string{canonicalRepo.Parent.DefaultBranchRef.Name}) if err != nil { return err } - if err := gitClient.Fetch(ctx, upstreamName, "", git.WithRepoDir(cloneDir)); err != nil { + if err := gc.Fetch(ctx, upstreamName, ""); err != nil { + return err + } + + if err := gc.SetRemoteBranches(ctx, upstreamName, `*`); err != nil { return err } } diff --git a/pkg/cmd/repo/clone/clone_test.go b/pkg/cmd/repo/clone/clone_test.go index 98fc3736c..0ea3a65af 100644 --- a/pkg/cmd/repo/clone/clone_test.go +++ b/pkg/cmd/repo/clone/clone_test.go @@ -247,6 +247,7 @@ func Test_RepoClone_hasParent(t *testing.T) { cs.Register(`git clone https://github.com/OWNER/REPO.git`, 0, "") cs.Register(`git -C REPO remote add -t trunk upstream https://github.com/hubot/ORIG.git`, 0, "") cs.Register(`git -C REPO fetch upstream`, 0, "") + cs.Register(`git -C REPO remote set-branches upstream *`, 0, "") _, err := runCloneCommand(httpClient, "OWNER/REPO") if err != nil { @@ -284,6 +285,7 @@ func Test_RepoClone_hasParent_upstreamRemoteName(t *testing.T) { cs.Register(`git clone https://github.com/OWNER/REPO.git`, 0, "") cs.Register(`git -C REPO remote add -t trunk test https://github.com/hubot/ORIG.git`, 0, "") cs.Register(`git -C REPO fetch test`, 0, "") + cs.Register(`git -C REPO remote set-branches test *`, 0, "") _, err := runCloneCommand(httpClient, "OWNER/REPO --upstream-remote-name test") if err != nil { diff --git a/pkg/cmd/repo/create/create.go b/pkg/cmd/repo/create/create.go index fe2269e7e..1bd37769b 100644 --- a/pkg/cmd/repo/create/create.go +++ b/pkg/cmd/repo/create/create.go @@ -662,8 +662,8 @@ func localInit(gitClient *git.Client, remoteURL, path string) error { return err } - // Clone the client so we do not modify the original client's RepoDir. - gc := cloneGitClient(gitClient) + // Copy the client so we do not modify the original client's RepoDir. + gc := gitClient.Copy() gc.RepoDir = path gitRemoteAdd, err := gc.Command(ctx, "remote", "add", "origin", remoteURL) @@ -794,14 +794,3 @@ func splitNameAndOwner(name string) (string, string, error) { } return repo.RepoName(), repo.RepoOwner(), nil } - -func cloneGitClient(c *git.Client) *git.Client { - return &git.Client{ - GhPath: c.GhPath, - RepoDir: c.RepoDir, - GitPath: c.GitPath, - Stderr: c.Stderr, - Stdin: c.Stdin, - Stdout: c.Stdout, - } -}