Merge pull request #238 from github/pr-create-just-works-TM

Support triangular git workflows in `pr create`
This commit is contained in:
Mislav Marohnić 2020-01-23 14:11:17 +01:00 committed by GitHub
commit e81a29c076
9 changed files with 635 additions and 99 deletions

View file

@ -84,9 +84,27 @@ type Client struct {
type graphQLResponse struct {
Data interface{}
Errors []struct {
Message string
Errors []GraphQLError
}
// GraphQLError is a single error returned in a GraphQL response
type GraphQLError struct {
Type string
Path []string
Message string
}
// GraphQLErrorResponse contains errors returned in a GraphQL response
type GraphQLErrorResponse struct {
Errors []GraphQLError
}
func (gr GraphQLErrorResponse) Error() string {
errorMessages := make([]string, 0, len(gr.Errors))
for _, e := range gr.Errors {
errorMessages = append(errorMessages, e.Message)
}
return fmt.Sprintf("graphql error: '%s'", strings.Join(errorMessages, ", "))
}
// GraphQL performs a GraphQL request and parses the response
@ -166,14 +184,9 @@ func handleResponse(resp *http.Response, data interface{}) error {
}
if len(gr.Errors) > 0 {
errorMessages := gr.Errors[0].Message
for _, e := range gr.Errors[1:] {
errorMessages += ", " + e.Message
}
return fmt.Errorf("graphql error: '%s'", errorMessages)
return &GraphQLErrorResponse{Errors: gr.Errors}
}
return nil
}
func handleHTTPError(resp *http.Response) error {

View file

@ -2,7 +2,6 @@ package api
import (
"bytes"
"fmt"
"io/ioutil"
"reflect"
"testing"
@ -47,5 +46,7 @@ func TestGraphQLError(t *testing.T) {
response := struct{}{}
http.StubResponse(200, bytes.NewBufferString(`{"errors":[{"message":"OH NO"}]}`))
err := client.GraphQL("", nil, &response)
eq(t, err, fmt.Errorf("graphql error: 'OH NO'"))
if err == nil || err.Error() != "graphql error: 'OH NO'" {
t.Fatalf("got %q", err.Error())
}
}

View file

@ -403,12 +403,8 @@ func PullRequestForBranch(client *Client, ghRepo Repo, branch string) (*PullRequ
return nil, &NotFoundError{fmt.Errorf("no open pull requests found for branch %q", branch)}
}
func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) {
repo, err := GitHubRepo(client, ghRepo)
if err != nil {
return nil, err
}
// CreatePullRequest creates a pull request in a GitHub repository
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
query := `
mutation CreatePullRequest($input: CreatePullRequestInput!) {
createPullRequest(input: $input) {
@ -434,7 +430,7 @@ func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{
}
}{}
err = client.GraphQL(query, variables, &result)
err := client.GraphQL(query, variables, &result)
if err != nil {
return nil, err
}

View file

@ -1,15 +1,62 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/pkg/errors"
)
// Repository contains information about a GitHub repo
type Repository struct {
ID string
ID string
Name string
Owner RepositoryOwner
IsPrivate bool
HasIssuesEnabled bool
ViewerPermission string
DefaultBranchRef struct {
Name string
Target struct {
OID string
}
}
Parent *Repository
}
// RepositoryOwner is the owner of a GitHub repository
type RepositoryOwner struct {
Login string
}
// RepoOwner is the login name of the owner
func (r Repository) RepoOwner() string {
return r.Owner.Login
}
// RepoName is the name of the repository
func (r Repository) RepoName() string {
return r.Name
}
// IsFork is true when this repository has a parent repository
func (r Repository) IsFork() bool {
return r.Parent != nil
}
// ViewerCanPush is true when the requesting user has push access
func (r Repository) ViewerCanPush() bool {
switch r.ViewerPermission {
case "ADMIN", "MAINTAIN", "WRITE":
return true
default:
return false
}
}
// GitHubRepo looks up the node ID of a named repository
@ -44,3 +91,131 @@ func GitHubRepo(client *Client, ghRepo Repo) (*Repository, error) {
return &result.Repository, nil
}
// RepoNetworkResult describes the relationship between related repositories
type RepoNetworkResult struct {
ViewerLogin string
Repositories []*Repository
}
// RepoNetwork inspects the relationship between multiple GitHub repositories
func RepoNetwork(client *Client, repos []Repo) (RepoNetworkResult, error) {
queries := []string{}
for i, repo := range repos {
queries = append(queries, fmt.Sprintf(`
repo_%03d: repository(owner: %q, name: %q) {
...repo
parent {
...repo
}
}
`, i, repo.RepoOwner(), repo.RepoName()))
}
// Since the query is constructed dynamically, we can't parse a response
// format using a static struct. Instead, hold the raw JSON data until we
// decide how to parse it manually.
graphqlResult := map[string]*json.RawMessage{}
result := RepoNetworkResult{}
err := client.GraphQL(fmt.Sprintf(`
fragment repo on Repository {
id
name
owner { login }
viewerPermission
defaultBranchRef {
name
target { oid }
}
isPrivate
}
query {
viewer { login }
%s
}
`, strings.Join(queries, "")), nil, &graphqlResult)
graphqlError, isGraphQLError := err.(*GraphQLErrorResponse)
if isGraphQLError {
// If the only errors are that certain repositories are not found,
// continue processing this response instead of returning an error
tolerated := true
for _, ge := range graphqlError.Errors {
if ge.Type != "NOT_FOUND" {
tolerated = false
}
}
if tolerated {
err = nil
}
}
if err != nil {
return result, err
}
keys := []string{}
for key := range graphqlResult {
keys = append(keys, key)
}
// sort keys to ensure `repo_{N}` entries are processed in order
sort.Sort(sort.StringSlice(keys))
// Iterate over keys of GraphQL response data and, based on its name,
// dynamically allocate the target struct an individual message gets decoded to.
for _, name := range keys {
jsonMessage := graphqlResult[name]
if name == "viewer" {
viewerResult := struct {
Login string
}{}
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
if err := decoder.Decode(&viewerResult); err != nil {
return result, err
}
result.ViewerLogin = viewerResult.Login
} else if strings.HasPrefix(name, "repo_") {
if jsonMessage == nil {
result.Repositories = append(result.Repositories, nil)
continue
}
repo := Repository{}
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
if err := decoder.Decode(&repo); err != nil {
return result, err
}
result.Repositories = append(result.Repositories, &repo)
} else {
return result, fmt.Errorf("unknown GraphQL result key %q", name)
}
}
return result, nil
}
// repositoryV3 is the repository result from GitHub API v3
type repositoryV3 struct {
NodeID string
Name string
Owner struct {
Login string
}
}
// ForkRepo forks the repository on GitHub and returns the new repository
func ForkRepo(client *Client, repo Repo) (*Repository, error) {
path := fmt.Sprintf("repos/%s/%s/forks", repo.RepoOwner(), repo.RepoName())
body := bytes.NewBufferString(`{}`)
result := repositoryV3{}
err := client.REST("POST", path, body, &result)
if err != nil {
return nil, err
}
return &Repository{
ID: result.NodeID,
Name: result.Name,
Owner: RepositoryOwner{
Login: result.Owner.Login,
},
ViewerPermission: "WRITE",
}, nil
}

View file

@ -3,6 +3,9 @@ package command
import (
"fmt"
"net/url"
"sort"
"strings"
"time"
"github.com/github/gh-cli/api"
"github.com/github/gh-cli/context"
@ -15,43 +18,95 @@ import (
func prCreate(cmd *cobra.Command, _ []string) error {
ctx := contextForCommand(cmd)
ucc, err := git.UncommittedChangeCount()
remotes, err := ctx.Remotes()
if err != nil {
return err
}
if ucc > 0 {
client, err := apiClientForContext(ctx)
if err != nil {
return errors.Wrap(err, "could not initialize API client")
}
baseRepoOverride, _ := cmd.Flags().GetString("repo")
repoContext, err := resolveRemotesToRepos(remotes, client, baseRepoOverride)
if err != nil {
return err
}
baseRepo, err := repoContext.BaseRepo()
if err != nil {
return errors.Wrap(err, "could not determine the base repository")
}
headBranch, err := ctx.Branch()
if err != nil {
return errors.Wrap(err, "could not determine the current branch")
}
baseBranch, err := cmd.Flags().GetString("base")
if err != nil {
return err
}
if baseBranch == "" {
baseBranch = baseRepo.DefaultBranchRef.Name
}
didForkRepo := false
var headRemote *context.Remote
headRepo, err := repoContext.HeadRepo()
if err != nil {
if baseRepo.IsPrivate {
return fmt.Errorf("cannot write to private repository '%s/%s'", baseRepo.RepoOwner(), baseRepo.RepoName())
}
headRepo, err = api.ForkRepo(client, baseRepo)
if err != nil {
return fmt.Errorf("error forking repo: %w", err)
}
didForkRepo = true
// TODO: support non-HTTPS git remote URLs
baseRepoURL := fmt.Sprintf("https://github.com/%s/%s.git", baseRepo.RepoOwner(), baseRepo.RepoName())
headRepoURL := fmt.Sprintf("https://github.com/%s/%s.git", headRepo.RepoOwner(), headRepo.RepoName())
// TODO: figure out what to name the new git remote
gitRemote, err := git.AddRemote("fork", baseRepoURL, headRepoURL)
if err != nil {
return fmt.Errorf("error adding remote: %w", err)
}
headRemote = &context.Remote{
Remote: gitRemote,
Owner: headRepo.RepoOwner(),
Repo: headRepo.RepoName(),
}
}
if headBranch == baseBranch && isSameRepo(baseRepo, headRepo) {
return fmt.Errorf("must be on a branch named differently than %q", baseBranch)
}
if headRemote == nil {
headRemote, err = repoContext.RemoteForRepo(headRepo)
if err != nil {
return errors.Wrap(err, "git remote not found for head repository")
}
}
if ucc, err := git.UncommittedChangeCount(); err == nil && ucc > 0 {
fmt.Fprintf(cmd.ErrOrStderr(), "Warning: %s\n", utils.Pluralize(ucc, "uncommitted change"))
}
repo, err := ctx.BaseRepo()
if err != nil {
return errors.Wrap(err, "could not determine GitHub repo")
}
head, err := ctx.Branch()
if err != nil {
return errors.Wrap(err, "could not determine current branch")
}
remote, err := guessRemote(ctx)
if err != nil {
return err
}
target, err := cmd.Flags().GetString("base")
if err != nil {
return err
}
if target == "" {
// TODO use default branch
target = "master"
}
fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n", utils.Cyan(head), utils.Cyan(target), repo.RepoOwner(), repo.RepoName())
if err = git.Push(remote, fmt.Sprintf("HEAD:%s", head)); err != nil {
return err
pushTries := 0
maxPushTries := 3
for {
// TODO: respect existing upstream configuration of the current branch
if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil {
if didForkRepo && pushTries < maxPushTries {
pushTries++
// first wait 2 seconds after forking, then 4s, then 6s
time.Sleep(time.Duration(2*pushTries) * time.Second)
continue
}
return err
}
break
}
isWeb, err := cmd.Flags().GetBool("web")
@ -59,11 +114,21 @@ func prCreate(cmd *cobra.Command, _ []string) error {
return errors.Wrap(err, "could not parse web")
}
if isWeb {
openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, repo.RepoOwner(), repo.RepoName(), head)
openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, headRepo.RepoOwner(), headRepo.RepoName(), headBranch)
fmt.Fprintf(cmd.ErrOrStderr(), "Opening %s in your browser.\n", openURL)
return utils.OpenInBrowser(openURL)
}
headBranchLabel := headBranch
if !isSameRepo(baseRepo, headRepo) {
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
}
fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n",
utils.Cyan(headBranchLabel),
utils.Cyan(baseBranch),
baseRepo.RepoOwner(),
baseRepo.RepoName())
title, err := cmd.Flags().GetString("title")
if err != nil {
return errors.Wrap(err, "could not parse title")
@ -104,20 +169,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
}
}
base, err := cmd.Flags().GetString("base")
if err != nil {
return errors.Wrap(err, "could not parse base")
}
if base == "" {
// TODO: use default branch for the repo
base = "master"
}
client, err := apiClientForContext(ctx)
if err != nil {
return errors.Wrap(err, "could not initialize api client")
}
isDraft, err := cmd.Flags().GetBool("draft")
if err != nil {
return errors.Wrap(err, "could not parse draft")
@ -128,11 +179,11 @@ func prCreate(cmd *cobra.Command, _ []string) error {
"title": title,
"body": body,
"draft": isDraft,
"baseRefName": base,
"headRefName": head,
"baseRefName": baseBranch,
"headRefName": headBranch,
}
pr, err := api.CreatePullRequest(client, repo, params)
pr, err := api.CreatePullRequest(client, baseRepo, params)
if err != nil {
return errors.Wrap(err, "failed to create pull request")
}
@ -141,10 +192,10 @@ func prCreate(cmd *cobra.Command, _ []string) error {
} else if action == PreviewAction {
openURL := fmt.Sprintf(
"https://github.com/%s/%s/compare/%s...%s?expand=1&title=%s&body=%s",
repo.RepoOwner(),
repo.RepoName(),
target,
head,
baseRepo.RepoOwner(),
baseRepo.RepoName(),
baseBranch,
headBranchLabel,
url.QueryEscape(title),
url.QueryEscape(body),
)
@ -163,22 +214,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
}
func guessRemote(ctx context.Context) (string, error) {
remotes, err := ctx.Remotes()
if err != nil {
return "", errors.Wrap(err, "could not read git remotes")
}
// TODO: consolidate logic with fsContext.BaseRepo
// TODO: check if the GH repo that the remote points to is writeable
remote, err := remotes.FindByName("upstream", "github", "origin", "*")
if err != nil {
return "", errors.Wrap(err, "could not determine suitable remote")
}
return remote.Name, nil
}
var prCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a pull request",
@ -196,3 +231,122 @@ func init() {
"The branch into which you want your code merged")
prCreateCmd.Flags().BoolP("web", "w", false, "Open the web browser to create a pull request")
}
// cap the number of git remotes looked up, since the user might have an
// unusally large number of git remotes
const maxRemotesForLookup = 5
func resolveRemotesToRepos(remotes context.Remotes, client *api.Client, base string) (resolvedRemotes, error) {
sort.Stable(remotes)
lenRemotesForLookup := len(remotes)
if lenRemotesForLookup > maxRemotesForLookup {
lenRemotesForLookup = maxRemotesForLookup
}
hasBaseOverride := base != ""
baseOverride := repoFromFullName(base)
foundBaseOverride := false
repos := []api.Repo{}
for _, r := range remotes[:lenRemotesForLookup] {
repos = append(repos, r)
if isSameRepo(r, baseOverride) {
foundBaseOverride = true
}
}
if hasBaseOverride && !foundBaseOverride {
// additionally, look up the explicitly specified base repo if it's not
// already covered by git remotes
repos = append(repos, baseOverride)
}
result := resolvedRemotes{remotes: remotes}
if hasBaseOverride {
result.baseOverride = baseOverride
}
networkResult, err := api.RepoNetwork(client, repos)
if err != nil {
return result, err
}
result.network = networkResult
return result, nil
}
type resolvedRemotes struct {
baseOverride api.Repo
remotes context.Remotes
network api.RepoNetworkResult
}
// BaseRepo is the first found repository in the "upstream", "github", "origin"
// git remote order, resolved to the parent repo if the git remote points to a fork
func (r resolvedRemotes) BaseRepo() (*api.Repository, error) {
if r.baseOverride != nil {
for _, repo := range r.network.Repositories {
if repo != nil && isSameRepo(repo, r.baseOverride) {
return repo, nil
}
}
return nil, fmt.Errorf("failed looking up information about the '%s/%s' repository",
r.baseOverride.RepoOwner(), r.baseOverride.RepoName())
}
for _, repo := range r.network.Repositories {
if repo == nil {
continue
}
if repo.IsFork() {
return repo.Parent, nil
}
return repo, nil
}
return nil, errors.New("not found")
}
// HeadRepo is the first found repository that has push access
func (r resolvedRemotes) HeadRepo() (*api.Repository, error) {
for _, repo := range r.network.Repositories {
if repo != nil && repo.ViewerCanPush() {
return repo, nil
}
}
return nil, errors.New("none of the repositories have push access")
}
// RemoteForRepo finds the git remote that points to a repository
func (r resolvedRemotes) RemoteForRepo(repo api.Repo) (*context.Remote, error) {
for i, remote := range r.remotes {
if isSameRepo(remote, repo) ||
// additionally, look up the resolved repository name in case this
// git remote points to this repository via a redirect
(r.network.Repositories[i] != nil && isSameRepo(r.network.Repositories[i], repo)) {
return remote, nil
}
}
return nil, errors.New("not found")
}
type ghRepo struct {
owner string
name string
}
func repoFromFullName(nwo string) (r ghRepo) {
parts := strings.SplitN(nwo, "/", 2)
if len(parts) == 2 {
r.owner, r.name = parts[0], parts[1]
}
return
}
func (r ghRepo) RepoOwner() string {
return r.owner
}
func (r ghRepo) RepoName() string {
return r.name
}
func isSameRepo(a, b api.Repo) bool {
return strings.EqualFold(a.RepoOwner(), b.RepoOwner()) &&
strings.EqualFold(a.RepoName(), b.RepoName())
}

View file

@ -10,6 +10,7 @@ import (
"strings"
"testing"
"github.com/github/gh-cli/api"
"github.com/github/gh-cli/context"
"github.com/github/gh-cli/git"
"github.com/github/gh-cli/test"
@ -52,8 +53,15 @@ func TestPRCreate(t *testing.T) {
http := initFakeHTTP()
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repository": {
"id": "REPOID"
{ "data": { "repo_000": {
"id": "REPOID",
"name": "REPO",
"owner": {"login": "OWNER"},
"defaultBranchRef": {
"name": "master",
"target": {"oid": "deadbeef"}
},
"viewerPermission": "WRITE"
} } }
`))
http.StubResponse(200, bytes.NewBufferString(`
@ -103,7 +111,20 @@ func TestPRCreate_web(t *testing.T) {
initContext = func() context.Context {
return ctx
}
initFakeHTTP()
http := initFakeHTTP()
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repo_000": {
"id": "REPOID",
"name": "REPO",
"owner": {"login": "OWNER"},
"defaultBranchRef": {
"name": "master",
"target": {"oid": "deadbeef"}
},
"viewerPermission": "WRITE"
} } }
`))
ranCommands := [][]string{}
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
@ -116,11 +137,7 @@ func TestPRCreate_web(t *testing.T) {
eq(t, err, nil)
eq(t, output.String(), "")
eq(t, output.Stderr(), `
Creating pull request for feature into master in OWNER/REPO
Opening https://github.com/OWNER/REPO/pull/feature in your browser.
`)
eq(t, output.Stderr(), "Opening https://github.com/OWNER/REPO/pull/feature in your browser.\n")
eq(t, len(ranCommands), 3)
eq(t, strings.Join(ranCommands[1], " "), "git push --set-upstream origin HEAD:feature")
@ -139,8 +156,15 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) {
http := initFakeHTTP()
http.StubResponse(200, bytes.NewBufferString(`
{ "data": { "repository": {
"id": "REPOID"
{ "data": { "repo_000": {
"id": "REPOID",
"name": "REPO",
"owner": {"login": "OWNER"},
"defaultBranchRef": {
"name": "master",
"target": {"oid": "deadbeef"}
},
"viewerPermission": "WRITE"
} } }
`))
http.StubResponse(200, bytes.NewBufferString(`
@ -165,3 +189,122 @@ Creating pull request for feature into master in OWNER/REPO
`)
}
func Test_resolvedRemotes_clonedFork(t *testing.T) {
resolved := resolvedRemotes{
baseOverride: nil,
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Owner: "OWNER",
Repo: "REPO",
},
},
network: api.RepoNetworkResult{
Repositories: []*api.Repository{
&api.Repository{
Name: "REPO",
Owner: api.RepositoryOwner{Login: "OWNER"},
ViewerPermission: "ADMIN",
Parent: &api.Repository{
Name: "REPO",
Owner: api.RepositoryOwner{Login: "PARENTOWNER"},
ViewerPermission: "READ",
},
},
},
},
}
baseRepo, err := resolved.BaseRepo()
if err != nil {
t.Fatalf("got %v", err)
}
if baseRepo.RepoOwner() != "PARENTOWNER" {
t.Errorf("got owner %q", baseRepo.RepoOwner())
}
baseRemote, err := resolved.RemoteForRepo(baseRepo)
if baseRemote != nil || err == nil {
t.Error("did not expect any remote for base")
}
headRepo, err := resolved.HeadRepo()
if err != nil {
t.Fatalf("got %v", err)
}
if headRepo.RepoOwner() != "OWNER" {
t.Errorf("got owner %q", headRepo.RepoOwner())
}
headRemote, err := resolved.RemoteForRepo(headRepo)
if err != nil {
t.Fatalf("got %v", err)
}
if headRemote.Name != "origin" {
t.Errorf("got remote %q", headRemote.Name)
}
}
func Test_resolvedRemotes_triangularSetup(t *testing.T) {
resolved := resolvedRemotes{
baseOverride: nil,
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Owner: "OWNER",
Repo: "REPO",
},
&context.Remote{
Remote: &git.Remote{Name: "fork"},
Owner: "MYSELF",
Repo: "REPO",
},
},
network: api.RepoNetworkResult{
Repositories: []*api.Repository{
&api.Repository{
Name: "NEWNAME",
Owner: api.RepositoryOwner{Login: "NEWOWNER"},
ViewerPermission: "READ",
},
&api.Repository{
Name: "REPO",
Owner: api.RepositoryOwner{Login: "MYSELF"},
ViewerPermission: "ADMIN",
},
},
},
}
baseRepo, err := resolved.BaseRepo()
if err != nil {
t.Fatalf("got %v", err)
}
if baseRepo.RepoOwner() != "NEWOWNER" {
t.Errorf("got owner %q", baseRepo.RepoOwner())
}
if baseRepo.RepoName() != "NEWNAME" {
t.Errorf("got name %q", baseRepo.RepoName())
}
baseRemote, err := resolved.RemoteForRepo(baseRepo)
if err != nil {
t.Fatalf("got %v", err)
}
if baseRemote.Name != "origin" {
t.Errorf("got remote %q", baseRemote.Name)
}
headRepo, err := resolved.HeadRepo()
if err != nil {
t.Fatalf("got %v", err)
}
if headRepo.RepoOwner() != "MYSELF" {
t.Errorf("got owner %q", headRepo.RepoOwner())
}
headRemote, err := resolved.RemoteForRepo(headRepo)
if err != nil {
t.Fatalf("got %v", err)
}
if headRemote.Name != "fork" {
t.Errorf("got remote %q", headRemote.Name)
}
}

