Support triangular git workflows in pr create
- The local git remotes are scanned and resolved to GitHub repositories - The "base" repo is the first result resolved to its parent repo (if a fork) - The name of the default branch is read from the base repo - The "head" repo is the first repo that has push access
This commit is contained in:
parent
99c17c3a5f
commit
7a614ce697
5 changed files with 364 additions and 84 deletions
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ClientOption represents an argument to NewClient
|
||||
|
|
@ -69,9 +70,27 @@ type Client struct {
|
|||
|
||||
type graphQLResponse struct {
|
||||
Data interface{}
|
||||
Errors []struct {
|
||||
Message string
|
||||
Errors []GraphQLError
|
||||
}
|
||||
|
||||
// GraphQLError is a single error returned in a GraphQL response
|
||||
type GraphQLError struct {
|
||||
Type string
|
||||
Path []string
|
||||
Message string
|
||||
}
|
||||
|
||||
// GraphQLErrorResponse contains errors returned in a GraphQL response
|
||||
type GraphQLErrorResponse struct {
|
||||
Errors []GraphQLError
|
||||
}
|
||||
|
||||
func (gr GraphQLErrorResponse) Error() string {
|
||||
errorMessages := make([]string, 0, len(gr.Errors))
|
||||
for _, e := range gr.Errors {
|
||||
errorMessages = append(errorMessages, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("graphql error: '%s'", strings.Join(errorMessages, ", "))
|
||||
}
|
||||
|
||||
// GraphQL performs a GraphQL request and parses the response
|
||||
|
|
@ -151,14 +170,9 @@ func handleResponse(resp *http.Response, data interface{}) error {
|
|||
}
|
||||
|
||||
if len(gr.Errors) > 0 {
|
||||
errorMessages := gr.Errors[0].Message
|
||||
for _, e := range gr.Errors[1:] {
|
||||
errorMessages += ", " + e.Message
|
||||
}
|
||||
return fmt.Errorf("graphql error: '%s'", errorMessages)
|
||||
return &GraphQLErrorResponse{Errors: gr.Errors}
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func handleHTTPError(resp *http.Response) error {
|
||||
|
|
|
|||
|
|
@ -403,12 +403,8 @@ func PullRequestForBranch(client *Client, ghRepo Repo, branch string) (*PullRequ
|
|||
return nil, &NotFoundError{fmt.Errorf("no open pull requests found for branch %q", branch)}
|
||||
}
|
||||
|
||||
func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) {
|
||||
repo, err := GitHubRepo(client, ghRepo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// CreatePullRequest creates a pull request in a GitHub repository
|
||||
func CreatePullRequest(client *Client, repo Repository, params map[string]interface{}) (*PullRequest, error) {
|
||||
query := `
|
||||
mutation CreatePullRequest($input: CreatePullRequestInput!) {
|
||||
createPullRequest(input: $input) {
|
||||
|
|
@ -434,7 +430,7 @@ func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{
|
|||
}
|
||||
}{}
|
||||
|
||||
err = client.GraphQL(query, variables, &result)
|
||||
err := client.GraphQL(query, variables, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,59 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Repository contains information about a GitHub repo
|
||||
type Repository struct {
|
||||
ID string
|
||||
ID string
|
||||
Name string
|
||||
Owner struct {
|
||||
Login string
|
||||
}
|
||||
|
||||
IsPrivate bool
|
||||
HasIssuesEnabled bool
|
||||
ViewerPermission string
|
||||
DefaultBranchRef struct {
|
||||
Name string
|
||||
Target struct {
|
||||
OID string
|
||||
}
|
||||
}
|
||||
|
||||
Parent *Repository
|
||||
}
|
||||
|
||||
// RepoOwner is the login name of the owner
|
||||
func (r Repository) RepoOwner() string {
|
||||
return r.Owner.Login
|
||||
}
|
||||
|
||||
// RepoName is the name of the repository
|
||||
func (r Repository) RepoName() string {
|
||||
return r.Name
|
||||
}
|
||||
|
||||
// IsFork is true when this repository has a parent repository
|
||||
func (r Repository) IsFork() bool {
|
||||
return r.Parent != nil
|
||||
}
|
||||
|
||||
// ViewerCanPush is true when the requesting user has push access
|
||||
func (r Repository) ViewerCanPush() bool {
|
||||
switch r.ViewerPermission {
|
||||
case "ADMIN", "MAINTAIN", "WRITE":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GitHubRepo looks up the node ID of a named repository
|
||||
|
|
@ -44,3 +88,102 @@ func GitHubRepo(client *Client, ghRepo Repo) (*Repository, error) {
|
|||
|
||||
return &result.Repository, nil
|
||||
}
|
||||
|
||||
// RepoNetworkResult describes the relationship between related repositories
|
||||
type RepoNetworkResult struct {
|
||||
ViewerLogin string
|
||||
Repositories []*Repository
|
||||
}
|
||||
|
||||
// RepoNetwork inspects the relationship between multiple GitHub repositories
|
||||
func RepoNetwork(client *Client, repos []Repo) (RepoNetworkResult, error) {
|
||||
queries := []string{}
|
||||
for i, repo := range repos {
|
||||
queries = append(queries, fmt.Sprintf(`
|
||||
repo_%03d: repository(owner: %q, name: %q) {
|
||||
...repo
|
||||
parent {
|
||||
...repo
|
||||
}
|
||||
}
|
||||
`, i, repo.RepoOwner(), repo.RepoName()))
|
||||
}
|
||||
|
||||
type ViewerOrRepo struct {
|
||||
Login string
|
||||
Repository
|
||||
}
|
||||
|
||||
graphqlResult := map[string]*json.RawMessage{}
|
||||
result := RepoNetworkResult{}
|
||||
|
||||
err := client.GraphQL(fmt.Sprintf(`
|
||||
fragment repo on Repository {
|
||||
id
|
||||
name
|
||||
owner { login }
|
||||
viewerPermission
|
||||
defaultBranchRef {
|
||||
name
|
||||
target { oid }
|
||||
}
|
||||
isPrivate
|
||||
}
|
||||
query {
|
||||
viewer { login }
|
||||
%s
|
||||
}
|
||||
`, strings.Join(queries, "")), nil, &graphqlResult)
|
||||
graphqlError, isGraphQLError := err.(*GraphQLErrorResponse)
|
||||
if isGraphQLError {
|
||||
// If the only errors are that certain repositories are not found,
|
||||
// continue processing this response instead of returning an error
|
||||
tolerated := true
|
||||
for _, ge := range graphqlError.Errors {
|
||||
if ge.Type != "NOT_FOUND" {
|
||||
tolerated = false
|
||||
}
|
||||
}
|
||||
if tolerated {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
keys := []string{}
|
||||
for key := range graphqlResult {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
// sort keys to ensure `repo_{N}` entries are processed in order
|
||||
sort.Sort(sort.StringSlice(keys))
|
||||
|
||||
for _, name := range keys {
|
||||
jsonMessage := graphqlResult[name]
|
||||
if name == "viewer" {
|
||||
viewerResult := struct {
|
||||
Login string
|
||||
}{}
|
||||
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
|
||||
if err := decoder.Decode(&viewerResult); err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.ViewerLogin = viewerResult.Login
|
||||
} else if strings.HasPrefix(name, "repo_") {
|
||||
if jsonMessage == nil {
|
||||
result.Repositories = append(result.Repositories, nil)
|
||||
continue
|
||||
}
|
||||
repo := Repository{}
|
||||
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
|
||||
if err := decoder.Decode(&repo); err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.Repositories = append(result.Repositories, &repo)
|
||||
} else {
|
||||
return result, fmt.Errorf("unknown GraphQL result key %q", name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ package command
|
|||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/context"
|
||||
|
|
@ -15,42 +17,62 @@ import (
|
|||
|
||||
func prCreate(cmd *cobra.Command, _ []string) error {
|
||||
ctx := contextForCommand(cmd)
|
||||
|
||||
ucc, err := git.UncommittedChangeCount()
|
||||
remotes, err := ctx.Remotes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ucc > 0 {
|
||||
|
||||
client, err := apiClientForContext(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not initialize API client")
|
||||
}
|
||||
|
||||
baseRepoOverride, _ := cmd.Flags().GetString("repo")
|
||||
repoContext, err := resolveRemotesToRepos(remotes, client, baseRepoOverride)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseRepo, err := repoContext.BaseRepo()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not determine the base repository")
|
||||
}
|
||||
|
||||
headBranch, err := ctx.Branch()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not determine the current branch")
|
||||
}
|
||||
|
||||
baseBranch, err := cmd.Flags().GetString("base")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if baseBranch == "" {
|
||||
baseBranch = baseRepo.DefaultBranchRef.Name
|
||||
}
|
||||
|
||||
headRepo, err := repoContext.HeadRepo()
|
||||
if err != nil {
|
||||
// TODO: auto-fork repository and add new git remote
|
||||
return errors.Wrap(err, "could not determine the head repository")
|
||||
}
|
||||
|
||||
if headBranch == baseBranch && isSameRepo(baseRepo, headRepo) {
|
||||
return fmt.Errorf("must be on a branch named differently than %q", baseBranch)
|
||||
}
|
||||
|
||||
fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n", utils.Cyan(headBranch), utils.Cyan(baseBranch), baseRepo.RepoOwner(), baseRepo.RepoName())
|
||||
|
||||
headRemote, err := repoContext.RemoteForRepo(headRepo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "")
|
||||
}
|
||||
|
||||
if ucc, err := git.UncommittedChangeCount(); err == nil && ucc > 0 {
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "Warning: %s\n", utils.Pluralize(ucc, "uncommitted change"))
|
||||
}
|
||||
|
||||
repo, err := ctx.BaseRepo()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not determine GitHub repo")
|
||||
}
|
||||
|
||||
head, err := ctx.Branch()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not determine current branch")
|
||||
}
|
||||
|
||||
remote, err := guessRemote(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
target, err := cmd.Flags().GetString("base")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if target == "" {
|
||||
// TODO use default branch
|
||||
target = "master"
|
||||
}
|
||||
|
||||
fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n", utils.Cyan(head), utils.Cyan(target), repo.RepoOwner(), repo.RepoName())
|
||||
|
||||
if err = git.Push(remote, fmt.Sprintf("HEAD:%s", head)); err != nil {
|
||||
// TODO: respect existing upstream configuration of the current branch
|
||||
if err = git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +81,7 @@ func prCreate(cmd *cobra.Command, _ []string) error {
|
|||
return errors.Wrap(err, "could not parse web")
|
||||
}
|
||||
if isWeb {
|
||||
openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, repo.RepoOwner(), repo.RepoName(), head)
|
||||
openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, baseRepo.RepoOwner(), baseRepo.RepoName(), headBranch)
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "Opening %s in your browser.\n", openURL)
|
||||
return utils.OpenInBrowser(openURL)
|
||||
}
|
||||
|
|
@ -104,20 +126,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
base, err := cmd.Flags().GetString("base")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not parse base")
|
||||
}
|
||||
if base == "" {
|
||||
// TODO: use default branch for the repo
|
||||
base = "master"
|
||||
}
|
||||
|
||||
client, err := apiClientForContext(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not initialize api client")
|
||||
}
|
||||
|
||||
isDraft, err := cmd.Flags().GetBool("draft")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not parse draft")
|
||||
|
|
@ -128,10 +136,11 @@ func prCreate(cmd *cobra.Command, _ []string) error {
|
|||
"title": title,
|
||||
"body": body,
|
||||
"draft": isDraft,
|
||||
"baseRefName": base,
|
||||
"headRefName": head,
|
||||
"baseRefName": baseBranch,
|
||||
"headRefName": headBranch,
|
||||
}
|
||||
|
||||
repo := api.Repository{}
|
||||
pr, err := api.CreatePullRequest(client, repo, params)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create pull request")
|
||||
|
|
@ -141,10 +150,10 @@ func prCreate(cmd *cobra.Command, _ []string) error {
|
|||
} else if action == PreviewAction {
|
||||
openURL := fmt.Sprintf(
|
||||
"https://github.com/%s/%s/compare/%s...%s?expand=1&title=%s&body=%s",
|
||||
repo.RepoOwner(),
|
||||
repo.RepoName(),
|
||||
target,
|
||||
head,
|
||||
baseRepo.RepoOwner(),
|
||||
baseRepo.RepoName(),
|
||||
baseBranch,
|
||||
headBranch,
|
||||
url.QueryEscape(title),
|
||||
url.QueryEscape(body),
|
||||
)
|
||||
|
|
@ -163,22 +172,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
|
|||
|
||||
}
|
||||
|
||||
func guessRemote(ctx context.Context) (string, error) {
|
||||
remotes, err := ctx.Remotes()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "could not read git remotes")
|
||||
}
|
||||
|
||||
// TODO: consolidate logic with fsContext.BaseRepo
|
||||
// TODO: check if the GH repo that the remote points to is writeable
|
||||
remote, err := remotes.FindByName("upstream", "github", "origin", "*")
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "could not determine suitable remote")
|
||||
}
|
||||
|
||||
return remote.Name, nil
|
||||
}
|
||||
|
||||
var prCreateCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a pull request",
|
||||
|
|
@ -196,3 +189,117 @@ func init() {
|
|||
"The branch into which you want your code merged")
|
||||
prCreateCmd.Flags().BoolP("web", "w", false, "Open the web browser to create a pull request")
|
||||
}
|
||||
|
||||
const maxRemotesForLookup = 5
|
||||
|
||||
func resolveRemotesToRepos(remotes context.Remotes, client *api.Client, base string) (resolvedRemotes, error) {
|
||||
sort.Stable(remotes)
|
||||
lenRemotesForLookup := len(remotes)
|
||||
if lenRemotesForLookup > maxRemotesForLookup {
|
||||
lenRemotesForLookup = maxRemotesForLookup
|
||||
}
|
||||
|
||||
hasBaseOverride := base != ""
|
||||
baseOverride := repoFromFullName(base)
|
||||
foundBaseOverride := false
|
||||
repos := []api.Repo{}
|
||||
for _, r := range remotes[:lenRemotesForLookup] {
|
||||
repos = append(repos, r)
|
||||
if isSameRepo(r, baseOverride) {
|
||||
foundBaseOverride = true
|
||||
}
|
||||
}
|
||||
if hasBaseOverride && !foundBaseOverride {
|
||||
repos = append(repos, baseOverride)
|
||||
}
|
||||
|
||||
result := resolvedRemotes{remotes: remotes}
|
||||
if hasBaseOverride {
|
||||
result.baseOverride = baseOverride
|
||||
}
|
||||
networkResult, err := api.RepoNetwork(client, repos)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.network = networkResult
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type resolvedRemotes struct {
|
||||
baseOverride api.Repo
|
||||
remotes context.Remotes
|
||||
network api.RepoNetworkResult
|
||||
}
|
||||
|
||||
// BaseRepo is the first found repository in the "upstream", "github", "origin"
|
||||
// git remote order, resolved to the parent repo if the git remote points to a fork
|
||||
func (r resolvedRemotes) BaseRepo() (*api.Repository, error) {
|
||||
if r.baseOverride != nil {
|
||||
for _, repo := range r.network.Repositories {
|
||||
if repo != nil && isSameRepo(repo, r.baseOverride) {
|
||||
return repo, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed looking up information about the '%s/%s' repository",
|
||||
r.baseOverride.RepoOwner(), r.baseOverride.RepoName())
|
||||
}
|
||||
|
||||
for _, repo := range r.network.Repositories {
|
||||
if repo == nil {
|
||||
continue
|
||||
}
|
||||
if repo.IsFork() {
|
||||
return repo.Parent, nil
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
// HeadRepo is the first found repository that has push access
|
||||
func (r resolvedRemotes) HeadRepo() (*api.Repository, error) {
|
||||
for _, repo := range r.network.Repositories {
|
||||
if repo != nil && repo.ViewerCanPush() {
|
||||
return repo.Parent, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("none of the repositories have push access")
|
||||
}
|
||||
|
||||
// RemoteForRepo finds the git remote that points to a repository
|
||||
func (r resolvedRemotes) RemoteForRepo(repo api.Repo) (*context.Remote, error) {
|
||||
for i, remote := range r.remotes {
|
||||
if isSameRepo(remote, repo) ||
|
||||
// FIXME: express better that this is because of repo renames
|
||||
(r.network.Repositories[i] != nil && isSameRepo(r.network.Repositories[i], repo)) {
|
||||
return remote, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
type ghRepo struct {
|
||||
owner string
|
||||
name string
|
||||
}
|
||||
|
||||
func repoFromFullName(nwo string) (r ghRepo) {
|
||||
parts := strings.SplitN(nwo, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
r.owner, r.name = parts[0], parts[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r ghRepo) RepoOwner() string {
|
||||
return r.owner
|
||||
}
|
||||
func (r ghRepo) RepoName() string {
|
||||
return r.name
|
||||
}
|
||||
|
||||
func isSameRepo(a, b api.Repo) bool {
|
||||
return strings.EqualFold(a.RepoOwner(), b.RepoOwner()) &&
|
||||
strings.EqualFold(a.RepoName(), b.RepoName())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,26 @@ func (r Remotes) FindByRepo(owner, name string) (*Remote, error) {
|
|||
return nil, fmt.Errorf("no matching remote found")
|
||||
}
|
||||
|
||||
func remoteNameSortScore(name string) int {
|
||||
switch strings.ToLower(name) {
|
||||
case "upstream":
|
||||
return 3
|
||||
case "github":
|
||||
return 2
|
||||
case "origin":
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// https://golang.org/pkg/sort/#Interface
|
||||
func (r Remotes) Len() int { return len(r) }
|
||||
func (r Remotes) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
|
||||
func (r Remotes) Less(i, j int) bool {
|
||||
return remoteNameSortScore(r[i].Name) > remoteNameSortScore(r[j].Name)
|
||||
}
|
||||
|
||||
// Remote represents a git remote mapped to a GitHub repository
|
||||
type Remote struct {
|
||||
*git.Remote
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue