diff --git a/pkg/cmd/repo/create/create.go b/pkg/cmd/repo/create/create.go index 2aa7834ea..4ab6025f9 100644 --- a/pkg/cmd/repo/create/create.go +++ b/pkg/cmd/repo/create/create.go @@ -132,6 +132,7 @@ func createRun(opts *CreateOptions) error { isNameAnArg := false isDescEmpty := opts.Description == "" isVisibilityPassed := false + inLocalRepo := projectDirErr == nil if opts.Name != "" { isNameAnArg = true @@ -201,6 +202,7 @@ func createRun(opts *CreateOptions) error { repoToCreate = ghrepo.New("", opts.Name) } + var templateRepoMainBranch string // Find template repo ID if opts.Template != "" { httpClient, err := opts.HttpClient() @@ -230,6 +232,7 @@ func createRun(opts *CreateOptions) error { } opts.Template = repo.ID + templateRepoMainBranch = repo.DefaultBranchRef.Name } input := repoCreateInput{ @@ -250,7 +253,7 @@ func createRun(opts *CreateOptions) error { createLocalDirectory := opts.ConfirmSubmit if !opts.ConfirmSubmit { - opts.ConfirmSubmit, err = confirmSubmission(input.Name, input.OwnerID, projectDirErr) + opts.ConfirmSubmit, err = confirmSubmission(input.Name, input.OwnerID, inLocalRepo) if err != nil { return err } @@ -284,7 +287,7 @@ func createRun(opts *CreateOptions) error { } remoteURL := ghrepo.FormatRemoteURL(repo, protocol) - if projectDirErr == nil { + if inLocalRepo { _, err = git.AddRemote("origin", remoteURL) if err != nil { return err @@ -295,7 +298,7 @@ func createRun(opts *CreateOptions) error { } else { if opts.IO.CanPrompt() { if !createLocalDirectory { - err := prompt.Confirm(fmt.Sprintf("Create a local project directory for %s?", ghrepo.FullName(repo)), &createLocalDirectory) + err := prompt.Confirm(fmt.Sprintf(`Create a local project directory for "%s"?`, ghrepo.FullName(repo)), &createLocalDirectory) if err != nil { return err } @@ -303,32 +306,18 @@ func createRun(opts *CreateOptions) error { } if createLocalDirectory { path := repo.Name - - gitInit, err := git.GitCommand("init", path) - if err != nil { - return err + checkoutBranch := "" + if opts.Template != "" { + // NOTE: we cannot read `defaultBranchRef` from the newly created repository as it will + // be null at this time. Instead, we assume that the main branch name of the new + // repository will be the same as that of the template repository. + checkoutBranch = templateRepoMainBranch } - isTTY := opts.IO.IsStdoutTTY() - if isTTY { - gitInit.Stdout = stdout - } - gitInit.Stderr = stderr - err = run.PrepareCmd(gitInit).Run() - if err != nil { - return err - } - gitRemoteAdd, err := git.GitCommand("-C", path, "remote", "add", "origin", remoteURL) - if err != nil { - return err - } - gitRemoteAdd.Stdout = stdout - gitRemoteAdd.Stderr = stderr - err = run.PrepareCmd(gitRemoteAdd).Run() - if err != nil { + if err := localInit(opts.IO, remoteURL, path, checkoutBranch); err != nil { return err } if isTTY { - fmt.Fprintf(stderr, "%s Initialized repository in './%s/'\n", cs.SuccessIcon(), path) + fmt.Fprintf(stderr, "%s Initialized repository in \"%s\"\n", cs.SuccessIcon(), path) } } } @@ -339,6 +328,56 @@ func createRun(opts *CreateOptions) error { return nil } +func localInit(io *iostreams.IOStreams, remoteURL, path, checkoutBranch string) error { + gitInit, err := git.GitCommand("init", path) + if err != nil { + return err + } + isTTY := io.IsStdoutTTY() + if isTTY { + gitInit.Stdout = io.Out + } + gitInit.Stderr = io.ErrOut + err = run.PrepareCmd(gitInit).Run() + if err != nil { + return err + } + + gitRemoteAdd, err := git.GitCommand("-C", path, "remote", "add", "origin", remoteURL) + if err != nil { + return err + } + gitRemoteAdd.Stdout = io.Out + gitRemoteAdd.Stderr = io.ErrOut + err = run.PrepareCmd(gitRemoteAdd).Run() + if err != nil { + return err + } + + if checkoutBranch == "" { + return nil + } + + gitFetch, err := git.GitCommand("-C", path, "fetch", "origin", fmt.Sprintf("+refs/heads/%[1]s:refs/remotes/origin/%[1]s", checkoutBranch)) + if err != nil { + return err + } + gitFetch.Stdout = io.Out + gitFetch.Stderr = io.ErrOut + err = run.PrepareCmd(gitFetch).Run() + if err != nil { + return err + } + + gitCheckout, err := git.GitCommand("-C", path, "checkout", checkoutBranch) + if err != nil { + return err + } + gitCheckout.Stdout = io.Out + gitCheckout.Stderr = io.ErrOut + return run.PrepareCmd(gitCheckout).Run() +} + func interactiveRepoCreate(isDescEmpty bool, isVisibilityPassed bool, repoName string) (string, string, string, error) { qs := []*survey.Question{} @@ -388,16 +427,18 @@ func interactiveRepoCreate(isDescEmpty bool, isVisibilityPassed bool, repoName s return answers.RepoName, answers.RepoDescription, strings.ToUpper(answers.RepoVisibility), nil } -func confirmSubmission(repoName string, repoOwner string, projectDirErr error) (bool, error) { +func confirmSubmission(repoName string, repoOwner string, inLocalRepo bool) (bool, error) { qs := []*survey.Question{} promptString := "" - if projectDirErr == nil { - promptString = "This will add remote origin to your current directory. Continue? " - } else if repoOwner != "" { - promptString = fmt.Sprintf("This will create '%s/%s' in your current directory. Continue? ", repoOwner, repoName) + if inLocalRepo { + promptString = `This will add an "origin" git remote to your local repository. Continue?` } else { - promptString = fmt.Sprintf("This will create '%s' in your current directory. Continue? ", repoName) + targetRepo := repoName + if repoOwner != "" { + targetRepo = fmt.Sprintf("%s/%s", repoOwner, repoName) + } + promptString = fmt.Sprintf(`This will create the "%s" repository on GitHub. Continue?`, targetRepo) } confirmSubmitQuestion := &survey.Question{ diff --git a/pkg/cmd/repo/create/create_test.go b/pkg/cmd/repo/create/create_test.go index 9914807ea..9aa975724 100644 --- a/pkg/cmd/repo/create/create_test.go +++ b/pkg/cmd/repo/create/create_test.go @@ -7,6 +7,7 @@ import ( "net/http" "testing" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/internal/config" "github.com/cli/cli/internal/run" "github.com/cli/cli/pkg/cmdutil" @@ -345,6 +346,7 @@ func TestRepoCreate_orgWithTeam(t *testing.T) { func TestRepoCreate_template(t *testing.T) { reg := &httpmock.Registry{} + defer reg.Verify(t) reg.Register( httpmock.GraphQL(`mutation CloneTemplateRepository\b`), httpmock.StringResponse(` @@ -370,32 +372,26 @@ func TestRepoCreate_template(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git remote add -f origin https://github\.com/OWNER/REPO\.git`, 0, "") - cs.Register(`git rev-parse --show-toplevel`, 0, "") + cs.Register(`git rev-parse --show-toplevel`, 1, "") + cs.Register(`git init REPO`, 0, "") + cs.Register(`git -C REPO remote add`, 0, "") + cs.Register(`git -C REPO fetch origin \+refs/heads/main:refs/remotes/origin/main`, 0, "") + cs.Register(`git -C REPO checkout main`, 0, "") - as, surveyTearDown := prompt.InitAskStubber() + _, surveyTearDown := prompt.InitAskStubber() defer surveyTearDown() - as.Stub([]*prompt.QuestionStub{ - { - Name: "repoVisibility", - Value: "PRIVATE", - }, - }) - as.Stub([]*prompt.QuestionStub{ - { - Name: "confirmSubmit", - Value: true, - }, - }) - - output, err := runCommand(httpClient, "REPO --template='OWNER/REPO'", true) + output, err := runCommand(httpClient, "REPO -y --private --template='OWNER/REPO'", true) if err != nil { t.Errorf("error running command `repo create`: %v", err) + return } assert.Equal(t, "", output.String()) - assert.Equal(t, "āœ“ Created repository OWNER/REPO on GitHub\nāœ“ Added remote https://github.com/OWNER/REPO.git\n", output.Stderr()) + assert.Equal(t, heredoc.Doc(` + āœ“ Created repository OWNER/REPO on GitHub + āœ“ Initialized repository in "REPO" + `), output.Stderr()) var reqBody struct { Query string @@ -404,10 +400,6 @@ func TestRepoCreate_template(t *testing.T) { } } - if len(reg.Requests) != 3 { - t.Fatalf("expected 3 HTTP requests, got %d", len(reg.Requests)) - } - bodyBytes, _ := ioutil.ReadAll(reg.Requests[2].Body) _ = json.Unmarshal(bodyBytes, &reqBody) if repoName := reqBody.Variables.Input["name"].(string); repoName != "REPO" {