Add repo sync command

This commit is contained in:
Sam Coe 2021-06-08 11:00:40 -04:00
parent 5984cf2a82
commit 86e16cc7c4
No known key found for this signature in database
GPG key ID: 8E322C20F811D086
5 changed files with 374 additions and 9 deletions

View file

@ -368,3 +368,30 @@ func getBranchShortName(output []byte) string {
branch := firstLine(output)
return strings.TrimPrefix(branch, "refs/heads/")
}
func IsAncestor(ancestor, commit string) (bool, error) {
cmd, err := GitCommand("merge-base", "--is-ancestor", ancestor, commit)
if err != nil {
return false, err
}
err = run.PrepareCmd(cmd).Run()
return err == nil, nil
}
func IsDirty() (bool, error) {
cmd, err := GitCommand("status", "--untracked-files=no", "--porcelain")
if err != nil {
return false, err
}
output, err := run.PrepareCmd(cmd).Output()
if err != nil {
return true, err
}
if len(output) > 0 {
return true, nil
}
return false, nil
}

View file

@ -8,6 +8,7 @@ import (
repoForkCmd "github.com/cli/cli/pkg/cmd/repo/fork"
gardenCmd "github.com/cli/cli/pkg/cmd/repo/garden"
repoListCmd "github.com/cli/cli/pkg/cmd/repo/list"
repoSyncCmd "github.com/cli/cli/pkg/cmd/repo/sync"
repoViewCmd "github.com/cli/cli/pkg/cmd/repo/view"
"github.com/cli/cli/pkg/cmdutil"
"github.com/spf13/cobra"
@ -38,6 +39,7 @@ func NewCmdRepo(f *cmdutil.Factory) *cobra.Command {
cmd.AddCommand(repoCloneCmd.NewCmdClone(f, nil))
cmd.AddCommand(repoCreateCmd.NewCmdCreate(f, nil))
cmd.AddCommand(repoListCmd.NewCmdList(f, nil))
cmd.AddCommand(repoSyncCmd.NewCmdSync(f, nil))
cmd.AddCommand(creditsCmd.NewCmdRepoCredits(f, nil))
cmd.AddCommand(gardenCmd.NewCmdGarden(f, nil))

42
pkg/cmd/repo/sync/http.go Normal file
View file

@ -0,0 +1,42 @@
package sync
import (
"bytes"
"encoding/json"
"fmt"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghrepo"
)
type commit struct {
Ref string `json:"ref"`
NodeID string `json:"node_id"`
URL string `json:"url"`
Object struct {
Type string `json:"type"`
SHA string `json:"sha"`
URL string `json:"url"`
} `json:"object"`
}
func latestCommit(client *api.Client, repo ghrepo.Interface, branch string) (commit, error) {
var response commit
path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch)
err := client.REST(repo.RepoHost(), "GET", path, nil, &response)
return response, err
}
func syncFork(client *api.Client, repo ghrepo.Interface, branch, SHA string, force bool) error {
path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch)
body := map[string]interface{}{
"sha": SHA,
"force": force,
}
requestByte, err := json.Marshal(body)
if err != nil {
return err
}
requestBody := bytes.NewReader(requestByte)
return client.REST(repo.RepoHost(), "PATCH", path, requestBody, nil)
}

286
pkg/cmd/repo/sync/sync.go Normal file
View file