View file

@ -93,7 +93,6 @@ func contextForCommand(cmd *cobra.Command) context.Context {
ctx := initContext()
if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" {
ctx.SetBaseRepo(repo)
ctx.SetBranch("master")
}
return ctx
}

View file

@ -35,6 +35,26 @@ func (r Remotes) FindByRepo(owner, name string) (*Remote, error) {
return nil, fmt.Errorf("no matching remote found")
}
func remoteNameSortScore(name string) int {
switch strings.ToLower(name) {
case "upstream":
return 3
case "github":
return 2
case "origin":
return 1
default:
return 0
}
}
// https://golang.org/pkg/sort/#Interface
func (r Remotes) Len() int { return len(r) }
func (r Remotes) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
func (r Remotes) Less(i, j int) bool {
return remoteNameSortScore(r[i].Name) > remoteNameSortScore(r[j].Name)
}
// Remote represents a git remote mapped to a GitHub repository
type Remote struct {
*git.Remote

View file

@ -2,8 +2,11 @@ package git
import (
"net/url"
"os/exec"
"regexp"
"strings"
"github.com/github/gh-cli/utils"
)
var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`)
@ -67,3 +70,35 @@ func parseRemotes(gitRemotes []string) (remotes RemoteSet) {
}
return
}
// AddRemote adds a new git remote. The initURL is the remote URL with which the
// automatic fetch is made and finalURL, if non-blank, is set as the remote URL
// after the fetch.
func AddRemote(name, initURL, finalURL string) (*Remote, error) {
addCmd := exec.Command("git", "remote", "add", "-f", name, initURL)
err := utils.PrepareCmd(addCmd).Run()
if err != nil {
return nil, err
}
if finalURL == "" {
finalURL = initURL
} else {
setCmd := exec.Command("git", "remote", "set-url", name, finalURL)
err := utils.PrepareCmd(setCmd).Run()
if err != nil {
return nil, err
}
}
finalURLParsed, err := url.Parse(initURL)
if err != nil {
return nil, err
}
return &Remote{
Name: name,
FetchURL: finalURLParsed,
PushURL: finalURLParsed,
}, nil
}