diff --git a/git/git.go b/git/git.go index abff866d6..afa9ccf0f 100644 --- a/git/git.go +++ b/git/git.go @@ -368,3 +368,30 @@ func getBranchShortName(output []byte) string { branch := firstLine(output) return strings.TrimPrefix(branch, "refs/heads/") } + +func IsAncestor(ancestor, commit string) (bool, error) { + cmd, err := GitCommand("merge-base", "--is-ancestor", ancestor, commit) + if err != nil { + return false, err + } + err = run.PrepareCmd(cmd).Run() + return err == nil, nil +} + +func IsDirty() (bool, error) { + cmd, err := GitCommand("status", "--untracked-files=no", "--porcelain") + if err != nil { + return false, err + } + + output, err := run.PrepareCmd(cmd).Output() + if err != nil { + return true, err + } + + if len(output) > 0 { + return true, nil + } + + return false, nil +} diff --git a/pkg/cmd/repo/repo.go b/pkg/cmd/repo/repo.go index 4fee2b19c..26e856ffd 100644 --- a/pkg/cmd/repo/repo.go +++ b/pkg/cmd/repo/repo.go @@ -8,6 +8,7 @@ import ( repoForkCmd "github.com/cli/cli/pkg/cmd/repo/fork" gardenCmd "github.com/cli/cli/pkg/cmd/repo/garden" repoListCmd "github.com/cli/cli/pkg/cmd/repo/list" + repoSyncCmd "github.com/cli/cli/pkg/cmd/repo/sync" repoViewCmd "github.com/cli/cli/pkg/cmd/repo/view" "github.com/cli/cli/pkg/cmdutil" "github.com/spf13/cobra" @@ -38,6 +39,7 @@ func NewCmdRepo(f *cmdutil.Factory) *cobra.Command { cmd.AddCommand(repoCloneCmd.NewCmdClone(f, nil)) cmd.AddCommand(repoCreateCmd.NewCmdCreate(f, nil)) cmd.AddCommand(repoListCmd.NewCmdList(f, nil)) + cmd.AddCommand(repoSyncCmd.NewCmdSync(f, nil)) cmd.AddCommand(creditsCmd.NewCmdRepoCredits(f, nil)) cmd.AddCommand(gardenCmd.NewCmdGarden(f, nil)) diff --git a/pkg/cmd/repo/sync/http.go b/pkg/cmd/repo/sync/http.go new file mode 100644 index 000000000..70c3e9a7e --- /dev/null +++ b/pkg/cmd/repo/sync/http.go @@ -0,0 +1,42 @@ +package sync + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/cli/cli/api" + "github.com/cli/cli/internal/ghrepo" +) + +type commit struct { + Ref string `json:"ref"` + NodeID string `json:"node_id"` + URL string `json:"url"` + Object struct { + Type string `json:"type"` + SHA string `json:"sha"` + URL string `json:"url"` + } `json:"object"` +} + +func latestCommit(client *api.Client, repo ghrepo.Interface, branch string) (commit, error) { + var response commit + path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch) + err := client.REST(repo.RepoHost(), "GET", path, nil, &response) + return response, err +} + +func syncFork(client *api.Client, repo ghrepo.Interface, branch, SHA string, force bool) error { + path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch) + body := map[string]interface{}{ + "sha": SHA, + "force": force, + } + requestByte, err := json.Marshal(body) + if err != nil { + return err + } + requestBody := bytes.NewReader(requestByte) + return client.REST(repo.RepoHost(), "PATCH", path, requestBody, nil) +} diff --git a/pkg/cmd/repo/sync/sync.go b/pkg/cmd/repo/sync/sync.go new file mode 100644 index 000000000..e6d5a891a --- /dev/null +++ b/pkg/cmd/repo/sync/sync.go @@ -0,0 +1,286 @@ +package sync + +import ( + "errors" + "fmt" + "net/http" + "os/exec" + "regexp" + + "github.com/AlecAivazis/survey/v2" + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/api" + "github.com/cli/cli/context" + "github.com/cli/cli/git" + "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/internal/run" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" + "github.com/cli/cli/pkg/prompt" + "github.com/cli/safeexec" + "github.com/spf13/cobra" +) + +type SyncOptions struct { + HttpClient func() (*http.Client, error) + IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) + Remotes func() (context.Remotes, error) + CurrentBranch func() (string, error) + DestArg string + SrcArg string + Branch string + Force bool + SkipConfirm bool +} + +func NewCmdSync(f *cmdutil.Factory, runF func(*SyncOptions) error) *cobra.Command { + opts := SyncOptions{ + HttpClient: f.HttpClient, + IO: f.IOStreams, + BaseRepo: f.BaseRepo, + Remotes: f.Remotes, + CurrentBranch: f.Branch, + } + + cmd := &cobra.Command{ + Use: "sync []", + Short: "Sync a repository", + Long: heredoc.Doc(` + Sync destination repository from source repository. + + Without an argument, the local repository is selected as the destination repository. + By default the source repository is the parent of the destination repository. + The source repository can be overridden with the --source flag. + `), + Example: heredoc.Doc(` + # Sync local repository from remote parent + $ gh repo sync + + # Sync local repository from remote parent on non-default branch + $ gh repo sync --branch v1 + + # Sync remote fork from remote parent + $ gh repo sync owner/cli-fork + + # Sync remote repo from another remote repo + $ gh repo sync owner/repo --source owner2/repo2 + `), + Args: cobra.MaximumNArgs(1), + RunE: func(c *cobra.Command, args []string) error { + if len(args) > 0 { + opts.DestArg = args[0] + } + if !opts.IO.CanPrompt() && !opts.SkipConfirm { + return &cmdutil.FlagError{Err: errors.New("`--confirm` required when not running interactively")} + } + if runF != nil { + return runF(&opts) + } + return syncRun(&opts) + }, + } + + cmd.Flags().StringVarP(&opts.SrcArg, "source", "s", "", "Source repository") + cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Branch to sync") + cmd.Flags().BoolVarP(&opts.Force, "force", "", false, "Discard destination repository changes") + cmd.Flags().BoolVarP(&opts.SkipConfirm, "confirm", "y", false, "Skip the confirmation prompt") + return cmd +} + +func syncRun(opts *SyncOptions) error { + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + + var local bool + var destRepo, srcRepo ghrepo.Interface + + if opts.DestArg == "" { + local = true + destRepo, err = opts.BaseRepo() + if err != nil { + return err + } + } else { + destRepo, err = ghrepo.FromFullName(opts.DestArg) + if err != nil { + return err + } + } + + if opts.SrcArg == "" { + if local { + srcRepo = destRepo + } else { + opts.IO.StartProgressIndicator() + srcRepo, err = api.RepoParent(apiClient, destRepo) + opts.IO.StopProgressIndicator() + if err != nil { + return err + } + if srcRepo == nil { + return fmt.Errorf("can't determine source repo for %s because repo is not fork", ghrepo.FullName(destRepo)) + } + } + } else { + srcRepo, err = ghrepo.FromFullName(opts.SrcArg) + if err != nil { + return err + } + } + + if !local && destRepo.RepoHost() != srcRepo.RepoHost() { + return fmt.Errorf("can't sync repos from different hosts") + } + + if opts.Branch == "" { + opts.IO.StartProgressIndicator() + opts.Branch, err = api.RepoDefaultBranch(apiClient, srcRepo) + opts.IO.StopProgressIndicator() + if err != nil { + return err + } + } + + srcStr := fmt.Sprintf("%s:%s", ghrepo.FullName(srcRepo), opts.Branch) + destStr := fmt.Sprintf("%s:%s", ghrepo.FullName(destRepo), opts.Branch) + if local { + destStr = fmt.Sprintf(".:%s", opts.Branch) + } + cs := opts.IO.ColorScheme() + if !opts.SkipConfirm && opts.IO.CanPrompt() { + if opts.Force { + fmt.Fprintf(opts.IO.ErrOut, "%s Using --force will cause diverging commits on %s to be discarded\n", cs.WarningIcon(), destStr) + } + var confirmed bool + confirmQuestion := &survey.Confirm{ + Message: fmt.Sprintf("Sync %s from %s?", destStr, srcStr), + Default: false, + } + err := prompt.SurveyAskOne(confirmQuestion, &confirmed) + if err != nil { + return err + } + + if !confirmed { + return cmdutil.CancelError + } + } + + opts.IO.StartProgressIndicator() + if local { + err = syncLocalRepo(srcRepo, opts) + } else { + err = syncRemoteRepo(apiClient, destRepo, srcRepo, opts) + } + opts.IO.StopProgressIndicator() + + if err != nil { + return err + } + + if opts.IO.IsStdoutTTY() { + success := cs.Bold(fmt.Sprintf("Synced %s from %s\n", destStr, srcStr)) + fmt.Fprintf(opts.IO.Out, "%s %s", cs.SuccessIconWithColor(cs.GreenBold), success) + } + + return nil +} + +func syncLocalRepo(srcRepo ghrepo.Interface, opts *SyncOptions) error { + // Remotes precedence by name + // 1. upstream + // 2. github + // 3. origin + // 4. other + remotes, err := opts.Remotes() + if err != nil { + return err + } + remote := remotes[0] + branch := opts.Branch + + _ = executeCmds([][]string{{"git", "fetch", remote.Name, fmt.Sprintf("+refs/heads/%s", branch)}}) + + hasLocalBranch := git.HasLocalBranch(branch) + if hasLocalBranch { + fastForward, err := git.IsAncestor(branch, fmt.Sprintf("%s/%s", remote.Name, branch)) + if err != nil { + return err + } + + if !fastForward && !opts.Force { + return fmt.Errorf("can't sync .:%s because there are diverging commits, try using `--force`", branch) + } + } + + startBranch, err := opts.CurrentBranch() + if err != nil { + return err + } + + dirtyRepo, err := git.IsDirty() + if err != nil { + return err + } + + var cmds [][]string + if dirtyRepo { + cmds = append(cmds, []string{"git", "stash", "push"}) + } + if startBranch != branch { + cmds = append(cmds, []string{"git", "checkout", branch}) + } + if hasLocalBranch { + if opts.Force { + cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) + } else { + cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)}) + } + } + if startBranch != branch { + cmds = append(cmds, []string{"git", "checkout", startBranch}) + } + if dirtyRepo { + cmds = append(cmds, []string{"git", "stash", "pop"}) + } + + return executeCmds(cmds) +} + +func syncRemoteRepo(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts *SyncOptions) error { + commit, err := latestCommit(client, srcRepo, opts.Branch) + if err != nil { + return err + } + + // This is not a great way to detect the error returned by the API + // Unfortunately API returns 422 for multiple reasons + notFastForwardErrorMessage := regexp.MustCompile(`^Update is not a fast forward$`) + err = syncFork(client, destRepo, opts.Branch, commit.Object.SHA, opts.Force) + var httpErr api.HTTPError + if err != nil && errors.As(err, &httpErr) && notFastForwardErrorMessage.MatchString(httpErr.Message) { + return fmt.Errorf("can't sync %s:%s because there are diverging commits, try using `--force`", + ghrepo.FullName(destRepo), + opts.Branch) + } + + return err +} + +func executeCmds(cmdQueue [][]string) error { + exe, err := safeexec.LookPath("git") + if err != nil { + return err + } + for _, args := range cmdQueue { + cmd := exec.Command(exe, args[1:]...) + if err := run.PrepareCmd(cmd).Run(); err != nil { + return err + } + } + return nil +} diff --git a/pkg/iostreams/color.go b/pkg/iostreams/color.go index 2dedbdbd5..4a4633205 100644 --- a/pkg/iostreams/color.go +++ b/pkg/iostreams/color.go @@ -9,15 +9,16 @@ import ( ) var ( - magenta = ansi.ColorFunc("magenta") - cyan = ansi.ColorFunc("cyan") - red = ansi.ColorFunc("red") - yellow = ansi.ColorFunc("yellow") - blue = ansi.ColorFunc("blue") - green = ansi.ColorFunc("green") - gray = ansi.ColorFunc("black+h") - bold = ansi.ColorFunc("default+b") - cyanBold = ansi.ColorFunc("cyan+b") + magenta = ansi.ColorFunc("magenta") + cyan = ansi.ColorFunc("cyan") + red = ansi.ColorFunc("red") + yellow = ansi.ColorFunc("yellow") + blue = ansi.ColorFunc("blue") + green = ansi.ColorFunc("green") + gray = ansi.ColorFunc("black+h") + bold = ansi.ColorFunc("default+b") + cyanBold = ansi.ColorFunc("cyan+b") + greenBold = ansi.ColorFunc("green+b") gray256 = func(t string) string { return fmt.Sprintf("\x1b[%d;5;%dm%s\x1b[m", 38, 242, t) @@ -96,6 +97,13 @@ func (c *ColorScheme) Green(t string) string { return green(t) } +func (c *ColorScheme) GreenBold(t string) string { + if !c.enabled { + return t + } + return greenBold(t) +} + func (c *ColorScheme) Greenf(t string, args ...interface{}) string { return c.Green(fmt.Sprintf(t, args...)) }