@ -0,0 +1,286 @@
package sync
import (
"errors"
"fmt"
"net/http"
"os/exec"
"regexp"
"github.com/AlecAivazis/survey/v2"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/internal/run"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
"github.com/cli/cli/pkg/prompt"
"github.com/cli/safeexec"
"github.com/spf13/cobra"
)
type SyncOptions struct {
HttpClient func() (*http.Client, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (context.Remotes, error)
CurrentBranch func() (string, error)
DestArg string
SrcArg string
Branch string
Force bool
SkipConfirm bool
}
func NewCmdSync(f *cmdutil.Factory, runF func(*SyncOptions) error) *cobra.Command {
opts := SyncOptions{
HttpClient: f.HttpClient,
IO: f.IOStreams,
BaseRepo: f.BaseRepo,
Remotes: f.Remotes,
CurrentBranch: f.Branch,
}
cmd := &cobra.Command{
Use: "sync [<destination-repository>]",
Short: "Sync a repository",
Long: heredoc.Doc(`
Sync destination repository from source repository.
Without an argument, the local repository is selected as the destination repository.
By default the source repository is the parent of the destination repository.
The source repository can be overridden with the --source flag.
`),
Example: heredoc.Doc(`
# Sync local repository from remote parent
$ gh repo sync
# Sync local repository from remote parent on non-default branch
$ gh repo sync --branch v1
# Sync remote fork from remote parent
$ gh repo sync owner/cli-fork
# Sync remote repo from another remote repo
$ gh repo sync owner/repo --source owner2/repo2
`),
Args: cobra.MaximumNArgs(1),
RunE: func(c *cobra.Command, args []string) error {
if len(args) > 0 {
opts.DestArg = args[0]
}
if !opts.IO.CanPrompt() && !opts.SkipConfirm {
return &cmdutil.FlagError{Err: errors.New("`--confirm` required when not running interactively")}
}
if runF != nil {
return runF(&opts)
}
return syncRun(&opts)
},
}
cmd.Flags().StringVarP(&opts.SrcArg, "source", "s", "", "Source repository")
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Branch to sync")
cmd.Flags().BoolVarP(&opts.Force, "force", "", false, "Discard destination repository changes")
cmd.Flags().BoolVarP(&opts.SkipConfirm, "confirm", "y", false, "Skip the confirmation prompt")
return cmd
}
func syncRun(opts *SyncOptions) error {
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
apiClient := api.NewClientFromHTTP(httpClient)
var local bool
var destRepo, srcRepo ghrepo.Interface
if opts.DestArg == "" {
local = true
destRepo, err = opts.BaseRepo()
if err != nil {
return err
}
} else {
destRepo, err = ghrepo.FromFullName(opts.DestArg)
if err != nil {
return err
}
}
if opts.SrcArg == "" {
if local {
srcRepo = destRepo
} else {
opts.IO.StartProgressIndicator()
srcRepo, err = api.RepoParent(apiClient, destRepo)
opts.IO.StopProgressIndicator()
if err != nil {
return err
}
if srcRepo == nil {
return fmt.Errorf("can't determine source repo for %s because repo is not fork", ghrepo.FullName(destRepo))
}
}
} else {
srcRepo, err = ghrepo.FromFullName(opts.SrcArg)
if err != nil {
return err
}
}
if !local && destRepo.RepoHost() != srcRepo.RepoHost() {
return fmt.Errorf("can't sync repos from different hosts")
}
if opts.Branch == "" {
opts.IO.StartProgressIndicator()
opts.Branch, err = api.RepoDefaultBranch(apiClient, srcRepo)
opts.IO.StopProgressIndicator()
if err != nil {
return err
}
}
srcStr := fmt.Sprintf("%s:%s", ghrepo.FullName(srcRepo), opts.Branch)
destStr := fmt.Sprintf("%s:%s", ghrepo.FullName(destRepo), opts.Branch)
if local {
destStr = fmt.Sprintf(".:%s", opts.Branch)
}
cs := opts.IO.ColorScheme()
if !opts.SkipConfirm && opts.IO.CanPrompt() {
if opts.Force {
fmt.Fprintf(opts.IO.ErrOut, "%s Using --force will cause diverging commits on %s to be discarded\n", cs.WarningIcon(), destStr)
}
var confirmed bool
confirmQuestion := &survey.Confirm{
Message: fmt.Sprintf("Sync %s from %s?", destStr, srcStr),
Default: false,
}
err := prompt.SurveyAskOne(confirmQuestion, &confirmed)
if err != nil {
return err
}
if !confirmed {
return cmdutil.CancelError
}
}
opts.IO.StartProgressIndicator()
if local {
err = syncLocalRepo(srcRepo, opts)
} else {
err = syncRemoteRepo(apiClient, destRepo, srcRepo, opts)
}
opts.IO.StopProgressIndicator()
if err != nil {
return err
}
if opts.IO.IsStdoutTTY() {
success := cs.Bold(fmt.Sprintf("Synced %s from %s\n", destStr, srcStr))
fmt.Fprintf(opts.IO.Out, "%s %s", cs.SuccessIconWithColor(cs.GreenBold), success)
}
return nil
}
func syncLocalRepo(srcRepo ghrepo.Interface, opts *SyncOptions) error {
// Remotes precedence by name
// 1. upstream
// 2. github
// 3. origin
// 4. other
remotes, err := opts.Remotes()
if err != nil {
return err
}
remote := remotes[0]
branch := opts.Branch
_ = executeCmds([][]string{{"git", "fetch", remote.Name, fmt.Sprintf("+refs/heads/%s", branch)}})
hasLocalBranch := git.HasLocalBranch(branch)
if hasLocalBranch {
fastForward, err := git.IsAncestor(branch, fmt.Sprintf("%s/%s", remote.Name, branch))
if err != nil {
return err
}
if !fastForward && !opts.Force {
return fmt.Errorf("can't sync .:%s because there are diverging commits, try using `--force`", branch)
}
}
startBranch, err := opts.CurrentBranch()
if err != nil {
return err
}
dirtyRepo, err := git.IsDirty()
if err != nil {
return err
}
var cmds [][]string
if dirtyRepo {
cmds = append(cmds, []string{"git", "stash", "push"})
}
if startBranch != branch {
cmds = append(cmds, []string{"git", "checkout", branch})
}
if hasLocalBranch {
if opts.Force {
cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)})
} else {
cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s/%s", remote, branch)})
}
}
if startBranch != branch {
cmds = append(cmds, []string{"git", "checkout", startBranch})
}
if dirtyRepo {
cmds = append(cmds, []string{"git", "stash", "pop"})
}
return executeCmds(cmds)
}
func syncRemoteRepo(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts *SyncOptions) error {
commit, err := latestCommit(client, srcRepo, opts.Branch)
if err != nil {
return err
}
// This is not a great way to detect the error returned by the API
// Unfortunately API returns 422 for multiple reasons
notFastForwardErrorMessage := regexp.MustCompile(`^Update is not a fast forward$`)
err = syncFork(client, destRepo, opts.Branch, commit.Object.SHA, opts.Force)
var httpErr api.HTTPError
if err != nil && errors.As(err, &httpErr) && notFastForwardErrorMessage.MatchString(httpErr.Message) {
return fmt.Errorf("can't sync %s:%s because there are diverging commits, try using `--force`",
ghrepo.FullName(destRepo),
opts.Branch)
}
return err
}
func executeCmds(cmdQueue [][]string) error {
exe, err := safeexec.LookPath("git")
if err != nil {
return err
}
for _, args := range cmdQueue {
cmd := exec.Command(exe, args[1:]...)
if err := run.PrepareCmd(cmd).Run(); err != nil {
return err
}
}
return nil
}

View file

@ -9,15 +9,16 @@ import (
)
var (
magenta = ansi.ColorFunc("magenta")
cyan = ansi.ColorFunc("cyan")
red = ansi.ColorFunc("red")
yellow = ansi.ColorFunc("yellow")
blue = ansi.ColorFunc("blue")
green = ansi.ColorFunc("green")
gray = ansi.ColorFunc("black+h")
bold = ansi.ColorFunc("default+b")
cyanBold = ansi.ColorFunc("cyan+b")
magenta = ansi.ColorFunc("magenta")
cyan = ansi.ColorFunc("cyan")
red = ansi.ColorFunc("red")
yellow = ansi.ColorFunc("yellow")
blue = ansi.ColorFunc("blue")
green = ansi.ColorFunc("green")
gray = ansi.ColorFunc("black+h")
bold = ansi.ColorFunc("default+b")
cyanBold = ansi.ColorFunc("cyan+b")
greenBold = ansi.ColorFunc("green+b")
gray256 = func(t string) string {
return fmt.Sprintf("\x1b[%d;5;%dm%s\x1b[m", 38, 242, t)
@ -96,6 +97,13 @@ func (c *ColorScheme) Green(t string) string {
return green(t)
}
func (c *ColorScheme) GreenBold(t string) string {
if !c.enabled {
return t
}
return greenBold(t)
}
func (c *ColorScheme) Greenf(t string, args ...interface{}) string {
return c.Green(fmt.Sprintf(t, args...))
}